Example #1
0
        /// <summary>
        /// Cross-validates a <c>BinaryDecisionTreeModelBuilder</c> (variance-based split quality
        /// checker, numeric binary splitter, leaf builder wrapping a
        /// <c>RegularizedLinearRegressionModelBuilder(0.005)</c>) on the normalized housing data
        /// with "MEDV" as the dependent feature, and requires the mean <c>Accuracy</c> of the
        /// returned reports to be at least 0.6.
        /// </summary>
        public void Regression_NumericAttrsAndOutcomesOnly_RegularizedRegression()
        {
            // Given: a validator driven by a fixed-seed randomizer.
            var seededRandom   = new Random(3);
            var crossValidator = new CrossValidator<double>(seededRandom);
            var housingData    = TestDataBuilder.ReadHousingDataNormalizedAttrs();
            var treePredictor  = new DecisionTreePredictor<double>();

            var splitQualityChecker = new VarianceBasedSplitQualityChecker();
            var splitSelector       = new BestSplitSelectorForNumericValues(new BinaryNumericDataSplitter());
            var leafBuilder         = new RegressionAndModelDecisionTreeLeafBuilder(
                new RegularizedLinearRegressionModelBuilder(0.005));
            var treeBuilder = new BinaryDecisionTreeModelBuilder(splitQualityChecker, splitSelector, leafBuilder);
            var fitMeasure  = new GoodnessOfFitQualityMeasure();

            // When: 15 folds with percetnagOfTrainData = 0.7.
            var reports = crossValidator.CrossValidate(
                modelBuilder: treeBuilder,
                modelBuilderParams: modelBuilderParams,
                predictor: treePredictor,
                qualityMeasure: fitMeasure,
                dataFrame: housingData,
                dependentFeatureName: "MEDV",
                percetnagOfTrainData: 0.7,
                folds: 15);

            // Then
            var meanAccuracy = reports.Average(report => report.Accuracy);
            Assert.IsTrue(meanAccuracy >= 0.6);
        }
Example #2
0
        /// <summary>
        /// Cross-validates the <c>binaryTreeBuilder</c> member (declared outside this method) on
        /// the categorical mushroom data with "type" as the dependent feature, and requires the
        /// mean <c>Accuracy</c> of the returned reports to be at least 0.99.
        /// </summary>
        public void Mushroom_BinarySplit()
        {
            // Given: a validator driven by a fixed-seed randomizer.
            var seededRandom   = new Random(3);
            var crossValidator = new CrossValidator<string>(seededRandom);
            var mushroomData   = TestDataBuilder.ReadMushroomDataWithCategoricalAttributes();
            var treePredictor  = new DecisionTreePredictor<string>();
            var confusionMatrixBuilder = new ConfusionMatrixBuilder<string>();

            // When: 2 folds with percetnagOfTrainData = 0.7.
            var reports = crossValidator.CrossValidate(
                modelBuilder: binaryTreeBuilder,
                modelBuilderParams: modelBuilderParams,
                predictor: treePredictor,
                qualityMeasure: confusionMatrixBuilder,
                dataFrame: mushroomData,
                dependentFeatureName: "type",
                percetnagOfTrainData: 0.7,
                folds: 2);

            // Then
            var meanAccuracy = reports.Average(report => report.Accuracy);
            Assert.IsTrue(meanAccuracy >= 0.99);
        }
        /// <summary>
        /// Generates 1000 rows of three random features in the range passed as min 0 / max 1,
        /// scored by a fixed linear function, then cross-validates
        /// <c>RegularizedGradientDescentModelBuilder(0, 1)</c> against the "result" column and
        /// requires the mean <c>Accuracy</c> of the returned reports to be at least 0.9.
        /// </summary>
        public void RegularizedGradientDescent_ArtificialFunction()
        {
            // Given
            var crossValidator = new CrossValidator<double>();

            // Target function: 0.3 + 0.5*x0 - 0.3*x1 + 0.7*x2.
            Func<IList<double>, double> targetFunction =
                features => 0.3 + (0.5 * features[0]) + (-0.3 * features[1]) + (0.7 * features[2]);

            var generatedData = TestDataBuilder.BuildRandomAbstractNumericDataFrame(
                targetFunction,
                featuresCount: 3,
                min: 0,
                max: 1,
                rowCount: 1000);
            var gradientDescentBuilder = new RegularizedGradientDescentModelBuilder(0, 1);
            var regressionParams       = new LinearRegressionParams(0.05);
            var regressionPredictor    = new LinearRegressionPredictor();
            var fitMeasure             = new GoodnessOfFitQualityMeasure();

            // When: 20 folds with percetnagOfTrainData = 0.8.
            var reports = crossValidator.CrossValidate(
                modelBuilder: gradientDescentBuilder,
                modelBuilderParams: regressionParams,
                predictor: regressionPredictor,
                qualityMeasure: fitMeasure,
                dataFrame: generatedData,
                dependentFeatureName: "result",
                percetnagOfTrainData: 0.8,
                folds: 20);

            // Then
            var meanAccuracy = reports.Average(report => report.Accuracy);
            Assert.IsTrue(meanAccuracy >= 0.9);
        }
