/// <summary>
/// Recursively fills <paramref name="normalizationFactors"/> with a density
/// normalization factor for each node of the tree. The factor for a node is the
/// fraction of training points that reached the node, divided by the integral of
/// the node's fitted bivariate Gaussian over the node's axis-aligned bounds.
/// </summary>
/// <param name="t">Tree whose per-node Gaussian statistics are read.</param>
/// <param name="nodeIndex">Index of the current node (children at 2i+1, 2i+2).</param>
/// <param name="nTrainingPoints">Total number of training points in the tree.</param>
/// <param name="bounds">Axis-aligned bounds associated with the current node.</param>
/// <param name="normalizationFactors">Output array, indexed by node index.</param>
static void ComputeNormalizationFactorsRecurse(
    Tree<AxisAlignedFeatureResponse, GaussianAggregator2d> t,
    int nodeIndex,
    int nTrainingPoints,
    Bounds bounds,
    double[] normalizationFactors)
{
    // Tree.GetNode() returns a copy of the value-type node, so fetch it once
    // instead of re-copying it for every field access below.
    var node = t.GetNode(nodeIndex);

    GaussianPdf2d g = node.TrainingDataStatistics.GetPdf();

    // Hoist the repeated standard deviations and correlation coefficient.
    double sigmaX = Math.Sqrt(g.VarianceX);
    double sigmaY = Math.Sqrt(g.VarianceY);
    double rho = g.CovarianceXY / (sigmaX * sigmaY);

    // Evaluate integral of bivariate normal distribution within this node's bounds.
    // NOTE(review): this uses only two CDF evaluations (upper corner minus lower
    // corner); an exact rectangle probability would need four inclusion-exclusion
    // terms. Preserved as-is — confirm intended semantics of M() before changing.
    double u = CumulativeNormalDistribution2d.M(
        (bounds.Upper[0] - g.MeanX) / sigmaX,
        (bounds.Upper[1] - g.MeanY) / sigmaY,
        rho);
    double l = CumulativeNormalDistribution2d.M(
        (bounds.Lower[0] - g.MeanX) / sigmaX,
        (bounds.Lower[1] - g.MeanY) / sigmaY,
        rho);

    // Fraction of training data at this node, normalized by the probability mass
    // of its Gaussian that falls inside the node's bounds.
    normalizationFactors[nodeIndex] =
        (double)node.TrainingDataStatistics.SampleCount / nTrainingPoints / (u - l);

    if (!node.IsLeaf)
    {
        // Left child: same bounds but clipped from above on the split axis.
        Bounds leftChildBounds = bounds.Clone();
        leftChildBounds.Upper[node.Feature.Axis] = node.Threshold;
        ComputeNormalizationFactorsRecurse(t, nodeIndex * 2 + 1, nTrainingPoints, leftChildBounds, normalizationFactors);

        // Right child: same bounds but clipped from below on the split axis.
        Bounds rightChildBounds = bounds.Clone();
        rightChildBounds.Lower[node.Feature.Axis] = node.Threshold;
        ComputeNormalizationFactorsRecurse(t, nodeIndex * 2 + 2, nTrainingPoints, rightChildBounds, normalizationFactors);
    }
}
/// <summary>
/// Applies a trained semi-supervised classification forest over a regular 2D grid
/// of test points and renders the aggregated per-pixel class posteriors as colours,
/// muddied towards grey where the class distribution has high entropy.
/// </summary>
/// <param name="forest">Trained forest to apply.</param>
/// <param name="trainingData">Training data (supplies plot bounds, class count, and the overlay points).</param>
/// <param name="PlotSize">Output bitmap dimensions in pixels.</param>
/// <param name="PlotDilation">Padding applied around the training data's bounding box.</param>
/// <returns>A bitmap of class-probability colours with the training data painted on top.</returns>
public static Bitmap VisualizeLabels(Forest<LinearFeatureResponse2d, SemiSupervisedClassificationStatisticsAggregator> forest, DataPointCollection trainingData, Size PlotSize, PointF PlotDilation)
{
    // Generate some test samples in a grid pattern (a useful basis for creating visualization images).
    PlotCanvas plotCanvas = new PlotCanvas(trainingData.GetRange(0), trainingData.GetRange(1), PlotSize, PlotDilation);

    // Apply the trained forest to the test data.
    Console.WriteLine("\nApplying the forest to test data...");
    DataPointCollection testData = DataPointCollection.Generate2dGrid(
        plotCanvas.plotRangeX, PlotSize.Width,
        plotCanvas.plotRangeY, PlotSize.Height);
    int[][] leafNodeIndices = forest.Apply(testData);

    Bitmap result = new Bitmap(PlotSize.Width, PlotSize.Height);

    // NOTE(review): the original code built a per-tree array of leaf GaussianPdf2d
    // distributions here, but nothing in this method ever read it. That dead
    // computation has been removed.

    // Form a palette of random colors, one per class.
    // First few colours are same as those in the book, remainder are random.
    Color[] colors = new Color[Math.Max(trainingData.CountClasses(), 4)];
    colors[0] = Color.FromArgb(183, 170, 8);
    colors[1] = Color.FromArgb(194, 32, 14);
    colors[2] = Color.FromArgb(4, 154, 10);
    colors[3] = Color.FromArgb(13, 26, 188);

    Color grey = Color.FromArgb(255, 127, 127, 127);

    System.Random r = new Random(0); // same seed every time so colours will be consistent
    for (int c = 4; c < colors.Length; c++)
    {
        // NOTE(review): Next's upper bound is exclusive, so 255 is never produced.
        // Kept as-is so the seeded colours stay identical to previous runs.
        colors[c] = Color.FromArgb(255, r.Next(0, 255), r.Next(0, 255), r.Next(0, 255));
    }

    // Paint the test data, one pixel per grid sample, in row-major order.
    int index = 0;
    for (int j = 0; j < PlotSize.Height; j++)
    {
        for (int i = 0; i < PlotSize.Width; i++)
        {
            // Aggregate statistics for this sample over all leaf nodes reached.
            HistogramAggregator h = new HistogramAggregator(trainingData.CountClasses());
            for (int t = 0; t < forest.TreeCount; t++)
            {
                int leafIndex = leafNodeIndices[t][index];
                SemiSupervisedClassificationStatisticsAggregator a =
                    forest.GetTree(t).GetNode(leafIndex).TrainingDataStatistics;
                h.Aggregate(a.HistogramAggregator);
            }

            // Let's muddy the colors with a little grey where entropy is high.
            float mudiness = 0.5f * (float)(h.Entropy());

            float R = 0.0f, G = 0.0f, B = 0.0f;
            for (int b = 0; b < trainingData.CountClasses(); b++)
            {
                float p = (1.0f - mudiness) * h.GetProbability(b); // NB probabilities sum to 1.0 over the classes
                R += colors[b].R * p;
                G += colors[b].G * p;
                B += colors[b].B * p;
            }

            R += grey.R * mudiness;
            G += grey.G * mudiness;
            B += grey.B * mudiness;

            Color c = Color.FromArgb(255, (byte)(R), (byte)(G), (byte)(B));
            result.SetPixel(i, j, c);

            index++;
        }
    }

    // Overlay the original training points.
    PaintTrainingData(trainingData, plotCanvas, result);

    return (result);
}
/// <summary>
/// Trains a semi-supervised classification forest, then transduces class labels
/// from labelled leaves to unlabelled leaves. For each tree, leaves form a graph
/// weighted by a symmetric negative-log-probability distance between their fitted
/// Gaussians; Floyd-Warshall shortest paths find each unlabelled leaf's nearest
/// labelled leaf, whose class histogram is then deep-copied onto it.
/// </summary>
/// <param name="trainingData">Training points (may contain unlabelled samples).</param>
/// <param name="parameters">Forest training parameters.</param>
/// <param name="a_">Training-context parameter forwarded to the classification context.</param>
/// <param name="b_">Training-context parameter forwarded to the classification context.</param>
/// <returns>The trained forest with label distributions propagated to all leaves.</returns>
public static Forest<LinearFeatureResponse2d, SemiSupervisedClassificationStatisticsAggregator> Train(
    DataPointCollection trainingData,
    TrainingParameters parameters,
    double a_,
    double b_)
{
    // Train the forest.
    Console.WriteLine("Training the forest...");

    Random random = new Random();

    ITrainingContext<LinearFeatureResponse2d, SemiSupervisedClassificationStatisticsAggregator> classificationContext =
        new SemiSupervisedClassificationTrainingContext(trainingData.CountClasses(), random, a_, b_);
    var forest = ForestTrainer<LinearFeatureResponse2d, SemiSupervisedClassificationStatisticsAggregator>.TrainForest(
        random, parameters, classificationContext, trainingData);

    // Label transduction to unlabelled leaves from nearest labelled leaf.
    for (int t = 0; t < forest.TreeCount; t++)
    {
        var tree = forest.GetTree(t);

        // Partition this tree's leaves into labelled and unlabelled sets.
        // leafIndices maps a compact leaf ordinal back to its tree node index;
        // the labelled/unlabelled lists store those compact ordinals.
        var leafIndices = new List<int>();
        var unlabelledLeafIndices = new List<int>();
        var labelledLeafIndices = new List<int>();
        for (int n = 0; n < tree.NodeCount; n++)
        {
            // GetNode returns a copy of the value-type node; fetch it once.
            var node = tree.GetNode(n);
            if (node.IsLeaf)
            {
                if (node.TrainingDataStatistics.HistogramAggregator.SampleCount == 0)
                    unlabelledLeafIndices.Add(leafIndices.Count);
                else
                    labelledLeafIndices.Add(leafIndices.Count);
                leafIndices.Add(n);
            }
        }

        // Build an upper triangular matrix of inter-leaf distances: the larger of
        // the two cross negative-log-probabilities, so the distance is symmetric.
        float[,] interLeafDistances = new float[leafIndices.Count, leafIndices.Count];
        for (int i = 0; i < leafIndices.Count; i++)
        {
            for (int j = i + 1; j < leafIndices.Count; j++)
            {
                SemiSupervisedClassificationStatisticsAggregator a = tree.GetNode(leafIndices[i]).TrainingDataStatistics;
                SemiSupervisedClassificationStatisticsAggregator b = tree.GetNode(leafIndices[j]).TrainingDataStatistics;
                GaussianPdf2d x = a.GaussianAggregator2d.GetPdf();
                GaussianPdf2d y = b.GaussianAggregator2d.GetPdf();
                interLeafDistances[i, j] = (float)(Math.Max(
                    x.GetNegativeLogProbability((float)(y.MeanX), (float)(y.MeanY)),
                    +y.GetNegativeLogProbability((float)(x.MeanX), (float)(x.MeanY))));
            }
        }

        // Find shortest paths between all pairs of nodes in the graph of leaf nodes.
        FloydWarshall pathFinder = new FloydWarshall(interLeafDistances);

        // Find the closest labelled leaf to each unlabelled leaf.
        float[] minDistances = new float[unlabelledLeafIndices.Count];
        int[] closestLabelledLeafIndices = new int[unlabelledLeafIndices.Count];
        for (int i = 0; i < minDistances.Length; i++)
        {
            minDistances[i] = float.PositiveInfinity;
            closestLabelledLeafIndices[i] = -1; // unused so deliberately invalid
        }
        for (int l = 0; l < labelledLeafIndices.Count; l++)
        {
            for (int u = 0; u < unlabelledLeafIndices.Count; u++)
            {
                // Evaluate the shortest-path distance once per pair (the original
                // queried the path finder twice when a new minimum was found).
                float d = pathFinder.GetMinimumDistance(unlabelledLeafIndices[u], labelledLeafIndices[l]);
                if (d < minDistances[u])
                {
                    minDistances[u] = d;
                    closestLabelledLeafIndices[u] = leafIndices[labelledLeafIndices[l]];
                }
            }
        }

        // Propagate class probability distributions to each unlabelled
        // leaf from its nearest labelled leaf.
        for (int u = 0; u < unlabelledLeafIndices.Count; u++)
        {
            // Robustness: if this tree has no labelled leaves at all, no nearest
            // neighbour exists and the index is still -1 — skip rather than
            // letting GetNode(-1) throw.
            if (closestLabelledLeafIndices[u] < 0)
                continue;

            // Unhelpfully, C# only allows us to pass value types by value
            // so Tree.GetNode() returns only a COPY of the Node. We update
            // this copy and then copy it back over the top of the
            // original via Tree.SetNode().
            // The C++ version is a lot better!
            var unlabelledLeafCopy = tree.GetNode(leafIndices[unlabelledLeafIndices[u]]);
            var labelledLeafCopy = tree.GetNode(closestLabelledLeafIndices[u]);

            unlabelledLeafCopy.TrainingDataStatistics.HistogramAggregator =
                (HistogramAggregator)(labelledLeafCopy.TrainingDataStatistics.HistogramAggregator.DeepClone());

            tree.SetNode(leafIndices[unlabelledLeafIndices[u]], unlabelledLeafCopy);
        }
    }

    return (forest);
}