/// <summary>
/// Cross-validates a regression/model decision tree on the normalized housing data set,
/// predicting "MEDV", and asserts that the mean goodness-of-fit score across folds is at least 0.6.
/// </summary>
/// <remarks>
/// The tree is a <c>BinaryDecisionTreeModelBuilder</c> with variance-based split quality,
/// binary numeric splits, and leaves produced by a <c>RegularizedLinearRegressionModelBuilder</c>
/// constructed with 0.005. The run uses 15 folds and a train share of 0.7, and the randomizer is
/// seeded with 3. The score is presumably R² (the original local was named for it).
/// </remarks>
public void Regression_NumericAttrsAndOutcomesOnly_RegularizedRegression()
{
    // Given
    var crossValidator = new CrossValidator<double>(new Random(3));
    var housingData = TestDataBuilder.ReadHousingDataNormalizedAttrs();
    var treePredictor = new DecisionTreePredictor<double>();

    var splitQualityChecker = new VarianceBasedSplitQualityChecker();
    var splitSelector = new BestSplitSelectorForNumericValues(new BinaryNumericDataSplitter());
    var leafBuilder = new RegressionAndModelDecisionTreeLeafBuilder(
        new RegularizedLinearRegressionModelBuilder(0.005));
    var regressionTreeBuilder = new BinaryDecisionTreeModelBuilder(
        splitQualityChecker,
        splitSelector,
        leafBuilder);

    // When
    var foldReports = crossValidator.CrossValidate(
        modelBuilder: regressionTreeBuilder,
        modelBuilderParams: modelBuilderParams,
        predictor: treePredictor,
        qualityMeasure: new GoodnessOfFitQualityMeasure(),
        dataFrame: housingData,
        dependentFeatureName: "MEDV",
        percetnagOfTrainData: 0.7,
        folds: 15);

    // Then
    var meanRSquared = foldReports.Average(report => report.Accuracy);
    Assert.IsTrue(meanRSquared >= 0.6);
}
/// <summary>
/// Cross-validates the <c>binaryTreeBuilder</c> member on the categorical mushroom data set,
/// predicting "type", and asserts a mean accuracy of at least 0.99.
/// </summary>
/// <remarks>Two folds, train share 0.7, randomizer seeded with 3.</remarks>
public void Mushroom_BinarySplit()
{
    // Given
    var seededRandom = new Random(3);
    var crossValidator = new CrossValidator<string>(seededRandom);
    var mushrooms = TestDataBuilder.ReadMushroomDataWithCategoricalAttributes();
    var treePredictor = new DecisionTreePredictor<string>();
    var confusionMatrixBuilder = new ConfusionMatrixBuilder<string>();

    // When
    var foldReports = crossValidator.CrossValidate(
        modelBuilder: binaryTreeBuilder,
        modelBuilderParams: modelBuilderParams,
        predictor: treePredictor,
        qualityMeasure: confusionMatrixBuilder,
        dataFrame: mushrooms,
        dependentFeatureName: "type",
        percetnagOfTrainData: 0.7,
        folds: 2);

    // Then
    var meanAccuracy = foldReports.Average(report => report.Accuracy);
    Assert.IsTrue(meanAccuracy >= 0.99);
}
/// <summary>
/// Cross-validates a custom tree builder that carries a chi-square statistical significance
/// checker on the categorical mushroom data set, predicting "type", and asserts a mean
/// accuracy of at least 0.99.
/// </summary>
/// <remarks>
/// The builder comes from <c>BuildCustomModelBuilder</c>, and the builder parameters are
/// <c>new DecisionTreeModelBuilderParams(false, true)</c>. The meaning of those flags is defined
/// outside this method. Two folds, train share 0.7, randomizer seeded with 3.
/// </remarks>
public void Mushroom_MultiSplit_StatisticalSignificanceHeuristic()
{
    // Given
    var crossValidator = new CrossValidator<string>(new Random(3));
    var mushrooms = TestDataBuilder.ReadMushroomDataWithCategoricalAttributes();
    var treePredictor = new DecisionTreePredictor<string>();

    // When
    var significanceChecker = new ChiSquareStatisticalSignificanceChecker();
    var treeBuilder = this.BuildCustomModelBuilder(statisticalSignificanceChecker: significanceChecker);
    var builderParams = new DecisionTreeModelBuilderParams(false, true);
    var confusionMatrixBuilder = new ConfusionMatrixBuilder<string>();
    var foldReports = crossValidator.CrossValidate(
        modelBuilder: treeBuilder,
        modelBuilderParams: builderParams,
        predictor: treePredictor,
        qualityMeasure: confusionMatrixBuilder,
        dataFrame: mushrooms,
        dependentFeatureName: "type",
        percetnagOfTrainData: 0.7,
        folds: 2);

    // Then
    var meanAccuracy = foldReports.Average(report => report.Accuracy);
    Assert.IsTrue(meanAccuracy >= 0.99);
}
/// <summary>
/// Cross-validates a custom tree builder that carries a chi-square statistical significance
/// checker on the congressional voting data, predicting "party", and asserts a mean accuracy
/// of at least 0.9 over 10 folds.
/// </summary>
/// <remarks>
/// The builder comes from <c>BuildCustomModelBuilder(true, ...)</c>. Judging by the test name,
/// the leading <c>true</c> selects binary splits. NOTE(review): confirm against that helper.
/// Train share 0.7, randomizer seeded with 3.
/// </remarks>
public void DiscreteClassification_CategoricalFeatures_BinarySplits_ConvressVotingData_StatisticalSignificanceTest_CrossValidation()
{
    // Given
    var randomizer = new Random(3);
    var splitter = new CrossValidator<string>(randomizer);
    // Direct cast: fails fast if the loaded frame is not a DataFrame, whereas an unchecked
    // 'as' would silently hand null to CrossValidate.
    var testData = (DataFrame)TestDataBuilder.ReadCongressData();
    var predictor = new DecisionTreePredictor<string>();

    // When
    var accuracies = splitter.CrossValidate(
        modelBuilder: this.BuildCustomModelBuilder(true, statisticalSignificanceChecker: new ChiSquareStatisticalSignificanceChecker()),
        modelBuilderParams: new DecisionTreeModelBuilderParams(false, true),
        predictor: predictor,
        qualityMeasure: new ConfusionMatrixBuilder<string>(),
        dataFrame: testData,
        dependentFeatureName: "party",
        percetnagOfTrainData: 0.7,
        folds: 10);

    // Then
    var averageAccuracy = accuracies.Select(report => report.Accuracy).Average();
    Assert.IsTrue(averageAccuracy >= 0.9);
}
/// <summary>
/// Cross-validates the <c>multiValueTreeBuilder</c> member on the congressional voting data,
/// predicting "party", and asserts a mean accuracy of at least 0.9 over 10 folds
/// (train share 0.7).
/// </summary>
public void DiscreteClassification_CategoricalFeatures_MultiValuesSplits_CongressVotingData_CrossValidation()
{
    // Given
    // Fixed seed so the randomizer (presumably driving fold sampling) yields the same splits on
    // every run. With an unseeded Random, the 0.9 threshold was checked against a different
    // split each time, which makes the test flaky.
    var randomizer = new Random(3);
    var splitter = new CrossValidator<string>(randomizer);
    // Direct cast: fails fast if the loaded frame is not a DataFrame, whereas an unchecked
    // 'as' would silently hand null to CrossValidate.
    var testData = (DataFrame)TestDataBuilder.ReadCongressData();
    var predictor = new DecisionTreePredictor<string>();

    // When
    var accuracies = splitter.CrossValidate(
        modelBuilder: this.multiValueTreeBuilder,
        modelBuilderParams: modelBuilderParams,
        predictor: predictor,
        qualityMeasure: new ConfusionMatrixBuilder<string>(),
        dataFrame: testData,
        dependentFeatureName: "party",
        percetnagOfTrainData: 0.7,
        folds: 10);

    // Then
    var averageAccuracy = accuracies.Select(report => report.Accuracy).Average();
    Assert.IsTrue(averageAccuracy >= 0.9);
}
/// <summary>
/// Cross-validates the <c>binaryTreeBuilder</c> member on the congressional voting data,
/// predicting "party", and asserts a mean accuracy of at least 0.9 over 10 folds
/// (train share 0.7, randomizer seeded with 3).
/// </summary>
public void DiscreteClassification_CategoricalFeatures_BinarySplits_ConvressVotingData_CrossValidation()
{
    // Given
    var randomizer = new Random(3);
    var splitter = new CrossValidator<string>(randomizer);
    // Direct cast: fails fast if the loaded frame is not a DataFrame, whereas an unchecked
    // 'as' would silently hand null to CrossValidate.
    var testData = (DataFrame)TestDataBuilder.ReadCongressData();
    var predictor = new DecisionTreePredictor<string>();

    // When
    var accuracies = splitter.CrossValidate(
        modelBuilder: binaryTreeBuilder,
        modelBuilderParams: modelBuilderParams,
        predictor: predictor,
        qualityMeasure: new ConfusionMatrixBuilder<string>(),
        dataFrame: testData,
        dependentFeatureName: "party",
        percetnagOfTrainData: 0.7,
        folds: 10);

    // Then
    var averageAccuracy = accuracies.Select(report => report.Accuracy).Average();
    Assert.IsTrue(averageAccuracy >= 0.9);
}
/// <summary>
/// Cross-validates the <c>multiValueTreeBuilderWithBetterNumercValsHandler</c> member on the
/// adult census data, predicting "income", and asserts a mean accuracy of at least 0.8 over
/// 5 folds (train share 0.7).
/// </summary>
/// <remarks>
/// The cross validator is built with its parameterless constructor. How it sources randomness
/// is not visible here.
/// </remarks>
public void DiscreteClassification_NumericFeatures_MultiValuesSplits_AdultCensusData_CrossValidation()
{
    // Given
    var crossValidator = new CrossValidator<object>();
    var censusData = TestDataBuilder.ReadAdultCensusDataFrame();
    var treePredictor = new DecisionTreePredictor<object>();

    // When
    var foldReports = crossValidator.CrossValidate(
        modelBuilder: multiValueTreeBuilderWithBetterNumercValsHandler,
        modelBuilderParams: modelBuilderParams,
        predictor: treePredictor,
        qualityMeasure: new ConfusionMatrixBuilder<object>(),
        dataFrame: censusData,
        dependentFeatureName: "income",
        percetnagOfTrainData: 0.7,
        folds: 5);

    // Then
    var meanAccuracy = foldReports.Average(report => report.Accuracy);
    Assert.IsTrue(meanAccuracy >= 0.8);
}
/// <summary>
/// Cross-validates the <c>binaryTreeBuilder</c> member on the iris data set, predicting
/// "iris_class", and asserts a mean accuracy of at least 0.9 over 10 folds (train share 0.7).
/// </summary>
public void DiscreteClassification_NumericFeatures_BinarySplits_IrisData_CrossValidation()
{
    // Given
    // The randomizer is seeded and actually handed to the cross validator, so the run
    // (presumably its fold sampling) is reproducible.
    var randomizer = new Random(3);
    var splitter = new CrossValidator<object>(randomizer);
    var testData = TestDataBuilder.ReadIrisData();
    var predictor = new DecisionTreePredictor<object>();

    // When
    var accuracies = splitter.CrossValidate(
        binaryTreeBuilder,
        modelBuilderParams,
        predictor,
        new ConfusionMatrixBuilder<object>(),
        testData,
        "iris_class",
        0.7,
        10);

    // Then
    var averageAccuracy = accuracies.Select(report => report.Accuracy).Average();
    Assert.IsTrue(averageAccuracy >= 0.9);
}
/// <summary>
/// Cross-validates a regression/model decision tree on the normalized housing data set,
/// predicting "MEDV", and asserts that the mean goodness-of-fit score across folds is at least 0.6.
/// </summary>
/// <remarks>
/// The tree is a <c>BinaryDecisionTreeModelBuilder</c> with variance-based split quality,
/// binary numeric splits, and leaves produced by a <c>RegularizedLinearRegressionModelBuilder</c>
/// constructed with 0.005. The run uses 15 folds and a train share of 0.7, and the randomizer is
/// seeded with 3. The score is presumably R² (the original local was named for it).
/// </remarks>
public void Regression_NumericAttrsAndOutcomesOnly_RegularizedRegression()
{
    // Given
    var seededRandom = new Random(3);
    var crossValidator = new CrossValidator<double>(seededRandom);
    var housingData = TestDataBuilder.ReadHousingDataNormalizedAttrs();
    var treePredictor = new DecisionTreePredictor<double>();

    var splitQualityChecker = new VarianceBasedSplitQualityChecker();
    var numericSplitter = new BinaryNumericDataSplitter();
    var splitSelector = new BestSplitSelectorForNumericValues(numericSplitter);
    var leafModelBuilder = new RegularizedLinearRegressionModelBuilder(0.005);
    var leafBuilder = new RegressionAndModelDecisionTreeLeafBuilder(leafModelBuilder);
    var regressionTreeBuilder = new BinaryDecisionTreeModelBuilder(
        splitQualityChecker,
        splitSelector,
        leafBuilder);

    // When
    var foldReports = crossValidator.CrossValidate(
        modelBuilder: regressionTreeBuilder,
        modelBuilderParams: modelBuilderParams,
        predictor: treePredictor,
        qualityMeasure: new GoodnessOfFitQualityMeasure(),
        dataFrame: housingData,
        dependentFeatureName: "MEDV",
        percetnagOfTrainData: 0.7,
        folds: 15);

    // Then
    var meanRSquared = foldReports.Average(report => report.Accuracy);
    Assert.IsTrue(meanRSquared >= 0.6);
}
/// <summary>
/// Cross-validates a custom tree builder that carries a chi-square statistical significance
/// checker on the categorical mushroom data set, predicting "type", and asserts a mean
/// accuracy of at least 0.99.
/// </summary>
/// <remarks>
/// The builder comes from <c>BuildCustomModelBuilder</c>, and the builder parameters are
/// <c>new DecisionTreeModelBuilderParams(false, true)</c>. The meaning of those flags is defined
/// outside this method. Two folds, train share 0.7, randomizer seeded with 3.
/// </remarks>
public void Mushroom_MultiSplit_StatisticalSignificanceHeuristic()
{
    // Given
    var crossValidator = new CrossValidator<string>(new Random(3));
    var mushrooms = TestDataBuilder.ReadMushroomDataWithCategoricalAttributes();
    var treePredictor = new DecisionTreePredictor<string>();

    // When
    var significanceChecker = new ChiSquareStatisticalSignificanceChecker();
    var treeBuilder = this.BuildCustomModelBuilder(statisticalSignificanceChecker: significanceChecker);
    var builderParams = new DecisionTreeModelBuilderParams(false, true);
    var confusionMatrixBuilder = new ConfusionMatrixBuilder<string>();
    var foldReports = crossValidator.CrossValidate(
        modelBuilder: treeBuilder,
        modelBuilderParams: builderParams,
        predictor: treePredictor,
        qualityMeasure: confusionMatrixBuilder,
        dataFrame: mushrooms,
        dependentFeatureName: "type",
        percetnagOfTrainData: 0.7,
        folds: 2);

    // Then
    var meanAccuracy = foldReports.Average(report => report.Accuracy);
    Assert.IsTrue(meanAccuracy >= 0.99);
}
/// <summary>
/// Cross-validates the <c>multiValueTreeBuilder</c> member on the categorical mushroom data
/// set, predicting "type", and asserts a mean accuracy of at least 0.99.
/// </summary>
/// <remarks>Two folds, train share 0.7, randomizer seeded with 3.</remarks>
public void Mushroom_MultiSplit()
{
    // Given
    var seededRandom = new Random(3);
    var crossValidator = new CrossValidator<string>(seededRandom);
    var mushrooms = TestDataBuilder.ReadMushroomDataWithCategoricalAttributes();
    var treePredictor = new DecisionTreePredictor<string>();
    var confusionMatrixBuilder = new ConfusionMatrixBuilder<string>();

    // When
    var foldReports = crossValidator.CrossValidate(
        modelBuilder: multiValueTreeBuilder,
        modelBuilderParams: modelBuilderParams,
        predictor: treePredictor,
        qualityMeasure: confusionMatrixBuilder,
        dataFrame: mushrooms,
        dependentFeatureName: "type",
        percetnagOfTrainData: 0.7,
        folds: 2);

    // Then
    var meanAccuracy = foldReports.Average(report => report.Accuracy);
    Assert.IsTrue(meanAccuracy >= 0.99);
}
/// <summary>
/// Cross-validates the <c>multiValueTreeBuilderWithBetterNumercValsHandler</c> member on the
/// adult census data, predicting "income", and asserts a mean accuracy of at least 0.8 over
/// 5 folds (train share 0.7).
/// </summary>
/// <remarks>
/// The cross validator is built with its parameterless constructor. How it sources randomness
/// is not visible here.
/// </remarks>
public void DiscreteClassification_NumericFeatures_MultiValuesSplits_AdultCensusData_CrossValidation()
{
    // Given
    var crossValidator = new CrossValidator<object>();
    var censusData = TestDataBuilder.ReadAdultCensusDataFrame();
    var treePredictor = new DecisionTreePredictor<object>();

    // When
    var foldReports = crossValidator.CrossValidate(
        modelBuilder: multiValueTreeBuilderWithBetterNumercValsHandler,
        modelBuilderParams: modelBuilderParams,
        predictor: treePredictor,
        qualityMeasure: new ConfusionMatrixBuilder<object>(),
        dataFrame: censusData,
        dependentFeatureName: "income",
        percetnagOfTrainData: 0.7,
        folds: 5);

    // Then
    var meanAccuracy = foldReports.Average(report => report.Accuracy);
    Assert.IsTrue(meanAccuracy >= 0.8);
}
/// <summary>
/// Cross-validates the <c>binaryTreeBuilder</c> member on the iris data set, predicting
/// "iris_class", and asserts a mean accuracy of at least 0.9 over 10 folds (train share 0.7).
/// </summary>
public void DiscreteClassification_NumericFeatures_BinarySplits_IrisData_CrossValidation()
{
    // Given
    // The randomizer is seeded and actually handed to the cross validator, so the run
    // (presumably its fold sampling) is reproducible.
    var randomizer = new Random(3);
    var splitter = new CrossValidator<object>(randomizer);
    var testData = TestDataBuilder.ReadIrisData();
    var predictor = new DecisionTreePredictor<object>();

    // When
    var accuracies = splitter.CrossValidate(
        binaryTreeBuilder,
        modelBuilderParams,
        predictor,
        new ConfusionMatrixBuilder<object>(),
        testData,
        "iris_class",
        0.7,
        10);

    // Then
    var averageAccuracy = accuracies.Select(report => report.Accuracy).Average();
    Assert.IsTrue(averageAccuracy >= 0.9);
}
/// <summary>
/// Cross-validates a custom tree builder that carries a chi-square statistical significance
/// checker (constructed with 0.05, presumably the significance level) on the congressional
/// voting data, predicting "party", and asserts a mean accuracy of at least 0.9 over 10 folds.
/// </summary>
/// <remarks>
/// The builder comes from <c>BuildCustomModelBuilder(true, ...)</c>. NOTE(review): the meaning
/// of the leading <c>true</c> is not visible here. The test name says multi-value splits, so
/// confirm the helper's flag matches. Train share 0.7.
/// </remarks>
public void DiscreteClassification_CategoricalFeatures_MultiValuesSplits_CongressVotingData_StatisticalSignificanceHeuristic_CrossValidation()
{
    // Given
    // Fixed seed so the randomizer (presumably driving fold sampling) yields the same splits on
    // every run. With an unseeded Random, the 0.9 threshold was checked against a different
    // split each time, which makes the test flaky.
    var randomizer = new Random(3);
    var splitter = new CrossValidator<string>(randomizer);
    // Direct cast: fails fast if the loaded frame is not a DataFrame, whereas an unchecked
    // 'as' would silently hand null to CrossValidate.
    var testData = (DataFrame)TestDataBuilder.ReadCongressData();
    var predictor = new DecisionTreePredictor<string>();

    // When
    var accuracies = splitter.CrossValidate(
        modelBuilder: this.BuildCustomModelBuilder(true, statisticalSignificanceChecker: new ChiSquareStatisticalSignificanceChecker(0.05)),
        modelBuilderParams: new DecisionTreeModelBuilderParams(false, true),
        predictor: predictor,
        qualityMeasure: new ConfusionMatrixBuilder<string>(),
        dataFrame: testData,
        dependentFeatureName: "party",
        percetnagOfTrainData: 0.7,
        folds: 10);

    // Then
    var averageAccuracy = accuracies.Select(report => report.Accuracy).Average();
    Assert.IsTrue(averageAccuracy >= 0.9);
}