        public void TestDecisionTreePersistence()
        {
            var provider   = new HeuristicLab.Problems.Instances.DataAnalysis.RegressionRealWorldInstanceProvider();
            var instance   = provider.GetDataDescriptors().Single(x => x.Name.Contains("Tower"));
            var regProblem = new RegressionProblem();

            regProblem.Load(provider.LoadData(instance));
            var problemData = regProblem.ProblemData;
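            // build the boosting state; maxSize limits the size of each tree, r and m are
            // (presumably) the row- and feature-subsampling fractions, and nu is the learning rate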
            var state       = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, new SquaredErrorLoss(), randSeed: 31415, maxSize: 100, r: 0.5, m: 1, nu: 1);

            GradientBoostedTreesAlgorithmStatic.MakeStep(state);

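            // Models[0] is (presumably) the initial constant model, so Skip(1) picks the first regression tree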
            var model   = ((IGradientBoostedTreesModel)state.GetModel());
            var treeM   = model.Models.Skip(1).First();
            var origStr = treeM.ToString();

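            // round-trip the tree through the XML persistence layer and compare the string representations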
            using (var memStream = new MemoryStream()) {
                Persistence.Default.Xml.XmlGenerator.Serialize(treeM, memStream);
                var buf = memStream.ToArray(); // ToArray() copies only the bytes written; GetBuffer() may include unused trailing capacity
                using (var restoreStream = new MemoryStream(buf)) {
                    var restoredTree = Persistence.Default.Xml.XmlParser.Deserialize(restoreStream);
                    var restoredStr  = restoredTree.ToString();
                    Assert.AreEqual(origStr, restoredStr);
                }
            }
        }
        private void BuildTree(double[,] xy, string[] allVariables, int maxSize)
        {
            int nRows         = xy.GetLength(0);
            var allowedInputs = allVariables.Skip(1);
            var dataset       = new Dataset(allVariables, xy);
            var problemData   = new RegressionProblemData(dataset, allowedInputs, allVariables.First());

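            // use all rows for training; the test partition is left empty (start == end == nRows)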
            problemData.TrainingPartition.Start = 0;
            problemData.TrainingPartition.End   = nRows;
            problemData.TestPartition.Start     = nRows;
            problemData.TestPartition.End       = nRows;
            var solution = GradientBoostedTreesAlgorithmStatic.TrainGbm(problemData, new SquaredErrorLoss(), maxSize, nu: 1, r: 1, m: 1, maxIterations: 1, randSeed: 31415);
            var model    = solution.Model;
            var treeM    = model.Models.Skip(1).First() as RegressionTreeModel;

            Console.WriteLine(treeM.ToString());
            Console.WriteLine();
        }
        public void TestDecisionTreePartialDependence()
        {
            var provider   = new HeuristicLab.Problems.Instances.DataAnalysis.RegressionRealWorldInstanceProvider();
            var instance   = provider.GetDataDescriptors().Single(x => x.Name.Contains("Tower"));
            var regProblem = new RegressionProblem();

            regProblem.Load(provider.LoadData(instance));
            var problemData = regProblem.ProblemData;
            var state       = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, new SquaredErrorLoss(), randSeed: 31415, maxSize: 10, r: 0.5, m: 1, nu: 0.02);

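            // run 1000 boosting steps; with the small learning rate (nu = 0.02) many iterations are needed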
            for (int i = 0; i < 1000; i++)
            {
                GradientBoostedTreesAlgorithmStatic.MakeStep(state);
            }


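            // pick the input variable the boosted ensemble considers most relevant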
            var mostImportantVar = state.GetVariableRelevance().OrderByDescending(kvp => kvp.Value).First();

            Console.WriteLine("var: {0} relevance: {1}", mostImportantVar.Key, mostImportantVar.Value);
            var model = ((IGradientBoostedTreesModel)state.GetModel());
            var treeM = model.Models.Skip(1).First();

            Console.WriteLine(treeM.ToString());
            Console.WriteLine();

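            // rough partial-dependence check: evaluate the model over the sorted values of the most relevant variable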
            var mostImportantVarValues = problemData.Dataset.GetDoubleValues(mostImportantVar.Key).OrderBy(x => x).ToArray();
            var ds = new ModifiableDataset(new string[] { mostImportantVar.Key },
                                           new IList[] { mostImportantVarValues.ToList() });

            var estValues = model.GetEstimatedValues(ds, Enumerable.Range(0, mostImportantVarValues.Length)).ToArray();

            for (int i = 0; i < mostImportantVarValues.Length; i += 10)
            {
                Console.WriteLine("{0,-5:N3} {1,-5:N3}", mostImportantVarValues[i], estValues[i]);
            }
        }
Example #4
        protected override void Run(CancellationToken cancellationToken)
        {
            // Set up the algorithm
            if (SetSeedRandomly)
            {
                Seed = new System.Random().Next();
            }

            // Set up the results display
            var iterations = new IntValue(0);

            Results.Add(new Result("Iterations", iterations));

            var table = new DataTable("Qualities");

            table.Rows.Add(new DataRow("Loss (train)"));
            table.Rows.Add(new DataRow("Loss (test)"));
            Results.Add(new Result("Qualities", table));
            var curLoss = new DoubleValue();

            Results.Add(new Result("Loss (train)", curLoss));

            // init
            var problemData  = (IRegressionProblemData)Problem.ProblemData.Clone();
            var lossFunction = LossFunctionParameter.Value;
            var state        = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, lossFunction, (uint)Seed, MaxSize, R, M, Nu);

            var updateInterval = UpdateIntervalParameter.Value.Value;

            // Loop until iteration limit reached or canceled.
            for (int i = 0; i < Iterations; i++)
            {
                cancellationToken.ThrowIfCancellationRequested();

                GradientBoostedTreesAlgorithmStatic.MakeStep(state);

                // iteration results
                if (i % updateInterval == 0)
                {
                    curLoss.Value = state.GetTrainLoss();
                    table.Rows["Loss (train)"].Values.Add(curLoss.Value);
                    table.Rows["Loss (test)"].Values.Add(state.GetTestLoss());
                    iterations.Value = i;
                }
            }

            // final results
            iterations.Value = Iterations;
            curLoss.Value    = state.GetTrainLoss();
            table.Rows["Loss (train)"].Values.Add(curLoss.Value);
            table.Rows["Loss (test)"].Values.Add(state.GetTestLoss());

            // produce variable relevance
            var orderedImpacts = state.GetVariableRelevance().Select(t => new { name = t.Key, impact = t.Value }).ToList();

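            // expose the relevance values as a single-column matrix via IStringConvertibleMatrix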
            var impacts = new DoubleMatrix();
            var matrix  = impacts as IStringConvertibleMatrix;

            matrix.Rows        = orderedImpacts.Count;
            matrix.RowNames    = orderedImpacts.Select(x => x.name);
            matrix.Columns     = 1;
            matrix.ColumnNames = new string[] { "Relative variable relevance" };

            int rowIdx = 0;

            foreach (var p in orderedImpacts)
            {
                matrix.SetValue(string.Format("{0:N2}", p.impact), rowIdx++, 0);
            }

            Results.Add(new Result("Variable relevance", impacts));
            Results.Add(new Result("Loss (test)", new DoubleValue(state.GetTestLoss())));

            // produce solution
            if (CreateSolution)
            {
                var model = state.GetModel();

                // for logistic regression we produce a classification solution
                if (lossFunction is LogisticRegressionLoss)
                {
                    var classificationModel = new DiscriminantFunctionClassificationModel(model,
                                                                                          new AccuracyMaximizationThresholdCalculator());
                    var classificationProblemData = new ClassificationProblemData(problemData.Dataset,
                                                                                  problemData.AllowedInputVariables, problemData.TargetVariable, problemData.Transformations);
                    classificationModel.RecalculateModelParameters(classificationProblemData, classificationProblemData.TrainingIndices);

                    var classificationSolution = new DiscriminantFunctionClassificationSolution(classificationModel, classificationProblemData);
                    Results.Add(new Result("Solution", classificationSolution));
                }
                else
                {
                    // otherwise we produce a regression solution
                    Results.Add(new Result("Solution", new RegressionSolution(model, problemData)));
                }
            }
        }
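        // retrains the boosted ensemble from the stored problem data and hyperparameters and returns the resulting model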
        private IGradientBoostedTreesModel RecalculateModel()
        {
            return GradientBoostedTreesAlgorithmStatic.TrainGbm(trainingProblemData, lossFunction, maxSize, nu, r, m, iterations, seed).Model;
        }