Example #1
0
 /// <summary>
 /// Asserts that every element of the matrix equals the corresponding element of the view.
 /// </summary>
 /// <param name="matrix">Expected values</param>
 /// <param name="view">Actual values, indexed as view[row][column]</param>
 private unsafe void AssertMatrixView(IMatrix <double> matrix, F64MatrixView view)
 {
     var rows = matrix.RowCount;
     var cols = matrix.ColumnCount;

     for (int row = 0; row < rows; row++)
     {
         var viewRow = view[row];
         for (int col = 0; col < cols; col++)
         {
             Assert.AreEqual(matrix.At(row, col), viewRow[col]);
         }
     }
 }
        /// <summary>
        /// Learns a decision tree from the provided observations and targets but limited to the observation indices provided by indices.
        /// Indices can contain the same index multiple times. Weights can be provided in order to weight each sample individually
        /// </summary>
        /// <param name="observations">Feature matrix view to learn from</param>
        /// <param name="targets">Target values, one per observation row</param>
        /// <param name="indices">Row indices to learn from; may contain duplicates</param>
        /// <param name="weights">Provide weights inorder to weigh each sample separetely</param>
        /// <returns>The learned binary tree</returns>
        public BinaryTree Learn(F64MatrixView observations, double[] targets, int[] indices, double[] weights)
        {
            Checks.VerifyObservationsAndTargets(observations, targets);
            Checks.VerifyIndices(indices, observations, targets);

            // Sample weights are currently only supported by DecisionTreeLearner,
            // so this dimension check lives here rather than in the shared Checks.
            // An empty weights array means uniform weighting and is always valid.
            var hasWeights = weights.Length != 0;
            if (hasWeights && (weights.Length != targets.Length || weights.Length != observations.RowCount))
            {
                throw new ArgumentException($"Weights length differ from observation row count and target length. Weights: {weights.Length}, observation: {observations.RowCount}, targets: {targets.Length}");
            }

            return m_treeBuilder.Build(observations, targets, indices, weights);
        }
 /// <summary>
 /// Learns a regression tree from the provided observations and targets but limited to the observation indices provided by indices.
 /// Indices can contain the same index multiple times.
 /// </summary>
 /// <param name="observations">Feature matrix view to learn from</param>
 /// <param name="targets">Target values, one per observation row</param>
 /// <param name="indices">Row indices to learn from; may contain duplicates</param>
 /// <returns>The learned regression tree model</returns>
 public new RegressionDecisionTreeModel Learn(F64MatrixView observations, double[] targets,
                                              int[] indices)
 {
     // Delegate the actual learning to the base class and wrap the result.
     var tree = base.Learn(observations, targets, indices);
     return new RegressionDecisionTreeModel(tree);
 }
        /// <summary>
        /// Builds a binary decision tree from the observations and targets, restricted to
        /// the rows listed in indices (which may contain duplicates). An empty weights
        /// array means all samples are weighted equally. The tree is grown depth-first
        /// using an explicit stack; shared scratch arrays are resized and reused between
        /// calls, so this method is not thread-safe.
        /// </summary>
        /// <param name="observations">Feature matrix view to learn from</param>
        /// <param name="targets">Target values, one per observation row</param>
        /// <param name="indices">Row indices to use; may repeat</param>
        /// <param name="weights">Per-sample weights, or empty for uniform weighting</param>
        /// <returns>The learned binary tree</returns>
        public BinaryTree Build(F64MatrixView observations, double[] targets, int[] indices, double[] weights)
        {
            // Reset importance from any previous build and size the per-sample
            // scratch arrays to the current number of (possibly repeated) samples.
            Array.Clear(m_variableImportance, 0, m_variableImportance.Length);

            Array.Resize(ref m_workTargets, indices.Length);
            Array.Resize(ref m_workFeature, indices.Length);
            Array.Resize(ref m_workIndices, indices.Length);

            var numberOfFeatures = observations.ColumnCount;

            // Zero means "not configured": consider every feature at each split.
            if (m_featuresPrSplit == 0)
            {
                m_featuresPrSplit = numberOfFeatures;
            }

            Array.Resize(ref m_bestSplitWorkIndices, indices.Length);
            m_bestSplitWorkIndices.Clear();
            Array.Resize(ref m_variableImportance, numberOfFeatures);
            Array.Resize(ref m_allFeatureIndices, numberOfFeatures);
            Array.Resize(ref m_featureCandidates, m_featuresPrSplit);

            m_featuresCandidatesSet = false;

            // Identity list of feature indices; SetNextFeatures draws candidates from it.
            for (int i = 0; i < m_allFeatureIndices.Length; i++)
            {
                m_allFeatureIndices[i] = i;
            }

            var allInterval = Interval1D.Create(0, indices.Length);

            // Work on copies so the caller's indices/targets/weights arrays stay untouched.
            indices.CopyTo(allInterval, m_workIndices);
            m_workIndices.IndexedCopy(targets, allInterval, m_workTargets);

            if (weights.Length != 0)
            {
                Array.Resize(ref m_workWeights, indices.Length);
                m_workIndices.IndexedCopy(weights, allInterval, m_workWeights);
            }

            // Distinct target values; stored in the tree (serve as class labels for
            // classification impurity calculators).
            var targetNames = targets.Distinct().ToArray();

            m_impurityCalculator.Init(targetNames, m_workTargets, m_workWeights, allInterval);
            var rootImpurity = m_impurityCalculator.NodeImpurity();

            // Flat node list plus one probability array per leaf.
            var nodes         = new List <Node>();
            var probabilities = new List <double[]>();

            // Explicit stack replaces recursion; 100 is just an initial capacity.
            var stack = new Stack <DecisionNodeCreationItem>(100);

            stack.Push(new DecisionNodeCreationItem(0, NodePositionType.Root, allInterval, rootImpurity, 0));

            var first                       = true;
            var currentNodeIndex            = 0;
            var currentLeafProbabilityIndex = 0;

            while (stack.Count > 0)
            {
                var bestSplitResult  = SplitResult.Initial();
                var bestFeatureIndex = -1;
                var parentItem       = stack.Pop();

                var  parentInterval  = parentItem.Interval;
                var  parentNodeDepth = parentItem.NodeDepth;
                Node parentNode      = Node.Default();

                if (nodes.Count != 0)
                {
                    parentNode = nodes[parentItem.ParentIndex];
                }

                var parentNodePositionType = parentItem.NodeType;
                var parentImpurity         = parentItem.Impurity;

                // NOTE(review): on the first popped item after the root split was stored,
                // the root node's child links are reset to -1 before UpdateParent assigns
                // them — confirm this is the intended root re-initialization.
                if (first && parentNode.FeatureIndex != -1)
                {
                    nodes[0] = new Node(parentNode.FeatureIndex,
                                        parentNode.Value, -1, -1, parentNode.NodeIndex, parentNode.LeafProbabilityIndex);

                    first = false;
                }

                // Depth limit forces a leaf regardless of possible splits.
                var isLeaf = (parentNodeDepth >= m_maximumTreeDepth);

                if (!isLeaf)
                {
                    SetNextFeatures(numberOfFeatures);

                    // Evaluate each candidate feature: gather its values for this
                    // interval, sort samples by feature value (co-sorting the index
                    // array), then search for the best threshold.
                    foreach (var featureIndex in m_featureCandidates)
                    {
                        m_workIndices.IndexedCopy(observations.ColumnView(featureIndex), parentInterval, m_workFeature);
                        m_workFeature.SortWith(parentInterval, m_workIndices);
                        m_workIndices.IndexedCopy(targets, parentInterval, m_workTargets);

                        if (weights.Length != 0)
                        {
                            m_workIndices.IndexedCopy(weights, parentInterval, m_workWeights);
                        }

                        var splitResult = m_splitSearcher.FindBestSplit(m_impurityCalculator, m_workFeature,
                                                                        m_workTargets, parentInterval, parentImpurity);

                        // Keep the best split and remember the index ordering that
                        // produced it, since later candidates re-sort m_workIndices.
                        if (splitResult.ImpurityImprovement > bestSplitResult.ImpurityImprovement)
                        {
                            bestSplitResult = splitResult;
                            m_workIndices.CopyTo(parentInterval, m_bestSplitWorkIndices);
                            bestFeatureIndex = featureIndex;
                        }
                    }

                    // No usable split, or the gain is below the configured minimum.
                    isLeaf = isLeaf || (bestSplitResult.SplitIndex < 0);
                    isLeaf = isLeaf || (bestSplitResult.ImpurityImprovement < m_minimumInformationGain);

                    // Restore the ordering belonging to the winning split.
                    m_bestSplitWorkIndices.CopyTo(parentInterval, m_workIndices);
                }

                if (isLeaf)
                {
                    // Recompute targets (and weights) for this interval, then let the
                    // impurity calculator produce the leaf value and probabilities.
                    m_bestSplitWorkIndices.IndexedCopy(targets, parentInterval, m_workTargets);

                    if (weights.Length != 0)
                    {
                        m_bestSplitWorkIndices.IndexedCopy(weights, parentInterval, m_workWeights);
                    }

                    m_impurityCalculator.UpdateInterval(parentInterval);
                    var value = m_impurityCalculator.LeafValue();

                    var leaf = new Node(-1, value, -1, -1,
                                        currentNodeIndex++, currentLeafProbabilityIndex++);

                    probabilities.Add(m_impurityCalculator.LeafProbabilities());

                    nodes.Add(leaf);
                    nodes.UpdateParent(parentNode, leaf, parentNodePositionType);
                }
                else
                {
                    // Importance: impurity gain weighted by the fraction of samples
                    // reaching this node.
                    m_variableImportance[bestFeatureIndex] += bestSplitResult.ImpurityImprovement * parentInterval.Length / allInterval.Length;

                    var split = new Node(bestFeatureIndex, bestSplitResult.Threshold, -1, -1,
                                         currentNodeIndex++, -1);

                    nodes.Add(split);
                    nodes.UpdateParent(parentNode, split, parentNodePositionType);

                    var nodeDepth = parentNodeDepth + 1;

                    // Push right first so the left child is processed first (LIFO).
                    stack.Push(new DecisionNodeCreationItem(split.NodeIndex, NodePositionType.Right,
                                                            Interval1D.Create(bestSplitResult.SplitIndex, parentInterval.ToExclusive),
                                                            bestSplitResult.ImpurityRight, nodeDepth));

                    stack.Push(new DecisionNodeCreationItem(split.NodeIndex, NodePositionType.Left,
                                                            Interval1D.Create(parentInterval.FromInclusive, bestSplitResult.SplitIndex),
                                                            bestSplitResult.ImpurityLeft, nodeDepth));
                }
            }

            if (first) // No valid split return single leaf result
            {
                m_impurityCalculator.UpdateInterval(allInterval);

                var leaf = new Node(-1, m_impurityCalculator.LeafValue(), -1, -1,
                                    currentNodeIndex++, currentLeafProbabilityIndex++);

                probabilities.Add(m_impurityCalculator.LeafProbabilities());

                nodes.Clear();
                nodes.Add(leaf);
            }

            return(new BinaryTree(nodes, probabilities, targetNames,
                                  m_variableImportance.ToArray()));
        }
 /// <summary>
 /// Learns a decision tree from the provided observations and targets but limited to the observation indices provided by indices.
 /// Indices can contain the same index multiple times.
 /// </summary>
 /// <param name="observations">Feature matrix view to learn from</param>
 /// <param name="targets">Target values, one per observation row</param>
 /// <param name="indices">Row indices to learn from; may contain duplicates</param>
 /// <returns>The learned binary tree</returns>
 public BinaryTree Learn(F64MatrixView observations, double[] targets, int[] indices)
 {
     // An empty weights array signals uniform sample weighting downstream.
     // Array.Empty avoids allocating a fresh zero-length array on every call.
     return Learn(observations, targets, indices, Array.Empty<double>());
 }
