Exemplo n.º 1
0
        protected override void InitializeTests()
        {
            // Set up the periodic regression tests, unless testing was disabled (/tf == infinity).
            if (Args.TestFrequency != int.MaxValue)
            {
                AddFullRegressionTests();
            }

            if (Args.PrintTestGraph)
            {
                // When the tests above were skipped (/tf == infinity), _firstTestSetHistory is still
                // null, so build one here purely for graph printing. It is deliberately not added to
                // Tests, since that would also print its results after the final iteration.
                if (_firstTestSetHistory == null)
                {
                    _firstTestSetHistory = new TestHistory(
                        new RegressionTest(ConstructScoreTracker(TestSets[0])), 0);
                }
            }

            // Trackers needed for train/valid graph printing, created only once.
            if (Args.PrintTrainValidGraph && _trainRegressionTest == null)
            {
                _trainRegressionTest = new RegressionTest(ConstructScoreTracker(TrainSet));
            }

            if (Args.PrintTrainValidGraph && _testRegressionTest == null && TestSets != null && TestSets.Length > 0)
            {
                _testRegressionTest = new RegressionTest(ConstructScoreTracker(TestSets[0]));
            }

            // Trackers used for early stopping on the training set and, when present, the validation set.
            TrainTest = new RegressionTest(ConstructScoreTracker(TrainSet), Args.EarlyStoppingMetrics);
            if (ValidSet != null)
            {
                ValidTest = new RegressionTest(ConstructScoreTracker(ValidSet), Args.EarlyStoppingMetrics);
            }

            // Pruning requires a validation test; pick the tolerant or the simple stopping condition.
            if (Args.EnablePruning && ValidTest != null)
            {
                if (!Args.UseTolerantPruning)
                {
                    // Simple early stopping: track the best iteration seen so far.
                    PruningTest = new TestHistory(ValidTest, 0);
                }
                else
                {
                    // Tolerant stopping: allow fluctuation within a window, up to a threshold.
                    PruningTest = new TestWindowWithTolerance(ValidTest, 0, Args.PruningWindowSize, Args.PruningThreshold);
                }
            }
        }
Exemplo n.º 2
0
        private protected override void InitializeTests()
        {
            // The training-set L1/L2 errors are always tracked.
            TrainTest = new BinaryClassificationTest(ConstructScoreTracker(TrainSet), _trainSetLabels, _sigmoidParameter);
            Tests.Add(TrainTest);

            if (ValidSet != null)
            {
                ValidTest = new BinaryClassificationTest(ConstructScoreTracker(ValidSet),
                                                         GetClassificationLabelsFromRatings(ValidSet).ToArray(), _sigmoidParameter);
                Tests.Add(ValidTest);
            }

            // When an external label is missing, the Rating column stands in for the L1/L2 error
            // computation. The resulting values may be meaningless if the regression value is not
            // an actual label value.
            if (TestSets != null)
            {
                for (int i = 0; i < TestSets.Length; i++)
                {
                    var testSet = TestSets[i];
                    Tests.Add(new BinaryClassificationTest(
                        ConstructScoreTracker(testSet),
                        GetClassificationLabelsFromRatings(testSet).ToArray(),
                        _sigmoidParameter));
                }
            }

            // Pruning requires a validation set; pick the tolerant or the simple stopping condition.
            if (FastTreeTrainerOptions.EnablePruning && ValidSet != null)
            {
                if (FastTreeTrainerOptions.UseTolerantPruning)
                {
                    // Tolerant stopping: allow fluctuation within a window, up to a threshold.
                    PruningTest = new TestWindowWithTolerance(ValidTest, 0, FastTreeTrainerOptions.PruningWindowSize, FastTreeTrainerOptions.PruningThreshold);
                }
                else
                {
                    // Simple early stopping: track the best iteration seen so far.
                    PruningTest = new TestHistory(ValidTest, 0);
                }
            }
        }