Example #4
0
        /// <summary>
        /// Cross-validates a builder obtained from <c>BuildCustomModelBuilder</c> (first argument
        /// <c>true</c>, default-constructed chi-square significance checker) on the congress
        /// voting data with "party" as the dependent feature, and requires the mean
        /// <c>Accuracy</c> of the returned reports to be at least 0.9.
        /// </summary>
        public void DiscreteClassification_CategoricalFeatures_BinarySplits_ConvressVotingData_StatisticalSignificanceTest_CrossValidation()
        {
            // Given: a validator driven by a fixed-seed randomizer.
            var seededRandom   = new Random(3);
            var crossValidator = new CrossValidator<string>(seededRandom);
            var congressData   = TestDataBuilder.ReadCongressData() as DataFrame;
            var treePredictor  = new DecisionTreePredictor<string>();

            var treeBuilder = this.BuildCustomModelBuilder(
                true,
                statisticalSignificanceChecker: new ChiSquareStatisticalSignificanceChecker());
            var treeBuilderParams      = new DecisionTreeModelBuilderParams(false, true);
            var confusionMatrixBuilder = new ConfusionMatrixBuilder<string>();

            // When: 10 folds with percetnagOfTrainData = 0.7.
            var reports = crossValidator.CrossValidate(
                modelBuilder: treeBuilder,
                modelBuilderParams: treeBuilderParams,
                predictor: treePredictor,
                qualityMeasure: confusionMatrixBuilder,
                dataFrame: congressData,
                dependentFeatureName: "party",
                percetnagOfTrainData: 0.7,
                folds: 10);

            // Then
            var meanAccuracy = reports.Average(report => report.Accuracy);
            Assert.IsTrue(meanAccuracy >= 0.9);
        }
Example #5
0
        /// <summary>
        /// Assembles a <c>LearningPipeline</c> over "Data/test_generated.data" (text loader,
        /// five text featurizers, a column concatenator, a <c>FastTreeBinaryClassifier</c>) and
        /// runs it through a <c>CrossValidator</c> configured for binary classifier trainers.
        /// The cross-validation result is not inspected.
        /// </summary>
        /// <param name="args">Command-line arguments; not used.</param>
        static void Main(string[] args)
        {
            var dataFilePath = "Data/test_generated.data";

            var pipeline = new LearningPipeline();
            pipeline.Add(new TextLoader(dataFilePath).CreateFrom<ReopenedIssueData>());

            // Each featurizer receives the same column for both constructor arguments.
            pipeline.Add(new TextFeaturizer(Columns.Environment, Columns.Environment));
            pipeline.Add(new TextFeaturizer(Columns.Type, Columns.Type));
            pipeline.Add(new TextFeaturizer(Columns.ProjectName, Columns.ProjectName));
            pipeline.Add(new TextFeaturizer(Columns.AsigneeEmail, Columns.AsigneeEmail));
            pipeline.Add(new TextFeaturizer(Columns.ReporterEmail, Columns.ReporterEmail));

            // First argument is Columns.Features; the remaining columns are concatenated inputs.
            pipeline.Add(new ColumnConcatenator(
                Columns.Features,
                Columns.Environment,
                Columns.Type,
                Columns.CommentsCount,
                Columns.CommentsLenght,
                Columns.ReporterCommentsCount,
                Columns.ProjectName,
                Columns.AsigneeEmail,
                Columns.ReporterEmail));
            pipeline.Add(new FastTreeBinaryClassifier());

            // Only Kind is configured; NumFolds is left untouched.
            var crossValidator = new CrossValidator();
            crossValidator.Kind = MacroUtilsTrainerKinds.SignatureBinaryClassifierTrainer;

            var crossValidationResult =
                crossValidator.CrossValidate<ReopenedIssueData, ReopenedIssuePrediction>(pipeline);
        }
