Exemplo n.º 1
0
        /// <summary>
        /// Searches for the best split using a brute force approach on all unique threshold values.
        /// The implementation assumes that the features and targets have been sorted
        /// together using the features as sort criteria.
        /// </summary>
        /// <param name="impurityCalculator">Calculator used to score candidate splits. It is reset to
        /// <paramref name="parentInterval"/> before the scan and moved to every candidate split index.</param>
        /// <param name="feature">Feature values; only the range covered by
        /// <paramref name="parentInterval"/> is read.</param>
        /// <param name="targets">Target values aligned with <paramref name="feature"/>. Not read directly
        /// by this method; the parameter is kept so the signature stays unchanged.</param>
        /// <param name="parentInterval">Index range [FromInclusive, ToExclusive) of the node to split.</param>
        /// <param name="parentImpurity">Impurity of the parent node, forwarded to the impurity
        /// improvement calculation.</param>
        /// <returns>The best split found. SplitIndex is -1 and all other values are 0.0 when no
        /// candidate produced an impurity improvement above zero.</returns>
        public SplitResult FindBestSplit(IImpurityCalculator impurityCalculator, double[] feature, double[] targets,
                                         Interval1D parentInterval, double parentImpurity)
        {
            var bestSplitIndex          = -1;
            var bestThreshold           = 0.0;
            var bestImpurityImprovement = 0.0;
            var bestImpurityLeft        = 0.0;
            var bestImpurityRight       = 0.0;

            var prevValue = feature[parentInterval.FromInclusive];

            impurityCalculator.UpdateInterval(parentInterval);

            for (int j = parentInterval.FromInclusive + 1; j < parentInterval.ToExclusive; j++)
            {
                var currentValue = feature[j];

                // A split can only be placed between two distinct adjacent feature values.
                if (prevValue != currentValue)
                {
                    var leftSize  = (double)(j - parentInterval.FromInclusive);
                    var rightSize = (double)(parentInterval.ToExclusive - j);

                    if (Math.Min(leftSize, rightSize) >= m_minimumSplitSize)
                    {
                        impurityCalculator.UpdateIndex(j);

                        var belowMinimumLeafWeight =
                            (impurityCalculator.WeightedLeft < m_minimumLeafWeight) ||
                            (impurityCalculator.WeightedRight < m_minimumLeafWeight);

                        // A rejected candidate must still reach the prevValue update at the end of
                        // the iteration. Skipping it (e.g. with 'continue') leaves prevValue stale,
                        // which lets a later split land between two equal feature values and
                        // produces a threshold that does not match the split index.
                        if (!belowMinimumLeafWeight)
                        {
                            var improvement = impurityCalculator.ImpurityImprovement(parentImpurity);

                            if (improvement > bestImpurityImprovement)
                            {
                                var childImpurities = impurityCalculator.ChildImpurities(); // could be avoided

                                bestImpurityImprovement = improvement;
                                // The threshold is the midpoint between the two neighbouring values.
                                bestThreshold           = (currentValue + prevValue) * 0.5;
                                bestSplitIndex          = j;
                                bestImpurityLeft        = childImpurities.Left;
                                bestImpurityRight       = childImpurities.Right;
                            }
                        }
                    }
                }

                prevValue = currentValue;
            }

            return new SplitResult(bestSplitIndex, bestThreshold,
                                   bestImpurityImprovement, bestImpurityLeft, bestImpurityRight);
        }
        /// <summary>
        /// Builds a binary decision tree from the rows of <paramref name="observations"/> selected by
        /// <paramref name="indices"/>. Nodes are created iteratively using an explicit stack of pending
        /// node items instead of recursion. For every non-leaf candidate the configured split searcher
        /// is run on each candidate feature and the split with the largest impurity improvement is kept.
        /// </summary>
        /// <param name="observations">Feature matrix. ColumnCount gives the number of features and each
        /// feature is read through ColumnView.</param>
        /// <param name="targets">Target values, looked up through the (re-ordered) work indices.</param>
        /// <param name="indices">Indices of the observations/targets to use; copied into an internal
        /// work buffer, the passed array is only read.</param>
        /// <param name="weights">Weights looked up through the work indices. An empty array means the
        /// weight buffer is not filled from it.</param>
        /// <returns>A <c>BinaryTree</c> created from the node list, the leaf probabilities, the distinct
        /// target values and a copy of the accumulated variable importance.</returns>
        public BinaryTree Build(F64MatrixView observations, double[] targets, int[] indices, double[] weights)
        {
            // Reset importance accumulated by a previous call and size the work buffers to the sample count.
            Array.Clear(m_variableImportance, 0, m_variableImportance.Length);

            Array.Resize(ref m_workTargets, indices.Length);
            Array.Resize(ref m_workFeature, indices.Length);
            Array.Resize(ref m_workIndices, indices.Length);

            var numberOfFeatures = observations.ColumnCount;

            // 0 means "use all features"; note that the field itself is overwritten here.
            if (m_featuresPrSplit == 0)
            {
                m_featuresPrSplit = numberOfFeatures;
            }

            Array.Resize(ref m_bestSplitWorkIndices, indices.Length);
            m_bestSplitWorkIndices.Clear();
            Array.Resize(ref m_variableImportance, numberOfFeatures);
            Array.Resize(ref m_allFeatureIndices, numberOfFeatures);
            Array.Resize(ref m_featureCandidates, m_featuresPrSplit);

            m_featuresCandidatesSet = false;

            for (int i = 0; i < m_allFeatureIndices.Length; i++)
            {
                m_allFeatureIndices[i] = i;
            }

            // Interval covering all selected samples; this is the interval of the root node.
            var allInterval = Interval1D.Create(0, indices.Length);

            indices.CopyTo(allInterval, m_workIndices);
            m_workIndices.IndexedCopy(targets, allInterval, m_workTargets);

            if (weights.Length != 0)
            {
                Array.Resize(ref m_workWeights, indices.Length);
                m_workIndices.IndexedCopy(weights, allInterval, m_workWeights);
            }

            // Distinct values are taken from the full targets array, not only the indexed subset.
            var targetNames = targets.Distinct().ToArray();

            m_impurityCalculator.Init(targetNames, m_workTargets, m_workWeights, allInterval);
            var rootImpurity = m_impurityCalculator.NodeImpurity();

            var nodes         = new List <Node>();
            var probabilities = new List <double[]>();

            var stack = new Stack <DecisionNodeCreationItem>(100);

            stack.Push(new DecisionNodeCreationItem(0, NodePositionType.Root, allInterval, rootImpurity, 0));

            // 'first' stays true until an item whose parent is a split node has been popped.
            var first                       = true;
            var currentNodeIndex            = 0;
            var currentLeafProbabilityIndex = 0;

            while (stack.Count > 0)
            {
                var bestSplitResult  = SplitResult.Initial();
                var bestFeatureIndex = -1;
                var parentItem       = stack.Pop();

                var  parentInterval  = parentItem.Interval;
                var  parentNodeDepth = parentItem.NodeDepth;
                Node parentNode      = Node.Default();

                // For the root item the node list is still empty, so the default node is used as parent.
                if (nodes.Count != 0)
                {
                    parentNode = nodes[parentItem.ParentIndex];
                }

                var parentNodePositionType = parentItem.NodeType;
                var parentImpurity         = parentItem.Impurity;

                // Runs once, for the first popped item whose parent is a split node: nodes[0] is
                // rewritten from that parent's values with both child indices set to -1.
                // NOTE(review): purpose of this rewrite is not visible from here - confirm.
                if (first && parentNode.FeatureIndex != -1)
                {
                    nodes[0] = new Node(parentNode.FeatureIndex,
                                        parentNode.Value, -1, -1, parentNode.NodeIndex, parentNode.LeafProbabilityIndex);

                    first = false;
                }

                // Depth limit reached: the node becomes a leaf without searching for a split.
                var isLeaf = (parentNodeDepth >= m_maximumTreeDepth);

                if (!isLeaf)
                {
                    // NOTE(review): presumably fills m_featureCandidates for this node - confirm.
                    SetNextFeatures(numberOfFeatures);

                    foreach (var featureIndex in m_featureCandidates)
                    {
                        // Load the feature for this node, sort it (re-ordering the work indices
                        // alongside), then reload targets/weights in the new order.
                        m_workIndices.IndexedCopy(observations.ColumnView(featureIndex), parentInterval, m_workFeature);
                        m_workFeature.SortWith(parentInterval, m_workIndices);
                        m_workIndices.IndexedCopy(targets, parentInterval, m_workTargets);

                        if (weights.Length != 0)
                        {
                            m_workIndices.IndexedCopy(weights, parentInterval, m_workWeights);
                        }

                        var splitResult = m_splitSearcher.FindBestSplit(m_impurityCalculator, m_workFeature,
                                                                        m_workTargets, parentInterval, parentImpurity);

                        // Keep the best split together with the index ordering that produced it.
                        if (splitResult.ImpurityImprovement > bestSplitResult.ImpurityImprovement)
                        {
                            bestSplitResult = splitResult;
                            m_workIndices.CopyTo(parentInterval, m_bestSplitWorkIndices);
                            bestFeatureIndex = featureIndex;
                        }
                    }

                    // No usable split, or not enough gain: fall back to a leaf.
                    isLeaf = isLeaf || (bestSplitResult.SplitIndex < 0);
                    isLeaf = isLeaf || (bestSplitResult.ImpurityImprovement < m_minimumInformationGain);

                    // Restore the ordering of the best split so the child intervals match its split index.
                    m_bestSplitWorkIndices.CopyTo(parentInterval, m_workIndices);
                }

                if (isLeaf)
                {
                    // Reload targets/weights using the best-split ordering before computing leaf values.
                    m_bestSplitWorkIndices.IndexedCopy(targets, parentInterval, m_workTargets);

                    if (weights.Length != 0)
                    {
                        m_bestSplitWorkIndices.IndexedCopy(weights, parentInterval, m_workWeights);
                    }

                    m_impurityCalculator.UpdateInterval(parentInterval);
                    var value = m_impurityCalculator.LeafValue();

                    // Leaf nodes use feature index -1 and get their own leaf probability index.
                    var leaf = new Node(-1, value, -1, -1,
                                        currentNodeIndex++, currentLeafProbabilityIndex++);

                    probabilities.Add(m_impurityCalculator.LeafProbabilities());

                    nodes.Add(leaf);
                    nodes.UpdateParent(parentNode, leaf, parentNodePositionType);
                }
                else
                {
                    // Importance is the improvement scaled by the fraction of all samples in this node.
                    m_variableImportance[bestFeatureIndex] += bestSplitResult.ImpurityImprovement * parentInterval.Length / allInterval.Length;

                    // Split nodes have no leaf probability index (-1); children are linked later.
                    var split = new Node(bestFeatureIndex, bestSplitResult.Threshold, -1, -1,
                                         currentNodeIndex++, -1);

                    nodes.Add(split);
                    nodes.UpdateParent(parentNode, split, parentNodePositionType);

                    var nodeDepth = parentNodeDepth + 1;

                    // Right is pushed before left, so the left child is popped and processed first.
                    stack.Push(new DecisionNodeCreationItem(split.NodeIndex, NodePositionType.Right,
                                                            Interval1D.Create(bestSplitResult.SplitIndex, parentInterval.ToExclusive),
                                                            bestSplitResult.ImpurityRight, nodeDepth));

                    stack.Push(new DecisionNodeCreationItem(split.NodeIndex, NodePositionType.Left,
                                                            Interval1D.Create(parentInterval.FromInclusive, bestSplitResult.SplitIndex),
                                                            bestSplitResult.ImpurityLeft, nodeDepth));
                }
            }

            // No split node was created: replace the node list with a single leaf over all samples.
            // NOTE(review): the loop appears to have already emitted a root leaf in this case, so the
            // counters are post-incremented again and this leaf gets index 1 with a second
            // probabilities entry - confirm this is intended.
            if (first)
            {
                m_impurityCalculator.UpdateInterval(allInterval);

                var leaf = new Node(-1, m_impurityCalculator.LeafValue(), -1, -1,
                                    currentNodeIndex++, currentLeafProbabilityIndex++);

                probabilities.Add(m_impurityCalculator.LeafProbabilities());

                nodes.Clear();
                nodes.Add(leaf);
            }

            // ToArray() hands out a copy so the internal importance buffer is not shared with the tree.
            return(new BinaryTree(nodes, probabilities, targetNames,
                                  m_variableImportance.ToArray()));
        }
        /// <summary>
        /// Evaluates a single split for the feature: a threshold is obtained from
        /// <c>RandomThreshold(min, max)</c>, where min and max are the smallest and largest feature
        /// values inside the interval, and the split is placed at the first index whose feature value
        /// is above the threshold and which satisfies the minimum split size.
        /// </summary>
        /// <param name="impurityCalculator">Calculator used to score the split. It is reset to
        /// <paramref name="parentInterval"/> and moved to the split index when a split is found.</param>
        /// <param name="feature">Feature values; only the range covered by
        /// <paramref name="parentInterval"/> is read.</param>
        /// <param name="targets">Target values aligned with <paramref name="feature"/>. Not read directly
        /// by this method.</param>
        /// <param name="parentInterval">Index range [FromInclusive, ToExclusive) of the node to split.</param>
        /// <param name="parentImpurity">Impurity of the parent node, forwarded to the impurity
        /// improvement calculation.</param>
        /// <returns><c>SplitResult.Initial()</c> when the interval holds no two distinct feature values.
        /// Otherwise the evaluated split, with SplitIndex -1 and zero impurities when no index
        /// qualified.</returns>
        public SplitResult FindBestSplit(IImpurityCalculator impurityCalculator, double[] feature, double[] targets, Interval1D parentInterval, double parentImpurity)
        {
            var min = double.MaxValue;
            var max = double.MinValue;

            for (int i = parentInterval.FromInclusive; i < parentInterval.ToExclusive; i++)
            {
                var value = feature[i];

                // min and max are tested independently. With an 'else if', a value that lowers min
                // (always true for the first element) would never be considered for max, leaving max
                // at double.MinValue for single-element or descending input.
                if (value < min)
                {
                    min = value;
                }

                if (value > max)
                {
                    max = value;
                }
            }

            // min == max: the feature is constant in the interval.
            // min > max: nothing was picked up by the scan (empty interval or only NaN values).
            if (min >= max)
            {
                return SplitResult.Initial();
            }

            var threshold = RandomThreshold(min, max);

            // Guarantee that at least one feature value lies strictly above the threshold.
            if (threshold == max)
            {
                threshold = min;
            }

            var splitIndex          = -1;
            var impurityImprovement = 0.0;
            var impurityLeft        = 0.0;
            var impurityRight       = 0.0;

            for (int i = parentInterval.FromInclusive; i < parentInterval.ToExclusive; i++)
            {
                var leftSize  = (double)(i - parentInterval.FromInclusive);
                var rightSize = (double)(parentInterval.ToExclusive - i);

                var currentFeature = feature[i];

                // NOTE(review): when the first value above the threshold fails the size constraint,
                // the scan continues and may split at a later index, leaving values above the
                // threshold on the left side - confirm this is intended.
                if (currentFeature > threshold && Math.Min(leftSize, rightSize) >= m_minimumSplitSize)
                {
                    splitIndex = i;

                    impurityCalculator.UpdateInterval(parentInterval);
                    impurityCalculator.UpdateIndex(i);
                    impurityImprovement = impurityCalculator.ImpurityImprovement(parentImpurity);

                    var childImpurities = impurityCalculator.ChildImpurities();
                    impurityLeft  = childImpurities.Left;
                    impurityRight = childImpurities.Right;

                    // Only one split is evaluated per call.
                    break;
                }
            }

            return new SplitResult(splitIndex, threshold, impurityImprovement,
                                   impurityLeft, impurityRight);
        }