Example #1
0
        public void RandomLearningCurvesCalculator_Calculate()
        {
            // Verifies that the learning curve produced for a regression decision tree
            // on the decision tree data set matches the recorded reference points.

            // Arrange: calculator scored with mean squared error at 20% and 80% sample sizes.
            // NOTE(review): the trailing positional arguments (0.8, 42, 5) are presumably
            // training percentage, seed and shuffles per sample - confirm against the constructor.
            var errorMetric = new MeanSquaredErrorRegressionMetric();
            var sampleFractions = new double[] { 0.2, 0.8 };
            var calculator = new RandomShuffleLearningCurvesCalculator<double>(
                errorMetric, sampleFractions, 0.8, 42, 5);

            var (observations, targets) = DataSetUtilities.LoadDecisionTreeDataSet();
            var treeLearner = new RegressionDecisionTreeLearner();

            // Act
            var actual = calculator.Calculate(treeLearner, observations, targets);

            // Assert: one expected curve point per sample percentage.
            var expected = new List<LearningCurvePoint>
            {
                new LearningCurvePoint(32, 0, 0.141565953928265),
                new LearningCurvePoint(128, 0.0, 0.068970597423950036),
            };

            CollectionAssert.AreEqual(expected, actual);
        }
Example #2
0
        public void LearningCurves_Calculate_ProbabilityPrediction()
        {
            // Example: computes learning curves for a probability-predicting classification
            // decision tree on the white wine quality data and traces the result as CSV text.

            #region Read data

            // The data comes from an embedded resource here;
            // swap in StreamReader(filepath) when reading from the filesystem.
            var targetColumn = "quality";
            var csvParser = new CsvParser(() => new StringReader(Resources.winequality_white));

            // Feature matrix: every column except the target column.
            var observations = csvParser
                .EnumerateRows(column => column != targetColumn)
                .ToF64Matrix();

            // Targets: the quality column reduced to a binary problem
            // (values below 5 become 0.0 = low quality, everything else 1.0 = high quality).
            var rawQuality = csvParser.EnumerateRows(targetColumn).ToF64Vector();
            var targets = rawQuality
                .Select(t => t < 5 ? 0.0 : 1.0)
                .ToArray();

            #endregion

            // Log loss is the metric used to measure the model error.
            var logLoss = new LogLossClassificationProbabilityMetric();

            // Learning curve calculator; observations are shuffled randomly.
            var curveCalculator = new RandomShuffleLearningCurvesCalculator<ProbabilityPrediction>(
                logLoss,
                samplePercentages: new double[] { 0.05, 0.1, 0.2, 0.4, 0.8, 1.0 },
                trainingPercentage: 0.7,
                numberOfShufflesPrSample: 5);

            // Learner whose learning curve is calculated.
            var treeLearner = new ClassificationDecisionTreeLearner(maximumTreeDepth: 5);

            // Calculate the learning curve.
            var curve = curveCalculator.Calculate(treeLearner, observations, targets);

            // Write the curve as csv into an in-memory writer.
            var csvWriter = new StringWriter();
            curve.Write(() => csvWriter);

            // Trace the result.
            // Plotting the learning curves helps determine whether the model suffers from
            // high bias or high variance, which in turn indicates what to try next
            // in order to improve the model.
            var csvText = csvWriter.ToString();
            Trace.WriteLine(csvText);

            // Alternatively, write to file:
            //curve.Write(() => new StreamWriter(filePath));
        }
Example #3
0
        public void RandomLearningCurvesCalculator_Calculate()
        {
            // Verifies that the learning curve produced for a regression decision tree
            // on the DecisionTreeData resource matches the recorded reference points.

            // Arrange: calculator scored with mean squared error at 20% and 80% sample sizes.
            // NOTE(review): the trailing positional arguments (0.8, 42, 5) are presumably
            // training percentage, seed and shuffles per sample - confirm against the constructor.
            var errorMetric = new MeanSquaredErrorRegressionMetric();
            var sampleFractions = new double[] { 0.2, 0.8 };
            var calculator = new RandomShuffleLearningCurvesCalculator<double>(
                errorMetric, sampleFractions, 0.8, 42, 5);

            // Columns whose name contains "T" are excluded from the features;
            // the "T" column itself supplies the targets.
            var targetColumn = "T";
            var csvParser = new CsvParser(() => new StringReader(Resources.DecisionTreeData));
            var observations = csvParser
                .EnumerateRows(name => !name.Contains(targetColumn))
                .ToF64Matrix();
            var targets = csvParser
                .EnumerateRows(targetColumn)
                .ToF64Vector();

            var treeLearner = new RegressionDecisionTreeLearner();

            // Act
            var actual = calculator.Calculate(treeLearner, observations, targets);

            // Assert: one expected curve point per sample percentage.
            var expected = new List<LearningCurvePoint>
            {
                new LearningCurvePoint(32, 0, 0.141565953928265),
                new LearningCurvePoint(128, 0.0, 0.068970597423950036),
            };

            CollectionAssert.AreEqual(expected, actual);
        }