/// <summary>
/// When a Gaussian process model is available and solution creation is enabled, wraps a clone of the
/// model into a discriminant-function classification solution, stores it in the solution parameter,
/// and adds/updates the solution and its training/test accuracy in the results collection.
/// </summary>
/// <returns>The operation returned by the base implementation.</returns>
public override IOperation Apply()
        {
            if (ModelParameter.ActualValue != null && CreateSolutionParameter.ActualValue.Value)
            {
                // Work on clones so FixParameters / RecalculateModelParameters do not modify the
                // objects held by the model and problem-data parameters.
                var m = (IGaussianProcessModel)ModelParameter.ActualValue.Clone();
                m.FixParameters();
                var data  = (IClassificationProblemData)ProblemDataParameter.ActualValue.Clone();
                var model = new DiscriminantFunctionClassificationModel(m, new NormalDistributionCutPointsThresholdCalculator());
                // Model parameters (thresholds) are recalculated on the training indices only.
                model.RecalculateModelParameters(data, data.TrainingIndices);
                var s = model.CreateDiscriminantFunctionClassificationSolution(data);

                SolutionParameter.ActualValue = s;
                var results = ResultsParameter.ActualValue;

                // Each entry is added or updated on its own: the presence of the solution entry does not
                // guarantee the accuracy entries exist (or vice versa), so indexing all three together
                // could fail on a partially populated collection.
                if (results.ContainsKey(SolutionParameterName))
                {
                    results[SolutionParameterName].Value = s;
                }
                else
                {
                    results.Add(new Result(SolutionParameterName, "The Gaussian process classification solution", s));
                }

                if (results.ContainsKey(TrainingAccuracyResultName))
                {
                    results[TrainingAccuracyResultName].Value = new DoubleValue(s.TrainingAccuracy);
                }
                else
                {
                    results.Add(new Result(TrainingAccuracyResultName,
                                           "The accuracy of the Gaussian process solution on the training partition.",
                                           new DoubleValue(s.TrainingAccuracy)));
                }

                if (results.ContainsKey(TestAccuracyResultName))
                {
                    results[TestAccuracyResultName].Value = new DoubleValue(s.TestAccuracy);
                }
                else
                {
                    results.Add(new Result(TestAccuracyResultName,
                                           "The accuracy of the Gaussian process solution on the test partition.",
                                           new DoubleValue(s.TestAccuracy)));
                }
            }
            return base.Apply();
        }
// Example #2
        /// <summary>
        /// Runs gradient boosted trees for <c>Iterations</c> steps. Train/test loss is recorded every
        /// <c>UpdateInterval</c> iterations and once more at the end. Afterwards the method publishes the
        /// variable relevance, the final test loss and, if <c>CreateSolution</c> is set, a solution.
        /// For logistic regression loss the solution is a classification solution; otherwise it is a
        /// regression solution.
        /// </summary>
        /// <param name="cancellationToken">Checked before every boosting step; when cancellation is
        /// requested an <see cref="System.OperationCanceledException"/> is thrown.</param>
        protected override void Run(CancellationToken cancellationToken)
        {
            // Set up the algorithm
            if (SetSeedRandomly)
            {
                Seed = new System.Random().Next();
            }

            // Set up the results display
            var iterations = new IntValue(0);
            Results.Add(new Result("Iterations", iterations));

            var table = new DataTable("Qualities");
            table.Rows.Add(new DataRow("Loss (train)"));
            table.Rows.Add(new DataRow("Loss (test)"));
            Results.Add(new Result("Qualities", table));

            var curLoss = new DoubleValue();
            Results.Add(new Result("Loss (train)", curLoss));

            // init (on a clone so the problem's own problem data is not touched)
            var problemData  = (IRegressionProblemData)Problem.ProblemData.Clone();
            var lossFunction = LossFunctionParameter.Value;
            var state        = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, lossFunction, (uint)Seed, MaxSize, R, M, Nu);

            var updateInterval = UpdateIntervalParameter.Value.Value;

            // Loop until iteration limit reached or canceled.
            for (int i = 0; i < Iterations; i++)
            {
                cancellationToken.ThrowIfCancellationRequested();

                GradientBoostedTreesAlgorithmStatic.MakeStep(state);

                // iteration results
                if (i % updateInterval == 0)
                {
                    curLoss.Value = state.GetTrainLoss();
                    table.Rows["Loss (train)"].Values.Add(curLoss.Value);
                    table.Rows["Loss (test)"].Values.Add(state.GetTestLoss());
                    // i + 1 steps have been completed at this point; this keeps the counter consistent
                    // with the final value (Iterations) assigned after the loop.
                    iterations.Value = i + 1;
                }
            }

            // final results
            iterations.Value = Iterations;
            curLoss.Value    = state.GetTrainLoss();
            table.Rows["Loss (train)"].Values.Add(curLoss.Value);
            table.Rows["Loss (test)"].Values.Add(state.GetTestLoss());

            // produce variable relevance
            var orderedImpacts = state.GetVariableRelevance().Select(t => new { name = t.Key, impact = t.Value }).ToList();

            var impacts = new DoubleMatrix();
            var matrix  = impacts as IStringConvertibleMatrix;

            matrix.Rows        = orderedImpacts.Count;
            matrix.RowNames    = orderedImpacts.Select(x => x.name);
            matrix.Columns     = 1;
            matrix.ColumnNames = new string[] { "Relative variable relevance" };

            // Values are written through the string-convertible interface as two-decimal ("N2") strings.
            int rowIdx = 0;
            foreach (var p in orderedImpacts)
            {
                matrix.SetValue(string.Format("{0:N2}", p.impact), rowIdx++, 0);
            }

            Results.Add(new Result("Variable relevance", impacts));
            Results.Add(new Result("Loss (test)", new DoubleValue(state.GetTestLoss())));

            // produce solution
            if (CreateSolution)
            {
                var model = state.GetModel();

                // for logistic regression we produce a classification solution
                if (lossFunction is LogisticRegressionLoss)
                {
                    var classificationModel = new DiscriminantFunctionClassificationModel(model,
                                                                                          new AccuracyMaximizationThresholdCalculator());
                    var classificationProblemData = new ClassificationProblemData(problemData.Dataset,
                                                                                  problemData.AllowedInputVariables, problemData.TargetVariable, problemData.Transformations);

                    // The constructor is not given the partition settings of problemData. Copy them
                    // explicitly so the thresholds and the reported train/test accuracies use the same
                    // partitions as the problem data the GBM state was built from.
                    classificationProblemData.TrainingPartition.Start = problemData.TrainingPartition.Start;
                    classificationProblemData.TrainingPartition.End   = problemData.TrainingPartition.End;
                    classificationProblemData.TestPartition.Start     = problemData.TestPartition.Start;
                    classificationProblemData.TestPartition.End       = problemData.TestPartition.End;

                    classificationModel.RecalculateModelParameters(classificationProblemData, classificationProblemData.TrainingIndices);

                    var classificationSolution = new DiscriminantFunctionClassificationSolution(classificationModel, classificationProblemData);
                    Results.Add(new Result("Solution", classificationSolution));
                }
                else
                {
                    // otherwise we produce a regression solution
                    Results.Add(new Result("Solution", new RegressionSolution(model, problemData)));
                }
            }
        }