Example #6
0
        /// <summary>
        /// Cross-validates a builder obtained from <c>BuildCustomModelBuilder</c> with a
        /// default-constructed chi-square significance checker on the categorical mushroom data
        /// ("type" as the dependent feature), and requires the mean <c>Accuracy</c> of the
        /// returned reports to be at least 0.99.
        /// </summary>
        public void Mushroom_MultiSplit_StatisticalSignificanceHeuristic()
        {
            // Given: a validator driven by a fixed-seed randomizer.
            var seededRandom   = new Random(3);
            var crossValidator = new CrossValidator<string>(seededRandom);
            var mushroomData   = TestDataBuilder.ReadMushroomDataWithCategoricalAttributes();
            var treePredictor  = new DecisionTreePredictor<string>();

            var treeBuilder = this.BuildCustomModelBuilder(
                statisticalSignificanceChecker: new ChiSquareStatisticalSignificanceChecker());
            var treeBuilderParams      = new DecisionTreeModelBuilderParams(false, true);
            var confusionMatrixBuilder = new ConfusionMatrixBuilder<string>();

            // When: 2 folds with percetnagOfTrainData = 0.7.
            var reports = crossValidator.CrossValidate(
                modelBuilder: treeBuilder,
                modelBuilderParams: treeBuilderParams,
                predictor: treePredictor,
                qualityMeasure: confusionMatrixBuilder,
                dataFrame: mushroomData,
                dependentFeatureName: "type",
                percetnagOfTrainData: 0.7,
                folds: 2);

            // Then
            var meanAccuracy = reports.Average(report => report.Accuracy);
            Assert.IsTrue(meanAccuracy >= 0.99);
        }
Example #7
0
        /// <summary>
        /// Generates 1000 rows of three random features (min 0, max 1) scored by a fixed linear
        /// function, then cross-validates <c>RegularizedLinearRegressionModelBuilder(0.5)</c>
        /// against the "result" column and requires the mean <c>Accuracy</c> of the returned
        /// reports to be at least 0.9.
        /// </summary>
        public void RegularizedLinearRegression_ArtificaialFunction()
        {
            // Given
            var crossValidator = new CrossValidator<double>();

            // Target function: 0.3 + 0.5*x0 - 0.3*x1 + 0.7*x2.
            Func<IList<double>, double> targetFunction =
                features => 0.3 + (0.5 * features[0]) + (-0.3 * features[1]) + (0.7 * features[2]);

            var generatedData = TestDataBuilder.BuildRandomAbstractNumericDataFrame(
                targetFunction,
                featuresCount: 3,
                min: 0,
                max: 1,
                rowCount: 1000);
            var regressionBuilder   = new RegularizedLinearRegressionModelBuilder(0.5);
            var regressionParams    = new LinearRegressionParams(0.05);
            var regressionPredictor = new LinearRegressionPredictor();
            var fitMeasure          = new GoodnessOfFitQualityMeasure();

            // When: 20 folds with percetnagOfTrainData = 0.8.
            var reports = crossValidator.CrossValidate(
                modelBuilder: regressionBuilder,
                modelBuilderParams: regressionParams,
                predictor: regressionPredictor,
                qualityMeasure: fitMeasure,
                dataFrame: generatedData,
                dependentFeatureName: "result",
                percetnagOfTrainData: 0.8,
                folds: 20);

            // Then
            var meanAccuracy = reports.Average(report => report.Accuracy);
            Assert.IsTrue(meanAccuracy >= 0.9);
        }
        /// <summary>
        /// Builds a random forest on top of <c>multiValueTreeBuilderWithBetterNumercValsHandler</c>,
        /// runs a single cross-validation fold on the congress voting data with "party" as the
        /// dependent feature, and requires the <c>Accuracy</c> of the first returned report to be
        /// at least 0.9.
        /// </summary>
        public void DiscreteClassification_DiscreteFeatures_MultiValuesSplits_CongressVoting()
        {
            // Given
            var treePredictor        = new DecisionTreePredictor<object>();
            var forestQualityMeasure = new ConfusionMatrixBuilder<object>();
            var forestBuilder = new RandomForestModelBuilder<object>(
                multiValueTreeBuilderWithBetterNumercValsHandler,
                treePredictor,
                forestQualityMeasure,
                // Maps an int to its square root, rounded half away from zero.
                n => (int)Math.Round(Math.Sqrt(n), MidpointRounding.AwayFromZero),
                () => new DecisionTreeModelBuilderParams(false));
            var forestPredictor = new RandomForestPredictor<object>(new DecisionTreePredictor<object>(), true);
            var congressData    = TestDataBuilder.ReadCongressData();
            var validator       = new CrossValidator<object>();

            var forestParams           = new RandomForestParams(100, 10);
            var confusionMatrixBuilder = new ConfusionMatrixBuilder<object>();

            // When: one fold; only the first report is examined.
            var reports = validator.CrossValidate(
                forestBuilder,
                forestParams,
                forestPredictor,
                confusionMatrixBuilder,
                congressData,
                "party",
                0.7,
                1);
            var firstReport = reports.First();

            // Then
            Assert.IsTrue(firstReport.Accuracy >= 0.9);
        }
