// Scaling parameter applied inside the sigmoid when computing the loss.
private readonly double _sigmoidParameter;

/// <summary>
/// Constructs the binary-classification gradient objective for <paramref name="trainSet"/>.
/// When <paramref name="unbalancedSets"/> is set, per-class example counts are computed
/// up front so derivatives can be reweighted; training requires both classes present.
/// </summary>
public ObjectiveImpl(
    Dataset trainSet,
    bool[] trainSetLabels,
    double learningRate,
    double shrinkage,
    double sigmoidParameter,
    bool unbalancedSets,
    double maxTreeOutput,
    int getDerivativesSampleRate,
    bool bestStepRankingRegressionTrees,
    int rngSeed,
    IParallelTraining parallelTraining)
    : base(
        trainSet,
        learningRate,
        shrinkage,
        maxTreeOutput,
        getDerivativesSampleRate,
        bestStepRankingRegressionTrees,
        rngSeed)
{
    _sigmoidParameter = sigmoidParameter;
    _labels = trainSetLabels;
    _unbalancedSets = unbalancedSets;
    if (_unbalancedSets)
    {
        // Class counts are needed to reweight derivatives for unbalanced data.
        BinaryClassificationTest.ComputeExampleCounts(_labels, out _npos, out _nneg);
        Contracts.Check(_nneg > 0 && _npos > 0, "Only one class in training set.");
    }
    _parallelTraining = parallelTraining;
}
/// <summary>
/// Sets up the early-stopping (pruning) test over the validation set.
/// </summary>
protected override void DefinePruningTest()
{
    var validationTest = new BinaryClassificationTest(
        ValidSetScore, ConvertTargetsToBool(ValidSet.Targets), _sigmoidParameter);
    // The loss index must mirror the loss chosen in
    // FastTreeClassification.ConstructOptimizationAlgorithm():
    // 3 = unbalanced-sets loss, 1 = normal loss.
    PruningLossIndex = Args.UnbalancedSets ? 3 : 1;
    PruningTest = new TestHistory(validationTest, PruningLossIndex);
}
/// <summary>
/// Builds the base optimization algorithm and, when line search is enabled,
/// attaches a <c>LineSearch</c> tree-output adjuster driven by the classification loss.
/// </summary>
private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
{
    var algorithm = base.ConstructOptimizationAlgorithm(ch);
    if (FastTreeTrainerOptions.UseLineSearch)
    {
        var lossCalculator = new BinaryClassificationTest(
            algorithm.TrainingScores, _trainSetLabels, _sigmoidParameter);
        // REVIEW: we should make loss indices an enum in BinaryClassificationTest.
        // 3 = unbalanced-sets loss, 1 = normal loss.
        int lossIndex = FastTreeTrainerOptions.UnbalancedSets ? 3 : 1;
        algorithm.AdjustTreeOutputsOverride = new LineSearch(
            lossCalculator,
            lossIndex,
            FastTreeTrainerOptions.MaximumNumberOfLineSearchSteps,
            FastTreeTrainerOptions.MinimumStepSize);
    }
    return algorithm;
}
/// <summary>
/// Creates the metric trackers for the training, validation, and test sets,
/// and configures the pruning (early stopping) test when enabled.
/// </summary>
private protected override void InitializeTests()
{
    // Training-set errors (L1/L2) are always tracked.
    TrainTest = new BinaryClassificationTest(
        ConstructScoreTracker(TrainSet), _trainSetLabels, _sigmoidParameter);
    Tests.Add(TrainTest);

    if (ValidSet != null)
    {
        ValidTest = new BinaryClassificationTest(
            ConstructScoreTracker(ValidSet),
            GetClassificationLabelsFromRatings(ValidSet).ToArray(),
            _sigmoidParameter);
        Tests.Add(ValidTest);
    }

    // If the external label is missing, the Rating column stands in for it.
    // The resulting values may not make much sense unless the regression value
    // is an actual label value.
    if (TestSets != null)
    {
        for (int t = 0; t < TestSets.Length; ++t)
        {
            bool[] labels = GetClassificationLabelsFromRatings(TestSets[t]).ToArray();
            Tests.Add(new BinaryClassificationTest(
                ConstructScoreTracker(TestSets[t]), labels, _sigmoidParameter));
        }
    }

    if (FastTreeTrainerOptions.EnablePruning && ValidSet != null)
    {
        if (FastTreeTrainerOptions.UseTolerantPruning)
        {
            // Tolerant stopping condition: evaluate over a sliding window with a threshold.
            PruningTest = new TestWindowWithTolerance(
                ValidTest,
                0,
                FastTreeTrainerOptions.PruningWindowSize,
                FastTreeTrainerOptions.PruningThreshold);
        }
        else
        {
            // Simple early-stopping condition on the first validation metric.
            PruningTest = new TestHistory(ValidTest, 0);
        }
    }
}