public void PipelineSweeperNoTransforms()
        {
            // Set up inputs for experiment
            string       pathData        = GetDataPath("adult.train");
            string       pathDataTest    = GetDataPath("adult.test");
            const int    numOfSampleRows = 1000;
            const string schema          = "sep=, col=Features:R4:0,2,4,10-12 col=Label:R4:14 header=+";

            var inputFileTrain = new SimpleFileHandle(Env, pathData, false, false);

#pragma warning disable 0618
            var datasetTrain = ImportTextData.ImportText(Env,
                new ImportTextData.Input { InputFile = inputFileTrain, CustomSchema = schema }).Data.Take(numOfSampleRows);
            var inputFileTest = new SimpleFileHandle(Env, pathDataTest, false, false);
            var datasetTest = ImportTextData.ImportText(Env,
                new ImportTextData.Input { InputFile = inputFileTest, CustomSchema = schema }).Data.Take(numOfSampleRows);
#pragma warning restore 0618
            const int       batchSize          = 5;
            const int       numIterations      = 20;
            const int       numTransformLevels = 2;
            SupportedMetric metric = PipelineSweeperSupportedMetrics.GetSupportedMetric(PipelineSweeperSupportedMetrics.Metrics.Auc);

            // Using the simple, uniform random sampling (with replacement) engine
            PipelineOptimizerBase autoMlEngine = new UniformRandomEngine(Env);

            // Create search object
            var amls = new AutoInference.AutoMlMlState(Env, metric, autoMlEngine, new IterationTerminator(numIterations),
                                                       MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer, datasetTrain, datasetTest);

            // Infer search space
            amls.InferSearchSpace(numTransformLevels);

            // Create macro object
            var pipelineSweepInput = new Microsoft.ML.Legacy.Models.PipelineSweeper()
            {
                BatchSize = batchSize,
            };

            var exp    = new Experiment(Env);
            var output = exp.Add(pipelineSweepInput);
            exp.Compile();
            exp.SetInput(pipelineSweepInput.TrainingData, datasetTrain);
            exp.SetInput(pipelineSweepInput.TestingData, datasetTest);
            exp.SetInput(pipelineSweepInput.State, amls);
            exp.SetInput(pipelineSweepInput.CandidateOutputs, new IDataView[0]);
            exp.Run();

            // Make sure you get back an AutoMlState, that it ran for the correct number of iterations,
            // and that the best pipeline reached at least minimal performance (i.e., AUC above 0.8 on this dataset).
            AutoInference.AutoMlMlState amlsOut = (AutoInference.AutoMlMlState)exp.GetOutput(output.State);
            Assert.NotNull(amlsOut);
            Assert.Equal(numIterations, amlsOut.GetAllEvaluatedPipelines().Length);
            Assert.True(amlsOut.GetBestPipeline().PerformanceSummary.MetricValue > 0.8);
        }
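
Throughout these examples, SimpleFileHandle takes two trailing booleans: call sites pass false/false when reading an existing file and true/true for a scratch file that is written and then discarded. A minimal sketch of that pattern, assuming (as the usage suggests) the flags mean "needs write access" and "auto-delete on dispose":

        // Sketch only: the flag meanings in the comments are inferred from usage, not from the declared signature.
        var readOnlyInput = new SimpleFileHandle(Env, pathData, false, false);             // read an existing file
        var scratchOutput = new SimpleFileHandle(Env, Path.GetTempFileName(), true, true); // temp file, cleaned up automatically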
Example #2
        public void TestSaveAndLoadTreeFeaturizer()
        {
            int dataPointCount = 200;
            var data           = SamplesUtils.DatasetUtils.GenerateFloatLabelFloatFeatureVectorSamples(dataPointCount).ToList();
            var dataView       = ML.Data.LoadFromEnumerable(data);

            dataView = ML.Data.Cache(dataView);

            var trainerOptions = new FastForestRegressionTrainer.Options
            {
                NumberOfThreads            = 1,
                NumberOfTrees              = 10,
                NumberOfLeaves             = 4,
                MinimumExampleCountPerLeaf = 10,
                FeatureColumnName          = "Features",
                LabelColumnName            = "Label"
            };

            var options = new FastForestRegressionFeaturizationEstimator.Options()
            {
                InputColumnName  = "Features",
                TreesColumnName  = "Trees",
                LeavesColumnName = "Leaves",
                PathsColumnName  = "Paths",
                TrainerOptions   = trainerOptions
            };

            var pipeline = ML.Transforms.FeaturizeByFastForestRegression(options)
                           .Append(ML.Transforms.Concatenate("CombinedFeatures", "Features", "Trees", "Leaves", "Paths"))
                           .Append(ML.Regression.Trainers.Sdca("Label", "CombinedFeatures"));
            var model      = pipeline.Fit(dataView);
            var prediction = model.Transform(dataView);
            var metrics    = ML.Regression.Evaluate(prediction);

            Assert.True(metrics.MeanAbsoluteError < 0.25);
            Assert.True(metrics.MeanSquaredError < 0.1);

            // Save the trained model into file.
            ITransformer loadedModel = null;
            var          tempPath    = Path.GetTempFileName();

            using (var file = new SimpleFileHandle(Env, tempPath, true, true))
            {
                using (var fs = file.CreateWriteStream())
                    ML.Model.Save(model, null, fs);

                using (var fs = file.OpenReadStream())
                    loadedModel = ML.Model.Load(fs, out var schema);
            }
            var loadedPrediction = loadedModel.Transform(dataView);
            var loadedMetrics    = ML.Regression.Evaluate(loadedPrediction);

            Assert.Equal(metrics.MeanAbsoluteError, loadedMetrics.MeanAbsoluteError);
            Assert.Equal(metrics.MeanSquaredError, loadedMetrics.MeanSquaredError);
        }
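
The SimpleFileHandle here mainly buys automatic cleanup of the temp file; the same round trip can also be written against a plain file path. A minimal sketch, assuming the standard path-based MLContext.Model.Save/Load overloads:

        var tempModelPath = Path.GetTempFileName();
        try
        {
            // The input schema is persisted alongside the transformer chain.
            ML.Model.Save(model, dataView.Schema, tempModelPath);
            var reloaded = ML.Model.Load(tempModelPath, out var loadedSchema);
            var reloadedMetrics = ML.Regression.Evaluate(reloaded.Transform(dataView));
        }
        finally
        {
            File.Delete(tempModelPath); // what SimpleFileHandle's auto-delete flag would otherwise do
        }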
Example #3
        public void TestLearn()
        {
            using (var env = new LocalEnvironment()
                             .AddStandardComponents()) // AutoInference.InferPipelines uses ComponentCatalog to read text data
            {
                string          pathData           = GetDataPath("adult.train");
                string          pathDataTest       = GetDataPath("adult.test");
                int             numOfSampleRows    = 1000;
                int             batchSize          = 5;
                int             numIterations      = 10;
                int             numTransformLevels = 3;
                SupportedMetric metric             = PipelineSweeperSupportedMetrics.GetSupportedMetric(PipelineSweeperSupportedMetrics.Metrics.Auc);

                // Using the simple, uniform random sampling (with replacement) engine
                PipelineOptimizerBase autoMlEngine = new UniformRandomEngine(env);

                // Test initial learning
                var amls = AutoInference.InferPipelines(env, autoMlEngine, pathData, "", out var schema, numTransformLevels, batchSize,
                                                        metric, out var bestPipeline, numOfSampleRows, new IterationTerminator(numIterations / 2), MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer);
                env.Check(amls.GetAllEvaluatedPipelines().Length == numIterations / 2);

                // Resume learning
                amls.UpdateTerminator(new IterationTerminator(numIterations));
                bestPipeline = amls.InferPipelines(numTransformLevels, batchSize, numOfSampleRows);
                env.Check(amls.GetAllEvaluatedPipelines().Length == numIterations);

                // Use best pipeline for another task
                var inputFileTrain = new SimpleFileHandle(env, pathData, false, false);
#pragma warning disable 0618
                var datasetTrain = ImportTextData.ImportText(env,
                    new ImportTextData.Input { InputFile = inputFileTrain, CustomSchema = schema }).Data;
                var inputFileTest = new SimpleFileHandle(env, pathDataTest, false, false);
                var datasetTest = ImportTextData.ImportText(env,
                    new ImportTextData.Input { InputFile = inputFileTest, CustomSchema = schema }).Data;
#pragma warning restore 0618

                // REVIEW: Theoretically, it could be the case that a new, very bad learner is introduced and
                // we get unlucky and only select it every time, such that this test fails. Not
                // likely at all, but a non-zero probability. Should be OK, since all current learners are returning AUC > .80.
                bestPipeline.RunTrainTestExperiment(datasetTrain, datasetTest, metric, MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer,
                                                    out var testMetricValue, out var trainMetricValue);
                env.Check(testMetricValue > 0.2);
            }
            Done();
        }
Example #4
        public void TestEstimatorSaveLoad()
        {
            IHostEnvironment env = new MLContext(1);
            var dataFile         = GetDataPath("images/images.tsv");
            var imageFolder      = Path.GetDirectoryName(dataFile);
            var data             = TextLoader.Create(env, new TextLoader.Options()
            {
                Columns = new[]
                {
                    new TextLoader.Column("ImagePath", DataKind.String, 0),
                    new TextLoader.Column("Name", DataKind.String, 1),
                }
            }, new MultiFileSource(dataFile));

            var pipe = new ImageLoadingEstimator(env, imageFolder, ("ImageReal", "ImagePath"))
                       .Append(new ImageResizingEstimator(env, "ImageReal", 100, 100, "ImageReal"))
                       .Append(new ImagePixelExtractingEstimator(env, "ImagePixels", "ImageReal"))
                       .Append(new ImageGrayscalingEstimator(env, ("ImageGray", "ImageReal")));

            pipe.GetOutputSchema(SchemaShape.Create(data.Schema));
            var model = pipe.Fit(data);

            var tempPath = Path.GetTempFileName();

            using (var file = new SimpleFileHandle(env, tempPath, true, true))
            {
                using (var fs = file.CreateWriteStream())
                    ML.Model.Save(model, null, fs);
                ITransformer model2;
                using (var fs = file.OpenReadStream())
                    model2 = ML.Model.Load(fs, out var schema);

                var transformerChain = model2 as TransformerChain <ITransformer>;
                Assert.NotNull(transformerChain);

                var newCols = ((ImageLoadingTransformer)transformerChain.First()).Columns;
                var oldCols = ((ImageLoadingTransformer)model.First()).Columns;
                Assert.True(newCols
                            .Zip(oldCols, (x, y) => x == y)
                            .All(x => x));
            }
            Done();
        }
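
One caveat with the Zip-based column comparison above: Enumerable.Zip truncates to the shorter sequence, so the assertion would still pass if one side had extra trailing columns. Assuming the column type's default equality agrees with its == operator, SequenceEqual also verifies that the counts match:

        Assert.True(newCols.SequenceEqual(oldCols));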
Example #5
        private IDataView GetBreastCancerDataviewWithTextColumns()
        {
            var dataPath  = GetDataPath("breast-cancer.txt");
            var inputFile = new SimpleFileHandle(Env, dataPath, false, false);

            return ImportTextData.TextLoader(Env, new ImportTextData.LoaderInput()
            {
                Arguments =
                {
                    HasHeader = true,
                    Column    = new[]
                    {
                        new TextLoader.Column("Label", type: null, 0),
                        new TextLoader.Column("F1", DataKind.Text, 1),
                        new TextLoader.Column("F2", DataKind.I4, 2),
                        new TextLoader.Column("Rest", type: null, new [] { new TextLoader.Range(3, 9) })
                    }
                },

                InputFile = inputFile
            }).Data;
        }
Example #6
        public void TestEstimatorSaveLoad()
        {
            IHostEnvironment env = new MLContext();
            var dataFile         = GetDataPath("images/images.tsv");
            var imageFolder      = Path.GetDirectoryName(dataFile);
            var data             = TextLoader.Create(env, new TextLoader.Arguments()
            {
                Column = new[]
                {
                    new TextLoader.Column("ImagePath", DataKind.TX, 0),
                    new TextLoader.Column("Name", DataKind.TX, 1),
                }
            }, new MultiFileSource(dataFile));

            var pipe = new ImageLoadingEstimator(env, imageFolder, ("ImagePath", "ImageReal"))
                       .Append(new ImageResizingEstimator(env, "ImageReal", "ImageReal", 100, 100))
                       .Append(new ImagePixelExtractingEstimator(env, "ImageReal", "ImagePixels"))
                       .Append(new ImageGrayscalingEstimator(env, ("ImageReal", "ImageGray")));

            pipe.GetOutputSchema(Core.Data.SchemaShape.Create(data.Schema));
            var model = pipe.Fit(data);

            var tempPath = Path.GetTempFileName();

            using (var file = new SimpleFileHandle(env, tempPath, true, true))
            {
                using (var fs = file.CreateWriteStream())
                    model.SaveTo(env, fs);
                var model2 = TransformerChain.LoadFrom(env, file.OpenReadStream());

                var newCols = ((ImageLoaderTransform)model2.First()).Columns;
                var oldCols = ((ImageLoaderTransform)model.First()).Columns;
                Assert.True(newCols
                            .Zip(oldCols, (x, y) => x == y)
                            .All(x => x));
            }
            Done();
        }
Example #7
        public void SetInputFromPath(GraphRunner runner, string varName, string path, TlcModule.DataKind kind)
        {
            _host.CheckUserArg(runner != null, nameof(runner), "Provide a GraphRunner instance.");
            _host.CheckUserArg(!string.IsNullOrWhiteSpace(varName), nameof(varName), "Specify a graph variable name.");
            _host.CheckUserArg(!string.IsNullOrWhiteSpace(path), nameof(path), "Specify a valid file path.");

            switch (kind)
            {
            case TlcModule.DataKind.FileHandle:
                var fh = new SimpleFileHandle(_host, path, false, false);
                runner.SetInput(varName, fh);
                break;

            case TlcModule.DataKind.DataView:
                IDataView loader = new BinaryLoader(_host, new BinaryLoader.Arguments(), path);
                runner.SetInput(varName, loader);
                break;

            case TlcModule.DataKind.PredictorModel:
                PredictorModelImpl pm;
                using (var fs = File.OpenRead(path))
                    pm = new PredictorModelImpl(_host, fs);
                runner.SetInput(varName, pm);
                break;

            case TlcModule.DataKind.TransformModel:
                TransformModelImpl tm;
                using (var fs = File.OpenRead(path))
                    tm = new TransformModelImpl(_host, fs);
                runner.SetInput(varName, tm);
                break;

            default:
                throw _host.Except("Port type {0} not supported", kind);
            }
        }
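
A hypothetical call site for SetInputFromPath, wiring a data file into a graph port declared as a FileHandle (the graph variable name 'inputFile' is illustrative):

        var runner = new GraphRunner(_host, graph[FieldNames.Nodes] as JArray);
        SetInputFromPath(runner, "inputFile", dataPath, TlcModule.DataKind.FileHandle);
        runner.RunAll();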
Example #8
        public void PipelineSweeperRocketEngine()
        {
            // Get datasets
            var          pathData        = GetDataPath("adult.train");
            var          pathDataTest    = GetDataPath("adult.test");
            const int    numOfSampleRows = 1000;
            int          numIterations   = 35;
            const string schema          =
                "sep=, col=Features:R4:0,2,4,10-12 col=workclass:TX:1 col=education:TX:3 col=marital_status:TX:5 col=occupation:TX:6 " +
                "col=relationship:TX:7 col=ethnicity:TX:8 col=sex:TX:9 col=native_country:TX:13 col=label_IsOver50K_:R4:14 header=+";
            var inputFileTrain = new SimpleFileHandle(Env, pathData, false, false);

#pragma warning disable 0618
            var datasetTrain = ImportTextData.ImportText(Env,
                new ImportTextData.Input { InputFile = inputFileTrain, CustomSchema = schema }).Data.Take(numOfSampleRows);
            var inputFileTest = new SimpleFileHandle(Env, pathDataTest, false, false);
            var datasetTest = ImportTextData.ImportText(Env,
                new ImportTextData.Input { InputFile = inputFileTest, CustomSchema = schema }).Data.Take(numOfSampleRows);
#pragma warning restore 0618
            // Define entrypoint graph
            string inputGraph = @"
                {
                  'Nodes': [                                
                    {
                      'Name': 'Models.PipelineSweeper',
                      'Inputs': {
                        'TrainingData': '$TrainingData',
                        'TestingData': '$TestingData',
                        'StateArguments': {
                            'Name': 'AutoMlState',
                            'Settings': {
                                'Metric': 'Auc',
                                'Engine': {
                                    'Name': 'Rocket',
                                    'Settings' : {
                                        'TopKLearners' : 2,
                                        'SecondRoundTrialsPerLearner' : 5
                                    },
                                },
                                'TerminatorArgs': {
                                    'Name': 'IterationLimited',
                                    'Settings': {
                                        'FinalHistoryLength': 35
                                    }
                                },
                                'TrainerKind': 'SignatureBinaryClassifierTrainer'
                            }
                        },
                        'BatchSize': 5
                      },
                      'Outputs': {
                        'State': '$StateOut',
                        'Results': '$ResultsOut'
                      }
                    },
                  ]
                }";

            JObject graph   = JObject.Parse(inputGraph);
            var     catalog = Env.ComponentCatalog;

            var runner = new GraphRunner(Env, catalog, graph[FieldNames.Nodes] as JArray);
            runner.SetInput("TrainingData", datasetTrain);
            runner.SetInput("TestingData", datasetTest);
            runner.RunAll();

            var autoMlState = runner.GetOutput <AutoInference.AutoMlMlState>("StateOut");
            Assert.NotNull(autoMlState);
            var allPipelines = autoMlState.GetAllEvaluatedPipelines();
            var bestPipeline = autoMlState.GetBestPipeline();
            Assert.Equal(numIterations, allPipelines.Length);
            Assert.True(bestPipeline.PerformanceSummary.MetricValue > 0.1);

            var results = runner.GetOutput <IDataView>("ResultsOut");
            Assert.NotNull(results);
            var rows = PipelinePattern.ExtractResults(Env, results,
                                                      "Graph", "MetricValue", "PipelineId", "TrainingMetricValue", "FirstInput", "PredictorModel");
            Assert.True(rows.Length == numIterations);
        }
Example #9
        public void PipelineSweeperMultiClassClassification()
        {
            // Get datasets
            // TODO (agoswami) : For now we use the same dataset for train and test since the repo does not have a separate test file for the iris dataset.
            // In the future the PipelineSweeper Macro will have an option to take just one dataset as input, and do the train-test split internally.
            var          pathData       = GetDataPath(@"iris.txt");
            var          pathDataTest   = GetDataPath(@"iris.txt");
            int          numIterations  = 2;
            const string schema         = "col=Species:R4:0 col=SepalLength:R4:1 col=SepalWidth:R4:2 col=PetalLength:R4:3 col=PetalWidth:R4:4";
            var          inputFileTrain = new SimpleFileHandle(Env, pathData, false, false);

#pragma warning disable 0618
            var datasetTrain = ImportTextData.ImportText(Env,
                new ImportTextData.Input { InputFile = inputFileTrain, CustomSchema = schema }).Data;
            var inputFileTest = new SimpleFileHandle(Env, pathDataTest, false, false);
            var datasetTest = ImportTextData.ImportText(Env,
                new ImportTextData.Input { InputFile = inputFileTest, CustomSchema = schema }).Data;
#pragma warning restore 0618

            // Define entrypoint graph
            string inputGraph = @"
                {
                  'Nodes': [
                    {
                      'Name': 'Models.PipelineSweeper',
                      'Inputs': {
                        'TrainingData': '$TrainingData',
                        'TestingData': '$TestingData',
                        'LabelColumns': ['Species'],
                        'StateArguments': {
                            'Name': 'AutoMlState',
                            'Settings': {
                                'Metric': 'AccuracyMicro',
                                'Engine': {
                                    'Name': 'Defaults'
                                },
                                'TerminatorArgs': {
                                    'Name': 'IterationLimited',
                                    'Settings': {
                                        'FinalHistoryLength': 2
                                    }
                                },
                                'TrainerKind': 'SignatureMultiClassClassifierTrainer',
                                'RequestedLearners' : [
                                    'LogisticRegressionClassifier',
                                    'StochasticDualCoordinateAscentClassifier'
                                ]
                            }
                        },
                        'BatchSize': 1
                      },
                      'Outputs': {
                        'State': '$StateOut',
                        'Results': '$ResultsOut'
                      }
                    },
                  ]
                }";

            JObject graphJson = JObject.Parse(inputGraph);
            var     catalog   = Env.ComponentCatalog;
            var     runner    = new GraphRunner(Env, catalog, graphJson[FieldNames.Nodes] as JArray);
            runner.SetInput("TrainingData", datasetTrain);
            runner.SetInput("TestingData", datasetTest);
            runner.RunAll();

            var autoMlState = runner.GetOutput <AutoInference.AutoMlMlState>("StateOut");
            Assert.NotNull(autoMlState);
            var allPipelines = autoMlState.GetAllEvaluatedPipelines();
            var bestPipeline = autoMlState.GetBestPipeline();
            Assert.Equal(numIterations, allPipelines.Length);

            var bestMicroAccuracyTrain = bestPipeline.PerformanceSummary.TrainingMetricValue;
            var bestMicroAccuracyTest  = bestPipeline.PerformanceSummary.MetricValue;
            Assert.True((0.97 < bestMicroAccuracyTrain) && (bestMicroAccuracyTrain < 0.99));
            Assert.True((0.97 < bestMicroAccuracyTest) && (bestMicroAccuracyTest < 0.99));

            var results = runner.GetOutput <IDataView>("ResultsOut");
            Assert.NotNull(results);
            var rows = PipelinePattern.ExtractResults(Env, results,
                                                      "Graph", "MetricValue", "PipelineId", "TrainingMetricValue", "FirstInput", "PredictorModel");
            Assert.True(rows.Length == numIterations);
            Assert.True(rows.All(r => r.MetricValue > 0.9));
        }
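
The JSON 'Settings' block mirrors the programmatic API used in the AutoMlMlState examples above; assuming the Metrics enum exposes an AccuracyMicro member alongside Auc, the metric requested here corresponds to:

        SupportedMetric metric = PipelineSweeperSupportedMetrics.GetSupportedMetric(
            PipelineSweeperSupportedMetrics.Metrics.AccuracyMicro);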
Example #10
        public void PipelineSweeperBasic()
        {
            // Get datasets
            var       pathData        = GetDataPath("adult.tiny.with-schema.txt");
            var       pathDataTest    = GetDataPath("adult.tiny.with-schema.txt");
            const int numOfSampleRows = 1000;
            int       numIterations   = 4;
            var       inputFileTrain  = new SimpleFileHandle(Env, pathData, false, false);

#pragma warning disable 0618
            var datasetTrain = ImportTextData.ImportText(Env,
                new ImportTextData.Input { InputFile = inputFileTrain }).Data.Take(numOfSampleRows);
            var inputFileTest = new SimpleFileHandle(Env, pathDataTest, false, false);
            var datasetTest = ImportTextData.ImportText(Env,
                new ImportTextData.Input { InputFile = inputFileTest }).Data.Take(numOfSampleRows);
#pragma warning restore 0618
            // Define entrypoint graph
            string inputGraph = @"
                {
                  'Nodes': [                                
                    {
                      'Name': 'Models.PipelineSweeper',
                      'Inputs': {
                        'TrainingData': '$TrainingData',
                        'TestingData': '$TestingData',
                        'StateArguments': {
                            'Name': 'AutoMlState',
                            'Settings': {
                                'Metric': 'Auc',
                                'Engine': {
                                    'Name': 'UniformRandom'
                                },
                                'TerminatorArgs': {
                                    'Name': 'IterationLimited',
                                    'Settings': {
                                        'FinalHistoryLength': 4
                                    }
                                },
                                'TrainerKind': 'SignatureBinaryClassifierTrainer'
                            }
                        },
                        'BatchSize': 2
                      },
                      'Outputs': {
                        'State': '$StateOut',
                        'Results': '$ResultsOut'
                      }
                    },
                  ]
                }";

            JObject graph   = JObject.Parse(inputGraph);
            var     catalog = Env.ComponentCatalog;

            var runner = new GraphRunner(Env, catalog, graph[FieldNames.Nodes] as JArray);
            runner.SetInput("TrainingData", datasetTrain);
            runner.SetInput("TestingData", datasetTest);
            runner.RunAll();

            var autoMlState = runner.GetOutput <AutoInference.AutoMlMlState>("StateOut");
            Assert.NotNull(autoMlState);
            var allPipelines = autoMlState.GetAllEvaluatedPipelines();
            var bestPipeline = autoMlState.GetBestPipeline();
            Assert.Equal(numIterations, allPipelines.Length);
            Assert.True(bestPipeline.PerformanceSummary.MetricValue > 0.1);

            var results = runner.GetOutput <IDataView>("ResultsOut");
            Assert.NotNull(results);
            var rows = PipelinePattern.ExtractResults(Env, results,
                                                      "Graph", "MetricValue", "PipelineId", "TrainingMetricValue", "FirstInput", "PredictorModel");
            Assert.True(rows.Length == numIterations);
            Assert.True(rows.All(r => r.TrainingMetricValue > 0.1));
        }
Example #11
        public void PipelineSweeperRoles()
        {
            // Get datasets
            var          pathData        = GetDataPath("adult.train");
            var          pathDataTest    = GetDataPath("adult.test");
            const int    numOfSampleRows = 100;
            int          numIterations   = 2;
            const string schema          =
                "sep=, col=age:R4:0 col=workclass:TX:1 col=fnlwgt:R4:2 col=education:TX:3 col=education_num:R4:4 col=marital_status:TX:5 col=occupation:TX:6 " +
                "col=relationship:TX:7 col=ethnicity:TX:8 col=sex:TX:9 col=Features:R4:10-12 col=native_country:TX:13 col=IsOver50K_:R4:14 header=+";
            var inputFileTrain = new SimpleFileHandle(Env, pathData, false, false);

#pragma warning disable 0618
            var datasetTrain = ImportTextData.ImportText(Env,
                new ImportTextData.Input { InputFile = inputFileTrain, CustomSchema = schema }).Data.Take(numOfSampleRows);
            var inputFileTest = new SimpleFileHandle(Env, pathDataTest, false, false);
            var datasetTest = ImportTextData.ImportText(Env,
                new ImportTextData.Input { InputFile = inputFileTest, CustomSchema = schema }).Data.Take(numOfSampleRows);
#pragma warning restore 0618

            // Define entrypoint graph
            string inputGraph = @"
                {
                  'Nodes': [
                    {
                      'Name': 'Models.PipelineSweeper',
                      'Inputs': {
                        'TrainingData': '$TrainingData',
                        'TestingData': '$TestingData',
                        'LabelColumns': ['IsOver50K_'],
                        'WeightColumns': ['education_num'],
                        'NameColumns': ['education'],
                        'TextFeatureColumns': ['workclass', 'marital_status', 'occupation'],
                        'StateArguments': {
                            'Name': 'AutoMlState',
                            'Settings': {
                                'Metric': 'Auc',
                                'Engine': {
                                    'Name': 'Defaults'
                                },
                                'TerminatorArgs': {
                                    'Name': 'IterationLimited',
                                    'Settings': {
                                        'FinalHistoryLength': 2
                                    }
                                },
                                'TrainerKind': 'SignatureBinaryClassifierTrainer',
                                'RequestedLearners' : [
                                    'LogisticRegressionBinaryClassifier',
                                    'FastTreeBinaryClassifier'
                                ]
                            }
                        },
                        'BatchSize': 1
                      },
                      'Outputs': {
                        'State': '$StateOut',
                        'Results': '$ResultsOut'
                      }
                    },
                  ]
                }";

            JObject graphJson = JObject.Parse(inputGraph);
            var     catalog   = Env.ComponentCatalog;
            var     runner    = new GraphRunner(Env, catalog, graphJson[FieldNames.Nodes] as JArray);
            runner.SetInput("TrainingData", datasetTrain);
            runner.SetInput("TestingData", datasetTest);
            runner.RunAll();

            var autoMlState = runner.GetOutput <AutoInference.AutoMlMlState>("StateOut");
            Assert.NotNull(autoMlState);
            var allPipelines = autoMlState.GetAllEvaluatedPipelines();
            var bestPipeline = autoMlState.GetBestPipeline();
            Assert.Equal(numIterations, allPipelines.Length);

            var trainAuc = bestPipeline.PerformanceSummary.TrainingMetricValue;
            var testAuc  = bestPipeline.PerformanceSummary.MetricValue;
            Assert.True((0.94 < trainAuc) && (trainAuc < 0.95));
            Assert.True((0.815 < testAuc) && (testAuc < 0.825));

            var results = runner.GetOutput <IDataView>("ResultsOut");
            Assert.NotNull(results);
            var rows = PipelinePattern.ExtractResults(Env, results,
                                                      "Graph", "MetricValue", "PipelineId", "TrainingMetricValue", "FirstInput", "PredictorModel");
            Assert.True(rows.Length == numIterations);
            Assert.True(rows.All(r => r.TrainingMetricValue > 0.1));
        }
Example #12
        public void CanSuccessfullyRetrieveSparseData()
        {
            string dataPath   = GetDataPath("SparseData.txt");
            string inputGraph = @"
            {
                'Nodes':
                [{
                        'Name': 'Data.TextLoader',
                        'Inputs': {
                            'InputFile': '$inputFile',
                            'Arguments': {
                                'UseThreads': true,
                                'HeaderFile': null,
                                'MaxRows': null,
                                'AllowQuoting': false,
                                'AllowSparse': true,
                                'InputSize': null,
                                'Separator': [
                                    '\t'
                                ],
                                'Column': [{
                                        'Name': 'C1',
                                        'Type': 'R4',
                                        'Source': [{
                                                'Min': 0,
                                                'Max': 0,
                                                'AutoEnd': false,
                                                'VariableEnd': false,
                                                'AllOther': false,
                                                'ForceVector': false
                                            }
                                        ],
                                        'KeyCount': null
                                    }, {
                                        'Name': 'C2',
                                        'Type': 'R4',
                                        'Source': [{
                                                'Min': 1,
                                                'Max': 1,
                                                'AutoEnd': false,
                                                'VariableEnd': false,
                                                'AllOther': false,
                                                'ForceVector': false
                                            }
                                        ],
                                        'KeyCount': null
                                    }, {
                                        'Name': 'C3',
                                        'Type': 'R4',
                                        'Source': [{
                                                'Min': 2,
                                                'Max': 2,
                                                'AutoEnd': false,
                                                'VariableEnd': false,
                                                'AllOther': false,
                                                'ForceVector': false
                                            }
                                        ],
                                        'KeyCount': null
                                    }, {
                                        'Name': 'C4',
                                        'Type': 'R4',
                                        'Source': [{
                                                'Min': 3,
                                                'Max': 3,
                                                'AutoEnd': false,
                                                'VariableEnd': false,
                                                'AllOther': false,
                                                'ForceVector': false
                                            }
                                        ],
                                        'KeyCount': null
                                    }, {
                                        'Name': 'C5',
                                        'Type': 'R4',
                                        'Source': [{
                                                'Min': 4,
                                                'Max': 4,
                                                'AutoEnd': false,
                                                'VariableEnd': false,
                                                'AllOther': false,
                                                'ForceVector': false
                                            }
                                        ],
                                        'KeyCount': null
                                    }
                                ],
                                'TrimWhitespace': false,
                                'HasHeader': true
                            }
                        },
                        'Outputs': {
                            'Data': '$data'
                        }
                    }
                ]
            }";

            JObject graph     = JObject.Parse(inputGraph);
            var     runner    = new GraphRunner(_env, graph[FieldNames.Nodes] as JArray);
            var     inputFile = new SimpleFileHandle(_env, dataPath, false, false);

            runner.SetInput("inputFile", inputFile);
            runner.RunAll();

            var data = runner.GetOutput <IDataView>("data");

            Assert.NotNull(data);

            using (var cursor = data.GetRowCursorForAllColumns())
            {
                var getters = new ValueGetter <float>[] {
                    cursor.GetGetter <float>(cursor.Schema[0]),
                    cursor.GetGetter <float>(cursor.Schema[1]),
                    cursor.GetGetter <float>(cursor.Schema[2]),
                    cursor.GetGetter <float>(cursor.Schema[3]),
                    cursor.GetGetter <float>(cursor.Schema[4])
                };


                Assert.True(cursor.MoveNext());

                float[] targets = new float[] { 1, 2, 3, 4, 5 };
                for (int i = 0; i < getters.Length; i++)
                {
                    float value = 0;
                    getters[i](ref value);
                    Assert.Equal(targets[i], value);
                }

                Assert.True(cursor.MoveNext());

                targets = new float[] { 0, 0, 0, 4, 5 };
                for (int i = 0; i < getters.Length; i++)
                {
                    float value = 0;
                    getters[i](ref value);
                    Assert.Equal(targets[i], value);
                }

                Assert.True(cursor.MoveNext());

                targets = new float[] { 0, 2, 0, 0, 0 };
                for (int i = 0; i < getters.Length; i++)
                {
                    float value = 0;
                    getters[i](ref value);
                    Assert.Equal(targets[i], value);
                }

                Assert.False(cursor.MoveNext());
            }
        }
Example #13
        public void TestSaveAndLoadDoubleTreeFeaturizer()
        {
            int dataPointCount = 200;
            var data           = SamplesUtils.DatasetUtils.GenerateFloatLabelFloatFeatureVectorSamples(dataPointCount).ToList();
            var dataView       = ML.Data.LoadFromEnumerable(data);

            dataView = ML.Data.Cache(dataView);

            var trainerOptions = new FastForestRegressionTrainer.Options
            {
                NumberOfThreads            = 1,
                NumberOfTrees              = 10,
                NumberOfLeaves             = 4,
                MinimumExampleCountPerLeaf = 10,
                FeatureColumnName          = "Features",
                LabelColumnName            = "Label"
            };

            // Trains tree featurization on "Features" and applies on "CopiedFeatures".
            var options = new FastForestRegressionFeaturizationEstimator.Options()
            {
                InputColumnName  = "CopiedFeatures",
                TrainerOptions   = trainerOptions,
                TreesColumnName  = "OhMyTrees",
                LeavesColumnName = "OhMyLeaves",
                PathsColumnName  = "OhMyPaths"
            };

            var pipeline = ML.Transforms.CopyColumns("CopiedFeatures", "Features")
                           .Append(ML.Transforms.FeaturizeByFastForestRegression(options))
                           .Append(ML.Transforms.Concatenate("CombinedFeatures", "Features", "OhMyTrees", "OhMyLeaves", "OhMyPaths"))
                           .Append(ML.Regression.Trainers.Sdca("Label", "CombinedFeatures"));
            var model      = pipeline.Fit(dataView);
            var prediction = model.Transform(dataView);
            var metrics    = ML.Regression.Evaluate(prediction);

            Assert.True(metrics.MeanAbsoluteError < 0.25);
            Assert.True(metrics.MeanSquaredError < 0.1);

            // Save the trained model into file and then load it back.
            ITransformer loadedModel = null;
            var          tempPath    = Path.GetTempFileName();

            using (var file = new SimpleFileHandle(Env, tempPath, true, true))
            {
                using (var fs = file.CreateWriteStream())
                    ML.Model.Save(model, null, fs);

                using (var fs = file.OpenReadStream())
                    loadedModel = ML.Model.Load(fs, out var schema);
            }

            // Compute prediction using the loaded model.
            var loadedPrediction = loadedModel.Transform(dataView);
            var loadedMetrics    = ML.Regression.Evaluate(loadedPrediction);

            // Check if the loaded model produces the same result as the trained model.
            Assert.Equal(metrics.MeanAbsoluteError, loadedMetrics.MeanAbsoluteError);
            Assert.Equal(metrics.MeanSquaredError, loadedMetrics.MeanSquaredError);

            var secondPipeline = ML.Transforms.CopyColumns("CopiedFeatures", "Features")
                                 .Append(ML.Transforms.NormalizeBinning("CopiedFeatures"))
                                 .Append(ML.Transforms.FeaturizeByFastForestRegression(options))
                                 .Append(ML.Transforms.Concatenate("CombinedFeatures", "Features", "OhMyTrees", "OhMyLeaves", "OhMyPaths"))
                                 .Append(ML.Regression.Trainers.Sdca("Label", "CombinedFeatures"));
            var secondModel      = secondPipeline.Fit(dataView);
            var secondPrediction = secondModel.Transform(dataView);
            var secondMetrics    = ML.Regression.Evaluate(secondPrediction);

            // The second pipeline trains a tree featurizer on a bin-based normalized feature, so the second pipeline
            // is different from the first pipeline.
            Assert.NotEqual(metrics.MeanAbsoluteError, secondMetrics.MeanAbsoluteError);
            Assert.NotEqual(metrics.MeanSquaredError, secondMetrics.MeanSquaredError);
        }
Example #14
        private void RunCore(IChannel ch, string cmd)
        {
            Host.AssertValue(ch);
            Host.AssertNonEmpty(cmd);

            ch.Trace("Constructing trainer");
            ITrainer trainer = Args.Trainer.CreateComponent(Host);

            IPredictor inputPredictor = null;

            if (Args.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, Args.InputModelFile, out inputPredictor))
            {
                ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");
            }

            ch.Trace("Constructing the training pipeline");
            IDataView trainPipe = CreateLoader();

            ISchema schema = trainPipe.Schema;
            string  label  = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.LabelColumn),
                                                                 Args.LabelColumn, DefaultColumnNames.Label);
            string features = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.FeatureColumn),
                                                                  Args.FeatureColumn, DefaultColumnNames.Features);
            string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.GroupColumn),
                                                               Args.GroupColumn, DefaultColumnNames.GroupId);
            string weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.WeightColumn),
                                                                Args.WeightColumn, DefaultColumnNames.Weight);
            string name = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.NameColumn),
                                                              Args.NameColumn, DefaultColumnNames.Name);

            TrainUtils.AddNormalizerIfNeeded(Host, ch, trainer, ref trainPipe, features, Args.NormalizeFeatures);

            ch.Trace("Binding columns");
            var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);
            var data       = new RoleMappedData(trainPipe, label, features, group, weight, name, customCols);

            RoleMappedData validData = null;

            if (!string.IsNullOrWhiteSpace(Args.ValidationFile))
            {
                if (!trainer.Info.SupportsValidation)
                {
                    ch.Warning("Ignoring validationFile: Trainer does not accept validation dataset.");
                }
                else
                {
                    ch.Trace("Constructing the validation pipeline");
                    IDataView validPipe = CreateRawLoader(dataFile: Args.ValidationFile);
                    validPipe = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainPipe, validPipe);
                    validData = new RoleMappedData(validPipe, data.Schema.GetColumnRoleNames());
                }
            }

            // In addition to the training set, some trainers can accept two extra data sets, a validation set and a
            // test set, during the training phase. The major difference is that the training process may indirectly
            // use the validation set to improve the model, whereas the learned model should be totally independent
            // of the test set. As with the validation set, the trainer can report scores computed on the test set.
            RoleMappedData testDataUsedInTrainer = null;

            if (!string.IsNullOrWhiteSpace(Args.TestFile))
            {
                // In contrast to the if-else block for validation above, we do not warn when a test file is provided
                // but cannot be used by the trainer, because this is the TrainTest command and a test file is always present.
                if (trainer.Info.SupportsTest)
                {
                    ch.Trace("Constructing the test pipeline");
                    IDataView testPipeUsedInTrainer = CreateRawLoader(dataFile: Args.TestFile);
                    testPipeUsedInTrainer = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainPipe, testPipeUsedInTrainer);
                    testDataUsedInTrainer = new RoleMappedData(testPipeUsedInTrainer, data.Schema.GetColumnRoleNames());
                }
            }

            var predictor = TrainUtils.Train(Host, ch, data, trainer, validData,
                                             Args.Calibrator, Args.MaxCalibrationExamples, Args.CacheData, inputPredictor, testDataUsedInTrainer);

            IDataLoader testPipe;
            bool        hasOutfile   = !string.IsNullOrEmpty(Args.OutputModelFile);
            var         tempFilePath = hasOutfile ? null : Path.GetTempFileName();

            using (var file = new SimpleFileHandle(ch, hasOutfile ? Args.OutputModelFile : tempFilePath, true, !hasOutfile))
            {
                TrainUtils.SaveModel(Host, ch, file, predictor, data, cmd);
                ch.Trace("Constructing the testing pipeline");
                using (var stream = file.OpenReadStream())
                    using (var rep = RepositoryReader.Open(stream, ch))
                        testPipe = LoadLoader(rep, Args.TestFile, true);
            }

            // Score.
            ch.Trace("Scoring and evaluating");
            ch.Assert(Args.Scorer == null || Args.Scorer is ICommandLineComponentFactory, "TrainTestCommand should only be used from the command line.");
            IDataScorerTransform scorePipe = ScoreUtils.GetScorer(Args.Scorer, predictor, testPipe, features, group, customCols, Host, data.Schema);

            // Evaluate.
            var evaluator = Args.Evaluator?.CreateComponent(Host) ??
                            EvaluateUtils.GetEvaluator(Host, scorePipe.Schema);
            var dataEval = new RoleMappedData(scorePipe, label, features,
                                              group, weight, name, customCols, opt: true);
            var metrics = evaluator.Evaluate(dataEval);

            MetricWriter.PrintWarnings(ch, metrics);
            evaluator.PrintFoldResults(ch, metrics);
            if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out var overall))
            {
                throw ch.Except("No overall metrics found");
            }
            overall = evaluator.GetOverallResults(overall);
            MetricWriter.PrintOverallMetrics(Host, ch, Args.SummaryFilename, overall, 1);
            evaluator.PrintAdditionalMetrics(ch, metrics);
            Dictionary <string, IDataView>[] metricValues = { metrics };
            SendTelemetryMetric(metricValues);
            if (!string.IsNullOrWhiteSpace(Args.OutputDataFile))
            {
                var perInst     = evaluator.GetPerInstanceMetrics(dataEval);
                var perInstData = new RoleMappedData(perInst, label, null, group, weight, name, customCols);
                var idv         = evaluator.GetPerInstanceDataViewToSave(perInstData);
                MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, idv);
            }
        }
Example #15
        public void PipelineSweeperRequestedLearners()
        {
            // Get datasets
            var          pathData        = GetDataPath("adult.train");
            var          pathDataTest    = GetDataPath("adult.test");
            const int    numOfSampleRows = 100;
            const string schema          =
                "sep=, col=Features:R4:0,2,4,10-12 col=workclass:TX:1 col=education:TX:3 col=marital_status:TX:5 col=occupation:TX:6 " +
                "col=relationship:TX:7 col=race:TX:8 col=sex:TX:9 col=native_country:TX:13 col=label_IsOver50K_:R4:14 header=+";
            var inputFileTrain = new SimpleFileHandle(Env, pathData, false, false);

#pragma warning disable 0618
            var datasetTrain = ImportTextData.ImportText(Env,
                new ImportTextData.Input { InputFile = inputFileTrain, CustomSchema = schema }).Data.Take(numOfSampleRows);
            var inputFileTest = new SimpleFileHandle(Env, pathDataTest, false, false);
            var datasetTest = ImportTextData.ImportText(Env,
                new ImportTextData.Input { InputFile = inputFileTest, CustomSchema = schema }).Data.Take(numOfSampleRows);
            var requestedLearners = new[] { "LogisticRegressionBinaryClassifier", "FastTreeBinaryClassifier" };
#pragma warning restore 0618
            // Define entrypoint graph
            string inputGraph = @"
                {
                  'Nodes': [                                
                    {
                      'Name': 'Models.PipelineSweeper',
                      'Inputs': {
                        'TrainingData': '$TrainingData',
                        'TestingData': '$TestingData',
                        'StateArguments': {
                            'Name': 'AutoMlState',
                            'Settings': {
                                'Metric': 'Auc',
                                'Engine': {
                                    'Name': 'Rocket',
                                    'Settings' : {
                                        'TopKLearners' : 2,
                                        'SecondRoundTrialsPerLearner' : 0
                                    },
                                },
                                'TerminatorArgs': {
                                    'Name': 'IterationLimited',
                                    'Settings': {
                                        'FinalHistoryLength': 35
                                    }
                                },
                                'TrainerKind': 'SignatureBinaryClassifierTrainer',
                                'RequestedLearners' : [
                                    'LogisticRegressionBinaryClassifier',
                                    'FastTreeBinaryClassifier'
                                ]
                            }
                        },
                        'BatchSize': 5
                      },
                      'Outputs': {
                        'State': '$StateOut',
                        'Results': '$ResultsOut'
                      }
                    },
                  ]
                }";

            JObject graph   = JObject.Parse(inputGraph);
            var     catalog = Env.ComponentCatalog;

            var runner = new GraphRunner(Env, catalog, graph[FieldNames.Nodes] as JArray);
            runner.SetInput("TrainingData", datasetTrain);
            runner.SetInput("TestingData", datasetTest);
            runner.RunAll();

            var autoMlState = runner.GetOutput <AutoInference.AutoMlMlState>("StateOut");
            Assert.NotNull(autoMlState);
            var space = autoMlState.GetSearchSpace();

            // Make sure the only learners left in the search space are the requested ones.
            Assert.Equal(requestedLearners.Length, space.Item2.Length);
            Assert.True(space.Item2.All(l => requestedLearners.Any(r => r == l.LearnerName)));
        }
Example #16
        public void CanSuccessfullyRetrieveQuotedData()
        {
            string dataPath   = GetDataPath("QuotingData.csv");
            string inputGraph = @"
            {  
               'Nodes':[  
                  {  
                     'Name':'Data.TextLoader',
                     'Inputs':{  
                        'InputFile':'$inputFile',
                        'Arguments':{  
                           'UseThreads':true,
                           'HeaderFile':null,
                           'MaxRows':null,
                           'AllowQuoting':true,
                           'AllowSparse':false,
                           'InputSize':null,
                           'Separator':[  
                              ','
                           ],
                           'Column':[  
                              {  
                                 'Name':'ID',
                                 'Type':'R4',
                                 'Source':[  
                                    {  
                                       'Min':0,
                                       'Max':0,
                                       'AutoEnd':false,
                                       'VariableEnd':false,
                                       'AllOther':false,
                                       'ForceVector':false
                                    }
                                 ],
                                 'KeyCount':null
                              },
                              {  
                                 'Name':'Text',
                                 'Type':'TX',
                                 'Source':[  
                                    {  
                                       'Min':1,
                                       'Max':1,
                                       'AutoEnd':false,
                                       'VariableEnd':false,
                                       'AllOther':false,
                                       'ForceVector':false
                                    }
                                 ],
                                 'KeyCount':null
                              }
                           ],
                           'TrimWhitespace':false,
                           'HasHeader':true
                        }
                     },
                     'Outputs':{  
                        'Data':'$data'
                     }
                  }
               ]
            }";

            JObject graph     = JObject.Parse(inputGraph);
            var     runner    = new GraphRunner(env, graph[FieldNames.Nodes] as JArray);
            var     inputFile = new SimpleFileHandle(env, dataPath, false, false);

            runner.SetInput("inputFile", inputFile);
            runner.RunAll();

            var data = runner.GetOutput <IDataView>("data");
            Assert.NotNull(data);

            using (var cursor = data.GetRowCursorForAllColumns())
            {
                var IDGetter   = cursor.GetGetter <float>(0);
                var TextGetter = cursor.GetGetter <ReadOnlyMemory <char> >(1);

                Assert.True(cursor.MoveNext());

                float ID = 0;
                IDGetter(ref ID);
                Assert.Equal(1, ID);

                ReadOnlyMemory <char> Text = new ReadOnlyMemory <char>();
                TextGetter(ref Text);
                Assert.Equal("This text contains comma, within quotes.", Text.ToString());

                Assert.True(cursor.MoveNext());

                ID = 0;
                IDGetter(ref ID);
                Assert.Equal(2, ID);

                Text = new ReadOnlyMemory <char>();
                TextGetter(ref Text);
                Assert.Equal("This text contains extra punctuations and special characters.;*<>?!@#$%^&*()_+=-{}|[]:;'", Text.ToString());

                Assert.True(cursor.MoveNext());

                ID = 0;
                IDGetter(ref ID);
                Assert.Equal(3, ID);

                Text = new ReadOnlyMemory <char>();
                TextGetter(ref Text);
                Assert.Equal("This text has no quotes", Text.ToString());

                Assert.False(cursor.MoveNext());
            }
        }
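
The assertions above pin down exactly three data rows; a hypothetical reconstruction of the QuotingData.csv asset (the real test file is not reproduced in this listing) would look like:

        // Hypothetical reconstruction of QuotingData.csv from the assertions above;
        // assumes the standard CSV convention that quoted fields may contain commas.
        private static void WriteQuotingDataSketch(string path)
        {
            System.IO.File.WriteAllLines(path, new[]
            {
                "ID,Text",
                "1,\"This text contains comma, within quotes.\"",
                "2,\"This text contains extra punctuations and special characters.;*<>?!@#$%^&*()_+=-{}|[]:;'\"",
                "3,This text has no quotes",
            });
        }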
Example #17
        private static unsafe void RunGraphCore(EnvironmentBlock* penv, IHostEnvironment env, string graphStr, int cdata, DataSourceBlock** ppdata)
        {
            Contracts.AssertValue(env);

            var     host = env.Register("RunGraph", penv->seed, null);
            JObject graph;

            try
            {
                graph = JObject.Parse(graphStr);
            }
            catch (JsonReaderException ex)
            {
                throw host.Except(ex, "Failed to parse experiment graph: {0}", ex.Message);
            }

            var runner = new GraphRunner(host, graph["nodes"] as JArray);

            var dvNative = new IDataView[cdata];

            try
            {
                for (int i = 0; i < cdata; i++)
                {
                    dvNative[i] = new NativeDataView(host, ppdata[i]);
                }

                // Setting inputs.
                var jInputs = graph["inputs"] as JObject;
                if (graph["inputs"] != null && jInputs == null)
                {
                    throw host.Except("Unexpected value for 'inputs': {0}", graph["inputs"]);
                }
                int iDv = 0;
                if (jInputs != null)
                {
                    foreach (var kvp in jInputs)
                    {
                        var pathValue = kvp.Value as JValue;
                        if (pathValue == null)
                        {
                            throw host.Except("Invalid value for input: {0}", kvp.Value);
                        }

                        var path    = pathValue.Value <string>();
                        var varName = kvp.Key;
                        var type    = runner.GetPortDataKind(varName);

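                        // The port kind declared for this graph variable determines how its JSON value is materialized below.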
                        switch (type)
                        {
                        case TlcModule.DataKind.FileHandle:
                            var fh = new SimpleFileHandle(host, path, false, false);
                            runner.SetInput(varName, fh);
                            break;

                        case TlcModule.DataKind.DataView:
                            IDataView dv;
                            if (!string.IsNullOrWhiteSpace(path))
                            {
                                var extension = Path.GetExtension(path);
                                if (extension == ".txt")
                                {
                                    dv = TextLoader.LoadFile(host, new TextLoader.Options(), new MultiFileSource(path));
                                }
                                else if (extension == ".dprep")
                                {
                                    dv = LoadDprepFile(BytesToString(penv->pythonPath), path);
                                }
                                else
                                {
                                    dv = new BinaryLoader(host, new BinaryLoader.Arguments(), path);
                                }
                            }
                            else
                            {
                                Contracts.Assert(iDv < dvNative.Length);
                                // prefetch all columns
                                dv = dvNative[iDv++];
                                var prefetch = new int[dv.Schema.Count];
                                for (int i = 0; i < prefetch.Length; i++)
                                {
                                    prefetch[i] = i;
                                }
                                dv = new CacheDataView(host, dv, prefetch);
                            }
                            runner.SetInput(varName, dv);
                            break;

                        case TlcModule.DataKind.PredictorModel:
                            PredictorModel pm;
                            if (!string.IsNullOrWhiteSpace(path))
                            {
                                using (var fs = File.OpenRead(path))
                                    pm = new PredictorModelImpl(host, fs);
                            }
                            else
                            {
                                throw host.Except("Model must be loaded from a file");
                            }
                            runner.SetInput(varName, pm);
                            break;

                        case TlcModule.DataKind.TransformModel:
                            TransformModel tm;
                            if (!string.IsNullOrWhiteSpace(path))
                            {
                                using (var fs = File.OpenRead(path))
                                    tm = new TransformModelImpl(host, fs);
                            }
                            else
                            {
                                throw host.Except("Model must be loaded from a file");
                            }
                            runner.SetInput(varName, tm);
                            break;

                        default:
                            throw host.Except("Port type {0} not supported", type);
                        }
                    }
                }
                runner.RunAll();

                // Reading outputs.
                using (var ch = host.Start("Reading outputs"))
                {
                    var jOutputs = graph["outputs"] as JObject;
                    if (jOutputs != null)
                    {
                        foreach (var kvp in jOutputs)
                        {
                            var pathValue = kvp.Value as JValue;
                            if (pathValue == null)
                            {
                                throw host.Except("Invalid value for input: {0}", kvp.Value);
                            }
                            var path    = pathValue.Value <string>();
                            var varName = kvp.Key;
                            var type    = runner.GetPortDataKind(varName);

                            switch (type)
                            {
                            case TlcModule.DataKind.FileHandle:
                                var fh = runner.GetOutput <IFileHandle>(varName);
                                throw host.ExceptNotSupp("File handle outputs not yet supported.");

                            case TlcModule.DataKind.DataView:
                                var idv = runner.GetOutput <IDataView>(varName);
                                if (path == CSR_MATRIX)
                                {
                                    SendViewToNativeAsCsr(ch, penv, idv);
                                }
                                else if (!string.IsNullOrWhiteSpace(path))
                                {
                                    SaveIdvToFile(idv, path, host);
                                }
                                else
                                {
                                    var infos = ProcessColumns(ref idv, penv->maxSlots, host);
                                    SendViewToNativeAsDataFrame(ch, penv, idv, infos);
                                }
                                break;

                            case TlcModule.DataKind.PredictorModel:
                                var pm = runner.GetOutput <PredictorModel>(varName);
                                if (!string.IsNullOrWhiteSpace(path))
                                {
                                    SavePredictorModelToFile(pm, path, host);
                                }
                                else
                                {
                                    throw host.Except("Returning in-memory models is not supported");
                                }
                                break;

                            case TlcModule.DataKind.TransformModel:
                                var tm = runner.GetOutput <TransformModel>(varName);
                                if (!string.IsNullOrWhiteSpace(path))
                                {
                                    using (var fs = File.OpenWrite(path))
                                        tm.Save(host, fs);
                                }
                                else
                                {
                                    throw host.Except("Returning in-memory models is not supported");
                                }
                                break;

                            case TlcModule.DataKind.Array:
                                var objArray = runner.GetOutput <object[]>(varName);
                                if (objArray is PredictorModel[])
                                {
                                    var modelArray = (PredictorModel[])objArray;
                                    // Save each model separately
                                    for (var i = 0; i < modelArray.Length; i++)
                                    {
                                        var modelPath = string.Format(CultureInfo.InvariantCulture, path, i);
                                        SavePredictorModelToFile(modelArray[i], modelPath, host);
                                    }
                                }
                                else
                                {
                                    throw host.Except("DataKind.Array type {0} not supported", objArray.First().GetType());
                                }
                                break;

                            default:
                                throw host.Except("Port type {0} not supported", type);
                            }
                        }
                    }
                }
            }
            finally
            {
                // The raw data view is disposable so it lets go of unmanaged raw pointers before we return.
                for (int i = 0; i < dvNative.Length; i++)
                {
                    var view = dvNative[i];
                    if (view == null)
                    {
                        continue;
                    }
                    host.Assert(view is IDisposable);
                    var disp = (IDisposable)dvNative[i];
                    disp.Dispose();
                }
            }
        }
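
RunGraphCore reads lowercase 'nodes', 'inputs', and 'outputs' sections. A minimal sketch of a graph string it could consume; the node and variable names here are illustrative, not taken from a real run:

        // Illustrative only (assumption: node and variable names are made up).
        // An empty path under 'inputs' makes RunGraphCore bind the next native
        // data view; a non-empty path under 'outputs' saves the result to disk.
        private const string SampleRunGraph = @"
            {
              'nodes':   [ { 'Name': 'Transforms.NoOperation',
                             'Inputs':  { 'Data': '$data1' },
                             'Outputs': { 'OutputData': '$out1' } } ],
              'inputs':  { 'data1': '' },
              'outputs': { 'out1': 'result.idv' }
            }";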
Example #18
        public void CanSuccessfullyApplyATransform()
        {
            string inputGraph = @"
            {
                'Nodes':
                [{
                        'Name': 'Data.TextLoader',
                        'Inputs': {
                            'InputFile': '$inputFile',
                            'Arguments': {
                                'UseThreads': true,
                                'HeaderFile': null,
                                'MaxRows': null,
                                'AllowQuoting': true,
                                'AllowSparse': true,
                                'InputSize': null,
                                'Separator': [
                                    '\t'
                                ],
                                'Column': [{
                                        'Name': 'String1',
                                        'Type': 'TX',
                                        'Source': [{
                                                'Min': 0,
                                                'Max': 0,
                                                'AutoEnd': false,
                                                'VariableEnd': false,
                                                'AllOther': false,
                                                'ForceVector': false
                                            }
                                        ],
                                        'KeyCount': null
                                    }, {
                                        'Name': 'Number1',
                                        'Type': 'R4',
                                        'Source': [{
                                                'Min': 1,
                                                'Max': 1,
                                                'AutoEnd': false,
                                                'VariableEnd': false,
                                                'AllOther': false,
                                                'ForceVector': false
                                            }
                                        ],
                                        'KeyCount': null
                                    }
                                ],
                                'TrimWhitespace': false,
                                'HasHeader': false
                            }
                        },
                        'Outputs': {
                            'Data': '$data'
                        }
                    }
                ]
            }";

            JObject graph     = JObject.Parse(inputGraph);
            var     runner    = new GraphRunner(_env, graph[FieldNames.Nodes] as JArray);
            var     inputFile = new SimpleFileHandle(_env, "fakeFile.txt", false, false);

            runner.SetInput("inputFile", inputFile);
            runner.RunAll();

            var data = runner.GetOutput <IDataView>("data");

            Assert.NotNull(data);
        }
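
The test stops at a null check because the loader is lazy and never touches the fake file. A sketch of how the view would be scanned, assuming a real tab-separated file behind it:

        // Sketch (assumption): scan the loaded view the way the neighboring loader
        // tests do; requires the input file to exist with the declared two columns.
        private static void ScanTwoColumnView(IDataView data)
        {
            using (var cursor = data.GetRowCursorForAllColumns())
            {
                var string1Getter = cursor.GetGetter<ReadOnlyMemory<char>>(cursor.Schema[0]);
                var number1Getter = cursor.GetGetter<float>(cursor.Schema[1]);

                ReadOnlyMemory<char> string1 = default;
                float number1 = 0;
                while (cursor.MoveNext())
                {
                    string1Getter(ref string1);
                    number1Getter(ref number1);
                }
            }
        }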
Example #19
        public void TestOnnxTransformSaveAndLoadWithCustomShapes()
        {
            // The loaded model has input shape [-1, 3] and output shape [-1].
            var modelFile = Path.Combine(Directory.GetCurrentDirectory(), "unknowndimensions", "test_unknowndimensions_float.onnx");

            var dataPoints = new InputWithCustomShape[]
            {
                // It's a flattened 3-by-3 tensor.
                // [1.1, 1.3, 1.2]
                // |1.9, 1.3, 1.2|
                // [1.1, 1.3, 1.8]
                new InputWithCustomShape()
                {
                    input = new float[] { 1.1f, 1.3f, 1.2f, 1.9f, 1.3f, 1.2f, 1.1f, 1.3f, 1.8f }
                },
                // It's a flattened 3-by-3 tensor.
                // [0, 0, 1]
                // |1, 0, 0|
                // [1, 0, 0]
                new InputWithCustomShape()
                {
                    input = new float[] { 0f, 0f, 1f, 1f, 0f, 0f, 1f, 0f, 0f }
                }
            };

            var shapeDictionary = new Dictionary <string, int[]>()
            {
                { nameof(InputWithCustomShape.input), new int[] { 3, 3 } }
            };
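            // This mapping pins the model's unknown (-1) leading dimension to a fixed 3-by-3 input shape.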

            var dataView = ML.Data.LoadFromEnumerable(dataPoints);

            var pipeline = ML.Transforms.ApplyOnnxModel(nameof(PredictionWithCustomShape.argmax),
                                                        nameof(InputWithCustomShape.input), modelFile, shapeDictionary);

            var model = pipeline.Fit(dataView);

            // Save the trained ONNX transformer into file and then load it back.
            ITransformer loadedModel = null;
            var          tempPath    = Path.GetTempFileName();

            using (var file = new SimpleFileHandle(Env, tempPath, true, true))
            {
                // Save.
                using (var fs = file.CreateWriteStream())
                    ML.Model.Save(model, null, fs);

                // Load.
                using (var fs = file.OpenReadStream())
                    loadedModel = ML.Model.Load(fs, out var schema);
            }

            var transformedDataView = loadedModel.Transform(dataView);

            // Conduct the same check for all the 3 called public APIs.
            var transformedDataPoints = ML.Data.CreateEnumerable <PredictionWithCustomShape>(transformedDataView, false).ToList();

            // One data point generates one transformed data point.
            Assert.Equal(dataPoints.Count(), transformedDataPoints.Count);

            // Check result numbers. They are results of applying ONNX argmax along the second axis; for example
            // [1.1, 1.3, 1.2] ---> [1] because 1.3 (indexed by 1) is the largest element.
            // |1.9, 1.3, 1.2| ---> |0|         1.9             0
            // [1.1, 1.3, 1.8] ---> [2]         1.8             2
            var expectedResults = new long[][]
            {
                new long[] { 1, 0, 2 },
                new long[] { 2, 0, 0 }
            };

            for (int i = 0; i < transformedDataPoints.Count; ++i)
            {
                Assert.Equal(expectedResults[i], transformedDataPoints[i].argmax);
            }

            (model as IDisposable)?.Dispose();
            (loadedModel as IDisposable)?.Dispose();
        }
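
The row classes used above sit outside this excerpt; a plausible sketch, with the vector size inferred from the 9-element inputs and the [3, 3] custom shape (the original attribute details may differ):

        // Assumed shapes for the row classes used above; the actual definitions
        // are not part of this excerpt and may differ in attribute details.
        private class InputWithCustomShape
        {
            [VectorType(3, 3)]
            public float[] input;
        }

        private class PredictionWithCustomShape
        {
            public long[] argmax;
        }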
Example #20
        public void CanSuccessfullyTrimSpaces()
        {
            string dataPath   = GetDataPath("TrimData.csv");
            string inputGraph = @"{
                'Nodes':
                [{
                        'Name': 'Data.TextLoader',
                        'Inputs': {
                            'InputFile': '$inputFile',
                            'Arguments': {
                                'UseThreads': true,
                                'HeaderFile': null,
                                'MaxRows': null,
                                'AllowQuoting': false,
                                'AllowSparse': false,
                                'InputSize': null,
                                'Separator': [
                                    ','
                                ],
                                'Column': [{
                                        'Name': 'ID',
                                        'Type': 'R4',
                                        'Source': [{
                                                'Min': 0,
                                                'Max': 0,
                                                'AutoEnd': false,
                                                'VariableEnd': false,
                                                'AllOther': false,
                                                'ForceVector': false
                                            }
                                        ],
                                        'KeyCount': null
                                    }, {
                                        'Name': 'Text',
                                        'Type': 'TX',
                                        'Source': [{
                                                'Min': 1,
                                                'Max': 1,
                                                'AutoEnd': false,
                                                'VariableEnd': false,
                                                'AllOther': false,
                                                'ForceVector': false
                                            }
                                        ],
                                        'KeyCount': null
                                    }
                                ],
                                'TrimWhitespace': true,
                                'HasHeader': true
                            }
                        },
                        'Outputs': {
                            'Data': '$data'
                        }
                    }
                ]
            }";

            JObject graph     = JObject.Parse(inputGraph);
            var     runner    = new GraphRunner(_env, graph[FieldNames.Nodes] as JArray);
            var     inputFile = new SimpleFileHandle(_env, dataPath, false, false);

            runner.SetInput("inputFile", inputFile);
            runner.RunAll();

            var data = runner.GetOutput <IDataView>("data");

            Assert.NotNull(data);

            using (var cursor = data.GetRowCursorForAllColumns())
            {
                var idGetter   = cursor.GetGetter <float>(cursor.Schema[0]);
                var textGetter = cursor.GetGetter <ReadOnlyMemory <char> >(cursor.Schema[1]);

                Assert.True(cursor.MoveNext());

                float id = 0;
                idGetter(ref id);
                Assert.Equal(1, id);

                ReadOnlyMemory <char> text = new ReadOnlyMemory <char>();
                textGetter(ref text);
                Assert.Equal("There is a space at the end", text.ToString());

                Assert.True(cursor.MoveNext());

                id = 0;
                idGetter(ref id);
                Assert.Equal(2, id);

                text = new ReadOnlyMemory <char>();
                textGetter(ref text);
                Assert.Equal("There is no space at the end", text.ToString());

                Assert.False(cursor.MoveNext());
            }
        }
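
A hypothetical reconstruction of TrimData.csv implied by the assertions; only the first row carries the trailing space that TrimWhitespace=true strips on load:

        // Hypothetical reconstruction of TrimData.csv from the assertions above;
        // the real test asset is not reproduced in this listing.
        private static void WriteTrimDataSketch(string path)
        {
            System.IO.File.WriteAllLines(path, new[]
            {
                "ID,Text",
                "1,There is a space at the end ",
                "2,There is no space at the end",
            });
        }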
Example #21
        public void PipelineSweeperSerialization()
        {
            // Get datasets
            var          pathData        = GetDataPath("adult.train");
            var          pathDataTest    = GetDataPath("adult.test");
            const int    numOfSampleRows = 1000;
            int          numIterations   = 10;
            const string schema          =
                "sep=, col=Features:R4:0,2,4,10-12 col=workclass:TX:1 col=education:TX:3 col=marital_status:TX:5 col=occupation:TX:6 " +
                "col=relationship:TX:7 col=ethnicity:TX:8 col=sex:TX:9 col=native_country:TX:13 col=label_IsOver50K_:R4:14 header=+";
            var inputFileTrain = new SimpleFileHandle(Env, pathData, false, false);

#pragma warning disable 0618
            var datasetTrain = ImportTextData.ImportText(Env,
                                                         new ImportTextData.Input {
                InputFile = inputFileTrain, CustomSchema = schema
            }).Data.Take(numOfSampleRows);
            var inputFileTest = new SimpleFileHandle(Env, pathDataTest, false, false);
            var datasetTest   = ImportTextData.ImportText(Env,
                                                          new ImportTextData.Input {
                InputFile = inputFileTest, CustomSchema = schema
            }).Data.Take(numOfSampleRows);
#pragma warning restore 0618

            // Define entrypoint graph
            string inputGraph = @"
                {
                  'Nodes': [
                    {
                      'Name': 'Models.PipelineSweeper',
                      'Inputs': {
                        'TrainingData': '$TrainingData',
                        'TestingData': '$TestingData',
                        'StateArguments': {
                            'Name': 'AutoMlState',
                            'Settings': {
                                'Metric': 'Auc',
                                'Engine': {
                                    'Name': 'UniformRandom'
                                },
                                'TerminatorArgs': {
                                    'Name': 'IterationLimited',
                                    'Settings': {
                                        'FinalHistoryLength': 10
                                    }
                                },
                                'TrainerKind': 'SignatureBinaryClassifierTrainer'
                            }
                        },
                        'BatchSize': 5
                      },
                      'Outputs': {
                        'State': '$StateOut',
                        'Results': '$ResultsOut'
                      }
                    },
                  ]
                }";

            JObject graphJson = JObject.Parse(inputGraph);
            var     catalog   = Env.ComponentCatalog;
            var     graph     = new EntryPointGraph(Env, catalog, graphJson[FieldNames.Nodes] as JArray);
            // Test if ToJson() works properly.
            var nodes  = new JArray(graph.AllNodes.Select(node => node.ToJson()));
            var runner = new GraphRunner(Env, catalog, nodes);
            runner.SetInput("TrainingData", datasetTrain);
            runner.SetInput("TestingData", datasetTest);
            runner.RunAll();

            var results = runner.GetOutput <IDataView>("ResultsOut");
            Assert.NotNull(results);
            var rows = PipelinePattern.ExtractResults(Env, results,
                                                      "Graph", "MetricValue", "PipelineId", "TrainingMetricValue", "FirstInput", "PredictorModel");
            Assert.True(rows.Length == numIterations);
        }
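
Since PipelinePattern.ExtractResults materializes the results view, the iteration count can also be cross-checked directly on the raw IDataView; a small helper sketch, not part of the original test:

        // Optional cross-check (not in the original test): count rows in the raw
        // results view. Usage: Assert.Equal(numIterations, CountRows(results));
        private static int CountRows(IDataView results)
        {
            int rowCount = 0;
            using (var cursor = results.GetRowCursorForAllColumns())
            {
                while (cursor.MoveNext())
                    rowCount++;
            }
            return rowCount;
        }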