/// <summary>
/// Sets up the tests evaluated during training: the optional full regression tests, the tests
/// used for graph printing, the train/validation tests built with the early-stopping metrics,
/// and (when pruning is enabled and a validation test exists) the pruning test.
/// </summary>
protected override void InitializeTests()
{
    // Initialize the full regression tests unless the test frequency is int.MaxValue ("never").
    if (Args.TestFrequency != int.MaxValue)
    {
        AddFullRegressionTests();
    }

    // If _firstTestSetHistory is null (which means the tests were not initialized due to
    // /tf==infinity), initialize it from the first test set so the test graph can be printed.
    // It is deliberately not added to the regular tests: doing so would print its results
    // only after the final iteration.
    // NOTE(review): TestSets[0] throws if no test set was supplied (the graph branch below
    // guards against this, this one does not) — confirm PrintTestGraph is validated upstream.
    if (Args.PrintTestGraph && _firstTestSetHistory == null)
    {
        var firstTestSetTest = new RegressionTest(ConstructScoreTracker(TestSets[0]));
        _firstTestSetHistory = new TestHistory(firstTestSetTest, 0);
    }

    // Lazily create the train/test regression tests used for train/valid graph printing.
    if (Args.PrintTrainValidGraph && _trainRegressionTest == null)
    {
        _trainRegressionTest = new RegressionTest(ConstructScoreTracker(TrainSet));
    }

    if (Args.PrintTrainValidGraph && _testRegressionTest == null && TestSets != null && TestSets.Length > 0)
    {
        _testRegressionTest = new RegressionTest(ConstructScoreTracker(TestSets[0]));
    }

    // Train/validation tests are built with the configured early-stopping metrics.
    TrainTest = new RegressionTest(ConstructScoreTracker(TrainSet), Args.EarlyStoppingMetrics);
    if (ValidSet != null)
    {
        ValidTest = new RegressionTest(ConstructScoreTracker(ValidSet), Args.EarlyStoppingMetrics);
    }

    // Pruning is evaluated against the validation test, so it requires one.
    if (Args.EnablePruning && ValidTest != null)
    {
        if (Args.UseTolerantPruning)
        {
            // Tolerant stopping condition, parameterized by PruningWindowSize and PruningThreshold.
            PruningTest = new TestWindowWithTolerance(ValidTest, 0, Args.PruningWindowSize, Args.PruningThreshold);
        }
        else
        {
            // Simple early-stopping condition.
            PruningTest = new TestHistory(ValidTest, 0);
        }
    }
}
/// <summary>
/// Registers the binary-classification tests run during training (training set, optional
/// validation set, and every test set) and, when pruning is enabled and a validation set
/// exists, configures the pruning test on top of the validation test.
/// </summary>
private protected override void InitializeTests()
{
    // Training-set metrics are always computed, using the cached training labels.
    TrainTest = new BinaryClassificationTest(ConstructScoreTracker(TrainSet), _trainSetLabels, _sigmoidParameter);
    Tests.Add(TrainTest);

    if (ValidSet != null)
    {
        // Tracker is built before the labels are derived, matching the argument evaluation order.
        var validTracker = ConstructScoreTracker(ValidSet);
        var validLabels = GetClassificationLabelsFromRatings(ValidSet).ToArray();
        ValidTest = new BinaryClassificationTest(validTracker, validLabels, _sigmoidParameter);
        Tests.Add(ValidTest);
    }

    // Test-set labels are derived from each set's ratings. If the ratings are not actual
    // label values, the resulting metrics may not be meaningful.
    if (TestSets != null)
    {
        foreach (var testSet in TestSets)
        {
            bool[] testLabels = GetClassificationLabelsFromRatings(testSet).ToArray();
            var testTracker = ConstructScoreTracker(testSet);
            Tests.Add(new BinaryClassificationTest(testTracker, testLabels, _sigmoidParameter));
        }
    }

    // Pruning is only configured when it is enabled and a validation set exists.
    if (!FastTreeTrainerOptions.EnablePruning || ValidSet == null)
    {
        return;
    }

    if (FastTreeTrainerOptions.UseTolerantPruning)
    {
        // Tolerant stopping condition, parameterized by PruningWindowSize and PruningThreshold.
        PruningTest = new TestWindowWithTolerance(ValidTest, 0, FastTreeTrainerOptions.PruningWindowSize, FastTreeTrainerOptions.PruningThreshold);
    }
    else
    {
        // Simple early-stopping condition.
        PruningTest = new TestHistory(ValidTest, 0);
    }
}