コード例 #1
0
            // Parameter for scaling the loss; assigned once in the constructor.
            private readonly double _sigmoidParameter;

            /// <summary>
            /// Creates the objective. The general training settings are forwarded to the
            /// base constructor; the labels, sigmoid parameter, unbalanced-sets flag and
            /// parallel-training object are stored on this instance.
            /// </summary>
            /// <remarks>
            /// When <paramref name="unbalancedSets"/> is true, the example counts are computed
            /// from <paramref name="trainSetLabels"/> into <c>_npos</c> and <c>_nneg</c>, and
            /// <c>Contracts.Check</c> is used to require that both counts are greater than zero.
            /// </remarks>
            public ObjectiveImpl(
                Dataset trainSet,
                bool[] trainSetLabels,
                double learningRate,
                double shrinkage,
                double sigmoidParameter,
                bool unbalancedSets,
                double maxTreeOutput,
                int getDerivativesSampleRate,
                bool bestStepRankingRegressionTrees,
                int rngSeed,
                IParallelTraining parallelTraining)
                : base(
                    trainSet,
                    learningRate,
                    shrinkage,
                    maxTreeOutput,
                    getDerivativesSampleRate,
                    bestStepRankingRegressionTrees,
                    rngSeed)
            {
                // Plain state capture; none of these assignments depend on each other.
                _parallelTraining = parallelTraining;
                _sigmoidParameter = sigmoidParameter;
                _unbalancedSets = unbalancedSets;
                _labels = trainSetLabels;

                // The per-class counts are only computed in unbalanced-sets mode.
                if (!unbalancedSets)
                {
                    return;
                }

                BinaryClassificationTest.ComputeExampleCounts(trainSetLabels, out _npos, out _nneg);

                // Both counts must be non-zero, i.e. the labels must contain both classes.
                Contracts.Check(_nneg > 0 && _npos > 0, "Only one class in training set.");
            }
コード例 #2
0
        /// <summary>
        /// Builds a binary-classification test over the validation-set scores and wraps it
        /// in a <c>TestHistory</c> that tracks the loss selected by <c>PruningLossIndex</c>.
        /// </summary>
        protected override void DefinePruningTest()
        {
            // Read the inputs in the same order the test constructor receives them.
            var validationScores = ValidSetScore;
            var validationLabels = ConvertTargetsToBool(ValidSet.Targets);
            var validationTest = new BinaryClassificationTest(validationScores, validationLabels, _sigmoidParameter);

            // As per FastTreeClassification.ConstructOptimizationAlgorithm()
            if (Args.UnbalancedSets)
            {
                PruningLossIndex = 3; // Unbalanced sets loss
            }
            else
            {
                PruningLossIndex = 1; // normal loss
            }

            PruningTest = new TestHistory(validationTest, PruningLossIndex);
        }
コード例 #3
0
        /// <summary>
        /// Returns the optimization algorithm built by the base class. When the
        /// <c>UseLineSearch</c> option is set, a <c>LineSearch</c> driven by a
        /// training-set <c>BinaryClassificationTest</c> is installed as its
        /// <c>AdjustTreeOutputsOverride</c>.
        /// </summary>
        /// <param name="ch">Channel forwarded unchanged to the base implementation.</param>
        /// <returns>The algorithm obtained from the base implementation.</returns>
        private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
        {
            OptimizationAlgorithm algorithm = base.ConstructOptimizationAlgorithm(ch);

            // Nothing to customise unless line search was requested.
            if (!FastTreeTrainerOptions.UseLineSearch)
            {
                return algorithm;
            }

            var trainingLoss = new BinaryClassificationTest(algorithm.TrainingScores, _trainSetLabels, _sigmoidParameter);

            // REVIEW: the loss indices should be made an enum in BinaryClassificationTest.
            int lossIndex;
            if (FastTreeTrainerOptions.UnbalancedSets)
            {
                lossIndex = 3; // Unbalanced sets loss
            }
            else
            {
                lossIndex = 1; // normal loss
            }

            algorithm.AdjustTreeOutputsOverride = new LineSearch(
                trainingLoss,
                lossIndex,
                FastTreeTrainerOptions.MaximumNumberOfLineSearchSteps,
                FastTreeTrainerOptions.MinimumStepSize);

            return algorithm;
        }
コード例 #4
0
        /// <summary>
        /// Registers the binary-classification tests: always one for the training set,
        /// one for the validation set when present, and one per entry of <c>TestSets</c>
        /// when present. If pruning is enabled and a validation set exists, also sets
        /// <c>PruningTest</c> on top of the validation test.
        /// </summary>
        private protected override void InitializeTests()
        {
            // The training-set test is always created and registered.
            var trainTracker = ConstructScoreTracker(TrainSet);
            TrainTest = new BinaryClassificationTest(trainTracker, _trainSetLabels, _sigmoidParameter);
            Tests.Add(TrainTest);

            if (ValidSet != null)
            {
                var validTracker = ConstructScoreTracker(ValidSet);
                var validLabels = GetClassificationLabelsFromRatings(ValidSet).ToArray();
                ValidTest = new BinaryClassificationTest(validTracker, validLabels, _sigmoidParameter);
                Tests.Add(ValidTest);
            }

            // If an external label is missing, the Rating column is used for the L1/L2 error.
            // The values may not make much sense if the regression value is not an actual label value.
            if (TestSets != null)
            {
                foreach (var testSet in TestSets)
                {
                    // Labels are materialised before the score tracker is constructed.
                    var testLabels = GetClassificationLabelsFromRatings(testSet).ToArray();
                    Tests.Add(new BinaryClassificationTest(ConstructScoreTracker(testSet), testLabels, _sigmoidParameter));
                }
            }

            // Pruning needs both the option and a validation set.
            if (!FastTreeTrainerOptions.EnablePruning || ValidSet == null)
            {
                return;
            }

            if (FastTreeTrainerOptions.UseTolerantPruning)
            {
                // Tolerant stopping condition, configured by the window size and threshold options.
                PruningTest = new TestWindowWithTolerance(
                    ValidTest,
                    0,
                    FastTreeTrainerOptions.PruningWindowSize,
                    FastTreeTrainerOptions.PruningThreshold);
            }
            else
            {
                // Simple early stopping condition.
                PruningTest = new TestHistory(ValidTest, 0);
            }
        }