/// <summary>
/// Train a new decision tree given some training data and a training
/// problem described by an ITrainingContext instance.
/// </summary>
/// <param name="random">The single random number generator.</param>
/// <param name="context">The ITrainingContext instance by which
/// the training framework interacts with the training data.
/// Implemented within client code.</param>
/// <param name="parameters">Training parameters.</param>
/// <param name="maxThreads">The maximum number of threads to use.</param>
/// <param name="data">The training data.</param>
/// <param name="progress">Progress reporting target. If null, progress of
/// Verbosity.Interest is written to Console.Out.</param>
/// <returns>A new decision tree.</returns>
public static Tree<F, S> TrainTree(
    Random random,
    ITrainingContext<F, S> context,
    TrainingParameters parameters,
    int maxThreads,
    IDataPointCollection data,
    ProgressWriter progress = null)
{
    // Supply a default progress writer so callers may omit the argument.
    if (progress == null)
    {
        progress = new ProgressWriter(Verbosity.Interest, Console.Out);
    }

    ParallelTreeTrainingOperation<F, S> trainingOperation = new ParallelTreeTrainingOperation<F, S>(
        random, context, parameters, maxThreads, data, progress);

    Tree<F, S> tree = new Tree<F, S>(parameters.MaxDecisionLevels);

    // Train from the root (node 0) over the full data range; recursion
    // continues until the termination criterion is met at every branch.
    trainingOperation.TrainNodesRecurse(tree.nodes_, 0, 0, data.Count(), 0);

    tree.CheckValid();

    return tree;
}
/// <summary>
/// Train the node at nodeIndex using the data points whose indices occupy
/// the range [i0, i1) of indices_, then (for split nodes) partition that
/// range in place and recurse into the two children.
/// </summary>
/// <param name="nodes">Flat node array indexed as a complete binary tree:
/// the children of node i are at 2i+1 and 2i+2.</param>
/// <param name="nodeIndex">Index of the node to train within nodes.</param>
/// <param name="i0">Start (inclusive) of this node's range in indices_.</param>
/// <param name="i1">End (exclusive) of this node's range in indices_.</param>
/// <param name="recurseDepth">Current recursion depth; only passed through
/// to the recursive calls within this method.</param>
public void TrainNodesRecurse(Node<F, S>[] nodes, int nodeIndex, int i0, int i1, int recurseDepth)
{
    System.Diagnostics.Debug.Assert(nodeIndex < nodes.Length);

    nodes[nodeIndex] = new Node<F, S>();

    progress_.Write(Verbosity.Verbose, "{0}{1}: ", Tree<F, S>.GetPrettyPrintPrefix(nodeIndex), i1 - i0);

    // First aggregate statistics over the samples at the parent node
    parentStatistics_.Clear();
    for (int i = i0; i < i1; i++)
    {
        parentStatistics_.Aggregate(data_, indices_[i]);
    }

    // Copy parent statistics to thread local storage in case client
    // IStatisticsAggregator implementations are not reentrant.
    for (int t = 0; t < maxThreads_; t++)
    {
        threadLocals_[t].parentStatistics_ = parentStatistics_.DeepClone();
    }

    // A node in the bottom level of the flat array has no room for children
    // (2*nodeIndex+1 would exceed nodes.Length), so it must become a leaf.
    if (nodeIndex >= nodes.Length / 2) // this is a leaf node, nothing else to do
    {
        nodes[nodeIndex].InitializeLeaf(parentStatistics_);
        progress_.WriteLine(Verbosity.Verbose, "Terminating at max depth.");
        return;
    }

    // Iterate over threads: each thread evaluates a contiguous slice of the
    // candidate features, accumulating its own best (gain, feature, threshold)
    // in its ThreadLocalData; results are merged after the Parallel.For.
    Parallel.For(0, maxThreads_, new Action<int>(threadIndex =>
    {
        ThreadLocalData tl = threadLocals_[threadIndex]; // shorthand

        tl.Clear();

        // Range of indices of candidate feature evaluated in this thread
        // (if the number of candidate features is not a multiple of the
        // number of threads, some threads may evaluate one more feature
        // than others).
        int f1 = (int)(Math.Round(parameters_.NumberOfCandidateFeatures * (double)threadIndex / (double)maxThreads_));
        int f2 = (int)(Math.Round(parameters_.NumberOfCandidateFeatures * (double)(threadIndex + 1) / (double)maxThreads_));

        // Iterate over candidate features
        for (int f = f1; f < f2; f++)
        {
            F feature = trainingContext_.GetRandomFeature(tl.random_);

            // One partition bin per inter-threshold interval: k thresholds
            // yield k+1 bins.
            for (int b = 0; b < parameters_.NumberOfCandidateThresholdsPerFeature + 1; b++)
            {
                threadLocals_[threadIndex].partitionStatistics_[b].Clear(); // reset statistics
            }

            // Compute feature response per sample at this node
            for (int i = i0; i < i1; i++)
            {
                tl.responses_[i] = feature.GetResponse(data_, indices_[i]);
            }

            int nThresholds;
            // nThresholds may be fewer than requested (or zero) when the
            // responses do not support that many distinct cut points.
            if ((nThresholds = ParallelTreeTrainingOperation<F, S>.ChooseCandidateThresholds(
                tl.random_, indices_, i0, i1, tl.responses_,
                parameters_.NumberOfCandidateThresholdsPerFeature, ref tl.thresholds)) == 0)
            {
                continue; // failed to find meaningful candidate thresholds for this feature
            }

            // Aggregate statistics over sample partitions
            for (int i = i0; i < i1; i++)
            {
                // Slightly faster than List<float>.BinarySearch() for < O(100) thresholds
                int b = 0;
                while (b < nThresholds && tl.responses_[i] >= tl.thresholds[b])
                {
                    b++;
                }

                tl.partitionStatistics_[b].Aggregate(data_, indices_[i]);
            }

            // Evaluate each candidate threshold t by summing bins [0, t] into
            // the left child and bins (t, nThresholds] into the right child.
            for (int t = 0; t < nThresholds; t++)
            {
                tl.leftChildStatistics_.Clear();
                tl.rightChildStatistics_.Clear();

                for (int p = 0; p < nThresholds + 1 /*i.e. nBins*/; p++)
                {
                    if (p <= t)
                    {
                        tl.leftChildStatistics_.Aggregate(tl.partitionStatistics_[p]);
                    }
                    else
                    {
                        tl.rightChildStatistics_.Aggregate(tl.partitionStatistics_[p]);
                    }
                }

                // Compute gain over sample partitions
                double gain = trainingContext_.ComputeInformationGain(tl.parentStatistics_, tl.leftChildStatistics_, tl.rightChildStatistics_);

                // NOTE(review): '>=' here (vs the strict '>' used in the
                // cross-thread merge below) means later ties win within a
                // thread — confirm this asymmetry is intentional.
                if (gain >= tl.maxGain)
                {
                    tl.maxGain = gain;
                    tl.bestFeature = feature;
                    tl.bestThreshold = tl.thresholds[t];
                }
            }
        }
    }));

    double maxGain = 0.0;
    F bestFeature = default(F);
    float bestThreshold = 0.0f;

    // Now merge over threads
    for (int threadIndex = 0; threadIndex < maxThreads_; threadIndex++)
    {
        ThreadLocalData tl = threadLocals_[threadIndex];
        if (tl.maxGain > maxGain)
        {
            maxGain = tl.maxGain;
            bestFeature = tl.bestFeature;
            bestThreshold = tl.bestThreshold;
        }
    }

    // No candidate split improved on zero gain: make this node a leaf.
    if (maxGain == 0.0)
    {
        nodes[nodeIndex].InitializeLeaf(parentStatistics_);
        progress_.WriteLine(Verbosity.Verbose, "Terminating with no beneficial split found.");
        return;
    }

    // Now reorder the data point indices using the winning feature and thresholds.
    // Also recompute child node statistics so the client can decide whether
    // to terminate training of this branch.
    leftChildStatistics_.Clear();
    rightChildStatistics_.Clear();

    for (int i = i0; i < i1; i++)
    {
        responses_[i] = bestFeature.GetResponse(data_, indices_[i]);
        if (responses_[i] < bestThreshold)
        {
            leftChildStatistics_.Aggregate(data_, indices_[i]);
        }
        else
        {
            rightChildStatistics_.Aggregate(data_, indices_[i]);
        }
    }

    if (trainingContext_.ShouldTerminate(parentStatistics_, leftChildStatistics_, rightChildStatistics_, maxGain))
    {
        nodes[nodeIndex].InitializeLeaf(parentStatistics_);
        progress_.WriteLine(Verbosity.Verbose, "Terminating because supplied termination criterion is satisfied.");
        return;
    }

    // Otherwise this is a new split node, recurse for children.
    nodes[nodeIndex].InitializeSplit(
        bestFeature,
        bestThreshold,
        parentStatistics_.DeepClone());

    // Now do partition sort - any sample with response greater goes left, otherwise right
    // NOTE(review): the comment above contradicts the aggregation loop, which
    // sends responses < bestThreshold to the LEFT child; presumably
    // Tree.Partition places responses below the threshold in [i0, ii) —
    // verify against Tree.Partition.
    int ii = Tree<F, S>.Partition(responses_, indices_, i0, i1, bestThreshold);

    System.Diagnostics.Debug.Assert(ii >= i0 && i1 >= ii);

    progress_.WriteLine(Verbosity.Verbose, "{0} (threshold = {1:G3}, gain = {2:G3}).", bestFeature.ToString(), bestThreshold, maxGain);

    // Children occupy the standard heap positions 2i+1 (left) and 2i+2 (right),
    // splitting this node's index range at the partition point ii.
    TrainNodesRecurse(nodes, nodeIndex * 2 + 1, i0, ii, recurseDepth + 1);
    TrainNodesRecurse(nodes, nodeIndex * 2 + 2, ii, i1, recurseDepth + 1);
}