/// <summary>
/// Creates the boosting optimizer selected by <c>Args.OptimizationAlgorithm</c> and configures it
/// with the tree learner, objective function, smoothing and dropout settings taken from <c>Args</c>.
/// </summary>
/// <param name="ch">Channel used for diagnostics and error reporting; validated with <c>Contracts.CheckValue</c>.</param>
/// <returns>
/// A <see cref="GradientDescent"/>, <see cref="AcceleratedGradientDescent"/> or
/// <see cref="ConjugateGradientDescent"/> instance, with <c>PrintTestGraph</c> subscribed to its
/// <c>PreScoreUpdateEvent</c>.
/// </returns>
/// <remarks>Throws the exception produced by <c>ch.Except</c> when the requested algorithm type is not recognized.</remarks>
protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
{
    Contracts.CheckValue(ch, nameof(ch));

    IGradientAdjuster adjuster = MakeGradientWrapper(ch);
    var requested = Args.OptimizationAlgorithm;

    OptimizationAlgorithm algorithm;
    if (requested == BoostedTreeArgs.OptimizationAlgorithmType.GradientDescent)
    {
        algorithm = new GradientDescent(Ensemble, TrainSet, InitTrainScores, adjuster);
    }
    else if (requested == BoostedTreeArgs.OptimizationAlgorithmType.AcceleratedGradientDescent)
    {
        algorithm = new AcceleratedGradientDescent(Ensemble, TrainSet, InitTrainScores, adjuster);
    }
    else if (requested == BoostedTreeArgs.OptimizationAlgorithmType.ConjugateGradientDescent)
    {
        algorithm = new ConjugateGradientDescent(Ensemble, TrainSet, InitTrainScores, adjuster);
    }
    else
    {
        throw ch.Except("Unknown optimization algorithm '{0}'", requested);
    }

    algorithm.TreeLearner = ConstructTreeLearner(ch);
    algorithm.ObjectiveFunction = ConstructObjFunc(ch);
    algorithm.Smoothing = Args.Smoothing;

    // Dropout applies to boosting; the RNG is freshly seeded from Args.RngSeed so a given seed is reproducible.
    algorithm.DropoutRate = Args.DropoutRate;
    algorithm.DropoutRng = new Random(Args.RngSeed);

    algorithm.PreScoreUpdateEvent += PrintTestGraph;
    return algorithm;
}
/// <summary>
/// Creates the <see cref="RandomForestOptimizer"/> used for random-forest (non-boosting) training and
/// configures it with the tree learner, objective function and smoothing taken from <c>Args</c>.
/// </summary>
/// <param name="ch">Channel used for diagnostics and error reporting; validated with <c>Host.CheckValue</c>.</param>
/// <returns>The configured optimizer, with <c>PrintTestGraph</c> subscribed to its <c>PreScoreUpdateEvent</c>.</returns>
protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
{
    Host.CheckValue(ch, nameof(ch));

    IGradientAdjuster adjuster = MakeGradientWrapper(ch);
    var forestOptimizer = new RandomForestOptimizer(Ensemble, TrainSet, InitTrainScores, adjuster)
    {
        TreeLearner = ConstructTreeLearner(ch),
        ObjectiveFunction = ConstructObjFunc(ch),
        Smoothing = Args.Smoothing,
        // No notion of dropout for non-boosting applications, so it is switched off entirely.
        DropoutRate = 0,
        DropoutRng = null,
    };

    forestOptimizer.PreScoreUpdateEvent += PrintTestGraph;
    return forestOptimizer;
}
/// <summary>
/// Initializes accelerated gradient descent. All arguments are forwarded unchanged to the base constructor.
/// Afterwards, <c>UseFastTrainingScoresUpdate</c> is set to <c>false</c>, opting out of the fast
/// training-score update path that the base class presumably offers (TODO confirm the reason).
/// </summary>
/// <param name="ensemble">Tree ensemble being trained.</param>
/// <param name="trainData">Training dataset.</param>
/// <param name="initTrainScores">Initial scores for the training documents.</param>
/// <param name="gradientWrapper">Gradient adjuster forwarded to the base optimizer.</param>
// NOTE(review): this constructor takes TreeEnsemble, while the GradientDescent base constructor in this
// snapshot takes InternalTreeEnsemble; confirm the two types are compatible.
public AcceleratedGradientDescent(TreeEnsemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
    : base(ensemble, trainData, initTrainScores, gradientWrapper)
    => UseFastTrainingScoresUpdate = false;
// REVIEW: Once the FastTree application is decoupled from the tree-learner and boosting logic, this class should be removed.
/// <summary>
/// Initializes the random-forest optimizer. All arguments are forwarded to the base constructor, and
/// <paramref name="gradientWrapper"/> is also kept in this class's own <c>_gradientWrapper</c> field.
/// </summary>
/// <param name="ensemble">Tree ensemble being trained.</param>
/// <param name="trainData">Training dataset.</param>
/// <param name="initTrainScores">Initial scores for the training documents.</param>
/// <param name="gradientWrapper">Gradient adjuster, passed to the base and retained locally.</param>
public RandomForestOptimizer(TreeEnsemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
    : base(ensemble, trainData, initTrainScores, gradientWrapper)
    => _gradientWrapper = gradientWrapper;
/// <summary>
/// Initializes gradient descent. The ensemble, training data and initial scores go to the base
/// optimizer; the gradient adjuster is kept locally in <c>_gradientWrapper</c>.
/// </summary>
/// <param name="ensemble">Tree ensemble being trained.</param>
/// <param name="trainData">Training dataset.</param>
/// <param name="initTrainScores">Initial scores for the training documents.</param>
/// <param name="gradientWrapper">Gradient adjuster used by this optimizer (not passed to the base).</param>
internal GradientDescent(InternalTreeEnsemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
    : base(ensemble, trainData, initTrainScores)
{
    // Starts empty; presumably filled with one score array per tree during training (TODO confirm).
    _treeScores = new List<double[]>();
    _gradientWrapper = gradientWrapper;
}
/// <summary>
/// Initializes conjugate gradient descent. All arguments are forwarded to the base constructor, and
/// <c>_currentDk</c> is allocated with one entry per document in <paramref name="trainData"/>.
/// (<c>_currentDk</c> is presumably the current search direction d_k; TODO confirm.)
/// </summary>
/// <param name="ensemble">Tree ensemble being trained.</param>
/// <param name="trainData">Training dataset; its <c>NumDocs</c> sizes <c>_currentDk</c>.</param>
/// <param name="initTrainScores">Initial scores for the training documents.</param>
/// <param name="gradientWrapper">Gradient adjuster forwarded to the base optimizer.</param>
public ConjugateGradientDescent(InternalTreeEnsemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
    : base(ensemble, trainData, initTrainScores, gradientWrapper)
    => _currentDk = new double[trainData.NumDocs];