/// <summary>
        /// Builds a decision forest by training trees one after another with
        /// ParallelTreeTrainer and adding each finished tree to the forest.
        /// </summary>
        /// <param name="random">Random number generator handed to every tree training call.</param>
        /// <param name="parameters">Training parameters. NumberOfTrees sets how many trees
        /// are trained; Verbose selects the verbosity of the default progress writer.</param>
        /// <param name="context">An ITrainingContext instance describing
        /// the training problem, e.g. classification, density estimation, etc.</param>
        /// <param name="maxThreads">The maximum number of threads to use; passed through
        /// unchanged to the tree trainer.</param>
        /// <param name="data">The training data.</param>
        /// <param name="progress">Optional progress sink. When null, a writer on
        /// Console.Out is created for the duration of the call.</param>
        /// <returns>A new decision forest containing the trained trees.</returns>
        public static Forest <F, S> TrainForest(
            Random random,
            TrainingParameters parameters,
            ITrainingContext <F, S> context,
            int maxThreads,
            IDataPointCollection data,
            ProgressWriter progress = null)
        {
            // No writer supplied: report to the console at a level chosen by parameters.Verbose.
            if (progress == null)
            {
                Verbosity consoleLevel = Verbosity.Interest;
                if (parameters.Verbose)
                {
                    consoleLevel = Verbosity.Verbose;
                }

                progress = new ProgressWriter(consoleLevel, Console.Out);
            }

            var trainedForest = new Forest<F, S>();

            int treeIndex = 0;
            while (treeIndex < parameters.NumberOfTrees)
            {
                // The leading '\r' returns to the start of the line so each message overwrites the last.
                progress.Write(Verbosity.Interest, "\rTraining tree {0}...", treeIndex);

                trainedForest.AddTree(
                    ParallelTreeTrainer<F, S>.TrainTree(random, context, parameters, maxThreads, data, progress));

                treeIndex++;
            }

            progress.WriteLine(Verbosity.Interest, "\rTrained {0} trees.         ", parameters.NumberOfTrees);

            return trainedForest;
        }
        /// <summary>
        /// Trains a density estimation forest on unlabelled 2D data points using
        /// a DensityEstimationTrainingContext and a freshly seeded Random.
        /// </summary>
        /// <param name="trainingData">Training points. Must be 2D and carry neither
        /// class labels nor target values.</param>
        /// <param name="parameters">Training parameters forwarded to ForestTrainer.</param>
        /// <param name="a">First prior hyperparameter, forwarded to DensityEstimationTrainingContext.</param>
        /// <param name="b">Second prior hyperparameter, forwarded to DensityEstimationTrainingContext.</param>
        /// <returns>The trained forest.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="trainingData"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="trainingData"/> is not 2D,
        /// is labelled, or has target values.</exception>
        public static Forest <AxisAlignedFeatureResponse, GaussianAggregator2d> Train(
            DataPointCollection trainingData,
            TrainingParameters parameters,
            double a,
            double b)
        {
            // Validate up front with argument exceptions (which still derive from
            // Exception, so existing catch blocks keep working).
            if (trainingData == null)
            {
                throw new ArgumentNullException("trainingData");
            }
            if (trainingData.Dimensions != 2)
            {
                throw new ArgumentException("Training data points for density estimation were not 2D.", "trainingData");
            }
            if (trainingData.HasLabels)
            {
                throw new ArgumentException("Density estimation training data should not be labelled.", "trainingData");
            }
            if (trainingData.HasTargetValues)
            {
                throw new ArgumentException("Training data should not have target values.", "trainingData");
            }

            // Train the forest
            Console.WriteLine("Training the forest...");

            Random random = new Random();

            ITrainingContext <AxisAlignedFeatureResponse, GaussianAggregator2d> densityEstimationTrainingContext =
                new DensityEstimationTrainingContext(a, b);

            return ForestTrainer <AxisAlignedFeatureResponse, GaussianAggregator2d> .TrainForest(
                random,
                parameters,
                densityEstimationTrainingContext,
                trainingData);
        }
        /// <summary>
        /// Trains a classification forest on labelled 2D data points using a
        /// ClassificationTrainingContext built from the supplied feature factory.
        /// </summary>
        /// <typeparam name="F">Feature response type produced by <paramref name="featureFactory"/>.</typeparam>
        /// <param name="trainingData">Training points. Must be 2D, labelled, and
        /// without target values. The class count is taken from CountClasses().</param>
        /// <param name="featureFactory">Factory passed to the classification training context.</param>
        /// <param name="TrainingParameters">Training parameters forwarded to ForestTrainer.</param>
        /// <returns>The trained forest.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="trainingData"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="trainingData"/> is not 2D,
        /// is unlabelled, or has target values.</exception>
        public static Forest <F, HistogramAggregator> Train <F>(
            DataPointCollection trainingData,
            IFeatureFactory <F> featureFactory,
            TrainingParameters TrainingParameters) where F : IFeatureResponse
        {
            // Validate up front with argument exceptions (which still derive from
            // Exception, so existing catch blocks keep working).
            if (trainingData == null)
            {
                throw new ArgumentNullException("trainingData");
            }
            if (trainingData.Dimensions != 2)
            {
                throw new ArgumentException("Training data points must be 2D.", "trainingData");
            }
            if (!trainingData.HasLabels)
            {
                throw new ArgumentException("Training data points must be labelled.", "trainingData");
            }
            if (trainingData.HasTargetValues)
            {
                throw new ArgumentException("Training data points should not have target values.", "trainingData");
            }

            Console.WriteLine("Running training...");

            // The same Random instance is shared by the training context and the forest trainer.
            Random random = new Random();
            ITrainingContext <F, HistogramAggregator> classificationContext =
                new ClassificationTrainingContext <F>(trainingData.CountClasses(), featureFactory, random);

            return ForestTrainer <F, HistogramAggregator> .TrainForest(
                random,
                TrainingParameters,
                classificationContext,
                trainingData);
        }
        /// <summary>
        /// Captures the inputs for training a single tree and pre-allocates the
        /// scratch buffers and statistics aggregators used while training.
        /// </summary>
        /// <param name="randomNumberGenerator">Random number generator stored for later use.</param>
        /// <param name="trainingContext">Training context; also the source of every statistics aggregator.</param>
        /// <param name="parameters">Training parameters; NumberOfCandidateThresholdsPerFeature
        /// sizes the partition statistics array.</param>
        /// <param name="data">Training data; its Count() sizes the index and response buffers.</param>
        /// <param name="progress">Progress reporting target.</param>
        public TreeTrainingOperation(
            Random randomNumberGenerator,
            ITrainingContext <F, S> trainingContext,
            TrainingParameters parameters,
            IDataPointCollection data,
            ProgressWriter progress)
        {
            // Plain references kept for the lifetime of the operation.
            random_          = randomNumberGenerator;
            progress_        = progress;
            data_            = data;
            trainingContext_ = trainingContext;
            parameters_      = parameters;

            // indices_ starts as the identity mapping 0..Count-1.
            indices_ = new int[data.Count()];
            int position = 0;
            while (position < indices_.Length)
            {
                indices_[position] = position;
                position++;
            }

            // One response slot per data point.
            responses_ = new float[data.Count()];

            // Each aggregator is a separate instance obtained from the context.
            parentStatistics_     = trainingContext_.GetStatisticsAggregator();
            leftChildStatistics_  = trainingContext_.GetStatisticsAggregator();
            rightChildStatistics_ = trainingContext_.GetStatisticsAggregator();

            // One more partition than there are candidate thresholds.
            partitionStatistics_ = new S[parameters.NumberOfCandidateThresholdsPerFeature + 1];
            for (int bin = 0; bin <= parameters.NumberOfCandidateThresholdsPerFeature; ++bin)
            {
                partitionStatistics_[bin] = trainingContext_.GetStatisticsAggregator();
            }
        }
        /// <summary>
        /// Train a new decision tree given some training data and a training
        /// problem described by an ITrainingContext instance.
        /// </summary>
        /// <param name="random">The single random number generator.</param>
        /// <param name="context">The ITrainingContext instance by which
        /// the training framework interacts with the training data.
        /// Implemented within client code.</param>
        /// <param name="parameters">Training parameters. MaxDecisionLevels sizes the
        /// tree; Verbose selects the verbosity of the default progress writer.</param>
        /// <param name="data">The training data.</param>
        /// <param name="progress">Progress reporting target. When null, a writer on
        /// Console.Out is created.</param>
        /// <returns>A new decision tree.</returns>
        static public Tree <F, S> TrainTree(
            Random random,
            ITrainingContext <F, S> context,
            TrainingParameters parameters,
            IDataPointCollection data,
            ProgressWriter progress = null)
        {
            if (progress == null)
            {
                // Follow parameters.Verbose so that a direct caller who asked for verbose
                // training actually sees the Verbosity.Verbose messages written below,
                // instead of always being capped at Verbosity.Interest.
                progress = new ProgressWriter(
                    parameters.Verbose ? Verbosity.Verbose : Verbosity.Interest,
                    Console.Out);
            }

            TreeTrainingOperation <F, S> trainingOperation = new TreeTrainingOperation <F, S>(
                random, context, parameters, data, progress);

            Tree <F, S> tree = new Tree <F, S>(parameters.MaxDecisionLevels);

            progress.WriteLine(Verbosity.Verbose, "");

            // Start at node 0 over the full index range [0, data.Count()); recurses
            // until the termination criterion is met.
            trainingOperation.TrainNodesRecurse(tree.nodes_, 0, 0, data.Count(), 0);

            progress.WriteLine(Verbosity.Verbose, "");

            tree.CheckValid();

            return tree;
        }
        /// <summary>
        /// Captures the inputs for training a single tree with several threads and
        /// pre-allocates the shared buffers plus one ThreadLocalData per thread.
        /// </summary>
        /// <param name="random">Random number generator; stored, and also passed to
        /// every ThreadLocalData constructor.</param>
        /// <param name="trainingContext">Training context; also the source of the statistics aggregators.</param>
        /// <param name="parameters">Training parameters.</param>
        /// <param name="maxThreads">Number of ThreadLocalData slots to create.</param>
        /// <param name="data">Training data; its Count() sizes the response and index buffers.</param>
        /// <param name="progress">Progress reporting target.</param>
        public ParallelTreeTrainingOperation(
            Random random,
            ITrainingContext <F, S> trainingContext,
            TrainingParameters parameters,
            int maxThreads,
            IDataPointCollection data,
            ProgressWriter progress)
        {
            // Plain references and settings kept for the lifetime of the operation.
            random_          = random;
            progress_        = progress;
            maxThreads_      = maxThreads;
            data_            = data;
            trainingContext_ = trainingContext;
            parameters_      = parameters;

            // Each aggregator is a separate instance obtained from the context.
            parentStatistics_     = trainingContext_.GetStatisticsAggregator();
            leftChildStatistics_  = trainingContext_.GetStatisticsAggregator();
            rightChildStatistics_ = trainingContext_.GetStatisticsAggregator();

            // One response slot per data point.
            responses_ = new float[data.Count()];

            // indices_ starts as the identity mapping 0..Count-1.
            indices_ = new int[data.Count()];
            int position = 0;
            while (position < indices_.Length)
            {
                indices_[position] = position;
                position++;
            }

            // One scratch-data object per thread slot, created in slot order.
            threadLocals_ = new ThreadLocalData[maxThreads_];
            for (int slot = 0; slot != maxThreads_ && slot < maxThreads_; ++slot)
            {
                threadLocals_[slot] = new ThreadLocalData(random_, trainingContext_, parameters_, data_);
            }
        }
            /// <summary>
            /// Allocates the per-thread scratch state: statistics aggregators, a
            /// response buffer and a private random number generator.
            /// </summary>
            /// <param name="random">Generator used once, via Next(), to seed this object's own Random.</param>
            /// <param name="trainingContext_">Training context that supplies the statistics aggregators.</param>
            /// <param name="parameters">Training parameters; NumberOfCandidateThresholdsPerFeature
            /// sizes the partition statistics array.</param>
            /// <param name="data">Training data; its Count() sizes the response buffer.</param>
            public ThreadLocalData(Random random, ITrainingContext <F, S> trainingContext_, TrainingParameters parameters, IDataPointCollection data)
            {
                // Each aggregator is a separate instance obtained from the context.
                parentStatistics_     = trainingContext_.GetStatisticsAggregator();
                leftChildStatistics_  = trainingContext_.GetStatisticsAggregator();
                rightChildStatistics_ = trainingContext_.GetStatisticsAggregator();

                // One more partition than there are candidate thresholds.
                partitionStatistics_ = new S[parameters.NumberOfCandidateThresholdsPerFeature + 1];
                int bin = 0;
                while (bin < parameters.NumberOfCandidateThresholdsPerFeature + 1)
                {
                    partitionStatistics_[bin] = trainingContext_.GetStatisticsAggregator();
                    bin++;
                }

                // One response slot per data point.
                responses_ = new float[data.Count()];

                // Seed a separate generator from the supplied one.
                int seed = random.Next();
                random_ = new Random(seed);
            }
