/// <summary>
/// Walks the subtree rooted at <paramref name="node"/> depth-first and appends every
/// <see cref="DataLeaf"/> it reaches to <paramref name="leaves"/>, visiting the
/// <c>Left</c> child of each split before its <c>Right</c> child.
/// </summary>
/// <param name="node">Root of the subtree to walk.</param>
/// <param name="leaves">Destination list; leaves are appended in traversal order.</param>
/// <exception cref="ArgumentException">
/// Thrown when a visited node is neither a <see cref="DataLeaf"/> nor an
/// <see cref="IDataSplit"/> (this includes a null node, since type patterns never match null).
/// </exception>
private static void AppendLeaves(IDataNode node, List<DataLeaf> leaves)
{
    switch (node)
    {
        case DataLeaf leaf:
            leaves.Add(leaf);
            break;

        case IDataSplit split:
            // Left subtree first so leaves come out in left-to-right order.
            AppendLeaves(split.Left, leaves);
            AppendLeaves(split.Right, leaves);
            break;

        default:
            throw new ArgumentException("Unknown node type: " + node);
    }
}
        /// <summary>
        /// Prunes the subtree rooted at <paramref name="node"/> bottom-up. Both children of a
        /// split are pruned first. The split is then kept only if its children's combined
        /// correct weight beats the correct weight of collapsing it into a single leaf by at
        /// least <c>Context.TotalWeight * GPUConstants.RequiredImprovementToSplit</c>.
        /// Otherwise it is replaced by a new <see cref="DataLeaf"/> that shares the split's
        /// <c>ClassDistribution</c>.
        /// </summary>
        /// <param name="node">Root of the subtree to prune.</param>
        /// <returns>
        /// The node that should stand in place of <paramref name="node"/>. This is the original
        /// leaf, the original split with its <c>Left</c>/<c>Right</c> reassigned to the pruned
        /// children, or a replacement leaf. It comes with its correct weight (for a leaf, the
        /// largest entry of <c>ClassDistribution</c>; for a split, the sum over its pruned
        /// children) and its total weight (sum of <c>ClassDistribution</c>, or the sum over
        /// children for a kept split).
        /// </returns>
        /// <exception cref="ArgumentException">
        /// Thrown for a node that is neither a <see cref="DataLeaf"/> nor an
        /// <see cref="IDataSplit"/> (including null). Also thrown for a kept split that is not
        /// an <see cref="AttributeSplit"/> or <see cref="CategoricalSplit"/>.
        /// </exception>
        private DataNodeAccuracy PruneDecisionTree(IDataNode node)
        {
            switch (node)
            {
                case DataLeaf _:
                    return new DataNodeAccuracy {
                        Node = node,
                        CorrectWeight = node.ClassDistribution.Max(),
                        TotalWeight = node.ClassDistribution.Sum(),
                    };

                case IDataSplit split:
                {
                    // Prune the children first so the comparison uses their already-pruned accuracy.
                    DataNodeAccuracy prunedLeft  = PruneDecisionTree(split.Left);
                    DataNodeAccuracy prunedRight = PruneDecisionTree(split.Right);

                    float correctAsSplit = prunedLeft.CorrectWeight + prunedRight.CorrectWeight;
                    float totalAsSplit   = prunedLeft.TotalWeight + prunedRight.TotalWeight;

                    float correctAsLeaf = node.ClassDistribution.Max();
                    float totalAsLeaf   = node.ClassDistribution.Sum();

#if DEBUGCUDA
                    Assert.AreEqual(totalAsLeaf, totalAsSplit, 1.0f);
#endif

                    float gain        = correctAsSplit - correctAsLeaf;
                    float minimumGain = Context.TotalWeight * GPUConstants.RequiredImprovementToSplit;

                    // Guard clause: the split does not earn its keep, so collapse it into a leaf.
                    // Kept as '<' (not '>=' on the other branch) so NaN still keeps the split.
                    if (gain < minimumGain)
                    {
                        return new DataNodeAccuracy {
                            Node = new DataLeaf {
                                ClassDistribution = node.ClassDistribution
                            },
                            CorrectWeight = correctAsLeaf,
                            TotalWeight = totalAsLeaf,
                        };
                    }

                    // Keep the split, but hang the pruned children off it.
                    switch (split)
                    {
                        case AttributeSplit attributeSplit:
                            attributeSplit.Left  = prunedLeft.Node;
                            attributeSplit.Right = prunedRight.Node;
                            break;

                        case CategoricalSplit categoricalSplit:
                            categoricalSplit.Left  = prunedLeft.Node;
                            categoricalSplit.Right = prunedRight.Node;
                            break;

                        default:
                            throw new ArgumentException("Unknown split type: " + split);
                    }

                    return new DataNodeAccuracy {
                        Node = split,
                        CorrectWeight = correctAsSplit,
                        TotalWeight = totalAsSplit,
                    };
                }

                default:
                    throw new ArgumentException("Unknown node type: " + node);
            }
        }