protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
{
    // Build the base optimizer first; optionally layer a line-search tree-output
    // adjustment on top of it when the user requested one.
    OptimizationAlgorithm algorithm = base.ConstructOptimizationAlgorithm(ch);
    if (Args.UseLineSearch)
    {
        var test = new BinaryClassificationTest(algorithm.TrainingScores, _trainSetLabels, _sigmoidParameter);
        // Loss index 3 selects the unbalanced-sets loss; 1 selects the normal loss.
        // REVIEW: we should make loss indices an enum in BinaryClassificationTest.
        int lossIndex = Args.UnbalancedSets ? 3 : 1;
        algorithm.AdjustTreeOutputsOverride =
            new LineSearch(test, lossIndex, Args.NumPostBracketSteps, Args.MinStepSize);
    }
    return algorithm;
}
protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
{
    // Delegate construction to the base class, then attach a line-search
    // override for tree outputs if the trainer was configured to use one.
    OptimizationAlgorithm algorithm = base.ConstructOptimizationAlgorithm(ch);
    if (Args.UseLineSearch)
    {
        var regressionTest = new RegressionTest(algorithm.TrainingScores);
        // Loss index 1 selects the L2 error.
        // REVIEW: We should make loss indices an enum in BinaryClassificationTest.
        algorithm.AdjustTreeOutputsOverride =
            new LineSearch(regressionTest, 1, Args.NumPostBracketSteps, Args.MinStepSize);
    }
    return algorithm;
}
private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
{
    // Start from the base optimizer; if line search is enabled, override the
    // tree-output adjustment with a LineSearch driven by a regression test.
    OptimizationAlgorithm algorithm = base.ConstructOptimizationAlgorithm(ch);
    if (FastTreeTrainerOptions.UseLineSearch)
    {
        var regressionTest = new RegressionTest(algorithm.TrainingScores);
        // Loss index 1 selects the L2 error.
        // REVIEW: We should make loss indices an enum in BinaryClassificationTest.
        // REVIEW: Nope, subcomponent.
        algorithm.AdjustTreeOutputsOverride = new LineSearch(
            regressionTest,
            1,
            FastTreeTrainerOptions.MaximumNumberOfLineSearchSteps,
            FastTreeTrainerOptions.MinimumStepSize);
    }
    return algorithm;
}