Exemplo n.º 1
0
        /// <summary>
        /// Creates an empty tree: stores the growth limits and the input/output
        /// layout, and leaves every per-node array unallocated.
        /// </summary>
        /// <param name="nFeatures">Stored in the <c>nFeatures</c> field.</param>
        /// <param name="nClasses">Class count per output; the first <paramref name="nOutputs"/> entries are copied into <c>NClasses</c>.</param>
        /// <param name="nOutputs">Number of outputs; also the length of <c>NClasses</c>.</param>
        /// <param name="splitter">Splitter kept for later use by the tree.</param>
        /// <param name="maxDepth">Stored in the <c>maxDepth</c> field.</param>
        /// <param name="minSamplesSplit">Stored in the <c>minSamplesSplit</c> field.</param>
        /// <param name="minSamplesLeaf">Stored in the <c>minSamplesLeaf</c> field.</param>
        public Tree(
            int nFeatures,
            uint[] nClasses,
            int nOutputs,
            SplitterBase splitter,
            uint maxDepth,
            uint minSamplesSplit,
            uint minSamplesLeaf)
        {
            // Growth parameters.
            this.splitter = splitter;
            this.maxDepth = maxDepth;
            this.minSamplesSplit = minSamplesSplit;
            this.minSamplesLeaf = minSamplesLeaf;

            // Shape of the inputs and outputs.
            this.nFeatures = nFeatures;
            this.nOutputs = nOutputs;

            // The tree keeps its own copy of the per-output class counts.
            this.NClasses = new uint[nOutputs];

            // One value slot per (output, class) pair, sized by the largest class count.
            this.maxNClasses = nClasses.Max();
            this.valueStride = (uint)this.nOutputs * this.maxNClasses;

            for (int output = 0; output < nOutputs; output++)
            {
                this.NClasses[output] = nClasses[output];
            }

            // No nodes yet: zero count, zero capacity, no backing arrays.
            this.NodeCount = 0;
            this.Capacity = 0;
            this.ChildrenLeft = null;
            this.ChildrenRight = null;
            this.Feature = null;
            this.Threshold = null;
            this.Value = null;
            this.Impurity = null;
            this.NNodeSamples = null;
        }
Exemplo n.º 2
0
        /// <summary>
        /// Build a decision tree from the training set (X, y). Nodes are expanded
        /// from an explicit stack of pending frames rather than by recursion.
        /// </summary>
        /// <param name="x">Training inputs, passed to the splitter.</param>
        /// <param name="y">Training targets, passed to the splitter.</param>
        /// <param name="sampleWeight">
        /// Optional weights; converted to an array for the splitter, or passed on
        /// as null when omitted.
        /// </param>
        public void build(MathNet.Numerics.LinearAlgebra.Generic.Matrix <double> x,
                          MathNet.Numerics.LinearAlgebra.Generic.Matrix <double> y,
                          MathNet.Numerics.LinearAlgebra.Generic.Vector <double> sampleWeight = null)
        {
            // Shallow trees start with room for a complete binary tree of depth
            // maxDepth (2^(maxDepth + 1) - 1 nodes); deeper ones start at 2047.
            int startingCapacity = this.maxDepth <= 10
                ? (int)Math.Pow(2, (this.maxDepth + 1)) - 1
                : 2047;

            this.Resize(startingCapacity);

            SplitterBase nodeSplitter = this.splitter;

            nodeSplitter.init(x, y, sampleWeight == null ? null : sampleWeight.ToArray());

            // Each pending node occupies one frame: [start, end, depth, parent, is_left].
            const uint FrameSize = 5;

            uint used     = FrameSize; // occupied slots; the root frame is pushed below
            uint capacity = 50;

            uint[] pending = new uint[capacity];

            // Root frame: covers every sample, depth 0, no parent.
            pending[0] = 0;
            pending[1] = nodeSplitter.n_samples;
            pending[2] = 0;
            pending[3] = _TREE_UNDEFINED;
            pending[4] = 0;

            // Outputs of the splitter; they keep their previous values when a node
            // is not split.
            uint   splitPos       = 0;
            uint   splitFeature   = 0;
            double splitThreshold = 0;
            double nodeImpurity   = 0;

            while (used > 0)
            {
                // Pop the most recently pushed frame.
                used -= FrameSize;

                uint first     = pending[used];
                uint last      = pending[used + 1];
                uint level     = pending[used + 2];
                uint parentId  = pending[used + 3];
                bool leftChild = pending[used + 4] != 0;

                uint sampleCount = last - first;

                nodeSplitter.node_reset(first, last, ref nodeImpurity);

                // A node is a leaf when a growth limit is hit or it is (almost) pure.
                bool leaf = level >= this.maxDepth
                            || sampleCount < this.minSamplesSplit
                            || sampleCount < 2 * this.minSamplesLeaf
                            || nodeImpurity < 1e-7;

                if (!leaf)
                {
                    nodeSplitter.node_split(ref splitPos, ref splitFeature, ref splitThreshold);

                    // A split position at or past the end of the range means no split.
                    leaf = splitPos >= last;
                }

                uint id = this.AddNode(parentId, leftChild, leaf, splitFeature,
                                       splitThreshold, nodeImpurity, sampleCount);

                if (leaf)
                {
                    // Only leaves get a value written; internal nodes store none.
                    nodeSplitter.node_value(this.Value, id * this.valueStride);
                    continue;
                }

                // Two frames are about to be pushed; double the stack if they do not fit.
                if (used + 2 * FrameSize > capacity)
                {
                    capacity *= 2;
                    var grown = new uint[capacity];
                    pending.CopyTo(grown, 0);
                    pending = grown;
                }

                // Right child goes on first so that the left child is popped first.
                pending[used]     = splitPos;
                pending[used + 1] = last;
                pending[used + 2] = level + 1;
                pending[used + 3] = id;
                pending[used + 4] = 0;
                used             += FrameSize;

                // Left child.
                pending[used]     = first;
                pending[used + 1] = splitPos;
                pending[used + 2] = level + 1;
                pending[used + 3] = id;
                pending[used + 4] = 1;
                used             += FrameSize;
            }

            // Shrink storage to the number of nodes actually created.
            this.Resize((int)this.NodeCount);

            // Drop the splitter reference so its memory can be released.
            this.splitter = null;
        }
