예제 #1
0
        /// <summary>
        /// Registers a <c>BinaryClassificationTest</c> for the training set, for the
        /// validation set (when one is present) and for every entry of <c>TestSets</c>,
        /// then selects the pruning test when pruning is enabled and a validation set exists.
        /// </summary>
        protected override void InitializeTests()
        {
            // The training-set test is always created and registered first.
            var trainTracker = ConstructScoreTracker(TrainSet);
            TrainTest = new BinaryClassificationTest(trainTracker, _trainSetLabels, _sigmoidParameter);
            Tests.Add(TrainTest);

            // The validation-set test only exists when a validation set was supplied.
            if (ValidSet != null)
            {
                var validTracker = ConstructScoreTracker(ValidSet);
                var validLabels = GetClassificationLabelsFromRatings(ValidSet).ToArray();
                ValidTest = new BinaryClassificationTest(validTracker, validLabels, _sigmoidParameter);
                Tests.Add(ValidTest);
            }

            // One test per test set; the labels are derived from each set through
            // GetClassificationLabelsFromRatings before its score tracker is built.
            if (TestSets != null)
            {
                for (int i = 0; i < TestSets.Length; i++)
                {
                    var testSet = TestSets[i];
                    bool[] testLabels = GetClassificationLabelsFromRatings(testSet).ToArray();
                    var testTracker = ConstructScoreTracker(testSet);
                    Tests.Add(new BinaryClassificationTest(testTracker, testLabels, _sigmoidParameter));
                }
            }

            // Pruning needs both the option switched on and a validation set to monitor.
            if (!Args.EnablePruning || ValidSet == null)
            {
                return;
            }

            // NOTE(review): the literal 0 passed below is presumably the index of the
            // loss tracked by the history - confirm against TestHistory's constructor.
            if (Args.UseTolerantPruning)
            {
                // Tolerant stopping condition, driven by the window size and threshold.
                PruningTest = new TestWindowWithTolerance(ValidTest, 0, Args.PruningWindowSize, Args.PruningThreshold);
            }
            else
            {
                // Simple early stopping condition.
                PruningTest = new TestHistory(ValidTest, 0);
            }
        }
예제 #2
0
        /// <summary>
        /// Registers an <c>NDCGTest</c> for the training set, for the validation set
        /// (when one is present) and for every entry of <c>TestSets</c>. The test built
        /// for the first test set is additionally wrapped in a <c>TestHistory</c> and
        /// stored in <c>_firstTestSetHistory</c>.
        /// </summary>
        protected virtual void AddFullNDCGTests()
        {
            // Training set: always present.
            Test trainNdcg = new NDCGTest(ConstructScoreTracker(TrainSet), TrainSet.Ratings, _args.sortingAlgorithm);
            Tests.Add(trainNdcg);

            // Validation set: optional.
            if (ValidSet != null)
            {
                Tests.Add(new NDCGTest(ConstructScoreTracker(ValidSet), ValidSet.Ratings, _args.sortingAlgorithm));
            }

            // Nothing further to register without test sets.
            if (TestSets == null)
            {
                return;
            }

            for (int i = 0; i < TestSets.Length; i++)
            {
                var testSet = TestSets[i];
                Test ndcg = new NDCGTest(ConstructScoreTracker(testSet), testSet.Ratings, _args.sortingAlgorithm);

                // Only the first test set gets a history wrapper.
                // NOTE(review): the literal 0 is presumably the tracked loss index - confirm.
                if (i == 0)
                {
                    _firstTestSetHistory = new TestHistory(ndcg, 0);
                }

                Tests.Add(ndcg);
            }
        }