/// <summary>
/// Train a new decision forest given some training data and a training
/// problem described by an instance of the ITrainingContext interface.
/// </summary>
/// <param name="random">Random number generator.</param>
/// <param name="parameters">Training parameters.</param>
/// <param name="context">An ITrainingContext instance describing
/// the training problem, e.g. classification, density estimation, etc.</param>
/// <param name="maxThreads">The maximum number of threads to use.</param>
/// <param name="data">The training data.</param>
/// <param name="progress">Optional progress reporting target.</param>
/// <returns>A new decision forest.</returns>
public static Forest<F, S> TrainForest(
    Random random,
    TrainingParameters parameters,
    ITrainingContext<F, S> context,
    int maxThreads,
    IDataPointCollection data,
    ProgressWriter progress = null)
{
    if (progress == null)
    {
        progress = new ProgressWriter(
            parameters.Verbose ? Verbosity.Verbose : Verbosity.Interest,
            Console.Out);
    }

    Forest<F, S> forest = new Forest<F, S>();

    for (int t = 0; t < parameters.NumberOfTrees; t++)
    {
        progress.Write(Verbosity.Interest, "\rTraining tree {0}...", t);

        Tree<F, S> tree = ParallelTreeTrainer<F, S>.TrainTree(
            random, context, parameters, maxThreads, data, progress);
        forest.AddTree(tree);
    }

    progress.WriteLine(Verbosity.Interest, "\rTrained {0} trees. ", parameters.NumberOfTrees);

    return forest;
}
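// A minimal usage sketch (not part of the library): train a forest on four
// threads. Here "context" and "data" are placeholders; concrete constructions
// of both appear in the example Train() wrappers below.
Random random = new Random();
TrainingParameters parameters = new TrainingParameters()
{
    NumberOfTrees = 10,
    MaxDecisionLevels = 9,
    NumberOfCandidateFeatures = 10,
    NumberOfCandidateThresholdsPerFeature = 1,
    Verbose = false
};
Forest<F, S> forest = ForestTrainer<F, S>.TrainForest(
    random, parameters, context, 4 /* maxThreads */, data);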
static public Forest<AxisAlignedFeatureResponse, GaussianAggregator2d> Train(
    DataPointCollection trainingData,
    TrainingParameters parameters,
    double a,
    double b)
{
    if (trainingData.Dimensions != 2)
        throw new Exception("Training data points for density estimation were not 2D.");
    if (trainingData.HasLabels)
        throw new Exception("Density estimation training data should not be labelled.");
    if (trainingData.HasTargetValues)
        throw new Exception("Training data should not have target values.");

    // Train the forest
    Console.WriteLine("Training the forest...");

    Random random = new Random();

    ITrainingContext<AxisAlignedFeatureResponse, GaussianAggregator2d> densityEstimationTrainingContext =
        new DensityEstimationTrainingContext(a, b);
    var forest = ForestTrainer<AxisAlignedFeatureResponse, GaussianAggregator2d>.TrainForest(
        random, parameters, densityEstimationTrainingContext, trainingData);

    return forest;
}
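// Usage sketch (values illustrative): train a density forest with a weak
// Gaussian prior equivalent to ten 'effective' observations of variance 400,
// matching the command line defaults declared in Main() below.
var densityForest = DensityEstimationExample.Train(trainingData, parameters, 10.0, 400.0);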
static public Forest<F, HistogramAggregator> Train<F>(
    DataPointCollection trainingData,
    IFeatureFactory<F> featureFactory,
    TrainingParameters trainingParameters) where F : IFeatureResponse
{
    if (trainingData.Dimensions != 2)
        throw new Exception("Training data points must be 2D.");
    if (!trainingData.HasLabels)
        throw new Exception("Training data points must be labelled.");
    if (trainingData.HasTargetValues)
        throw new Exception("Training data points should not have target values.");

    Console.WriteLine("Running training...");

    Random random = new Random();

    ITrainingContext<F, HistogramAggregator> classificationContext =
        new ClassificationTrainingContext<F>(trainingData.CountClasses(), featureFactory, random);
    var forest = ForestTrainer<F, HistogramAggregator>.TrainForest(
        random, trainingParameters, classificationContext, trainingData);

    return forest;
}
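// Usage sketch, mirroring the CLAS mode of Main() below: train an
// axis-aligned classification forest on labelled 2D data (trainingData and
// trainingParameters are assumed to be constructed as elsewhere in this demo).
var classificationForest = ClassificationExample.Train(
    trainingData, new AxisAlignedFeatureFactory(), trainingParameters);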
public TreeTrainingOperation(
    Random randomNumberGenerator,
    ITrainingContext<F, S> trainingContext,
    TrainingParameters parameters,
    IDataPointCollection data,
    ProgressWriter progress)
{
    data_ = data;
    trainingContext_ = trainingContext;
    parameters_ = parameters;
    random_ = randomNumberGenerator;
    progress_ = progress;

    // Start with the identity permutation; the training procedure reorders
    // these indices in place as data points are partitioned between child nodes.
    indices_ = new int[data.Count()];
    for (int i = 0; i < indices_.Length; i++)
        indices_[i] = i;

    responses_ = new float[data.Count()];

    parentStatistics_ = trainingContext_.GetStatisticsAggregator();
    leftChildStatistics_ = trainingContext_.GetStatisticsAggregator();
    rightChildStatistics_ = trainingContext_.GetStatisticsAggregator();

    // n candidate thresholds partition the response values into n + 1 bins,
    // each of which needs its own statistics aggregator.
    partitionStatistics_ = new S[parameters.NumberOfCandidateThresholdsPerFeature + 1];
    for (int i = 0; i < partitionStatistics_.Length; i++)
        partitionStatistics_[i] = trainingContext_.GetStatisticsAggregator();
}
/// <summary>
/// Train a new decision tree given some training data and a training
/// problem described by an ITrainingContext instance.
/// </summary>
/// <param name="random">Random number generator.</param>
/// <param name="context">The ITrainingContext instance by which
/// the training framework interacts with the training data.
/// Implemented within client code.</param>
/// <param name="parameters">Training parameters.</param>
/// <param name="data">The training data.</param>
/// <param name="progress">Optional progress reporting target.</param>
/// <returns>A new decision tree.</returns>
static public Tree<F, S> TrainTree(
    Random random,
    ITrainingContext<F, S> context,
    TrainingParameters parameters,
    IDataPointCollection data,
    ProgressWriter progress = null)
{
    if (progress == null)
        progress = new ProgressWriter(Verbosity.Interest, Console.Out);

    TreeTrainingOperation<F, S> trainingOperation = new TreeTrainingOperation<F, S>(
        random, context, parameters, data, progress);

    Tree<F, S> tree = new Tree<F, S>(parameters.MaxDecisionLevels);

    progress.WriteLine(Verbosity.Verbose, "");

    // Recurses until the termination criterion is met.
    trainingOperation.TrainNodesRecurse(tree.nodes_, 0, 0, data.Count(), 0);

    progress.WriteLine(Verbosity.Verbose, "");

    tree.CheckValid();

    return tree;
}
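// Minimal sketch (assuming this TrainTree overload lives on the serial
// TreeTrainer<F, S> class, mirroring the ParallelTreeTrainer call above):
// train a single tree and wrap it in a one-tree forest.
Tree<F, S> tree = TreeTrainer<F, S>.TrainTree(random, context, parameters, data);
var singleTreeForest = new Forest<F, S>();
singleTreeForest.AddTree(tree);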
public ParallelTreeTrainingOperation(
    Random random,
    ITrainingContext<F, S> trainingContext,
    TrainingParameters parameters,
    int maxThreads,
    IDataPointCollection data,
    ProgressWriter progress)
{
    data_ = data;
    trainingContext_ = trainingContext;
    parameters_ = parameters;
    maxThreads_ = maxThreads;
    random_ = random;
    progress_ = progress;

    parentStatistics_ = trainingContext_.GetStatisticsAggregator();
    leftChildStatistics_ = trainingContext_.GetStatisticsAggregator();
    rightChildStatistics_ = trainingContext_.GetStatisticsAggregator();

    responses_ = new float[data.Count()];

    indices_ = new int[data.Count()];
    for (int i = 0; i < indices_.Length; i++)
        indices_[i] = i;

    // One ThreadLocalData per worker thread, so threads never share mutable state.
    threadLocals_ = new ThreadLocalData[maxThreads_];
    for (int threadIndex = 0; threadIndex < maxThreads_; threadIndex++)
        threadLocals_[threadIndex] = new ThreadLocalData(random_, trainingContext_, parameters_, data_);
}
public ThreadLocalData(
    Random random,
    ITrainingContext<F, S> trainingContext,
    TrainingParameters parameters,
    IDataPointCollection data)
{
    parentStatistics_ = trainingContext.GetStatisticsAggregator();
    leftChildStatistics_ = trainingContext.GetStatisticsAggregator();
    rightChildStatistics_ = trainingContext.GetStatisticsAggregator();

    partitionStatistics_ = new S[parameters.NumberOfCandidateThresholdsPerFeature + 1];
    for (int i = 0; i < partitionStatistics_.Length; i++)
        partitionStatistics_[i] = trainingContext.GetStatisticsAggregator();

    responses_ = new float[data.Count()];

    // Each thread gets its own Random, seeded from the shared generator.
    random_ = new Random(random.Next());
}
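// System.Random is not thread-safe, so sharing one instance across worker
// threads would corrupt its internal state. Seeding a separate Random per
// thread from the shared generator's Next() keeps a training run reproducible
// from a single root seed while giving each thread an independent stream.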
static void Main(string[] args)
{
    if (args.Length == 0 || args[0] == "/?" || args[0].ToLower() == "help")
    {
        DisplayHelp();
        return;
    }

    // These command line parameters are reused over several command line modes...
    StringParameter trainingDataPath = new StringParameter("path", "Path of file containing training data.");
    NaturalParameter T = new NaturalParameter("t", "No. of trees in the forest (default = {0}).", 10);
    NaturalParameter D = new NaturalParameter("d", "Maximum tree levels (default = {0}).", 10, 20);
    NaturalParameter F = new NaturalParameter("f", "No. of candidate feature responses per decision node (default = {0}).", 10);
    NaturalParameter L = new NaturalParameter("l", "No. of candidate thresholds per feature response (default = {0}).", 1);
    SingleParameter a = new SingleParameter("a", "The number of 'effective' prior observations (default = {0}).", true, false, 10.0f);
    SingleParameter b = new SingleParameter("b", "The variance of the effective observations (default = {0}).", true, true, 400.0f);
    SimpleSwitchParameter verboseSwitch = new SimpleSwitchParameter("Enables verbose progress indication.");
    SingleParameter plotPaddingX = new SingleParameter("padx", "Pad plot horizontally (default = {0}).", true, false, 0.1f);
    SingleParameter plotPaddingY = new SingleParameter("pady", "Pad plot vertically (default = {0}).", true, false, 0.1f);
    EnumParameter split = new EnumParameter(
        "s",
        "Specify what kind of split function to use (default = {0}).",
        new string[] { "axis", "linear" },
        new string[] { "axis-aligned split", "linear split" },
        "axis");

    // Behaviour depends on command line mode...
    string mode = args[0].ToLower(); // first argument defines the command line mode

    if (mode == "clas" || mode == "class")
    {
        #region Supervised classification

        CommandLineParser parser = new CommandLineParser();
        parser.Command = "SW " + mode.ToUpper();

        parser.AddArgument(trainingDataPath);
        parser.AddSwitch("T", T);
        parser.AddSwitch("D", D);
        parser.AddSwitch("F", F);
        parser.AddSwitch("L", L);
        parser.AddSwitch("SPLIT", split);
        parser.AddSwitch("PADX", plotPaddingX);
        parser.AddSwitch("PADY", plotPaddingY);
        parser.AddSwitch("VERBOSE", verboseSwitch);

        // Default values up above should be fine here.
        if (args.Length == 1)
        {
            parser.PrintHelp();
            DisplayTextFiles(CLAS_DATA_PATH);
            return;
        }

        if (parser.Parse(args, 1) == false)
            return;

        TrainingParameters trainingParameters = new TrainingParameters()
        {
            MaxDecisionLevels = D.Value - 1,
            NumberOfCandidateFeatures = F.Value,
            NumberOfCandidateThresholdsPerFeature = L.Value,
            NumberOfTrees = T.Value,
            Verbose = verboseSwitch.Used
        };

        PointF plotDilation = new PointF(plotPaddingX.Value, plotPaddingY.Value);

        DataPointCollection trainingData = LoadTrainingData(
            trainingDataPath.Value, CLAS_DATA_PATH, 2, DataDescriptor.HasClassLabels);

        if (split.Value == "linear")
        {
            Forest<LinearFeatureResponse2d, HistogramAggregator> forest = ClassificationExample.Train(
                trainingData, new LinearFeatureFactory(), trainingParameters);

            using (Bitmap result = ClassificationExample.Visualize(forest, trainingData, new Size(300, 300), plotDilation))
                ShowVisualizationImage(result);
        }
        else if (split.Value == "axis")
        {
            Forest<AxisAlignedFeatureResponse, HistogramAggregator> forest = ClassificationExample.Train(
                trainingData, new AxisAlignedFeatureFactory(), trainingParameters);

            using (Bitmap result = ClassificationExample.Visualize(forest, trainingData, new Size(300, 300), plotDilation))
                ShowVisualizationImage(result);
        }

        #endregion
    }
    else if (mode == "density")
    {
        #region Density estimation

        CommandLineParser parser = new CommandLineParser();
        parser.Command = "SW " + mode.ToUpper();

        parser.AddArgument(trainingDataPath);
        parser.AddSwitch("T", T);
        parser.AddSwitch("D", D);
        parser.AddSwitch("F", F);
        parser.AddSwitch("L", L);
        // For density estimation (and semi-supervised learning) we add
        // command line options to set the hyperparameters of the prior.
        parser.AddSwitch("a", a);
        parser.AddSwitch("b", b);
        parser.AddSwitch("PADX", plotPaddingX);
        parser.AddSwitch("PADY", plotPaddingY);
        parser.AddSwitch("VERBOSE", verboseSwitch);

        // Override default values for command line options.
        T.Value = 1;
        D.Value = 3;
        F.Value = 5;
        L.Value = 1;
        a.Value = 0;
        b.Value = 900;

        if (args.Length == 1)
        {
            parser.PrintHelp();
            DisplayTextFiles(DENSITY_DATA_PATH);
            return;
        }

        if (parser.Parse(args, 1) == false)
            return;

        TrainingParameters parameters = new TrainingParameters()
        {
            MaxDecisionLevels = D.Value - 1,
            NumberOfCandidateFeatures = F.Value,
            NumberOfCandidateThresholdsPerFeature = L.Value,
            NumberOfTrees = T.Value,
            Verbose = verboseSwitch.Used
        };

        DataPointCollection trainingData = LoadTrainingData(
            trainingDataPath.Value, DENSITY_DATA_PATH, 2, DataDescriptor.Unadorned);

        Forest<AxisAlignedFeatureResponse, GaussianAggregator2d> forest =
            DensityEstimationExample.Train(trainingData, parameters, a.Value, b.Value);

        PointF plotDilation = new PointF(plotPaddingX.Value, plotPaddingY.Value);

        using (Bitmap result = DensityEstimationExample.Visualize(forest, trainingData, new Size(300, 300), plotDilation))
            ShowVisualizationImage(result);

        #endregion
    }
    else if (mode == "ssclas" || mode == "ssclass")
    {
        #region Semi-supervised classification

        CommandLineParser parser = new CommandLineParser();
        parser.Command = "SW " + mode.ToUpper();

        parser.AddArgument(trainingDataPath);
        parser.AddSwitch("T", T);
        parser.AddSwitch("D", D);
        parser.AddSwitch("F", F);
        parser.AddSwitch("L", L);
        parser.AddSwitch("split", split);
        parser.AddSwitch("a", a);
        parser.AddSwitch("b", b);

        EnumParameter plotMode = new EnumParameter(
            "plot",
            "Determines what to plot.",
            new string[] { "density", "labels" },
            new string[] { "plot recovered density estimate", "plot class likelihood" },
            "labels");
        parser.AddSwitch("plot", plotMode);

        parser.AddSwitch("PADX", plotPaddingX);
        parser.AddSwitch("PADY", plotPaddingY);
        parser.AddSwitch("VERBOSE", verboseSwitch);

        // Override default values for command line options.
        T.Value = 10;
        D.Value = 12 - 1;
        F.Value = 30;
        L.Value = 1;

        if (args.Length == 1)
        {
            parser.PrintHelp();
            DisplayTextFiles(SSCLAS_DATA_PATH);
            return;
        }

        if (parser.Parse(args, 1) == false)
            return;

        DataPointCollection trainingData = LoadTrainingData(
            trainingDataPath.Value, SSCLAS_DATA_PATH, 2, DataDescriptor.HasClassLabels);

        TrainingParameters parameters = new TrainingParameters()
        {
            MaxDecisionLevels = D.Value - 1,
            NumberOfCandidateFeatures = F.Value,
            NumberOfCandidateThresholdsPerFeature = L.Value,
            NumberOfTrees = T.Value,
            Verbose = verboseSwitch.Used
        };

        Forest<LinearFeatureResponse2d, SemiSupervisedClassificationStatisticsAggregator> forest =
            SemiSupervisedClassificationExample.Train(trainingData, parameters, a.Value, b.Value);

        PointF plotPadding = new PointF(plotPaddingX.Value, plotPaddingY.Value);

        if (plotMode.Value == "labels")
        {
            using (Bitmap result = SemiSupervisedClassificationExample.VisualizeLabels(forest, trainingData, new Size(300, 300), plotPadding))
                ShowVisualizationImage(result);
        }
        else if (plotMode.Value == "density")
        {
            using (Bitmap result = SemiSupervisedClassificationExample.VisualizeDensity(forest, trainingData, new Size(300, 300), plotPadding))
                ShowVisualizationImage(result);
        }

        #endregion
    }
    else if (mode == "regression")
    {
        #region Regression

        CommandLineParser parser = new CommandLineParser();
        parser.Command = "SW " + mode.ToUpper();

        parser.AddArgument(trainingDataPath);
        parser.AddSwitch("T", T);
        parser.AddSwitch("D", D);
        parser.AddSwitch("F", F);
        parser.AddSwitch("L", L);
        parser.AddSwitch("PADX", plotPaddingX);
        parser.AddSwitch("PADY", plotPaddingY);
        parser.AddSwitch("VERBOSE", verboseSwitch);

        // Override default values for command line options.
        T.Value = 10;
        D.Value = 2;
        a.Value = 0; // prior turned off by default
        b.Value = 900;

        if (args.Length == 1)
        {
            parser.PrintHelp();
            DisplayTextFiles(REGRESSION_DATA_PATH);
            return;
        }

        if (parser.Parse(args, 1) == false)
            return;

        RegressionExample regressionDemo = new RegressionExample();

        regressionDemo.PlotDilation.X = plotPaddingX.Value;
        regressionDemo.PlotDilation.Y = plotPaddingY.Value;
        regressionDemo.TrainingParameters = new TrainingParameters()
        {
            MaxDecisionLevels = D.Value - 1,
            NumberOfCandidateFeatures = F.Value,
            NumberOfCandidateThresholdsPerFeature = L.Value,
            NumberOfTrees = T.Value,
            Verbose = verboseSwitch.Used
        };

        DataPointCollection trainingData = LoadTrainingData(
            trainingDataPath.Value, REGRESSION_DATA_PATH, 1, DataDescriptor.HasTargetValues);

        using (Bitmap result = regressionDemo.Run(trainingData))
            ShowVisualizationImage(result);

        #endregion
    }
    else
    {
        Console.WriteLine("Unrecognized command line argument, try SW HELP.");
        return;
    }
}
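// Illustrative invocations (file names hypothetical; the exact switch syntax
// is whatever CommandLineParser accepts, e.g. "/T 50"):
//
//   SW CLAS exp02.txt /SPLIT linear /VERBOSE   (supervised classification)
//   SW DENSITY gaussians.txt /A 10 /B 400      (2D density estimation)
//   SW SSCLAS spiral.txt /PLOT density         (semi-supervised classification)
//   SW REGRESSION line.txt                     (1D regression)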
public static Forest<LinearFeatureResponse2d, SemiSupervisedClassificationStatisticsAggregator> Train(
    DataPointCollection trainingData,
    TrainingParameters parameters,
    double a_,
    double b_)
{
    // Train the forest
    Console.WriteLine("Training the forest...");

    Random random = new Random();

    ITrainingContext<LinearFeatureResponse2d, SemiSupervisedClassificationStatisticsAggregator> classificationContext =
        new SemiSupervisedClassificationTrainingContext(trainingData.CountClasses(), random, a_, b_);
    var forest = ForestTrainer<LinearFeatureResponse2d, SemiSupervisedClassificationStatisticsAggregator>.TrainForest(
        random, parameters, classificationContext, trainingData);

    // Label transduction to unlabelled leaves from nearest labelled leaf
    List<int> unlabelledLeafIndices = null;
    List<int> labelledLeafIndices = null;
    int[] closestLabelledLeafIndices = null;
    List<int> leafIndices = null;

    for (int t = 0; t < forest.TreeCount; t++)
    {
        var tree = forest.GetTree(t);

        // Partition this tree's leaves into labelled and unlabelled sets
        // (stored as positions within the leafIndices list).
        leafIndices = new List<int>();
        unlabelledLeafIndices = new List<int>();
        labelledLeafIndices = new List<int>();

        for (int n = 0; n < tree.NodeCount; n++)
        {
            if (tree.GetNode(n).IsLeaf)
            {
                if (tree.GetNode(n).TrainingDataStatistics.HistogramAggregator.SampleCount == 0)
                    unlabelledLeafIndices.Add(leafIndices.Count);
                else
                    labelledLeafIndices.Add(leafIndices.Count);

                leafIndices.Add(n);
            }
        }

        // Build an upper triangular matrix of inter-leaf distances. The
        // distance between two leaves is the larger of the two negative log
        // probabilities of one leaf's Gaussian mean under the other's density.
        float[,] interLeafDistances = new float[leafIndices.Count, leafIndices.Count];
        for (int i = 0; i < leafIndices.Count; i++)
        {
            for (int j = i + 1; j < leafIndices.Count; j++)
            {
                SemiSupervisedClassificationStatisticsAggregator a = tree.GetNode(leafIndices[i]).TrainingDataStatistics;
                SemiSupervisedClassificationStatisticsAggregator b = tree.GetNode(leafIndices[j]).TrainingDataStatistics;
                GaussianPdf2d x = a.GaussianAggregator2d.GetPdf();
                GaussianPdf2d y = b.GaussianAggregator2d.GetPdf();

                interLeafDistances[i, j] = (float)Math.Max(
                    x.GetNegativeLogProbability((float)y.MeanX, (float)y.MeanY),
                    y.GetNegativeLogProbability((float)x.MeanX, (float)x.MeanY));
            }
        }

        // Find shortest paths between all pairs of nodes in the graph of leaf nodes
        FloydWarshall pathFinder = new FloydWarshall(interLeafDistances);

        // Find the closest labelled leaf to each unlabelled leaf
        float[] minDistances = new float[unlabelledLeafIndices.Count];
        closestLabelledLeafIndices = new int[unlabelledLeafIndices.Count];
        for (int i = 0; i < minDistances.Length; i++)
        {
            minDistances[i] = float.PositiveInfinity;
            closestLabelledLeafIndices[i] = -1; // unused so deliberately invalid
        }

        for (int l = 0; l < labelledLeafIndices.Count; l++)
        {
            for (int u = 0; u < unlabelledLeafIndices.Count; u++)
            {
                if (pathFinder.GetMinimumDistance(unlabelledLeafIndices[u], labelledLeafIndices[l]) < minDistances[u])
                {
                    minDistances[u] = pathFinder.GetMinimumDistance(unlabelledLeafIndices[u], labelledLeafIndices[l]);
                    closestLabelledLeafIndices[u] = leafIndices[labelledLeafIndices[l]];
                }
            }
        }

        // Propagate class probability distributions to each unlabelled
        // leaf from its nearest labelled leaf.
        for (int u = 0; u < unlabelledLeafIndices.Count; u++)
        {
            // Unhelpfully, C# only allows us to pass value types by value,
            // so Tree.GetNode() returns only a COPY of the node. We update
            // this copy and then copy it back over the top of the original
            // via Tree.SetNode(). The C++ version is a lot better!
            var unlabelledLeafCopy = tree.GetNode(leafIndices[unlabelledLeafIndices[u]]);
            var labelledLeafCopy = tree.GetNode(closestLabelledLeafIndices[u]);

            unlabelledLeafCopy.TrainingDataStatistics.HistogramAggregator =
                (HistogramAggregator)labelledLeafCopy.TrainingDataStatistics.HistogramAggregator.DeepClone();

            tree.SetNode(leafIndices[unlabelledLeafIndices[u]], unlabelledLeafCopy);
        }
    }

    return forest;
}
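// After transduction, every leaf in every tree carries a class histogram:
// leaves that received no labelled training points have borrowed a deep copy
// of the histogram of their nearest labelled leaf (nearest along shortest
// paths through the leaf graph), so at test time any point reaching such a
// leaf still yields a class posterior.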