public static Forest<AxisAlignedFeatureResponse, GaussianAggregator2d> Train(
    DataPointCollection trainingData,
    TrainingParameters parameters,
    double a,
    double b)
{
    if (trainingData.Dimensions != 2)
        throw new Exception("Training data points for density estimation were not 2D.");
    if (trainingData.HasLabels)
        throw new Exception("Density estimation training data should not be labelled.");
    if (trainingData.HasTargetValues)
        throw new Exception("Training data should not have target values.");

    // Train the forest
    Console.WriteLine("Training the forest...");

    Random random = new Random();
    ITrainingContext<AxisAlignedFeatureResponse, GaussianAggregator2d> densityEstimationTrainingContext
        = new DensityEstimationTrainingContext(a, b);
    var forest = ForestTrainer<AxisAlignedFeatureResponse, GaussianAggregator2d>.TrainForest(
        random, parameters, densityEstimationTrainingContext, trainingData);

    return forest;
}
public static Forest<F, HistogramAggregator> Train<F>(
    DataPointCollection trainingData,
    IFeatureFactory<F> featureFactory,
    TrainingParameters trainingParameters) where F : IFeatureResponse
{
    if (trainingData.Dimensions != 2)
        throw new Exception("Training data points must be 2D.");
    if (trainingData.HasLabels == false)
        throw new Exception("Training data points must be labelled.");
    if (trainingData.HasTargetValues)
        throw new Exception("Training data points should not have target values.");

    Console.WriteLine("Running training...");

    Random random = new Random();
    ITrainingContext<F, HistogramAggregator> classificationContext
        = new ClassificationTrainingContext<F>(trainingData.CountClasses(), featureFactory, random);
    var forest = ForestTrainer<F, HistogramAggregator>.TrainForest(
        random, trainingParameters, classificationContext, trainingData);

    return forest;
}
public Bitmap Run(DataPointCollection trainingData)
{
    // Train the forest
    Console.WriteLine("Training the forest...");

    Random random = new Random();
    ITrainingContext<AxisAlignedFeatureResponse, LinearFitAggregator1d> regressionTrainingContext
        = new RegressionTrainingContext();
    var forest = ForestTrainer<AxisAlignedFeatureResponse, LinearFitAggregator1d>.TrainForest(
        random, TrainingParameters, regressionTrainingContext, trainingData);

    // Generate some test samples in a grid pattern (a useful basis for creating visualization images)
    PlotCanvas plotCanvas = new PlotCanvas(
        trainingData.GetRange(0), trainingData.GetTargetRange(), PlotSize, PlotDilation);
    DataPointCollection testData = DataPointCollection.Generate1dGrid(
        plotCanvas.plotRangeX, PlotSize.Width);

    // Apply the trained forest to the test data
    Console.WriteLine("\nApplying the forest to test data...");
    int[][] leafNodeIndices = forest.Apply(testData);

    #region Generate Visualization Image
    Bitmap result = new Bitmap(PlotSize.Width, PlotSize.Height);

    // Plot the learned density
    Color inverseDensityColor = Color.FromArgb(
        255, 255 - DensityColor.R, 255 - DensityColor.G, 255 - DensityColor.B);

    double[] mean_y_given_x = new double[PlotSize.Width];

    for (int i = 0; i < PlotSize.Width; i++)
    {
        double totalProbability = 0.0;
        for (int j = 0; j < PlotSize.Height; j++)
        {
            // Map pixel coordinate (i,j) in visualization image back to point in input space
            float x = plotCanvas.plotRangeX.Item1 + i * plotCanvas.stepX;
            float y = plotCanvas.plotRangeY.Item1 + j * plotCanvas.stepY;

            // Aggregate statistics for this sample over all trees
            double probability = 0.0;
            for (int t = 0; t < forest.TreeCount; t++)
            {
                Node<AxisAlignedFeatureResponse, LinearFitAggregator1d> leafNodeCopy
                    = forest.GetTree(t).GetNode(leafNodeIndices[t][i]);
                LinearFitAggregator1d leafStatistics = leafNodeCopy.TrainingDataStatistics;
                probability += leafStatistics.GetProbability(x, y);
            }
            probability /= forest.TreeCount;

            mean_y_given_x[i] += probability * y;
            totalProbability += probability;

            float scale = 10.0f * (float)probability;
            Color weightedColor = Color.FromArgb(
                255,
                (byte)(Math.Min(scale * inverseDensityColor.R + 0.5f, 255.0f)),
                (byte)(Math.Min(scale * inverseDensityColor.G + 0.5f, 255.0f)),
                (byte)(Math.Min(scale * inverseDensityColor.B + 0.5f, 255.0f)));
            Color c = Color.FromArgb(
                255, 255 - weightedColor.R, 255 - weightedColor.G, 255 - weightedColor.B);
            result.SetPixel(i, j, c);
        }

        // NB We don't really compute the mean over y, just over the region of y that is plotted
        mean_y_given_x[i] /= totalProbability;
    }

    // Also plot the mean curve and the original training data
    using (Graphics g = Graphics.FromImage(result))
    {
        g.InterpolationMode = System.Drawing.Drawing2D.InterpolationMode.HighQualityBicubic;
        g.SmoothingMode = System.Drawing.Drawing2D.SmoothingMode.HighQuality;

        using (Pen meanPen = new Pen(MeanColor, 2))
        {
            for (int i = 0; i < PlotSize.Width - 1; i++)
            {
                g.DrawLine(
                    meanPen,
                    (float)(i),
                    (float)((mean_y_given_x[i] - plotCanvas.plotRangeY.Item1) / plotCanvas.stepY),
                    (float)(i + 1),
                    (float)((mean_y_given_x[i + 1] - plotCanvas.plotRangeY.Item1) / plotCanvas.stepY));
            }
        }

        using (Brush dataPointBrush = new SolidBrush(DataPointColor))
        using (Pen dataPointBorderPen = new Pen(DataPointBorderColor))
        {
            for (int s = 0; s < trainingData.Count(); s++)
            {
                // Map sample coordinate back to a pixel coordinate in the visualization image
                PointF x = new PointF(
                    (trainingData.GetDataPoint(s)[0] - plotCanvas.plotRangeX.Item1) / plotCanvas.stepX,
                    (trainingData.GetTarget(s) - plotCanvas.plotRangeY.Item1) / plotCanvas.stepY);

                RectangleF rectangle = new RectangleF(x.X - 2.0f, x.Y - 2.0f, 4.0f, 4.0f);
                g.FillRectangle(dataPointBrush, rectangle);
                g.DrawRectangle(dataPointBorderPen, rectangle.X, rectangle.Y, rectangle.Width, rectangle.Height);
            }
        }
    }
    #endregion

    return result;
}
public static Forest<LinearFeatureResponse2d, SemiSupervisedClassificationStatisticsAggregator> Train(
    DataPointCollection trainingData,
    TrainingParameters parameters,
    double a_,
    double b_)
{
    // Train the forest
    Console.WriteLine("Training the forest...");

    Random random = new Random();
    ITrainingContext<LinearFeatureResponse2d, SemiSupervisedClassificationStatisticsAggregator> classificationContext
        = new SemiSupervisedClassificationTrainingContext(trainingData.CountClasses(), random, a_, b_);
    var forest = ForestTrainer<LinearFeatureResponse2d, SemiSupervisedClassificationStatisticsAggregator>.TrainForest(
        random, parameters, classificationContext, trainingData);

    // Label transduction to unlabelled leaves from nearest labelled leaf
    for (int t = 0; t < forest.TreeCount; t++)
    {
        var tree = forest.GetTree(t);

        var leafIndices = new List<int>();
        var unlabelledLeafIndices = new List<int>();
        var labelledLeafIndices = new List<int>();

        for (int n = 0; n < tree.NodeCount; n++)
        {
            if (tree.GetNode(n).IsLeaf)
            {
                if (tree.GetNode(n).TrainingDataStatistics.HistogramAggregator.SampleCount == 0)
                    unlabelledLeafIndices.Add(leafIndices.Count);
                else
                    labelledLeafIndices.Add(leafIndices.Count);
                leafIndices.Add(n);
            }
        }

        // Build an upper triangular matrix of inter-leaf distances
        float[,] interLeafDistances = new float[leafIndices.Count, leafIndices.Count];
        for (int i = 0; i < leafIndices.Count; i++)
        {
            for (int j = i + 1; j < leafIndices.Count; j++)
            {
                SemiSupervisedClassificationStatisticsAggregator a
                    = tree.GetNode(leafIndices[i]).TrainingDataStatistics;
                SemiSupervisedClassificationStatisticsAggregator b
                    = tree.GetNode(leafIndices[j]).TrainingDataStatistics;
                GaussianPdf2d x = a.GaussianAggregator2d.GetPdf();
                GaussianPdf2d y = b.GaussianAggregator2d.GetPdf();

                interLeafDistances[i, j] = (float)Math.Max(
                    x.GetNegativeLogProbability((float)y.MeanX, (float)y.MeanY),
                    y.GetNegativeLogProbability((float)x.MeanX, (float)x.MeanY));
            }
        }

        // Find shortest paths between all pairs of nodes in the graph of leaf nodes
        FloydWarshall pathFinder = new FloydWarshall(interLeafDistances);

        // Find the closest labelled leaf to each unlabelled leaf
        float[] minDistances = new float[unlabelledLeafIndices.Count];
        int[] closestLabelledLeafIndices = new int[unlabelledLeafIndices.Count];
        for (int i = 0; i < minDistances.Length; i++)
        {
            minDistances[i] = float.PositiveInfinity;
            closestLabelledLeafIndices[i] = -1; // unused so deliberately invalid
        }

        for (int l = 0; l < labelledLeafIndices.Count; l++)
        {
            for (int u = 0; u < unlabelledLeafIndices.Count; u++)
            {
                float distance = pathFinder.GetMinimumDistance(unlabelledLeafIndices[u], labelledLeafIndices[l]);
                if (distance < minDistances[u])
                {
                    minDistances[u] = distance;
                    closestLabelledLeafIndices[u] = leafIndices[labelledLeafIndices[l]];
                }
            }
        }

        // Propagate class probability distributions to each unlabelled
        // leaf from its nearest labelled leaf.
        for (int u = 0; u < unlabelledLeafIndices.Count; u++)
        {
            // Unhelpfully, C# only allows us to pass value types by value
            // so Tree.GetNode() returns only a COPY of the Node. We update
            // this copy and then copy it back over the top of the
            // original via Tree.SetNode(). The C++ version is a lot better!
            var unlabelledLeafCopy = tree.GetNode(leafIndices[unlabelledLeafIndices[u]]);
            var labelledLeafCopy = tree.GetNode(closestLabelledLeafIndices[u]);

            unlabelledLeafCopy.TrainingDataStatistics.HistogramAggregator
                = (HistogramAggregator)labelledLeafCopy.TrainingDataStatistics.HistogramAggregator.DeepClone();

            tree.SetNode(leafIndices[unlabelledLeafIndices[u]], unlabelledLeafCopy);
        }
    }

    return forest;
}
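// The label-transduction step above leans on the library's FloydWarshall helper for
// all-pairs shortest paths between leaves. As a minimal, self-contained sketch of the
// assumed behaviour (the class name FloydWarshallSketch is hypothetical; the real
// helper's internals may differ, e.g. in how it treats the unfilled lower triangle):

```csharp
using System;

class FloydWarshallSketch
{
    readonly float[,] d;

    // Accepts an upper-triangular matrix of pairwise distances, as built in
    // Train() above; entries on/below the diagonal are ignored.
    public FloydWarshallSketch(float[,] upperTriangularDistances)
    {
        int n = upperTriangularDistances.GetLength(0);
        d = new float[n, n];

        // Symmetrize: distance i->j equals j->i; a node is at distance 0 from itself.
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                d[i, j] = (i == j) ? 0.0f
                    : (j > i) ? upperTriangularDistances[i, j]
                    : upperTriangularDistances[j, i];

        // Standard Floyd-Warshall relaxation: successively allow each vertex k
        // as an intermediate hop and keep the shorter route.
        for (int k = 0; k < n; k++)
            for (int i = 0; i < n; i++)
                for (int j = 0; j < n; j++)
                    if (d[i, k] + d[k, j] < d[i, j])
                        d[i, j] = d[i, k] + d[k, j];
    }

    public float GetMinimumDistance(int i, int j) { return d[i, j]; }
}

class Program
{
    static void Main()
    {
        // Three leaves: direct edge 0-2 costs 10, but routing via leaf 1 costs 1 + 1 = 2.
        var fw = new FloydWarshallSketch(new float[,] {
            { 0, 1, 10 },
            { 0, 0,  1 },
            { 0, 0,  0 } });
        Console.WriteLine(fw.GetMinimumDistance(0, 2)); // shorter path through leaf 1
    }
}
```

// This is why direct inter-leaf distances alone are not enough: an unlabelled leaf may
// be closest to a labelled leaf only via a chain of intermediate leaves.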