/// <summary>
        /// Fits a decision tree to the given episodes with a <c>DecisionLearner</c> and
        /// converts the fitted data-node tree into its decision-node form.
        /// </summary>
        /// <param name="episodes">Episodes that are turned into data points via <c>EpisodesToDataPoints</c>.</param>
        /// <returns>
        /// The converted tree root together with the accuracy reported by the learner
        /// for the fitted tree.
        /// </returns>
        private DecisionNodeAccuracy CalculateDecisionTree(List <Episode> episodes)
        {
            // The learner is disposed when the using block is left, whether by return or by exception.
            using (DecisionLearner learner = new DecisionLearner(_cudaManager, EpisodesToDataPoints(episodes)))
            {
                DataNodeAccuracy fitted = learner.FitDecisionTree();

                // Capture the learner-reported accuracy before converting the tree.
                double accuracy = fitted.Accuracy;
                IDecisionNode root = NodeToDecisionTree(fitted.Node);

#if DEBUGCUDA
                SanityCheckDecisionTree(fitted.Node, root, accuracy, episodes);
#endif

                return new DecisionNodeAccuracy {
                    Root = root,
                    Accuracy = accuracy,
                };
            }
        }
        /// <summary>
        /// Runs the level-by-level fitting kernels, reads the resulting node array back,
        /// rebuilds it into an <c>IDataNode</c> tree and prunes that tree.
        /// </summary>
        /// <returns>
        /// The result of <c>PruneDecisionTree</c> applied to the tree read from the node array,
        /// i.e. the pruned root together with its correct/total weights.
        /// </returns>
        public DataNodeAccuracy FitDecisionTree()
        {
            Initialize();

            // One-off setup launches: the first is sized by the number of data points,
            // the second is launched through ExecuteTask().
            Kernels["dlInitDataPoints"].Execute(Context.DataPoints.Count);
            Kernels["dlInitialNode"].ExecuteTask();
            // At most GPUConstants.MaxLevels iterations; the order of the launches inside
            // one iteration matters and must not be rearranged.
            for (int i = 0; i < GPUConstants.MaxLevels; ++i)
            {
                // Runs on every iteration, including the last one.
                // NOTE(review): judging by the kernel name this decides which open nodes
                // become leaves - confirm against the kernel source.
                Kernels["dlLeafOpenNodes"]
                .Execute(NumSmallBlocks * GPUConstants.NumThreadsPerBlock, GPUConstants.NumThreadsPerBlock);
                bool isFinalLevel = i == GPUConstants.MaxLevels - 1;
                if (isFinalLevel)
                {
                    break;                     // Can't split anymore
                }

                // Running count of candidate splits written into _allSplits. Each splitter
                // receives the count accumulated so far and returns how many it added.
                // NOTE(review): the second argument is presumably a write offset into
                // _allSplits - verify in the splitter implementations.
                int numSplits = 0;

                numSplits += _attributeSplitter.FindOptimalSplit(_allSplits, numSplits);
                numSplits += _categoricalSplitter.FindOptimalSplit(_allSplits, numSplits);

                // Takes the numSplits candidates in _allSplits and fills _bestSplits.
                Kernels["dlFindOptimalSplitPerNode"]
                .Arguments(_allSplits, numSplits, _bestSplits)
                .Execute(NumSmallBlocks * GPUConstants.NumThreadsPerBlock, GPUConstants.NumThreadsPerBlock);

                // Both splitters are handed the same _bestSplits buffer.
                _attributeSplitter.ApplyOptimalSplit(_bestSplits);
                _categoricalSplitter.ApplyOptimalSplit(_bestSplits);

                // Last launch of the iteration, before the loop moves on to the next level.
                Kernels["dlNextLevel"]
                .Execute(NumSmallBlocks * GPUConstants.NumThreadsPerBlock, GPUConstants.NumThreadsPerBlock);
            }

            // Read the node array back and rebuild the tree starting at index 0
            // (presumably the root node - confirm in ReadDecisionTree).
            GPUNode[] gpuNodes = Context.Nodes.Read();
#if DEBUGCUDA
            SanityCheckGPUNodes(gpuNodes);
#endif
            IDataNode        unpruned = ReadDecisionTree(gpuNodes, 0);
            DataNodeAccuracy root     = PruneDecisionTree(unpruned);
            return(root);
        }
        /// <summary>
        /// Prunes a tree bottom-up. A split is kept only when its (already pruned) children
        /// classify at least <c>Context.TotalWeight * GPUConstants.RequiredImprovementToSplit</c>
        /// more weight correctly than a single leaf at this node would; otherwise the split
        /// is replaced by a leaf carrying the node's class distribution.
        /// </summary>
        /// <param name="node">Root of the subtree to prune; a <c>DataLeaf</c> or an <c>IDataSplit</c>.</param>
        /// <returns>
        /// The pruned subtree with its correct weight (largest class-distribution entry for a
        /// leaf, sum over the children for a kept split) and its total weight.
        /// </returns>
        /// <exception cref="ArgumentException">
        /// The node is neither a leaf nor a split, or a kept split is of an unknown concrete type.
        /// </exception>
        private DataNodeAccuracy PruneDecisionTree(IDataNode node)
        {
            // A leaf is returned unchanged, scored by its own class distribution.
            if (node is DataLeaf)
            {
                return new DataNodeAccuracy {
                    Node = node,
                    CorrectWeight = node.ClassDistribution.Max(),
                    TotalWeight = node.ClassDistribution.Sum(),
                };
            }

            IDataSplit split = node as IDataSplit;
            if (split == null)
            {
                throw new ArgumentException("Unknown node type: " + node);
            }

            // Children are pruned first so the comparison below uses their final form.
            DataNodeAccuracy prunedLeft  = PruneDecisionTree(split.Left);
            DataNodeAccuracy prunedRight = PruneDecisionTree(split.Right);

            float keptCorrect = prunedLeft.CorrectWeight + prunedRight.CorrectWeight;
            float keptTotal   = prunedLeft.TotalWeight + prunedRight.TotalWeight;

            // What this node would score if it were collapsed into a single leaf.
            float collapsedCorrect = node.ClassDistribution.Max();
            float collapsedTotal   = node.ClassDistribution.Sum();

#if DEBUGCUDA
            Assert.AreEqual(collapsedTotal, keptTotal, 1.0f);
#endif

            float gain      = keptCorrect - collapsedCorrect;
            float threshold = Context.TotalWeight * GPUConstants.RequiredImprovementToSplit;
            if (gain < threshold)
            {
                // The split does not pay for itself: collapse this node into a leaf.
                return new DataNodeAccuracy {
                    Node = new DataLeaf {
                        ClassDistribution = node.ClassDistribution
                    },
                    CorrectWeight = collapsedCorrect,
                    TotalWeight = collapsedTotal,
                };
            }

            // Keep the split; its children are replaced in place by their pruned versions.
            AttributeSplit   attributeSplit   = split as AttributeSplit;
            CategoricalSplit categoricalSplit = split as CategoricalSplit;
            if (attributeSplit != null)
            {
                attributeSplit.Left  = prunedLeft.Node;
                attributeSplit.Right = prunedRight.Node;
            }
            else if (categoricalSplit != null)
            {
                categoricalSplit.Left  = prunedLeft.Node;
                categoricalSplit.Right = prunedRight.Node;
            }
            else
            {
                throw new ArgumentException("Unknown split type: " + split);
            }

            return new DataNodeAccuracy {
                Node = split,
                CorrectWeight = keptCorrect,
                TotalWeight = keptTotal,
            };
        }