Example #9
0
        /// <summary>
        /// Builds a random forest on top of <c>multiValueTreeBuilderWithBetterNumercValsHandler</c>,
        /// runs a single cross-validation fold on the congress voting data with "party" as the
        /// dependent feature, and requires the <c>Accuracy</c> of the first returned report to be
        /// at least 0.9.
        /// </summary>
        public void DiscreteClassification_DiscreteFeatures_MultiValuesSplits_CongressVoting()
        {
            // Given
            var treePredictor        = new DecisionTreePredictor<object>();
            var forestQualityMeasure = new ConfusionMatrixBuilder<object>();
            var forestBuilder = new RandomForestModelBuilder<object>(
                multiValueTreeBuilderWithBetterNumercValsHandler,
                treePredictor,
                forestQualityMeasure,
                // Maps an int to its square root, rounded half away from zero.
                n => (int)Math.Round(Math.Sqrt(n), MidpointRounding.AwayFromZero),
                () => new DecisionTreeModelBuilderParams(false));
            var forestPredictor = new RandomForestPredictor<object>(new DecisionTreePredictor<object>(), true);
            var congressData    = TestDataBuilder.ReadCongressData();
            var validator       = new CrossValidator<object>();

            var forestParams           = new RandomForestParams(100, 10);
            var confusionMatrixBuilder = new ConfusionMatrixBuilder<object>();

            // When: one fold; only the first report is examined.
            var reports = validator.CrossValidate(
                forestBuilder,
                forestParams,
                forestPredictor,
                confusionMatrixBuilder,
                congressData,
                "party",
                0.7,
                1);
            var firstReport = reports.First();

            // Then
            Assert.IsTrue(firstReport.Accuracy >= 0.9);
        }
Example #10
0
        /// <summary>
        /// Cross-validates a <c>GeneralizedAdditiveModelRegressor</c> pipeline over the
        /// "SalaryData.csv" data set with 5 folds, prints each fold's <c>Rms</c>, then prints the
        /// per-fold averages of <c>RSquared</c> and <c>Rms</c> and waits for a line of input.
        /// </summary>
        /// <param name="args">Command-line arguments; not used.</param>
        static void Main(string[] args)
        {
            var dataset     = MLNetUtilities.GetDataPathByDatasetName("SalaryData.csv");
            // Resolved but not consumed anywhere below.
            var testDataset = MLNetUtilities.GetDataPathByDatasetName("SalaryData-test.csv");

            var pipeline = new LearningPipeline();
            pipeline.Add(new TextLoader(dataset).CreateFrom<SalaryData>(useHeader: true, separator: ','));
            pipeline.Add(new ColumnConcatenator("Features", "YearsExperience"));
            pipeline.Add(new GeneralizedAdditiveModelRegressor());

            var crossValidator = new CrossValidator();
            crossValidator.Kind     = MacroUtilsTrainerKinds.SignatureRegressorTrainer;
            crossValidator.NumFolds = 5;

            var crossValidatorOutput = crossValidator.CrossValidate<SalaryData, SalaryPrediction>(pipeline);
            var foldMetrics          = crossValidatorOutput.RegressionMetrics;

            Console.Write(Environment.NewLine);
            Console.WriteLine("Root Mean Squared for each fold:");
            foreach (var metric in foldMetrics)
            {
                Console.WriteLine(metric.Rms);
            }

            // Averages are sum / fold count, exactly as reported per fold above.
            var averageR2  = foldMetrics.Sum(metric => metric.RSquared) / foldMetrics.Count;
            var averageRms = foldMetrics.Sum(metric => metric.Rms) / foldMetrics.Count;

            Console.Write(Environment.NewLine);
            Console.WriteLine($"Average R^2: {averageR2}");
            Console.WriteLine($"Average RMS: {averageRms}");

            Console.ReadLine();
        }
