/// <summary>
/// Creates an empty tree: records the input/output layout and the stopping
/// parameters, and leaves all node storage unallocated.
/// </summary>
/// <param name="nFeatures">Number of input features; stored as given.</param>
/// <param name="nClasses">
/// Class count per output. The first <paramref name="nOutputs"/> entries are copied
/// into <c>NClasses</c>, and the largest entry determines the value stride.
/// </param>
/// <param name="nOutputs">Number of outputs; also the length of <c>NClasses</c>.</param>
/// <param name="splitter">Splitter kept for later use by <c>build</c>.</param>
/// <param name="maxDepth">Depth limit, stored as given.</param>
/// <param name="minSamplesSplit">Minimum sample count for splitting a node, stored as given.</param>
/// <param name="minSamplesLeaf">Minimum sample count per leaf, stored as given.</param>
public Tree(
    int nFeatures,
    uint[] nClasses,
    int nOutputs,
    SplitterBase splitter,
    uint maxDepth,
    uint minSamplesSplit,
    uint minSamplesLeaf)
{
    // --- Input/output layout ---
    this.nFeatures = nFeatures;
    this.nOutputs = nOutputs;
    NClasses = new uint[nOutputs];
    this.maxNClasses = nClasses.Max();

    // One value slot per (output, class) pair, padded to the widest output.
    this.valueStride = (uint)this.nOutputs * this.maxNClasses;

    for (int output = 0; output < nOutputs; output++)
    {
        NClasses[output] = nClasses[output];
    }

    // --- Growth parameters ---
    this.splitter = splitter;
    this.maxDepth = maxDepth;
    this.minSamplesSplit = minSamplesSplit;
    this.minSamplesLeaf = minSamplesLeaf;

    // --- Node storage: nothing is allocated until the tree is resized ---
    NodeCount = 0;
    Capacity = 0;
    ChildrenLeft = null;
    ChildrenRight = null;
    Feature = null;
    Threshold = null;
    Value = null;
    Impurity = null;
    NNodeSamples = null;
}
/// <summary>
/// Grows the decision tree over the training set (x, y), using an explicit stack
/// of pending nodes instead of recursion.
/// </summary>
/// <param name="x">Training inputs; handed to the splitter's <c>init</c> unchanged.</param>
/// <param name="y">Training targets; handed to the splitter's <c>init</c> unchanged.</param>
/// <param name="sampleWeight">
/// Optional per-sample weights. Converted to an array for the splitter, or passed
/// through as null when omitted.
/// </param>
/// <remarks>
/// On completion the node storage is resized to <c>NodeCount</c> and the splitter
/// reference held by the tree is cleared.
/// </remarks>
public void build(
    MathNet.Numerics.LinearAlgebra.Generic.Matrix<double> x,
    MathNet.Numerics.LinearAlgebra.Generic.Matrix<double> y,
    MathNet.Numerics.LinearAlgebra.Generic.Vector<double> sampleWeight = null)
{
    // Initial storage: a complete binary tree of the requested depth,
    // or 2047 nodes (2^11 - 1) when the depth limit exceeds 10.
    int initialCapacity = this.maxDepth <= 10
        ? (1 << (int)(this.maxDepth + 1)) - 1
        : 2047;
    this.Resize(initialCapacity);

    SplitterBase nodeSplitter = this.splitter;
    var weights = sampleWeight == null ? null : sampleWeight.ToArray();
    nodeSplitter.init(x, y, weights);

    // Each pending node occupies one frame: [start, end, depth, parent, isLeft].
    const uint FrameWidth = 5;
    uint capacity = 50;
    uint[] frames = new uint[capacity];
    uint top = 0;

    // Seed the stack with the root, which spans every sample.
    frames[top] = 0;
    frames[top + 1] = nodeSplitter.n_samples;
    frames[top + 2] = 0;
    frames[top + 3] = _TREE_UNDEFINED;
    frames[top + 4] = 0;
    top += FrameWidth;

    // Declared outside the loop on purpose: a leaf is recorded with whatever
    // feature/threshold the most recent split call left behind.
    uint pos = 0;
    uint feature = 0;
    double threshold = 0;
    double impurity = 0;

    while (top > 0)
    {
        // Pop the most recently pushed frame.
        top -= FrameWidth;
        uint start = frames[top];
        uint end = frames[top + 1];
        uint depth = frames[top + 2];
        uint parent = frames[top + 3];
        bool isLeft = frames[top + 4] != 0;
        uint sampleCount = end - start;

        nodeSplitter.node_reset(start, end, ref impurity);

        // Stop on the depth limit, on too few samples, or on near-zero impurity.
        bool isLeaf = depth >= this.maxDepth
            || sampleCount < this.minSamplesSplit
            || sampleCount < 2 * this.minSamplesLeaf
            || impurity < 1e-7;

        if (!isLeaf)
        {
            nodeSplitter.node_split(ref pos, ref feature, ref threshold);

            // A split position at or past the end of the range means no split was made.
            isLeaf = pos >= end;
        }

        uint nodeId = this.AddNode(parent, isLeft, isLeaf, feature, threshold, impurity, sampleCount);

        if (isLeaf)
        {
            // Values are written for leaves only, never for internal nodes.
            nodeSplitter.node_value(this.Value, nodeId * this.valueStride);
            continue;
        }

        // Make room for the two child frames pushed below.
        if (top + 2 * FrameWidth > capacity)
        {
            capacity *= 2;
            Array.Resize(ref frames, (int)capacity);
        }

        // Right child goes on first so that the left child is popped next.
        frames[top] = pos;
        frames[top + 1] = end;
        frames[top + 2] = depth + 1;
        frames[top + 3] = nodeId;
        frames[top + 4] = 0;
        top += FrameWidth;

        frames[top] = start;
        frames[top + 1] = pos;
        frames[top + 2] = depth + 1;
        frames[top + 3] = nodeId;
        frames[top + 4] = 1;
        top += FrameWidth;
    }

    // Trim storage to the nodes actually created.
    this.Resize((int)this.NodeCount);

    // Release memory
    this.splitter = null;
}
/// <summary>
/// Shared fitting routine: resolves and validates the hyper-parameters, creates the
/// configured criterion and splitter, then constructs and builds the tree.
/// </summary>
/// <param name="x">Training inputs, forwarded to the tree builder.</param>
/// <param name="y">Training targets, forwarded to the tree builder.</param>
/// <param name="nSamples">Expected number of samples; checked against the weight count.</param>
/// <param name="sampleWeight">Optional per-sample weights; may be null.</param>
/// <param name="isClassification">Passed to the max-features computation.</param>
/// <exception cref="ArgumentException">
/// A hyper-parameter is out of range, or the number of weights differs from
/// <paramref name="nSamples"/>.
/// </exception>
/// <exception cref="InvalidOperationException">
/// The configured criterion or splitter kind is not recognised.
/// </exception>
private void FitCommon(
    Matrix<double> x,
    Matrix<double> y,
    int nSamples,
    Vector<double> sampleWeight,
    bool isClassification)
{
    // An unset depth limit means "unbounded".
    int maxDepth = this.maxDepth ?? int.MaxValue;

    if (this.maxFeatures == null)
    {
        this.MaxFeaturesValue = this.NFeatures;
    }
    else
    {
        this.MaxFeaturesValue = maxFeatures.ComputeMaxFeatures(this.NFeatures, isClassification);
    }

    if (this.minSamplesSplit <= 0)
    {
        throw new ArgumentException("min_samples_split must be greater than zero.");
    }

    if (this.minSamplesLeaf <= 0)
    {
        throw new ArgumentException("minSamplesLeaf must be greater than zero.");
    }

    if (maxDepth <= 0)
    {
        throw new ArgumentException("maxDepth must be greater than zero. ");
    }

    if (!(0 < MaxFeaturesValue && MaxFeaturesValue <= this.NFeatures))
    {
        throw new ArgumentException("maxFeatures must be in (0, n_features]");
    }

    if (sampleWeight != null && sampleWeight.Count != nSamples)
    {
        throw new ArgumentException(
            string.Format(
                "Number of weights={0} does not match number of samples={1}",
                sampleWeight.Count,
                nSamples));
    }

    // A node needs at least two leaves' worth of samples to be split. The raised
    // value is kept in a local so the configured minSamplesSplit field is not
    // overwritten: writing it back would leave a stale threshold in place if
    // minSamplesLeaf were lowered before a later fit.
    var effectiveMinSamplesSplit = Math.Max(this.minSamplesSplit, 2 * this.minSamplesLeaf);

    // Build tree
    ICriterion criterion = null;
    switch (this.criterion)
    {
        case Criterion.Gini:
            criterion = new Gini((uint)nOutputs, NClasses.ToArray());
            break;
        case Criterion.Entropy:
            criterion = new Entropy((uint)nOutputs, NClasses.ToArray());
            break;
        case Criterion.Mse:
            criterion = new MSE((uint)nOutputs);
            break;
        default:
            throw new InvalidOperationException("Unknown criterion type");
    }

    SplitterBase splitter = null;
    switch (this.splitter)
    {
        case Splitter.Best:
            splitter = new BestSplitter(
                criterion, (uint)this.MaxFeaturesValue, (uint)this.minSamplesLeaf, randomState);
            break;
        case Splitter.PresortBest:
            splitter = new PresortBestSplitter(
                criterion, (uint)this.MaxFeaturesValue, (uint)this.minSamplesLeaf, randomState);
            break;
        case Splitter.Random:
            splitter = new RandomSplitter(
                criterion, (uint)this.MaxFeaturesValue, (uint)this.minSamplesLeaf, randomState);
            break;
        default:
            throw new InvalidOperationException("Unknown splitter type");
    }

    this.Tree = new Tree(
        this.NFeatures,
        this.NClasses.ToArray(),
        this.nOutputs,
        splitter,
        (uint)maxDepth,
        (uint)effectiveMinSamplesSplit,
        (uint)this.minSamplesLeaf);

    this.Tree.build(x, y, sampleWeight: sampleWeight);
}