public void ClassificationModelSelectingEnsembleLearner_Learn_Start_With_3_Models()
{
    // Trains a model-selecting ensemble whose selection is a forward search
    // (the `3` argument presumably is the number of starting models, per the test name)
    // and checks the log-loss error of the resulting model on the glass data set.

    // Arrange: twelve decision-tree learners that differ only in their constructor
    // argument (presumably the maximum tree depth), in the original order.
    var treeArguments = new[] { 2, 5, 7, 9, 11, 21, 23, 1, 14, 17, 19, 33 };
    IIndexedLearner<ProbabilityPrediction>[] candidateLearners = treeArguments
        .Select<int, IIndexedLearner<ProbabilityPrediction>>(arg => new ClassificationDecisionTreeLearner(arg))
        .ToArray();

    var logLoss = new LogLossClassificationProbabilityMetric();
    var meanStrategy = new MeanProbabilityClassificationEnsembleStrategy();
    var forwardSelection = new ForwardSearchClassificationEnsembleSelection(logLoss, meanStrategy, 5, 3, true);
    var crossValidation = new RandomCrossValidation<ProbabilityPrediction>(5, 23);

    var learner = new ClassificationModelSelectingEnsembleLearner(candidateLearners, crossValidation, meanStrategy, forwardSelection);
    var (observations, targets) = DataSetUtilities.LoadGlassDataSet();

    // Act: fit on every row, then predict on the same rows.
    var ensembleModel = learner.Learn(observations, targets);
    var probabilities = ensembleModel.PredictProbability(observations);
    var error = logLoss.Error(targets, probabilities);

    // Assert
    Assert.AreEqual(0.55183985816428427, error, 0.0001);
}
public void ClassificationBackwardEliminationModelSelectingEnsembleLearner_Learn_Indexed()
{
    // Trains a backward-elimination model-selecting ensemble on only the first 25 rows
    // of the glass data set, then measures log-loss error over all rows.

    // Arrange: twelve decision-tree learners that differ only in their constructor
    // argument (presumably the maximum tree depth), in the original order.
    IIndexedLearner<ProbabilityPrediction>[] candidateLearners = new[] { 2, 5, 7, 9, 11, 21, 23, 1, 14, 17, 19, 33 }
        .Select<int, IIndexedLearner<ProbabilityPrediction>>(arg => new ClassificationDecisionTreeLearner(arg))
        .ToArray();

    var logLoss = new LogLossClassificationProbabilityMetric();
    var meanStrategy = new MeanProbabilityClassificationEnsembleStrategy();
    var crossValidation = new RandomCrossValidation<ProbabilityPrediction>(5, 23);
    var learner = new ClassificationBackwardEliminationModelSelectingEnsembleLearner(candidateLearners, 5, crossValidation, meanStrategy, logLoss);

    var (observations, targets) = DataSetUtilities.LoadGlassDataSet();
    var trainingIndices = Enumerable.Range(0, 25).ToArray();

    // Act: fit on the index subset, predict on every row.
    var ensembleModel = learner.Learn(observations, targets, trainingIndices);
    var probabilities = ensembleModel.PredictProbability(observations);
    var error = logLoss.Error(targets, probabilities);

    // Assert
    Assert.AreEqual(2.3682546920482164, error, 0.0001);
}
public void ClassificationForwardSearchModelSelectingEnsembleLearner_Learn_Start_With_3_Models()
{
    // Trains a forward-search model-selecting ensemble (the `3, true` arguments presumably
    // configure the starting model count per the test name) using stratified
    // cross-validation, and checks the log-loss error on the glass data set.
    //
    // The glass data is loaded through the shared DataSetUtilities.LoadGlassDataSet()
    // helper, like the other tests in this file, rather than by duplicating the CSV
    // parsing inline. NOTE(review): the helper is presumably backed by the same
    // Resources.Glass CSV; the Learn_Indexed tests, one using each loading style, share
    // an expected value.

    // Arrange: twelve decision-tree learners that differ only in their constructor
    // argument (presumably the maximum tree depth), in the original order.
    IIndexedLearner<ProbabilityPrediction>[] learners = new[] { 2, 5, 7, 9, 11, 21, 23, 1, 14, 17, 19, 33 }
        .Select<int, IIndexedLearner<ProbabilityPrediction>>(arg => new ClassificationDecisionTreeLearner(arg))
        .ToArray();

    var metric = new LogLossClassificationProbabilityMetric();
    var ensembleStrategy = new MeanProbabilityClassificationEnsembleStrategy();
    var crossValidation = new StratifiedCrossValidation<ProbabilityPrediction>(5, 23);
    var sut = new ClassificationForwardSearchModelSelectingEnsembleLearner(learners, 5, crossValidation, ensembleStrategy, metric, 3, true);

    var (observations, targets) = DataSetUtilities.LoadGlassDataSet();

    // Act
    var model = sut.Learn(observations, targets);
    var predictions = model.PredictProbability(observations);
    var actual = metric.Error(targets, predictions);

    // Assert
    Assert.AreEqual(0.54434276244488244, actual, 0.0001);
}
public void ClassificationRandomModelSelectingEnsembleLearner_Learn_Without_Replacement()
{
    // Trains a random model-selecting ensemble with stratified cross-validation.
    // The trailing `false` argument presumably disables selection with replacement,
    // per the test name. Checks the log-loss error on the glass data set.

    // Arrange: twelve decision-tree learners that differ only in their constructor
    // argument (presumably the maximum tree depth), in the original order.
    var treeArguments = new[] { 2, 5, 7, 9, 11, 21, 23, 1, 14, 17, 19, 33 };
    IIndexedLearner<ProbabilityPrediction>[] candidateLearners = treeArguments
        .Select<int, IIndexedLearner<ProbabilityPrediction>>(arg => new ClassificationDecisionTreeLearner(arg))
        .ToArray();

    var logLoss = new LogLossClassificationProbabilityMetric();
    var meanStrategy = new MeanProbabilityClassificationEnsembleStrategy();
    var stratifiedCv = new StratifiedCrossValidation<ProbabilityPrediction>(5, 23);
    var learner = new ClassificationRandomModelSelectingEnsembleLearner(candidateLearners, 5, stratifiedCv, meanStrategy, logLoss, 1, false);

    var (observations, targets) = DataSetUtilities.LoadGlassDataSet();

    // Act
    var ensembleModel = learner.Learn(observations, targets);
    var probabilities = ensembleModel.PredictProbability(observations);
    var error = logLoss.Error(targets, probabilities);

    // Assert
    Assert.AreEqual(0.5805783545646459, error, 0.0001);
}
public void ClassificationModelSelectingEnsembleLearner_Constructor_EnsembleSelection_Null()
{
    // Constructs the learner with a null ensemble selection; all other arguments are non-null.
    // No metric is created: the constructor under test does not take one.
    // NOTE(review): no assertion is made in the body. The expected exception is presumably
    // declared by a test attribute (e.g. ExpectedException) that is not visible here; confirm.
    var learners = new IIndexedLearner<ProbabilityPrediction>[4];
    var ensembleStrategy = new MeanProbabilityClassificationEnsembleStrategy();
    var crossValidation = new RandomCrossValidation<ProbabilityPrediction>(5);

    var sut = new ClassificationModelSelectingEnsembleLearner(learners, crossValidation, ensembleStrategy, null);
}
public void ClassificationModelSelectingEnsembleLearner_Constructor_CrossValidation_Null()
{
    // Passes a null cross-validation to the constructor; every other argument is non-null.
    // NOTE(review): no assertion is made in the body. The expected exception is presumably
    // declared by a test attribute that is not visible here; confirm.
    var learnerSlots = new IIndexedLearner<ProbabilityPrediction>[4];
    var logLoss = new LogLossClassificationProbabilityMetric();
    var meanStrategy = new MeanProbabilityClassificationEnsembleStrategy();
    var forwardSelection = new ForwardSearchClassificationEnsembleSelection(logLoss, meanStrategy, 5, 1, true);

    // The constructed instance is never used; only construction matters.
    _ = new ClassificationModelSelectingEnsembleLearner(learnerSlots, null, meanStrategy, forwardSelection);
}
public void ClassificationModelSelectingEnsembleLearner_Constructor_Learners_Null()
{
    // Passes a null learner array to the constructor; every other argument is non-null.
    // NOTE(review): no assertion is made in the body. The expected exception is presumably
    // declared by a test attribute that is not visible here; confirm.
    var logLoss = new LogLossClassificationProbabilityMetric();
    var meanStrategy = new MeanProbabilityClassificationEnsembleStrategy();
    var forwardSelection = new ForwardSearchClassificationEnsembleSelection(logLoss, meanStrategy, 5, 1, true);
    var randomCv = new RandomCrossValidation<ProbabilityPrediction>(5);

    // The constructed instance is never used; only construction matters.
    _ = new ClassificationModelSelectingEnsembleLearner(null, randomCv, meanStrategy, forwardSelection);
}
public void ClassificationModelSelectingEnsembleLearner_Learn_Indexed()
{
    // Trains a model-selecting ensemble (forward-search selection) on only the first 25 rows
    // of the glass data set, then measures log-loss error over all rows.
    //
    // The glass data is loaded through the shared DataSetUtilities.LoadGlassDataSet()
    // helper, like the other tests in this file, rather than by duplicating the CSV
    // parsing inline. NOTE(review): the helper is presumably backed by the same
    // Resources.Glass CSV; the backward-elimination Learn_Indexed test uses the helper
    // and expects this same value.

    // Arrange: twelve decision-tree learners that differ only in their constructor
    // argument (presumably the maximum tree depth), in the original order.
    IIndexedLearner<ProbabilityPrediction>[] learners = new[] { 2, 5, 7, 9, 11, 21, 23, 1, 14, 17, 19, 33 }
        .Select<int, IIndexedLearner<ProbabilityPrediction>>(arg => new ClassificationDecisionTreeLearner(arg))
        .ToArray();

    var metric = new LogLossClassificationProbabilityMetric();
    var ensembleStrategy = new MeanProbabilityClassificationEnsembleStrategy();
    var ensembleSelection = new ForwardSearchClassificationEnsembleSelection(metric, ensembleStrategy, 5, 1, true);
    var crossValidation = new RandomCrossValidation<ProbabilityPrediction>(5, 23);
    var sut = new ClassificationModelSelectingEnsembleLearner(learners, crossValidation, ensembleStrategy, ensembleSelection);

    var (observations, targets) = DataSetUtilities.LoadGlassDataSet();
    var indices = Enumerable.Range(0, 25).ToArray();

    // Act: fit on the index subset, predict on every row.
    var model = sut.Learn(observations, targets, indices);
    var predictions = model.PredictProbability(observations);
    var actual = metric.Error(targets, predictions);

    // Assert
    Assert.AreEqual(2.3682546920482164, actual, 0.0001);
}
public void MeanProbabilityClassificationEnsembleStrategy_Combine()
{
    // Combines three two-class probability predictions and compares the result against
    // per-class averages: (0.3 + 0.66 + 0.01) / 3 for class 0.0 and
    // (0.88 + 0.33 + 0.99) / 3 for class 1.0, written out to 15 decimal places.
    // The expected label is 1.0, the class with the larger average.
    var first = new ProbabilityPrediction(1.0, new Dictionary<double, double> { [0.0] = 0.3, [1.0] = 0.88 });
    var second = new ProbabilityPrediction(0.0, new Dictionary<double, double> { [0.0] = 0.66, [1.0] = 0.33 });
    var third = new ProbabilityPrediction(1.0, new Dictionary<double, double> { [0.0] = 0.01, [1.0] = 0.99 });
    var ensemblePredictions = new[] { first, second, third };

    var strategy = new MeanProbabilityClassificationEnsembleStrategy();
    var combined = strategy.Combine(ensemblePredictions);

    var expected = new ProbabilityPrediction(1.0, new Dictionary<double, double> { [0.0] = 0.323333333333333, [1.0] = 0.733333333333333 });
    Assert.AreEqual(expected, combined);
}