/// <summary>
/// Builds the optimization algorithm through the base implementation and, when
/// <c>Args.UseLineSearch</c> is set, installs a <c>LineSearch</c> as its
/// <c>AdjustTreeOutputsOverride</c>.
/// </summary>
/// <param name="ch">Channel forwarded to the base implementation.</param>
/// <returns>The algorithm produced by the base class, with the line search attached if requested.</returns>
protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
        {
            // REVIEW: Loss indices should become an enum in BinaryClassificationTest
            // (counter-opinion recorded by the original authors: nope, subcomponent).
            const int l2ErrorLossIndex = 1;

            OptimizationAlgorithm algorithm = base.ConstructOptimizationAlgorithm(ch);

            // Nothing else to configure unless line search was requested.
            if (!Args.UseLineSearch)
            {
                return algorithm;
            }

            // The line search evaluates its loss on the algorithm's training scores.
            var trainingLoss = new RegressionTest(algorithm.TrainingScores);
            algorithm.AdjustTreeOutputsOverride = new LineSearch(
                trainingLoss,
                l2ErrorLossIndex,
                Args.NumPostBracketSteps,
                Args.MinStepSize);

            return algorithm;
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Appends a <c>RegressionTest</c> to <c>Tests</c> for the training set, then for
        /// the validation set (when one exists), then for each test set (when any exist),
        /// in that order. Each test is built on a score tracker for its data set.
        /// </summary>
        private void AddFullRegressionTests()
        {
            // Training L1/L2 errors are computed unconditionally.
            Tests.Add(new RegressionTest(ConstructScoreTracker(TrainSet)));

            // The validation set is optional.
            if (ValidSet != null)
            {
                Tests.Add(new RegressionTest(ConstructScoreTracker(ValidSet)));
            }

            // No test sets: nothing more to register.
            if (TestSets == null)
            {
                return;
            }

            // When the external label is missing, the Rating column is used for the L1/L2 error.
            // Those values may not be meaningful if the regression value is not an actual label value.
            foreach (var testSet in TestSets)
            {
                Tests.Add(new RegressionTest(ConstructScoreTracker(testSet)));
            }
        }