Example #11
0
        /// <summary>
        /// Cross-validates the <c>multiValueTreeBuilder</c> member on the congress voting data
        /// with "party" as the dependent feature, and requires the mean <c>Accuracy</c> of the
        /// returned reports to be at least 0.9.
        /// </summary>
        public void DiscreteClassification_CategoricalFeatures_MultiValuesSplits_CongressVotingData_CrossValidation()
        {
            // Given
            // Seeded (3, like the other seeded cross-validation tests in this file) so the
            // accuracy threshold below is checked against a reproducible run; an unseeded
            // Random made the outcome vary between runs.
            var randomizer = new Random(3);
            var splitter   = new CrossValidator<string>(randomizer);
            var testData   = TestDataBuilder.ReadCongressData() as DataFrame;

            var predictor = new DecisionTreePredictor<string>();

            // When: 10 folds with percetnagOfTrainData = 0.7.
            var accuracies = splitter.CrossValidate(
                modelBuilder: this.multiValueTreeBuilder,
                modelBuilderParams: modelBuilderParams,
                predictor: predictor,
                qualityMeasure: new ConfusionMatrixBuilder<string>(),
                dataFrame: testData,
                dependentFeatureName: "party",
                percetnagOfTrainData: 0.7,
                folds: 10);

            // Then
            var averageAccuracy = accuracies.Select(report => report.Accuracy).Average();

            Assert.IsTrue(averageAccuracy >= 0.9);
        }
        /// <summary>
        /// Cross-validates the <c>binaryTreeBuilder</c> member on the congress voting data with
        /// "party" as the dependent feature, and requires the mean <c>Accuracy</c> of the
        /// returned reports to be at least 0.9.
        /// </summary>
        public void DiscreteClassification_CategoricalFeatures_BinarySplits_ConvressVotingData_CrossValidation()
        {
            // Given: a validator driven by a fixed-seed randomizer.
            var seededRandom   = new Random(3);
            var crossValidator = new CrossValidator<string>(seededRandom);
            var congressData   = TestDataBuilder.ReadCongressData() as DataFrame;
            var treePredictor  = new DecisionTreePredictor<string>();
            var confusionMatrixBuilder = new ConfusionMatrixBuilder<string>();

            // When: 10 folds with percetnagOfTrainData = 0.7.
            var reports = crossValidator.CrossValidate(
                modelBuilder: binaryTreeBuilder,
                modelBuilderParams: modelBuilderParams,
                predictor: treePredictor,
                qualityMeasure: confusionMatrixBuilder,
                dataFrame: congressData,
                dependentFeatureName: "party",
                percetnagOfTrainData: 0.7,
                folds: 10);

            // Then
            var meanAccuracy = reports.Average(report => report.Accuracy);
            Assert.IsTrue(meanAccuracy >= 0.9);
        }
Example #13
0
        /// <summary>
        /// Cross-validates <c>multiValueTreeBuilderWithBetterNumercValsHandler</c> on the adult
        /// census data frame with "income" as the dependent feature, and requires the mean
        /// <c>Accuracy</c> of the returned reports to be at least 0.8.
        /// </summary>
        public void DiscreteClassification_NumericFeatures_MultiValuesSplits_AdultCensusData_CrossValidation()
        {
            // Given: a default-constructed validator (no randomizer supplied).
            var crossValidator = new CrossValidator<object>();
            var censusData     = TestDataBuilder.ReadAdultCensusDataFrame();
            var treePredictor  = new DecisionTreePredictor<object>();
            var confusionMatrixBuilder = new ConfusionMatrixBuilder<object>();

            // When: positional arguments end with 0.7 and 5.
            var reports = crossValidator.CrossValidate(
                multiValueTreeBuilderWithBetterNumercValsHandler,
                modelBuilderParams,
                treePredictor,
                confusionMatrixBuilder,
                censusData,
                "income",
                0.7,
                5);

            // Then
            var meanAccuracy = reports.Average(report => report.Accuracy);
            Assert.IsTrue(meanAccuracy >= 0.8);
        }
