Example #1
 protected DiscriminantFunctionClassificationSolution(DiscriminantFunctionClassificationSolution original, Cloner cloner)
   : base(original, cloner) {
   // copy the per-row caches of already evaluated estimated values and estimated class values,
   // so the cloned solution does not have to re-evaluate its model
   valueEvaluationCache = new Dictionary<int, double>(original.valueEvaluationCache);
   classValueEvaluationCache = new Dictionary<int, double>(original.classValueEvaluationCache);
 }
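
In HeuristicLab's deep-clone convention, a cloning constructor like the one above is typically paired with a Clone(Cloner) override that routes through it. A minimal sketch of that companion override (assumed here, it is not part of the original snippet):

 // Assumed companion to the cloning constructor above (HeuristicLab deep-clone convention):
 // the Cloner tracks already-cloned objects, so shared references stay shared in the copy.
 public override IDeepCloneable Clone(Cloner cloner) {
   return new DiscriminantFunctionClassificationSolution(this, cloner);
 }
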
    protected override void Run(CancellationToken cancellationToken) {
      // Set up the algorithm
      if (SetSeedRandomly) Seed = new System.Random().Next();

      // Set up the results display
      var iterations = new IntValue(0);
      Results.Add(new Result("Iterations", iterations));

      var table = new DataTable("Qualities");
      table.Rows.Add(new DataRow("Loss (train)"));
      table.Rows.Add(new DataRow("Loss (test)"));
      Results.Add(new Result("Qualities", table));
      var curLoss = new DoubleValue();
      Results.Add(new Result("Loss (train)", curLoss));

      // init
      var problemData = (IRegressionProblemData)Problem.ProblemData.Clone();
      var lossFunction = LossFunctionParameter.Value;
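       // create the boosting state: MaxSize bounds the size of each tree, R and M are the
       // row and variable sampling ratios per iteration, and Nu is the learning rate (shrinkage)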
      var state = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, lossFunction, (uint)Seed, MaxSize, R, M, Nu);

      var updateInterval = UpdateIntervalParameter.Value.Value;
      // Loop until iteration limit reached or canceled.
      for (int i = 0; i < Iterations; i++) {
        cancellationToken.ThrowIfCancellationRequested();

        GradientBoostedTreesAlgorithmStatic.MakeStep(state);

        // iteration results
        if (i % updateInterval == 0) {
          curLoss.Value = state.GetTrainLoss();
          table.Rows["Loss (train)"].Values.Add(curLoss.Value);
          table.Rows["Loss (test)"].Values.Add(state.GetTestLoss());
          iterations.Value = i;
        }
      }

      // final results
      iterations.Value = Iterations;
      curLoss.Value = state.GetTrainLoss();
      table.Rows["Loss (train)"].Values.Add(curLoss.Value);
      table.Rows["Loss (test)"].Values.Add(state.GetTestLoss());

      // produce variable relevance
      var orderedImpacts = state.GetVariableRelevance().Select(t => new { name = t.Key, impact = t.Value }).ToList();

      var impacts = new DoubleMatrix();
      var matrix = impacts as IStringConvertibleMatrix;
      matrix.Rows = orderedImpacts.Count;
      matrix.RowNames = orderedImpacts.Select(x => x.name);
      matrix.Columns = 1;
      matrix.ColumnNames = new string[] { "Relative variable relevance" };

      int rowIdx = 0;
      foreach (var p in orderedImpacts) {
        matrix.SetValue(string.Format("{0:N2}", p.impact), rowIdx++, 0);
      }

      Results.Add(new Result("Variable relevance", impacts));
      Results.Add(new Result("Loss (test)", new DoubleValue(state.GetTestLoss())));

      // produce solution 
      if (CreateSolution) {
        var model = state.GetModel();

        // for logistic regression we produce a classification solution
        if (lossFunction is LogisticRegressionLoss) {
          var classificationModel = new DiscriminantFunctionClassificationModel(model,
            new AccuracyMaximizationThresholdCalculator());
          var classificationProblemData = new ClassificationProblemData(problemData.Dataset,
            problemData.AllowedInputVariables, problemData.TargetVariable, problemData.Transformations);
          classificationModel.RecalculateModelParameters(classificationProblemData, classificationProblemData.TrainingIndices);

          var classificationSolution = new DiscriminantFunctionClassificationSolution(classificationModel, classificationProblemData);
          Results.Add(new Result("Solution", classificationSolution));
        } else {
          // otherwise we produce a regression solution
          Results.Add(new Result("Solution", new RegressionSolution(model, problemData)));
        }
      }
    }
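
For context, each MakeStep call above advances the boosted ensemble by one stage: it fits a small regression tree to the pseudo-residuals (the negative gradient of the loss at the current predictions) and adds that tree's output scaled by the learning rate Nu. The following is a self-contained sketch of that idea, reduced to squared-error loss and single-split "stump" learners; it illustrates the technique only and is not the HeuristicLab implementation:

using System;
using System.Linq;

class GradientBoostingSketch {
  // a stump splits one feature at a threshold and predicts a constant on each side
  record Stump(int Feature, double Threshold, double Left, double Right) {
    public double Predict(double[] row) => row[Feature] <= Threshold ? Left : Right;
  }

  // least-squares fit of a stump to the current pseudo-residuals
  static Stump FitStump(double[][] x, double[] target) {
    Stump best = null; double bestError = double.MaxValue;
    for (int f = 0; f < x[0].Length; f++) {
      foreach (double t in x.Select(r => r[f]).Distinct()) {
        var left  = Enumerable.Range(0, x.Length).Where(i => x[i][f] <= t).ToArray();
        var right = Enumerable.Range(0, x.Length).Where(i => x[i][f] >  t).ToArray();
        if (left.Length == 0 || right.Length == 0) continue;
        double lMean = left.Average(i => target[i]), rMean = right.Average(i => target[i]);
        double err = left.Sum(i => Math.Pow(target[i] - lMean, 2))
                   + right.Sum(i => Math.Pow(target[i] - rMean, 2));
        if (err < bestError) { bestError = err; best = new Stump(f, t, lMean, rMean); }
      }
    }
    return best;
  }

  static void Main() {
    // toy data: target is a simple function of one input variable
    double[][] x = Enumerable.Range(0, 20).Select(i => new[] { (double)i }).ToArray();
    double[] y = x.Select(row => 3.0 * row[0]).ToArray();

    double nu = 0.1;                    // learning rate, plays the role of Nu above
    double[] f = new double[y.Length];  // current ensemble prediction, starts at zero

    for (int iter = 0; iter < 200; iter++) {
      // pseudo-residuals for squared error are simply y - f (the negative loss gradient)
      double[] residuals = y.Zip(f, (yi, fi) => yi - fi).ToArray();
      Stump stump = FitStump(x, residuals);       // fit the base learner to the residuals
      for (int i = 0; i < f.Length; i++)          // damped additive update of the ensemble
        f[i] += nu * stump.Predict(x[i]);
    }

    double trainMse = y.Zip(f, (yi, fi) => Math.Pow(yi - fi, 2)).Average();
    Console.WriteLine($"train MSE after boosting: {trainMse:F4}");
  }
}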