Exemplo n.º 3
0
        /// <summary>
        /// Validates the hyper-parameters, creates the configured criterion and
        /// splitter, then constructs <c>this.Tree</c> and builds it from the data.
        /// </summary>
        /// <param name="x">Training inputs, forwarded to the tree builder.</param>
        /// <param name="y">Training targets, forwarded to the tree builder.</param>
        /// <param name="nSamples">Expected number of samples; checked against the weight count.</param>
        /// <param name="sampleWeight">Optional per-sample weights; may be null.</param>
        /// <param name="isClassification">Passed to the max-features computation.</param>
        /// <exception cref="ArgumentException">
        /// A hyper-parameter is out of range, or the number of weights differs
        /// from <paramref name="nSamples"/>.
        /// </exception>
        /// <exception cref="InvalidOperationException">
        /// The configured criterion or splitter is not one of the handled values.
        /// </exception>
        private void FitCommon(
            Matrix <double> x,
            Matrix <double> y,
            int nSamples,
            Vector <double> sampleWeight,
            bool isClassification)
        {
            // An unset depth limit means "unbounded".
            int maxDepth = this.maxDepth ?? int.MaxValue;

            if (this.maxFeatures == null)
            {
                this.MaxFeaturesValue = this.NFeatures;
            }
            else
            {
                this.MaxFeaturesValue = maxFeatures.ComputeMaxFeatures(this.NFeatures, isClassification);
            }

            if (this.minSamplesSplit <= 0)
            {
                throw new ArgumentException("min_samples_split must be greater than zero.");
            }

            if (this.minSamplesLeaf <= 0)
            {
                throw new ArgumentException("minSamplesLeaf must be greater than zero.");
            }

            if (maxDepth <= 0)
            {
                throw new ArgumentException("maxDepth must be greater than zero. ");
            }

            if (!(0 < MaxFeaturesValue && MaxFeaturesValue <= this.NFeatures))
            {
                throw new ArgumentException("maxFeatures must be in (0, n_features]");
            }

            if (sampleWeight != null && sampleWeight.Count != nSamples)
            {
                throw new ArgumentException(
                          string.Format(
                              "Number of weights={0} does not match number of samples={1}",
                              sampleWeight.Count,
                              nSamples));
            }

            // A split must leave at least minSamplesLeaf samples on each side, so the
            // effective split threshold is never below twice the leaf minimum. The
            // adjusted value is kept in a local: assigning it to the minSamplesSplit
            // field would overwrite the configured hyper-parameter on every fit.
            var effectiveMinSamplesSplit = Math.Max(this.minSamplesSplit, 2 * this.minSamplesLeaf);

            // Impurity criterion selected by the configuration.
            ICriterion criterion = null;

            switch (this.criterion)
            {
            case Criterion.Gini:
                criterion = new Gini((uint)nOutputs, NClasses.ToArray());
                break;

            case Criterion.Entropy:
                criterion = new Entropy((uint)nOutputs, NClasses.ToArray());
                break;

            case Criterion.Mse:
                criterion = new MSE((uint)nOutputs);
                break;

            default:
                throw new InvalidOperationException("Unknown criterion type");
            }

            // Split strategy selected by the configuration; every variant receives
            // the same criterion, feature budget, leaf minimum and random state.
            SplitterBase splitter = null;

            switch (this.splitter)
            {
            case Splitter.Best:
                splitter = new BestSplitter(criterion, (uint)this.MaxFeaturesValue, (uint)this.minSamplesLeaf, randomState);
                break;

            case Splitter.PresortBest:
                splitter = new PresortBestSplitter(criterion, (uint)this.MaxFeaturesValue, (uint)this.minSamplesLeaf, randomState);
                break;

            case Splitter.Random:
                splitter = new RandomSplitter(criterion, (uint)this.MaxFeaturesValue, (uint)this.minSamplesLeaf, randomState);
                break;

            default:
                throw new InvalidOperationException("Unknown splitter type");
            }

            // Build tree
            this.Tree = new Tree(
                this.NFeatures,
                this.NClasses.ToArray(),
                this.nOutputs,
                splitter,
                (uint)maxDepth,
                (uint)effectiveMinSamplesSplit,
                (uint)this.minSamplesLeaf);

            this.Tree.build(x, y, sampleWeight: sampleWeight);
        }