Example #14
0
        /// <summary>
        /// Cross-validates the <c>binaryTreeBuilder</c> member on the iris data with "iris_class"
        /// as the dependent feature, and requires the mean <c>Accuracy</c> of the returned
        /// reports to be at least 0.9.
        /// </summary>
        public void DiscreteClassification_NumericFeatures_BinarySplits_IrisData_CrossValidation()
        {
            // Given
            // The validator is default-constructed; the Random instance that used to be
            // created here was never handed to it, so it has been removed.
            var splitter  = new CrossValidator<object>();
            var testData  = TestDataBuilder.ReadIrisData();
            var predictor = new DecisionTreePredictor<object>();

            // When: positional arguments end with 0.7 and 10.
            var accuracies = splitter.CrossValidate(
                binaryTreeBuilder,
                modelBuilderParams,
                predictor,
                new ConfusionMatrixBuilder<object>(),
                testData,
                "iris_class",
                0.7,
                10);

            // Then
            var averageAccuracy = accuracies.Select(report => report.Accuracy).Average();

            Assert.IsTrue(averageAccuracy >= 0.9);
        }
        /// <summary>
        /// Cross-validates a <c>BinaryDecisionTreeModelBuilder</c> (variance-based split quality
        /// checker, numeric binary splitter, leaf builder wrapping a
        /// <c>RegularizedLinearRegressionModelBuilder(0.005)</c>) on the normalized housing data
        /// with "MEDV" as the dependent feature, and requires the mean <c>Accuracy</c> of the
        /// returned reports to be at least 0.6.
        /// </summary>
        public void Regression_NumericAttrsAndOutcomesOnly_RegularizedRegression()
        {
            // Given: a validator driven by a fixed-seed randomizer.
            var seededRandom   = new Random(3);
            var crossValidator = new CrossValidator<double>(seededRandom);
            var housingData    = TestDataBuilder.ReadHousingDataNormalizedAttrs();
            var treePredictor  = new DecisionTreePredictor<double>();

            var splitQualityChecker = new VarianceBasedSplitQualityChecker();
            var splitSelector       = new BestSplitSelectorForNumericValues(new BinaryNumericDataSplitter());
            var leafBuilder         = new RegressionAndModelDecisionTreeLeafBuilder(
                new RegularizedLinearRegressionModelBuilder(0.005));
            var treeBuilder = new BinaryDecisionTreeModelBuilder(splitQualityChecker, splitSelector, leafBuilder);
            var fitMeasure  = new GoodnessOfFitQualityMeasure();

            // When: 15 folds with percetnagOfTrainData = 0.7.
            var reports = crossValidator.CrossValidate(
                modelBuilder: treeBuilder,
                modelBuilderParams: modelBuilderParams,
                predictor: treePredictor,
                qualityMeasure: fitMeasure,
                dataFrame: housingData,
                dependentFeatureName: "MEDV",
                percetnagOfTrainData: 0.7,
                folds: 15);

            // Then
            var meanAccuracy = reports.Average(report => report.Accuracy);
            Assert.IsTrue(meanAccuracy >= 0.6);
        }
        /// <summary>
        /// Cross-validates a builder obtained from <c>BuildCustomModelBuilder</c> with a
        /// default-constructed chi-square significance checker on the categorical mushroom data
        /// ("type" as the dependent feature), and requires the mean <c>Accuracy</c> of the
        /// returned reports to be at least 0.99.
        /// </summary>
        public void Mushroom_MultiSplit_StatisticalSignificanceHeuristic()
        {
            // Given: a validator driven by a fixed-seed randomizer.
            var seededRandom   = new Random(3);
            var crossValidator = new CrossValidator<string>(seededRandom);
            var mushroomData   = TestDataBuilder.ReadMushroomDataWithCategoricalAttributes();
            var treePredictor  = new DecisionTreePredictor<string>();

            var treeBuilder = this.BuildCustomModelBuilder(
                statisticalSignificanceChecker: new ChiSquareStatisticalSignificanceChecker());
            var treeBuilderParams      = new DecisionTreeModelBuilderParams(false, true);
            var confusionMatrixBuilder = new ConfusionMatrixBuilder<string>();

            // When: 2 folds with percetnagOfTrainData = 0.7.
            var reports = crossValidator.CrossValidate(
                modelBuilder: treeBuilder,
                modelBuilderParams: treeBuilderParams,
                predictor: treePredictor,
                qualityMeasure: confusionMatrixBuilder,
                dataFrame: mushroomData,
                dependentFeatureName: "type",
                percetnagOfTrainData: 0.7,
                folds: 2);

            // Then
            var meanAccuracy = reports.Average(report => report.Accuracy);
            Assert.IsTrue(meanAccuracy >= 0.99);
        }
        /// <summary>
        /// Cross-validates the <c>multiValueTreeBuilder</c> member on the categorical mushroom
        /// data with "type" as the dependent feature, and requires the mean <c>Accuracy</c> of
        /// the returned reports to be at least 0.99.
        /// </summary>
        public void Mushroom_MultiSplit()
        {
            // Given: a validator driven by a fixed-seed randomizer.
            var seededRandom   = new Random(3);
            var crossValidator = new CrossValidator<string>(seededRandom);
            var mushroomData   = TestDataBuilder.ReadMushroomDataWithCategoricalAttributes();
            var treePredictor  = new DecisionTreePredictor<string>();
            var confusionMatrixBuilder = new ConfusionMatrixBuilder<string>();

            // When: 2 folds with percetnagOfTrainData = 0.7.
            var reports = crossValidator.CrossValidate(
                modelBuilder: multiValueTreeBuilder,
                modelBuilderParams: modelBuilderParams,
                predictor: treePredictor,
                qualityMeasure: confusionMatrixBuilder,
                dataFrame: mushroomData,
                dependentFeatureName: "type",
                percetnagOfTrainData: 0.7,
                folds: 2);

            // Then
            var meanAccuracy = reports.Average(report => report.Accuracy);
            Assert.IsTrue(meanAccuracy >= 0.99);
        }
        /// <summary>
        /// Cross-validates the <c>binaryTreeBuilder</c> member on the iris data with "iris_class"
        /// as the dependent feature, and requires the mean <c>Accuracy</c> of the returned
        /// reports to be at least 0.9.
        /// </summary>
        public void DiscreteClassification_NumericFeatures_BinarySplits_IrisData_CrossValidation()
        {
            // Given
            // The validator is default-constructed; the Random instance that used to be
            // created here was never handed to it, so it has been removed.
            var splitter = new CrossValidator<object>();
            var testData = TestDataBuilder.ReadIrisData();
            var predictor = new DecisionTreePredictor<object>();

            // When: positional arguments end with 0.7 and 10.
            var accuracies = splitter.CrossValidate(
                binaryTreeBuilder,
                modelBuilderParams,
                predictor,
                new ConfusionMatrixBuilder<object>(),
                testData,
                "iris_class",
                0.7,
                10);

            // Then
            var averageAccuracy = accuracies.Select(report => report.Accuracy).Average();
            Assert.IsTrue(averageAccuracy >= 0.9);
        }
        /// <summary>
        /// Cross-validates a random forest built on
        /// <c>multiValueTreeBuilderWithBetterNumercValsHandler</c> against the Titanic data with
        /// the "FarePerPerson", "PassengerId" and "FamilySize" columns removed, using "Survived"
        /// as the dependent feature, and requires the mean <c>Accuracy</c> of the returned
        /// reports to be at least 0.75.
        /// </summary>
        public void DiscreteClassification_MixedFeatures_MultiValueSplits_CleanedTitanicData()
        {
            // Given
            var randomForestBuilder = new RandomForestModelBuilder<object>(
                multiValueTreeBuilderWithBetterNumercValsHandler,
                new DecisionTreePredictor<object>(),
                new ConfusionMatrixBuilder<object>(),
                // Maps an int to its square root, rounded half away from zero.
                i => (int)Math.Round(Math.Sqrt(i), MidpointRounding.AwayFromZero),
                () => new DecisionTreeModelBuilderParams(false, true));
            var randomForestPredictor = new RandomForestPredictor<object>(new DecisionTreePredictor<object>());

            // Keep every column except the three excluded ones.
            var excludedColumns = new[] { "FarePerPerson", "PassengerId", "FamilySize" };
            var baseData = TestDataBuilder.ReadTitanicData();
            baseData = baseData.GetSubsetByColumns(baseData.ColumnNames.Except(excludedColumns).ToList());
            var crossValidator = new CrossValidator<object>();

            // When: positional arguments end with 0.75 and 1.
            var reports = crossValidator.CrossValidate(
                randomForestBuilder,
                new RandomForestParams(200, 10),
                randomForestPredictor,
                new ConfusionMatrixBuilder<object>(),
                baseData,
                "Survived",
                0.75,
                1);

            // Then
            Assert.IsTrue(reports.Select(report => report.Accuracy).Average() >= 0.75);
        }
