/// <summary>
/// Learns a model-selecting ensemble from twelve decision tree learners using a
/// forward-search ensemble selection, predicts probabilities for the same glass
/// observations it was learned from, and compares the metric error with a recorded value.
/// </summary>
public void ClassificationModelSelectingEnsembleLearner_Learn_Start_With_3_Models()
        {
            // Constructor argument of each candidate decision tree learner, in array order.
            var treeArguments = new[] { 2, 5, 7, 9, 11, 21, 23, 1, 14, 17, 19, 33 };
            var candidates = new IIndexedLearner<ProbabilityPrediction>[treeArguments.Length];
            for (var i = 0; i < candidates.Length; i++)
            {
                candidates[i] = new ClassificationDecisionTreeLearner(treeArguments[i]);
            }

            var errorMetric = new LogLossClassificationProbabilityMetric();
            var strategy = new MeanProbabilityClassificationEnsembleStrategy();
            var selection = new ForwardSearchClassificationEnsembleSelection(errorMetric, strategy, 5, 3, true);
            var crossValidation = new RandomCrossValidation<ProbabilityPrediction>(5, 23);

            var learner = new ClassificationModelSelectingEnsembleLearner(
                candidates, crossValidation, strategy, selection);

            var (observations, targets) = DataSetUtilities.LoadGlassDataSet();

            // Predictions are made on the very observations used for learning.
            var probabilities = learner.Learn(observations, targets).PredictProbability(observations);
            var error = errorMetric.Error(targets, probabilities);

            Assert.AreEqual(0.55183985816428427, error, 0.0001);
        }
        // Example no. 2
        // 0
        /// <summary>
        /// Learns a backward-elimination model-selecting ensemble on the first 25 rows
        /// (indices 0..24) of the glass data set, predicts probabilities for all
        /// observations, and compares the metric error with a recorded value.
        /// </summary>
        public void ClassificationBackwardEliminationModelSelectingEnsembleLearner_Learn_Indexed()
        {
            // One decision tree learner per constructor argument, kept in this order.
            var candidates = new[] { 2, 5, 7, 9, 11, 21, 23, 1, 14, 17, 19, 33 }
                .Select(argument => (IIndexedLearner<ProbabilityPrediction>)new ClassificationDecisionTreeLearner(argument))
                .ToArray();

            var errorMetric = new LogLossClassificationProbabilityMetric();
            var strategy = new MeanProbabilityClassificationEnsembleStrategy();
            var crossValidation = new RandomCrossValidation<ProbabilityPrediction>(5, 23);

            var learner = new ClassificationBackwardEliminationModelSelectingEnsembleLearner(
                candidates, 5, crossValidation, strategy, errorMetric);

            var (observations, targets) = DataSetUtilities.LoadGlassDataSet();

            // Only these row indices are handed to the learner.
            var trainingIndices = Enumerable.Range(0, 25).ToArray();

            var probabilities = learner.Learn(observations, targets, trainingIndices)
                                       .PredictProbability(observations);
            var error = errorMetric.Error(targets, probabilities);

            Assert.AreEqual(2.3682546920482164, error, 0.0001);
        }
        /// <summary>
        /// Learns a forward-search model-selecting ensemble from twelve decision tree
        /// learners on the glass data parsed from <c>Resources.Glass</c>, predicts
        /// probabilities for the same observations, and compares the metric error with
        /// a recorded value.
        /// </summary>
        public void ClassificationForwardSearchModelSelectingEnsembleLearner_Learn_Start_With_3_Models()
        {
            // Constructor argument of each candidate decision tree learner, in array order.
            var treeArguments = new[] { 2, 5, 7, 9, 11, 21, 23, 1, 14, 17, 19, 33 };
            var candidates = new IIndexedLearner<ProbabilityPrediction>[treeArguments.Length];
            for (var i = 0; i < treeArguments.Length; i++)
            {
                candidates[i] = new ClassificationDecisionTreeLearner(treeArguments[i]);
            }

            var errorMetric = new LogLossClassificationProbabilityMetric();
            var strategy = new MeanProbabilityClassificationEnsembleStrategy();
            var crossValidation = new StratifiedCrossValidation<ProbabilityPrediction>(5, 23);

            var learner = new ClassificationForwardSearchModelSelectingEnsembleLearner(
                candidates, 5, crossValidation, strategy, errorMetric, 3, true);

            // Every column except "Target" forms the observation matrix; "Target" is the label vector.
            var csv = new CsvParser(() => new StringReader(Resources.Glass));
            var observations = csv.EnumerateRows(column => column != "Target").ToF64Matrix();
            var targets = csv.EnumerateRows("Target").ToF64Vector();

            var probabilities = learner.Learn(observations, targets).PredictProbability(observations);
            var error = errorMetric.Error(targets, probabilities);

            Assert.AreEqual(0.54434276244488244, error, 0.0001);
        }
        // Example no. 4
        // 0
        /// <summary>
        /// Learns a random model-selecting ensemble from twelve decision tree learners on
        /// the glass data set, predicts probabilities for the same observations, and
        /// compares the metric error with a recorded value.
        /// </summary>
        public void ClassificationRandomModelSelectingEnsembleLearner_Learn_Without_Replacement()
        {
            IIndexedLearner<ProbabilityPrediction>[] candidates =
            {
                new ClassificationDecisionTreeLearner(2),
                new ClassificationDecisionTreeLearner(5),
                new ClassificationDecisionTreeLearner(7),
                new ClassificationDecisionTreeLearner(9),
                new ClassificationDecisionTreeLearner(11),
                new ClassificationDecisionTreeLearner(21),
                new ClassificationDecisionTreeLearner(23),
                new ClassificationDecisionTreeLearner(1),
                new ClassificationDecisionTreeLearner(14),
                new ClassificationDecisionTreeLearner(17),
                new ClassificationDecisionTreeLearner(19),
                new ClassificationDecisionTreeLearner(33),
            };

            var errorMetric = new LogLossClassificationProbabilityMetric();
            var strategy = new MeanProbabilityClassificationEnsembleStrategy();
            var crossValidation = new StratifiedCrossValidation<ProbabilityPrediction>(5, 23);

            var learner = new ClassificationRandomModelSelectingEnsembleLearner(
                candidates, 5, crossValidation, strategy, errorMetric, 1, false);

            var (observations, targets) = DataSetUtilities.LoadGlassDataSet();

            // Predictions are made on the very observations used for learning.
            var ensemble = learner.Learn(observations, targets);
            var error = errorMetric.Error(targets, ensemble.PredictProbability(observations));

            Assert.AreEqual(0.5805783545646459, error, 0.0001);
        }
        /// <summary>
        /// Verifies that constructing a <c>ClassificationModelSelectingEnsembleLearner</c>
        /// with a null ensemble selection is rejected with an
        /// <see cref="System.ArgumentNullException"/>.
        /// </summary>
        /// <remarks>
        /// The expectation is asserted explicitly so the test neither errors out when the
        /// constructor throws nor passes silently when it does not.
        /// </remarks>
        public void ClassificationModelSelectingEnsembleLearner_Constructor_EnsembleSelection_Null()
        {
            var learners = new IIndexedLearner<ProbabilityPrediction>[4];
            var ensembleStrategy = new MeanProbabilityClassificationEnsembleStrategy();
            var crossValidation = new RandomCrossValidation<ProbabilityPrediction>(5);

            try
            {
                new ClassificationModelSelectingEnsembleLearner(learners, crossValidation, ensembleStrategy, null);
            }
            catch (System.ArgumentNullException)
            {
                // Expected: the null ensemble selection was rejected.
                return;
            }

            Assert.Fail("Expected an ArgumentNullException for a null ensemble selection.");
        }
        /// <summary>
        /// Verifies that constructing a <c>ClassificationModelSelectingEnsembleLearner</c>
        /// with a null cross validation is rejected with an
        /// <see cref="System.ArgumentNullException"/>.
        /// </summary>
        /// <remarks>
        /// The expectation is asserted explicitly so the test neither errors out when the
        /// constructor throws nor passes silently when it does not.
        /// </remarks>
        public void ClassificationModelSelectingEnsembleLearner_Constructor_CrossValidation_Null()
        {
            var learners = new IIndexedLearner<ProbabilityPrediction>[4];
            var metric = new LogLossClassificationProbabilityMetric();
            var ensembleStrategy = new MeanProbabilityClassificationEnsembleStrategy();
            var ensembleSelection = new ForwardSearchClassificationEnsembleSelection(metric, ensembleStrategy, 5, 1, true);

            try
            {
                new ClassificationModelSelectingEnsembleLearner(learners, null, ensembleStrategy, ensembleSelection);
            }
            catch (System.ArgumentNullException)
            {
                // Expected: the null cross validation was rejected.
                return;
            }

            Assert.Fail("Expected an ArgumentNullException for a null cross validation.");
        }
        /// <summary>
        /// Verifies that constructing a <c>ClassificationModelSelectingEnsembleLearner</c>
        /// with a null learner array is rejected with an
        /// <see cref="System.ArgumentNullException"/>.
        /// </summary>
        /// <remarks>
        /// The expectation is asserted explicitly so the test neither errors out when the
        /// constructor throws nor passes silently when it does not.
        /// </remarks>
        public void ClassificationModelSelectingEnsembleLearner_Constructor_Learners_Null()
        {
            var metric = new LogLossClassificationProbabilityMetric();
            var ensembleStrategy = new MeanProbabilityClassificationEnsembleStrategy();
            var ensembleSelection = new ForwardSearchClassificationEnsembleSelection(metric, ensembleStrategy, 5, 1, true);
            var crossValidation = new RandomCrossValidation<ProbabilityPrediction>(5);

            try
            {
                new ClassificationModelSelectingEnsembleLearner(null, crossValidation, ensembleStrategy, ensembleSelection);
            }
            catch (System.ArgumentNullException)
            {
                // Expected: the null learner array was rejected.
                return;
            }

            Assert.Fail("Expected an ArgumentNullException for a null learner array.");
        }
        /// <summary>
        /// Learns a model-selecting ensemble with forward-search selection on the first
        /// 25 rows (indices 0..24) of the glass data parsed from <c>Resources.Glass</c>,
        /// predicts probabilities for all observations, and compares the metric error
        /// with a recorded value.
        /// </summary>
        public void ClassificationModelSelectingEnsembleLearner_Learn_Indexed()
        {
            // One decision tree learner per constructor argument, kept in this order.
            var candidates = new[] { 2, 5, 7, 9, 11, 21, 23, 1, 14, 17, 19, 33 }
                .Select(argument => (IIndexedLearner<ProbabilityPrediction>)new ClassificationDecisionTreeLearner(argument))
                .ToArray();

            var errorMetric = new LogLossClassificationProbabilityMetric();
            var strategy = new MeanProbabilityClassificationEnsembleStrategy();
            var selection = new ForwardSearchClassificationEnsembleSelection(errorMetric, strategy, 5, 1, true);
            var crossValidation = new RandomCrossValidation<ProbabilityPrediction>(5, 23);

            var learner = new ClassificationModelSelectingEnsembleLearner(
                candidates, crossValidation, strategy, selection);

            // Every column except "Target" forms the observation matrix; "Target" is the label vector.
            var csv = new CsvParser(() => new StringReader(Resources.Glass));
            var observations = csv.EnumerateRows(column => column != "Target").ToF64Matrix();
            var targets = csv.EnumerateRows("Target").ToF64Vector();

            // Only these row indices are handed to the learner.
            var trainingIndices = Enumerable.Range(0, 25).ToArray();

            var probabilities = learner.Learn(observations, targets, trainingIndices)
                                       .PredictProbability(observations);
            var error = errorMetric.Error(targets, probabilities);

            Assert.AreEqual(2.3682546920482164, error, 0.0001);
        }
        /// <summary>
        /// Combines three probability predictions with
        /// <c>MeanProbabilityClassificationEnsembleStrategy</c> and checks the result
        /// against the per-class arithmetic mean of the inputs with prediction 1.0.
        /// </summary>
        public void MeanProbabilityClassificationEnsembleStrategy_Combine()
        {
            var first = new ProbabilityPrediction(
                1.0, new Dictionary<double, double> { { 0.0, 0.3 }, { 1.0, 0.88 } });
            var second = new ProbabilityPrediction(
                0.0, new Dictionary<double, double> { { 0.0, 0.66 }, { 1.0, 0.33 } });
            var third = new ProbabilityPrediction(
                1.0, new Dictionary<double, double> { { 0.0, 0.01 }, { 1.0, 0.99 } });
            var modelPredictions = new ProbabilityPrediction[] { first, second, third };

            // (0.3 + 0.66 + 0.01) / 3 and (0.88 + 0.33 + 0.99) / 3 respectively.
            var expectedProbabilities = new Dictionary<double, double>
            {
                { 0.0, 0.323333333333333 },
                { 1.0, 0.733333333333333 },
            };
            var expected = new ProbabilityPrediction(1.0, expectedProbabilities);

            var strategy = new MeanProbabilityClassificationEnsembleStrategy();
            var combined = strategy.Combine(modelPredictions);

            Assert.AreEqual(expected, combined);
        }