// Example no. 8
// 0
        /// <summary>
        /// Command line entry point. The first argument selects the mode
        /// ("clas"/"class", "density", "ssclas"/"ssclass" or "regression"); the
        /// remaining arguments are parsed by a CommandLineParser configured for
        /// that mode. Each mode loads training data, runs the corresponding
        /// example and shows the resulting visualization image.
        /// With no arguments, "/?" or "help", general help is displayed. With only
        /// a mode argument, that mode's help and data file list are displayed.
        /// </summary>
        /// <param name="args">Raw command line arguments.</param>
        static void Main(string[] args)
        {
            // Culture-invariant casing is used for all mode handling so that matching does
            // not depend on the current culture (e.g. Turkish casing of 'I' would otherwise
            // stop "DENSITY" or "REGRESSION" from matching).
            if (args.Length == 0 || args[0] == "/?" || args[0].ToLowerInvariant() == "help")
            {
                DisplayHelp();
                return;
            }

            // These command line parameters are reused over several command line modes...
            StringParameter       trainingDataPath = new StringParameter("path", "Path of file containing training data.");
            NaturalParameter      T             = new NaturalParameter("t", "No. of trees in the forest (default = {0}).", 10);
            NaturalParameter      D             = new NaturalParameter("d", "Maximum tree levels (default = {0}).", 10, 20);
            NaturalParameter      F             = new NaturalParameter("f", "No. of candidate feature responses per decision node (default = {0}).", 10);
            NaturalParameter      L             = new NaturalParameter("l", "No. of candidate thresholds per feature response (default = {0}).", 1);
            SingleParameter       a             = new SingleParameter("a", "The number of 'effective' prior observations (default = {0}).", true, false, 10.0f);
            SingleParameter       b             = new SingleParameter("b", "The variance of the effective observations (default = {0}).", true, true, 400.0f);
            SimpleSwitchParameter verboseSwitch = new SimpleSwitchParameter("Enables verbose progress indication.");
            SingleParameter       plotPaddingX  = new SingleParameter("padx", "Pad plot horizontally (default = {0}).", true, false, 0.1f);
            SingleParameter       plotPaddingY  = new SingleParameter("pady", "Pad plot vertically (default = {0}).", true, false, 0.1f);
            EnumParameter         split         = new EnumParameter(
                "s",
                "Specify what kind of split function to use (default = {0}).",
                new string[] { "axis", "linear" },
                new string[] { "axis-aligned split", "linear split" },
                "axis");

            // Behaviour depends on command line mode...
            string mode = args[0].ToLowerInvariant(); // first argument defines the command line mode

            if (mode == "clas" || mode == "class")
            {
                #region Supervised classification
                CommandLineParser parser = new CommandLineParser();

                parser.Command = "SW " + mode.ToUpperInvariant();

                parser.AddArgument(trainingDataPath);
                parser.AddSwitch("T", T);
                parser.AddSwitch("D", D);
                parser.AddSwitch("F", F);
                parser.AddSwitch("L", L);
                parser.AddSwitch("SPLIT", split);

                parser.AddSwitch("PADX", plotPaddingX);
                parser.AddSwitch("PADY", plotPaddingY);
                parser.AddSwitch("VERBOSE", verboseSwitch);

                // Default values up above should be fine here.

                // Mode given without further arguments: show mode-specific help only.
                if (args.Length == 1)
                {
                    parser.PrintHelp();
                    DisplayTextFiles(CLAS_DATA_PATH);
                    return;
                }

                // Parsing starts after the mode argument.
                if (parser.Parse(args, 1) == false)
                {
                    return;
                }

                TrainingParameters trainingParameters = new TrainingParameters()
                {
                    // The command line gives tree levels; decision levels are one fewer.
                    MaxDecisionLevels                     = D.Value - 1,
                    NumberOfCandidateFeatures             = F.Value,
                    NumberOfCandidateThresholdsPerFeature = L.Value,
                    NumberOfTrees = T.Value,
                    Verbose       = verboseSwitch.Used
                };

                PointF plotDilation = new PointF(plotPaddingX.Value, plotPaddingY.Value);

                DataPointCollection trainingData = LoadTrainingData(
                    trainingDataPath.Value,
                    CLAS_DATA_PATH,
                    2,
                    DataDescriptor.HasClassLabels);

                // The two branches differ only in the feature factory (and hence the forest's feature type).
                if (split.Value == "linear")
                {
                    Forest <LinearFeatureResponse2d, HistogramAggregator> forest = ClassificationExample.Train(
                        trainingData,
                        new LinearFeatureFactory(),
                        trainingParameters);

                    using (Bitmap result = ClassificationExample.Visualize(forest, trainingData, new Size(300, 300), plotDilation))
                    {
                        ShowVisualizationImage(result);
                    }
                }
                else if (split.Value == "axis")
                {
                    Forest <AxisAlignedFeatureResponse, HistogramAggregator> forest = ClassificationExample.Train(
                        trainingData,
                        new AxisAlignedFeatureFactory(),
                        trainingParameters);

                    using (Bitmap result = ClassificationExample.Visualize(forest, trainingData, new Size(300, 300), plotDilation))
                    {
                        ShowVisualizationImage(result);
                    }
                }
                #endregion
            }
            else if (mode == "density")
            {
                #region Density Estimation
                CommandLineParser parser = new CommandLineParser();

                parser.Command = "SW " + mode.ToUpperInvariant();

                parser.AddArgument(trainingDataPath);
                parser.AddSwitch("T", T);
                parser.AddSwitch("D", D);
                parser.AddSwitch("F", F);
                parser.AddSwitch("L", L);

                // For density estimation (and semi-supervised learning) we add
                // a command line option to set the hyperparameters of the prior.
                parser.AddSwitch("a", a);
                parser.AddSwitch("b", b);

                parser.AddSwitch("PADX", plotPaddingX);
                parser.AddSwitch("PADY", plotPaddingY);
                parser.AddSwitch("VERBOSE", verboseSwitch);

                // Override default values for command line options.
                T.Value = 1;
                D.Value = 3;
                F.Value = 5;
                L.Value = 1;
                a.Value = 0;
                b.Value = 900;

                // Mode given without further arguments: show mode-specific help only.
                if (args.Length == 1)
                {
                    parser.PrintHelp();
                    DisplayTextFiles(DENSITY_DATA_PATH);
                    return;
                }

                // Parsing starts after the mode argument.
                if (parser.Parse(args, 1) == false)
                {
                    return;
                }

                TrainingParameters parameters = new TrainingParameters()
                {
                    // The command line gives tree levels; decision levels are one fewer.
                    MaxDecisionLevels                     = D.Value - 1,
                    NumberOfCandidateFeatures             = F.Value,
                    NumberOfCandidateThresholdsPerFeature = L.Value,
                    NumberOfTrees = T.Value,
                    Verbose       = verboseSwitch.Used
                };

                DataPointCollection trainingData = LoadTrainingData(
                    trainingDataPath.Value,
                    DENSITY_DATA_PATH,
                    2,
                    DataDescriptor.Unadorned);

                Forest <AxisAlignedFeatureResponse, GaussianAggregator2d> forest = DensityEstimationExample.Train(trainingData, parameters, a.Value, b.Value);

                PointF plotDilation = new PointF(plotPaddingX.Value, plotPaddingY.Value);

                using (Bitmap result = DensityEstimationExample.Visualize(forest, trainingData, new Size(300, 300), plotDilation))
                {
                    ShowVisualizationImage(result);
                }
                #endregion
            }
            // Accept both spellings, mirroring "clas"/"class" above. (The second test
            // previously repeated "ssclas", so "ssclass" was rejected as unrecognized.)
            else if (mode == "ssclas" || mode == "ssclass")
            {
                #region Semi-supervised classification

                CommandLineParser parser = new CommandLineParser();

                parser.Command = "SW " + mode.ToUpperInvariant();

                parser.AddArgument(trainingDataPath);
                parser.AddSwitch("T", T);
                parser.AddSwitch("D", D);
                parser.AddSwitch("F", F);
                parser.AddSwitch("L", L);

                // NOTE(review): the split switch is registered but split.Value is never
                // read in this mode - confirm whether it should be honoured or removed.
                parser.AddSwitch("split", split);

                parser.AddSwitch("a", a);
                parser.AddSwitch("b", b);

                EnumParameter plotMode = new EnumParameter(
                    "plot",
                    "Determines what to plot",
                    new string[] { "density", "labels" },
                    new string[] { "plot recovered density estimate", "plot class likelihood" },
                    "labels");
                parser.AddSwitch("plot", plotMode);

                parser.AddSwitch("PADX", plotPaddingX);
                parser.AddSwitch("PADY", plotPaddingY);

                parser.AddSwitch("VERBOSE", verboseSwitch);

                // Override default values for command line options.
                // NOTE(review): D defaults to 11 here and MaxDecisionLevels below subtracts
                // one again - confirm the double subtraction is intended.
                T.Value = 10;
                D.Value = 12 - 1;
                F.Value = 30;
                L.Value = 1;

                // Mode given without further arguments: show mode-specific help only.
                if (args.Length == 1)
                {
                    parser.PrintHelp();
                    DisplayTextFiles(SSCLAS_DATA_PATH);
                    return;
                }

                // Parsing starts after the mode argument.
                if (parser.Parse(args, 1) == false)
                {
                    return;
                }

                DataPointCollection trainingData = LoadTrainingData(
                    trainingDataPath.Value,
                    SSCLAS_DATA_PATH,
                    2,
                    DataDescriptor.HasClassLabels);

                TrainingParameters parameters = new TrainingParameters()
                {
                    // The command line gives tree levels; decision levels are one fewer.
                    MaxDecisionLevels                     = D.Value - 1,
                    NumberOfCandidateFeatures             = F.Value,
                    NumberOfCandidateThresholdsPerFeature = L.Value,
                    NumberOfTrees = T.Value,
                    Verbose       = verboseSwitch.Used
                };

                Forest <LinearFeatureResponse2d, SemiSupervisedClassificationStatisticsAggregator> forest = SemiSupervisedClassificationExample.Train(
                    trainingData, parameters, a.Value, b.Value);

                PointF plotPadding = new PointF(plotPaddingX.Value, plotPaddingY.Value);

                if (plotMode.Value == "labels")
                {
                    using (Bitmap result = SemiSupervisedClassificationExample.VisualizeLabels(forest, trainingData, new Size(300, 300), plotPadding))
                    {
                        ShowVisualizationImage(result);
                    }
                }
                else if (plotMode.Value == "density")
                {
                    using (Bitmap result = SemiSupervisedClassificationExample.VisualizeDensity(forest, trainingData, new Size(300, 300), plotPadding))
                    {
                        ShowVisualizationImage(result);
                    }
                }
                #endregion
            }
            else if (mode == "regression")
            {
                #region Regression
                CommandLineParser parser = new CommandLineParser();
                parser.Command = "SW " + mode.ToUpperInvariant();

                parser.AddArgument(trainingDataPath);
                parser.AddSwitch("T", T);
                parser.AddSwitch("D", D);
                parser.AddSwitch("F", F);
                parser.AddSwitch("L", L);

                parser.AddSwitch("PADX", plotPaddingX);
                parser.AddSwitch("PADY", plotPaddingY);
                parser.AddSwitch("VERBOSE", verboseSwitch);

                // Override default values for command line options
                // NOTE(review): a and b are overridden here but are neither registered as
                // switches nor read in this mode - confirm whether they are needed.
                T.Value = 10;
                D.Value = 2;
                a.Value = 0; // prior turned off by default
                b.Value = 900;

                // Mode given without further arguments: show mode-specific help only.
                if (args.Length == 1)
                {
                    parser.PrintHelp();
                    DisplayTextFiles(REGRESSION_DATA_PATH);
                    return;
                }

                // Parsing starts after the mode argument.
                if (parser.Parse(args, 1) == false)
                {
                    return;
                }

                RegressionExample regressionDemo = new RegressionExample();

                regressionDemo.PlotDilation.X = plotPaddingX.Value;
                regressionDemo.PlotDilation.Y = plotPaddingY.Value;

                regressionDemo.TrainingParameters = new TrainingParameters()
                {
                    // The command line gives tree levels; decision levels are one fewer.
                    MaxDecisionLevels                     = D.Value - 1,
                    NumberOfCandidateFeatures             = F.Value,
                    NumberOfCandidateThresholdsPerFeature = L.Value,
                    NumberOfTrees = T.Value,
                    Verbose       = verboseSwitch.Used
                };

                // Regression data is loaded as 1D points with target values.
                DataPointCollection trainingData = LoadTrainingData(
                    trainingDataPath.Value,
                    REGRESSION_DATA_PATH,
                    1,
                    DataDescriptor.HasTargetValues);

                using (Bitmap result = regressionDemo.Run(trainingData))
                {
                    ShowVisualizationImage(result);
                }
                #endregion
            }
            else
            {
                Console.WriteLine("Unrecognized command line argument, try SW HELP.");
                return;
            }
        }
        /// <summary>
        /// Trains a semi-supervised classification forest and then, within each
        /// tree, copies the class histogram of the closest labelled leaf onto every
        /// leaf whose histogram has no samples (label transduction).
        /// "Closest" is measured by shortest path length (FloydWarshall) over a
        /// graph of leaves whose edge weights are derived from the leaves' 2D
        /// Gaussian densities.
        /// </summary>
        /// <param name="trainingData">Training data; the class count is taken from CountClasses().</param>
        /// <param name="parameters">Training parameters forwarded to ForestTrainer.</param>
        /// <param name="a_">First prior hyperparameter, forwarded to the training context.</param>
        /// <param name="b_">Second prior hyperparameter, forwarded to the training context.</param>
        /// <returns>The trained forest. Trees that contain no labelled leaf, or no
        /// unlabelled leaf, are returned as trained, without any label propagation.</returns>
        public static Forest <LinearFeatureResponse2d, SemiSupervisedClassificationStatisticsAggregator> Train(
            DataPointCollection trainingData,
            TrainingParameters parameters,
            double a_,
            double b_)
        {
            // Train the forest
            Console.WriteLine("Training the forest...");

            // The same Random instance is shared by the training context and the forest trainer.
            Random random = new Random();

            ITrainingContext <LinearFeatureResponse2d, SemiSupervisedClassificationStatisticsAggregator> classificationContext
                = new SemiSupervisedClassificationTrainingContext(trainingData.CountClasses(), random, a_, b_);
            var forest = ForestTrainer <LinearFeatureResponse2d, SemiSupervisedClassificationStatisticsAggregator> .TrainForest(
                random,
                parameters,
                classificationContext,
                trainingData);

            // Label transduction to unlabelled leaves from nearest labelled leaf
            for (int t = 0; t < forest.TreeCount; t++)
            {
                var tree = forest.GetTree(t);

                // leafIndices maps a leaf's ordinal (position in this list) to its node
                // index in the tree. The other two lists hold leaf ORDINALS, not node indices.
                List <int> leafIndices           = new List <int>();
                List <int> unlabelledLeafIndices = new List <int>();
                List <int> labelledLeafIndices   = new List <int>();

                for (int n = 0; n < tree.NodeCount; n++)
                {
                    var node = tree.GetNode(n);
                    if (!node.IsLeaf)
                    {
                        continue;
                    }

                    // A leaf counts as unlabelled when its class histogram holds no samples.
                    if (node.TrainingDataStatistics.HistogramAggregator.SampleCount == 0)
                    {
                        unlabelledLeafIndices.Add(leafIndices.Count);
                    }
                    else
                    {
                        labelledLeafIndices.Add(leafIndices.Count);
                    }

                    leafIndices.Add(n);
                }

                // Nothing to propagate to, or nothing to propagate from. Without this
                // guard a tree with unlabelled leaves but no labelled leaf would look up
                // node index -1 below.
                if (unlabelledLeafIndices.Count == 0 || labelledLeafIndices.Count == 0)
                {
                    continue;
                }

                // Build an upper triangular matrix of inter-leaf distances
                float[,] interLeafDistances = new float[leafIndices.Count, leafIndices.Count];
                for (int i = 0; i < leafIndices.Count; i++)
                {
                    GaussianPdf2d x = tree.GetNode(leafIndices[i]).TrainingDataStatistics.GaussianAggregator2d.GetPdf();

                    for (int j = i + 1; j < leafIndices.Count; j++)
                    {
                        GaussianPdf2d y = tree.GetNode(leafIndices[j]).TrainingDataStatistics.GaussianAggregator2d.GetPdf();

                        // Symmetric distance: the larger of the two negative log
                        // probabilities of one leaf's mean under the other leaf's density.
                        interLeafDistances[i, j] = (float)(Math.Max(
                                                               x.GetNegativeLogProbability((float)(y.MeanX), (float)(y.MeanY)),
                                                               y.GetNegativeLogProbability((float)(x.MeanX), (float)(x.MeanY))));
                    }
                }

                // Find shortest paths between all pairs of nodes in the graph of leaf nodes
                FloydWarshall pathFinder = new FloydWarshall(interLeafDistances);

                for (int u = 0; u < unlabelledLeafIndices.Count; u++)
                {
                    // Find the closest labelled leaf to this unlabelled leaf. The strict
                    // '<' keeps the first labelled leaf among equally distant ones.
                    float minDistance = float.PositiveInfinity;
                    int closestLabelledNodeIndex = -1; // stays invalid if no labelled leaf is reachable

                    for (int l = 0; l < labelledLeafIndices.Count; l++)
                    {
                        float distance = pathFinder.GetMinimumDistance(unlabelledLeafIndices[u], labelledLeafIndices[l]);
                        if (distance < minDistance)
                        {
                            minDistance = distance;
                            closestLabelledNodeIndex = leafIndices[labelledLeafIndices[l]];
                        }
                    }

                    if (closestLabelledNodeIndex < 0)
                    {
                        // No labelled leaf at a finite distance: leave this leaf unlabelled
                        // rather than indexing the tree with -1.
                        continue;
                    }

                    // Propagate the class probability distribution to the unlabelled
                    // leaf from its nearest labelled leaf.
                    //
                    // Tree.GetNode() returns only a COPY of the Node (value type), so
                    // we update the copy and then write it back over the original via
                    // Tree.SetNode().
                    int unlabelledNodeIndex = leafIndices[unlabelledLeafIndices[u]];

                    var unlabelledLeafCopy = tree.GetNode(unlabelledNodeIndex);
                    var labelledLeafCopy   = tree.GetNode(closestLabelledNodeIndex);

                    unlabelledLeafCopy.TrainingDataStatistics.HistogramAggregator
                        = (HistogramAggregator)(labelledLeafCopy.TrainingDataStatistics.HistogramAggregator.DeepClone());

                    tree.SetNode(unlabelledNodeIndex, unlabelledLeafCopy);
                }
            }

            return forest;
        }