Example #20
0
        /// <summary>
        /// Cross-validates a random forest built on
        /// <c>multiValueTreeBuilderWithBetterNumercValsHandler</c> against the Titanic data with
        /// the "FarePerPerson", "PassengerId" and "FamilySize" columns removed, using "Survived"
        /// as the dependent feature, and requires the mean <c>Accuracy</c> of the returned
        /// reports to be at least 0.75.
        /// </summary>
        public void DiscreteClassification_MixedFeatures_MultiValueSplits_CleanedTitanicData()
        {
            // Given
            var randomForestBuilder = new RandomForestModelBuilder<object>(
                multiValueTreeBuilderWithBetterNumercValsHandler,
                new DecisionTreePredictor<object>(),
                new ConfusionMatrixBuilder<object>(),
                // Maps an int to its square root, rounded half away from zero.
                i => (int)Math.Round(Math.Sqrt(i), MidpointRounding.AwayFromZero),
                () => new DecisionTreeModelBuilderParams(false, true));
            var randomForestPredictor = new RandomForestPredictor<object>(new DecisionTreePredictor<object>());

            // Keep every column except the three excluded ones.
            var excludedColumns = new[] { "FarePerPerson", "PassengerId", "FamilySize" };
            var baseData = TestDataBuilder.ReadTitanicData();
            baseData = baseData.GetSubsetByColumns(baseData.ColumnNames.Except(excludedColumns).ToList());
            var crossValidator = new CrossValidator<object>();

            // When: positional arguments end with 0.75 and 1.
            var reports = crossValidator.CrossValidate(
                randomForestBuilder,
                new RandomForestParams(200, 10),
                randomForestPredictor,
                new ConfusionMatrixBuilder<object>(),
                baseData,
                "Survived",
                0.75,
                1);

            // Then
            Assert.IsTrue(reports.Select(report => report.Accuracy).Average() >= 0.75);
        }
        /// <summary>
        /// Cross-validates a builder obtained from <c>BuildCustomModelBuilder</c> (first argument
        /// <c>true</c>, chi-square significance checker constructed with 0.05) on the congress
        /// voting data with "party" as the dependent feature, and requires the mean
        /// <c>Accuracy</c> of the returned reports to be at least 0.9.
        /// </summary>
        public void DiscreteClassification_CategoricalFeatures_MultiValuesSplits_CongressVotingData_StatisticalSignificanceHeuristic_CrossValidation()
        {
            // Given
            // Seeded (3, like the other seeded cross-validation tests in this file) so the
            // accuracy threshold below is checked against a reproducible run; an unseeded
            // Random made the outcome vary between runs.
            var randomizer = new Random(3);
            var splitter = new CrossValidator<string>(randomizer);
            var testData = TestDataBuilder.ReadCongressData() as DataFrame;

            var predictor = new DecisionTreePredictor<string>();

            // When: 10 folds with percetnagOfTrainData = 0.7.
            var accuracies = splitter.CrossValidate(
                modelBuilder: this.BuildCustomModelBuilder(true, statisticalSignificanceChecker: new ChiSquareStatisticalSignificanceChecker(0.05)),
                modelBuilderParams: new DecisionTreeModelBuilderParams(false, true),
                predictor: predictor,
                qualityMeasure: new ConfusionMatrixBuilder<string>(),
                dataFrame: testData,
                dependentFeatureName: "party",
                percetnagOfTrainData: 0.7,
                folds: 10);

            // Then
            var averageAccuracy = accuracies.Select(report => report.Accuracy).Average();
            Assert.IsTrue(averageAccuracy >= 0.9);
        }
        /// <summary>
        /// Cross-validates <c>multiValueTreeBuilderWithBetterNumercValsHandler</c> on the adult
        /// census data frame with "income" as the dependent feature, and requires the mean
        /// <c>Accuracy</c> of the returned reports to be at least 0.8.
        /// </summary>
        public void DiscreteClassification_NumericFeatures_MultiValuesSplits_AdultCensusData_CrossValidation()
        {
            // Given: a default-constructed validator (no randomizer supplied).
            var crossValidator = new CrossValidator<object>();
            var censusData     = TestDataBuilder.ReadAdultCensusDataFrame();
            var treePredictor  = new DecisionTreePredictor<object>();
            var confusionMatrixBuilder = new ConfusionMatrixBuilder<object>();

            // When: positional arguments end with 0.7 and 5.
            var reports = crossValidator.CrossValidate(
                multiValueTreeBuilderWithBetterNumercValsHandler,
                modelBuilderParams,
                treePredictor,
                confusionMatrixBuilder,
                censusData,
                "income",
                0.7,
                5);

            // Then
            var meanAccuracy = reports.Average(report => report.Accuracy);
            Assert.IsTrue(meanAccuracy >= 0.8);
        }