/// <summary>
/// Converts the episodes to data points, fits a decision tree to them with a
/// <c>DecisionLearner</c>, and returns the converted tree together with the
/// accuracy reported for the fitted root.
/// </summary>
/// <param name="episodes">Episodes passed to <c>EpisodesToDataPoints</c> to obtain the training data.</param>
/// <returns>
/// A <c>DecisionNodeAccuracy</c> whose <c>Root</c> is the result of
/// <c>NodeToDecisionTree</c> and whose <c>Accuracy</c> is the fitted root's accuracy.
/// </returns>
private DecisionNodeAccuracy CalculateDecisionTree(List<Episode> episodes)
{
    List<DataPoint> trainingPoints = EpisodesToDataPoints(episodes);

    // The learner is disposed as soon as the fitted tree has been converted.
    using (DecisionLearner learner = new DecisionLearner(_cudaManager, trainingPoints))
    {
        DataNodeAccuracy fitted = learner.FitDecisionTree();
        double fittedAccuracy = fitted.Accuracy;
        IDecisionNode convertedRoot = NodeToDecisionTree(fitted.Node);

#if DEBUGCUDA
        // Debug-only cross-check of the converted tree against the fitted node.
        SanityCheckDecisionTree(fitted.Node, convertedRoot, fittedAccuracy, episodes);
#endif

        return new DecisionNodeAccuracy
        {
            Root = convertedRoot,
            Accuracy = fittedAccuracy,
        };
    }
}
/// <summary>
/// Runs the decision-tree kernels level by level (up to
/// <c>GPUConstants.MaxLevels</c>), reads the resulting nodes back through
/// <c>Context.Nodes.Read()</c>, and returns the pruned tree.
/// </summary>
/// <returns>The result of <c>PruneDecisionTree</c> applied to the tree read from the node array.</returns>
public DataNodeAccuracy FitDecisionTree()
{
    Initialize();

    Kernels["dlInitDataPoints"].Execute(Context.DataPoints.Count);
    Kernels["dlInitialNode"].ExecuteTask();

    for (int level = 0; level < GPUConstants.MaxLevels; ++level)
    {
        Kernels["dlLeafOpenNodes"]
            .Execute(NumSmallBlocks * GPUConstants.NumThreadsPerBlock, GPUConstants.NumThreadsPerBlock);

        // The last level can't be split anymore, so stop once its leaf kernel has run.
        if (level + 1 == GPUConstants.MaxLevels)
        {
            break;
        }

        // Each splitter receives the running count of splits found so far and
        // returns how many it contributed.
        int splitCount = 0;
        splitCount += _attributeSplitter.FindOptimalSplit(_allSplits, splitCount);
        splitCount += _categoricalSplitter.FindOptimalSplit(_allSplits, splitCount);

        Kernels["dlFindOptimalSplitPerNode"]
            .Arguments(_allSplits, splitCount, _bestSplits)
            .Execute(NumSmallBlocks * GPUConstants.NumThreadsPerBlock, GPUConstants.NumThreadsPerBlock);

        _attributeSplitter.ApplyOptimalSplit(_bestSplits);
        _categoricalSplitter.ApplyOptimalSplit(_bestSplits);

        Kernels["dlNextLevel"]
            .Execute(NumSmallBlocks * GPUConstants.NumThreadsPerBlock, GPUConstants.NumThreadsPerBlock);
    }

    GPUNode[] nodes = Context.Nodes.Read();

#if DEBUGCUDA
    SanityCheckGPUNodes(nodes);
#endif

    // NOTE(review): index 0 is presumably the root node - confirm against ReadDecisionTree.
    IDataNode fullTree = ReadDecisionTree(nodes, 0);
    return PruneDecisionTree(fullTree);
}
/// <summary>
/// Recursively prunes the tree rooted at <paramref name="node"/>. A split is kept
/// only when the combined <c>CorrectWeight</c> of its pruned children exceeds the
/// largest entry of the node's own <c>ClassDistribution</c> by at least
/// <c>Context.TotalWeight * GPUConstants.RequiredImprovementToSplit</c>; otherwise
/// the node is replaced by a new <c>DataLeaf</c> with the same class distribution.
/// Splits that are kept are modified in place: their <c>Left</c> and <c>Right</c>
/// are reassigned to the pruned children.
/// </summary>
/// <param name="node">Root of the subtree to prune.</param>
/// <returns>The pruned subtree together with its correct and total weights.</returns>
/// <exception cref="ArgumentException">
/// <paramref name="node"/> is neither a <c>DataLeaf</c> nor an <c>IDataSplit</c>, or a
/// retained split is neither an <c>AttributeSplit</c> nor a <c>CategoricalSplit</c>.
/// </exception>
private DataNodeAccuracy PruneDecisionTree(IDataNode node)
{
    // Leaves are returned unchanged with weights taken from their distribution.
    if (node is DataLeaf)
    {
        return new DataNodeAccuracy
        {
            Node = node,
            CorrectWeight = node.ClassDistribution.Max(),
            TotalWeight = node.ClassDistribution.Sum(),
        };
    }

    if (!(node is IDataSplit))
    {
        throw new ArgumentException("Unknown node type: " + node);
    }

    IDataSplit split = (IDataSplit)node;

    // Prune both subtrees first so the decision below uses their pruned weights.
    DataNodeAccuracy prunedLeft = PruneDecisionTree(split.Left);
    DataNodeAccuracy prunedRight = PruneDecisionTree(split.Right);

    float correctAsSplit = prunedLeft.CorrectWeight + prunedRight.CorrectWeight;
    float totalAsSplit = prunedLeft.TotalWeight + prunedRight.TotalWeight;

    // Weights this node would have if it were collapsed into a single leaf.
    float correctAsLeaf = node.ClassDistribution.Max();
    float totalAsLeaf = node.ClassDistribution.Sum();

#if DEBUGCUDA
    Assert.AreEqual(totalAsLeaf, totalAsSplit, 1.0f);
#endif

    float gain = correctAsSplit - correctAsLeaf;
    float minimumGain = Context.TotalWeight * GPUConstants.RequiredImprovementToSplit;

    if (gain < minimumGain)
    {
        // Split not justified, revert back to leaf at this level.
        return new DataNodeAccuracy
        {
            Node = new DataLeaf { ClassDistribution = node.ClassDistribution },
            CorrectWeight = correctAsLeaf,
            TotalWeight = totalAsLeaf,
        };
    }

    // Keep the split and attach the pruned children through the concrete split type.
    if (split is AttributeSplit)
    {
        AttributeSplit byAttribute = (AttributeSplit)split;
        byAttribute.Left = prunedLeft.Node;
        byAttribute.Right = prunedRight.Node;
    }
    else if (split is CategoricalSplit)
    {
        CategoricalSplit byCategory = (CategoricalSplit)split;
        byCategory.Left = prunedLeft.Node;
        byCategory.Right = prunedRight.Node;
    }
    else
    {
        throw new ArgumentException("Unknown split type: " + split);
    }

    return new DataNodeAccuracy
    {
        Node = split,
        CorrectWeight = correctAsSplit,
        TotalWeight = totalAsSplit,
    };
}