Example #6
0
 /// <summary>
 /// Verify that indices are valid and match observations and targets.
 /// </summary>
 /// <param name="indices">Row indices to validate</param>
 /// <param name="observations">Observation matrix the indices refer to</param>
 /// <param name="targets">Target values the indices refer to</param>
 public static void VerifyIndices(int[] indices, F64MatrixView observations, double[] targets)
 {
     // Forward to the count-based overload that holds the actual checks.
     var rowCount = observations.RowCount;
     var targetCount = targets.Length;
     VerifyIndices(indices, rowCount, targetCount);
 }
Example #7
0
 /// <summary>
 /// Verify that observations and targets are valid.
 /// </summary>
 /// <param name="observations">Observation matrix to validate</param>
 /// <param name="targets">Target values to validate</param>
 public static void VerifyObservationsAndTargets(F64MatrixView observations, double[] targets)
 {
     // Forward to the count-based overload that holds the actual checks.
     var rows = observations.RowCount;
     var cols = observations.ColumnCount;
     VerifyObservationsAndTargets(rows, cols, targets.Length);
 }
 /// <summary>
 /// Learns a classification tree from the provided observations and targets but limited to the observation indices provided by indices.
 /// Indices can contain the same index multiple times.
 /// </summary>
 /// <param name="observations">Feature matrix view to learn from</param>
 /// <param name="targets">Target values, one per observation row</param>
 /// <param name="indices">Row indices to learn from; may contain duplicates</param>
 /// <param name="weights">Per-sample weights, or empty for uniform weighting</param>
 /// <returns>The learned classification tree model</returns>
 public new ClassificationDecisionTreeModel Learn(F64MatrixView observations, double[] targets, int[] indices, double[] weights)
 {
     // Delegate the actual learning to the base class and wrap the result.
     var tree = base.Learn(observations, targets, indices, weights);
     return new ClassificationDecisionTreeModel(tree);
 }
 /// <summary>
 /// Learns a decision tree from the provided observations and targets but limited to the observation indices provided by indices.
 /// Indices can contain the same index multiple times. Weights can be provided in order to weight each sample individually
 /// </summary>
 /// <param name="observations">Feature matrix view to learn from</param>
 /// <param name="targets">Target values, one per observation row</param>
 /// <param name="indices">Row indices to learn from; may contain duplicates</param>
 /// <param name="weights">Provide weights inorder to weigh each sample separetely</param>
 /// <returns>The learned binary tree</returns>
 public BinaryTree Learn(F64MatrixView observations, double[] targets, int[] indices, double[] weights)
 {
     // The tree builder performs all of the actual work.
     var tree = m_treeBuilder.Build(observations, targets, indices, weights);
     return tree;
 }