/// <summary>
/// Trains a density estimation forest over unlabelled 2D training data.
/// </summary>
/// <param name="trainingData">2D data points; must be unlabelled and have no target values.</param>
/// <param name="parameters">Forest training parameters.</param>
/// <param name="a">First prior parameter forwarded to the density estimation training context.</param>
/// <param name="b">Second prior parameter forwarded to the density estimation training context.</param>
/// <returns>The trained density estimation forest.</returns>
/// <exception cref="ArgumentException">
/// Thrown if the data is not 2D, is labelled, or has target values.
/// </exception>
public static Forest<AxisAlignedFeatureResponse, GaussianAggregator2d> Train(
            DataPointCollection trainingData,
            TrainingParameters parameters,
            double a,
            double b)
        {
            // Validate the input up front: density estimation requires raw
            // (unlabelled, target-free) 2D points.
            if (trainingData.Dimensions != 2)
            {
                throw new ArgumentException("Training data points for density estimation were not 2D.");
            }
            if (trainingData.HasLabels)
            {
                throw new ArgumentException("Density estimation training data should not be labelled.");
            }
            if (trainingData.HasTargetValues)
            {
                throw new ArgumentException("Training data should not have target values.");
            }

            // Train the forest
            Console.WriteLine("Training the forest...");

            Random random = new Random();

            ITrainingContext<AxisAlignedFeatureResponse, GaussianAggregator2d> densityEstimationTrainingContext =
                new DensityEstimationTrainingContext(a, b);

            var forest = ForestTrainer<AxisAlignedFeatureResponse, GaussianAggregator2d>.TrainForest(
                random,
                parameters,
                densityEstimationTrainingContext,
                trainingData);

            return forest;
        }
        /// <summary>
        /// Trains a classification forest over labelled 2D training data using the
        /// supplied feature factory to generate candidate split features.
        /// </summary>
        /// <typeparam name="F">The feature response type used for node split functions.</typeparam>
        /// <param name="trainingData">2D data points; must be labelled and have no target values.</param>
        /// <param name="featureFactory">Factory used to propose candidate features during training.</param>
        /// <param name="TrainingParameters">Forest training parameters (name kept for caller compatibility).</param>
        /// <returns>The trained classification forest.</returns>
        /// <exception cref="ArgumentException">
        /// Thrown if the data is not 2D, is unlabelled, or has target values.
        /// </exception>
        public static Forest<F, HistogramAggregator> Train<F>(
            DataPointCollection trainingData,
            IFeatureFactory<F> featureFactory,
            TrainingParameters TrainingParameters) where F : IFeatureResponse
        {
            // Validate the input up front: classification requires labelled,
            // target-free 2D points.
            if (trainingData.Dimensions != 2)
            {
                throw new ArgumentException("Training data points must be 2D.");
            }
            if (!trainingData.HasLabels)
            {
                throw new ArgumentException("Training data points must be labelled.");
            }
            if (trainingData.HasTargetValues)
            {
                throw new ArgumentException("Training data points should not have target values.");
            }

            Console.WriteLine("Running training...");

            Random random = new Random();

            ITrainingContext<F, HistogramAggregator> classificationContext =
                new ClassificationTrainingContext<F>(trainingData.CountClasses(), featureFactory, random);

            var forest = ForestTrainer<F, HistogramAggregator>.TrainForest(
                random,
                TrainingParameters,
                classificationContext,
                trainingData);

            return forest;
        }
        // NOTE(review): removed non-code scrape artifact ("Esempio n. 3" / vote count
        // from the code-search site) that was pasted between methods and broke compilation.
        /// <summary>
        /// Trains a regression forest on the supplied 1D training data (with targets)
        /// and renders a visualization: the learned conditional density p(y|x) as a
        /// color wash, the mean prediction curve, and the original training points.
        /// </summary>
        /// <param name="trainingData">Labelled 1D training data with target values.</param>
        /// <returns>A <see cref="Bitmap"/> of dimensions <c>PlotSize</c>.</returns>
        public Bitmap Run(DataPointCollection trainingData)
        {
            // Train the forest
            Console.WriteLine("Training the forest...");

            Random random = new Random();
            ITrainingContext<AxisAlignedFeatureResponse, LinearFitAggregator1d> regressionTrainingContext = new RegressionTrainingContext();

            var forest = ForestTrainer<AxisAlignedFeatureResponse, LinearFitAggregator1d>.TrainForest(
                random,
                TrainingParameters,
                regressionTrainingContext,
                trainingData);

            // Generate some test samples in a grid pattern (a useful basis for creating visualization images)
            PlotCanvas plotCanvas = new PlotCanvas(trainingData.GetRange(0), trainingData.GetTargetRange(), PlotSize, PlotDilation);

            DataPointCollection testData = DataPointCollection.Generate1dGrid(plotCanvas.plotRangeX, PlotSize.Width);

            // Apply the trained forest to the test data
            Console.WriteLine("\nApplying the forest to test data...");

            int[][] leafNodeIndices = forest.Apply(testData);

            #region Generate Visualization Image
            Bitmap result = new Bitmap(PlotSize.Width, PlotSize.Height);

            // The density is rendered against the channel-wise inverse of DensityColor
            // so high-probability regions show up in DensityColor on a light background.
            Color inverseDensityColor = Color.FromArgb(255, 255 - DensityColor.R, 255 - DensityColor.G, 255 - DensityColor.B);

            // Per image column i: accumulates sum over plotted y of p(y|x_i) * y,
            // normalized by the column's total probability below.
            double[] meanYGivenX = new double[PlotSize.Width];

            for (int i = 0; i < PlotSize.Width; i++)
            {
                double totalProbability = 0.0;
                for (int j = 0; j < PlotSize.Height; j++)
                {
                    // Map pixel coordinate (i,j) in visualization image back to point in input space
                    float x = plotCanvas.plotRangeX.Item1 + i * plotCanvas.stepX;
                    float y = plotCanvas.plotRangeY.Item1 + j * plotCanvas.stepY;

                    double probability = 0.0;

                    // Aggregate statistics for this sample over all trees
                    for (int t = 0; t < forest.TreeCount; t++)
                    {
                        Node<AxisAlignedFeatureResponse, LinearFitAggregator1d> leafNodeCopy = forest.GetTree(t).GetNode(leafNodeIndices[t][i]);

                        LinearFitAggregator1d leafStatistics = leafNodeCopy.TrainingDataStatistics;

                        probability += leafStatistics.GetProbability(x, y);
                    }

                    probability /= forest.TreeCount;

                    meanYGivenX[i]   += probability * y;
                    totalProbability += probability;

                    float scale = 10.0f * (float)probability;

                    Color weightedColor = Color.FromArgb(
                        255,
                        (byte)(Math.Min(scale * inverseDensityColor.R + 0.5f, 255.0f)),
                        (byte)(Math.Min(scale * inverseDensityColor.G + 0.5f, 255.0f)),
                        (byte)(Math.Min(scale * inverseDensityColor.B + 0.5f, 255.0f)));

                    // BUG FIX: the blue channel previously reused weightedColor.G,
                    // which mis-tinted the density plot; it must invert weightedColor.B.
                    Color c = Color.FromArgb(255, 255 - weightedColor.R, 255 - weightedColor.G, 255 - weightedColor.B);

                    result.SetPixel(i, j, c);
                }

                // NB We don't really compute the mean over y, just over the region of y that is plotted
                meanYGivenX[i] /= totalProbability;
            }

            // Also plot the mean curve and the original training data
            using (Graphics g = Graphics.FromImage(result))
            {
                g.InterpolationMode = System.Drawing.Drawing2D.InterpolationMode.HighQualityBicubic;
                g.SmoothingMode     = System.Drawing.Drawing2D.SmoothingMode.HighQuality;

                // Mean prediction curve: one line segment per pair of adjacent columns.
                using (Pen meanPen = new Pen(MeanColor, 2))
                {
                    for (int i = 0; i < PlotSize.Width - 1; i++)
                    {
                        g.DrawLine(
                            meanPen,
                            (float)(i),
                            (float)((meanYGivenX[i] - plotCanvas.plotRangeY.Item1) / plotCanvas.stepY),
                            (float)(i + 1),
                            (float)((meanYGivenX[i + 1] - plotCanvas.plotRangeY.Item1) / plotCanvas.stepY));
                    }
                }

                // Training points: 4x4 filled rectangles with a border, centred on each sample.
                using (Brush dataPointBrush = new SolidBrush(DataPointColor))
                using (Pen dataPointBorderPen = new Pen(DataPointBorderColor))
                {
                    for (int s = 0; s < trainingData.Count(); s++)
                    {
                        // Map sample coordinate back to a pixel coordinate in the visualization image
                        PointF p = new PointF(
                            (trainingData.GetDataPoint(s)[0] - plotCanvas.plotRangeX.Item1) / plotCanvas.stepX,
                            (trainingData.GetTarget(s) - plotCanvas.plotRangeY.Item1) / plotCanvas.stepY);

                        RectangleF rectangle = new RectangleF(p.X - 2.0f, p.Y - 2.0f, 4.0f, 4.0f);
                        g.FillRectangle(dataPointBrush, rectangle);
                        g.DrawRectangle(dataPointBorderPen, rectangle.X, rectangle.Y, rectangle.Width, rectangle.Height);
                    }
                }
            }
            #endregion

            return result;
        }
        /// <summary>
        /// Trains a semi-supervised classification forest, then performs label
        /// transduction: every unlabelled leaf receives a copy of the class histogram
        /// of its nearest labelled leaf, where "nearest" is measured by shortest path
        /// in a graph of inter-leaf generative-model distances.
        /// </summary>
        /// <param name="trainingData">Training data, partially labelled.</param>
        /// <param name="parameters">Forest training parameters.</param>
        /// <param name="a_">First prior parameter forwarded to the semi-supervised training context.</param>
        /// <param name="b_">Second prior parameter forwarded to the semi-supervised training context.</param>
        /// <returns>The trained forest with class distributions propagated to all leaves.</returns>
        public static Forest <LinearFeatureResponse2d, SemiSupervisedClassificationStatisticsAggregator> Train(
            DataPointCollection trainingData,
            TrainingParameters parameters,
            double a_,
            double b_)
        {
            // Train the forest
            Console.WriteLine("Training the forest...");

            Random random = new Random();

            ITrainingContext <LinearFeatureResponse2d, SemiSupervisedClassificationStatisticsAggregator> classificationContext
                = new SemiSupervisedClassificationTrainingContext(trainingData.CountClasses(), random, a_, b_);
            var forest = ForestTrainer <LinearFeatureResponse2d, SemiSupervisedClassificationStatisticsAggregator> .TrainForest(
                random,
                parameters,
                classificationContext,
                trainingData);

            // Label transduction to unlabelled leaves from nearest labelled leaf
            // NB The index lists below use two coordinate systems:
            //  * leafIndices[k] maps a compact leaf ordinal k to a tree node index;
            //  * unlabelledLeafIndices/labelledLeafIndices hold compact ordinals k,
            //    i.e. positions within leafIndices, not node indices.
            List <int> unlabelledLeafIndices = null;
            List <int> labelledLeafIndices   = null;

            int[]      closestLabelledLeafIndices = null;
            List <int> leafIndices = null;

            // Transduction is performed independently, per tree.
            for (int t = 0; t < forest.TreeCount; t++)
            {
                var tree = forest.GetTree(t);
                leafIndices = new List <int>();

                unlabelledLeafIndices = new List <int>();
                labelledLeafIndices   = new List <int>();

                // Enumerate leaves, classifying each as labelled or unlabelled by
                // whether any labelled samples reached it during training
                // (HistogramAggregator.SampleCount counts labelled samples only here
                // -- TODO confirm against the aggregator implementation).
                for (int n = 0; n < tree.NodeCount; n++)
                {
                    if (tree.GetNode(n).IsLeaf)
                    {
                        if (tree.GetNode(n).TrainingDataStatistics.HistogramAggregator.SampleCount == 0)
                        {
                            unlabelledLeafIndices.Add(leafIndices.Count);
                        }
                        else
                        {
                            labelledLeafIndices.Add(leafIndices.Count);
                        }

                        leafIndices.Add(n);
                    }
                }

                // Build an upper triangular matrix of inter-leaf distances
                // Distance between two leaves is the (symmetrized, via Math.Max)
                // negative log probability of each leaf's Gaussian mean under the
                // other leaf's Gaussian model.
                float[,] interLeafDistances = new float[leafIndices.Count, leafIndices.Count];
                for (int i = 0; i < leafIndices.Count; i++)
                {
                    for (int j = i + 1; j < leafIndices.Count; j++)
                    {
                        SemiSupervisedClassificationStatisticsAggregator a = tree.GetNode(leafIndices[i]).TrainingDataStatistics;
                        SemiSupervisedClassificationStatisticsAggregator b = tree.GetNode(leafIndices[j]).TrainingDataStatistics;
                        GaussianPdf2d x = a.GaussianAggregator2d.GetPdf();
                        GaussianPdf2d y = b.GaussianAggregator2d.GetPdf();

                        interLeafDistances[i, j] = (float)(Math.Max(
                                                               x.GetNegativeLogProbability((float)(y.MeanX), (float)(y.MeanY)),
                                                               +y.GetNegativeLogProbability((float)(x.MeanX), (float)(x.MeanY))));
                    }
                }

                // Find shortest paths between all pairs of nodes in the graph of leaf nodes
                FloydWarshall pathFinder = new FloydWarshall(interLeafDistances);

                // Find the closest labelled leaf to each unlabelled leaf
                float[] minDistances = new float[unlabelledLeafIndices.Count];
                closestLabelledLeafIndices = new int[unlabelledLeafIndices.Count];
                for (int i = 0; i < minDistances.Length; i++)
                {
                    minDistances[i] = float.PositiveInfinity;
                    closestLabelledLeafIndices[i] = -1; // unused so deliberately invalid
                }

                // Exhaustive search over (labelled, unlabelled) pairs; note the
                // stored winner is a tree NODE index (leafIndices[...]), so it can be
                // passed straight to tree.GetNode() below.
                for (int l = 0; l < labelledLeafIndices.Count; l++)
                {
                    for (int u = 0; u < unlabelledLeafIndices.Count; u++)
                    {
                        if (pathFinder.GetMinimumDistance(unlabelledLeafIndices[u], labelledLeafIndices[l]) < minDistances[u])
                        {
                            minDistances[u] = pathFinder.GetMinimumDistance(unlabelledLeafIndices[u], labelledLeafIndices[l]);
                            closestLabelledLeafIndices[u] = leafIndices[labelledLeafIndices[l]];
                        }
                    }
                }

                // Propagate class probability distributions to each unlabelled
                // leaf from its nearest labelled leaf.
                for (int u = 0; u < unlabelledLeafIndices.Count; u++)
                {
                    // Unhelpfully, C# only allows us to pass value types by value
                    // so Tree.GetNode() returns only a COPY of the Node. We update
                    // this copy and then copy it back over the top of the
                    // original via Tree.SetNode().

                    // The C++ version is a lot better!

                    var unlabelledLeafCopy = tree.GetNode(leafIndices[unlabelledLeafIndices[u]]);
                    var labelledLeafCopy   = tree.GetNode(closestLabelledLeafIndices[u]);

                    // DeepClone so the unlabelled leaf gets its own histogram rather
                    // than sharing state with the labelled leaf.
                    unlabelledLeafCopy.TrainingDataStatistics.HistogramAggregator
                        = (HistogramAggregator)(labelledLeafCopy.TrainingDataStatistics.HistogramAggregator.DeepClone());

                    tree.SetNode(leafIndices[unlabelledLeafIndices[u]], unlabelledLeafCopy);
                }
            }

            return(forest);
        }