Ejemplo n.º 1
0
        /// <summary>
        /// Grid search without cross-validation: for random forests the out-of-bag estimate is unbiased, so the
        /// out-of-bag RMS error of a model trained on the training partition is used directly as the selection criterion.
        /// </summary>
        /// <param name="problemData">The regression problem data; every model is trained on <c>problemData.TrainingIndices</c>.</param>
        /// <param name="parameterRanges">The candidate values for each parameter, keyed by parameter name.</param>
        /// <param name="seed">The random seed passed to every random forest model.</param>
        /// <param name="maxDegreeOfParallelism">The maximum number of threads used to evaluate grid points concurrently.</param>
        /// <returns>
        /// The parameter combination with the lowest out-of-bag RMS error. Ties are broken in favour of the combination
        /// enumerated first, so the result does not depend on thread scheduling. A default-constructed
        /// <see cref="RFParameter"/> is returned when the grid is empty or no combination yields an out-of-bag error
        /// below <see cref="double.MaxValue"/> (for example, when every error is NaN).
        /// </returns>
        public static RFParameter GridSearch(IRegressionProblemData problemData, Dictionary <string, IEnumerable <double> > parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1)
        {
            // Dictionary<TKey,TValue>.Keys and .Values enumerate in the same order, so setters[i] applies
            // the i-th value of every combination produced from the Values sequence.
            var setters      = parameterRanges.Keys.Select(GenerateSetter).ToList();
            var crossProduct = parameterRanges.Values.CartesianProduct();

            double      bestOutOfBagRmsError = double.MaxValue;
            long        bestIndex            = -1; // enumeration index of the current best combination; -1 = none yet
            RFParameter bestParameters       = new RFParameter();
            var         locker               = new object();

            var options = new ParallelOptions {
                MaxDegreeOfParallelism = maxDegreeOfParallelism
            };

            Parallel.ForEach(crossProduct, options, (parameterCombination, loopState, index) => {
                var parameterValues = parameterCombination.ToList();
                var parameters      = new RFParameter();
                for (int i = 0; i < setters.Count; ++i)
                {
                    setters[i](parameters, parameterValues[i]);
                }
                double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
                RandomForestModel.CreateRegressionModel(problemData, problemData.TrainingIndices, parameters.N, parameters.R, parameters.M, seed, out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);

                lock (locker) {
                    // Exact equality is intended here: on identical errors prefer the earlier combination, which is
                    // what a sequential scan would pick, regardless of which thread finishes first.
                    bool isBetter = outOfBagRmsError < bestOutOfBagRmsError
                                    || (bestIndex >= 0 && outOfBagRmsError == bestOutOfBagRmsError && index < bestIndex);
                    if (isBetter)
                    {
                        bestOutOfBagRmsError = outOfBagRmsError;
                        bestIndex            = index;
                        // 'parameters' is created per iteration and never touched again, so no defensive copy is needed.
                        bestParameters       = parameters;
                    }
                }
            });
            return bestParameters;
        }
 /// <summary>
 /// Creates a random forest regression model for <paramref name="problemData"/> by delegating to
 /// <c>RandomForestModel.CreateRegressionModel</c>, and passes the error estimates it reports back through the out parameters.
 /// </summary>
 /// <param name="problemData">The regression problem data.</param>
 /// <param name="nTrees">The number of trees, forwarded unchanged.</param>
 /// <param name="r">Forwarded unchanged (presumably the per-tree sample ratio — confirm against RandomForestModel).</param>
 /// <param name="m">Forwarded unchanged (presumably the per-split feature ratio — confirm against RandomForestModel).</param>
 /// <param name="seed">The random seed, forwarded unchanged.</param>
 /// <param name="rmsError">Receives the callee's <c>rmsError</c> output.</param>
 /// <param name="avgRelError">Receives the callee's <c>avgRelError</c> output.</param>
 /// <param name="outOfBagRmsError">Receives the callee's <c>outOfBagRmsError</c> output.</param>
 /// <param name="outOfBagAvgRelError">Receives the callee's <c>outOfBagAvgRelError</c> output.</param>
 /// <returns>The model produced by <c>RandomForestModel.CreateRegressionModel</c>.</returns>
 public static RandomForestModel CreateRandomForestRegressionModel(IRegressionProblemData problemData, int nTrees,
                                                                   double r, double m, int seed,
                                                                   out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError)
 {
     // Each out argument is bound by name, so the mapping stays correct regardless of the
     // order in which the callee declares its out parameters.
     RandomForestModel model = RandomForestModel.CreateRegressionModel(
         problemData, nTrees, r, m, seed,
         rmsError: out rmsError,
         outOfBagRmsError: out outOfBagRmsError,
         avgRelError: out avgRelError,
         outOfBagAvgRelError: out outOfBagAvgRelError);
     return model;
 }
Ejemplo n.º 3
0
        /// <summary>
        /// Estimates the test error of one random forest configuration over the supplied (training, test) partitions:
        /// for each partition a model is trained on the training rows and its mean squared error is measured on the test rows.
        /// </summary>
        /// <param name="problemData">The regression problem data supplying the dataset and the target variable.</param>
        /// <param name="partitions">Pairs of (training row indices, test row indices); one model is trained per pair.</param>
        /// <param name="nTrees">The number of trees, forwarded to <c>RandomForestModel.CreateRegressionModel</c>.</param>
        /// <param name="r">Forwarded unchanged to <c>RandomForestModel.CreateRegressionModel</c>.</param>
        /// <param name="m">Forwarded unchanged to <c>RandomForestModel.CreateRegressionModel</c>.</param>
        /// <param name="seed">The random seed used for every partition's model.</param>
        /// <param name="avgTestMse">
        /// The mean of the per-partition test MSEs. A partition whose MSE cannot be computed contributes NaN, which makes
        /// the average NaN. With an empty <paramref name="partitions"/> array the result is NaN as well (0 / 0).
        /// </param>
        private static void CrossValidate(IRegressionProblemData problemData, Tuple <IEnumerable <int>, IEnumerable <int> >[] partitions, int nTrees, double r, double m, int seed, out double avgTestMse)
        {
            avgTestMse = 0;
            var ds             = problemData.Dataset;
            var targetVariable = GetTargetVariableName(problemData);

            foreach (var tuple in partitions)
            {
                // The training-error estimates are not used here. The out arguments follow the same order as the other
                // call of this overload in this file (rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError),
                // so each local receives the value its name claims.
                double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
                var    trainingRandomForestPartition = tuple.Item1;
                var    testRandomForestPartition     = tuple.Item2;
                var    model           = RandomForestModel.CreateRegressionModel(problemData, trainingRandomForestPartition, nTrees, r, m, seed, out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);
                var    estimatedValues = model.GetEstimatedValues(ds, testRandomForestPartition);
                var    targetValues    = ds.GetDoubleValues(targetVariable, testRandomForestPartition);
                OnlineCalculatorError calculatorError;
                double mse = OnlineMeanSquaredErrorCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
                if (calculatorError != OnlineCalculatorError.None)
                {
                    // Mark the fold as invalid; NaN propagates into the average on purpose.
                    mse = double.NaN;
                }
                avgTestMse += mse;
            }
            avgTestMse /= partitions.Length;
        }