Esempio n. 1
0
        public void AutoNormalizationAndCaching()
        {
            // Locate the sentiment training file shipped with the test assets.
            var trainFile = GetDataPath(TestDatasets.Sentiment.trainFilename);

            using (var env = new LocalEnvironment(seed: 1, conc: 1))
            {
                // Pipeline: text loader followed by the text featurizer.
                var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(trainFile));
                var featurized = TextFeaturizingEstimator.Create(env, MakeSentimentTextTransformArgs(false), loader);

                // Configure the linear trainer.
                var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
                {
                    NumThreads = 1,
                    ConvergenceTolerance = 1f
                });

                // Wrap the data in a cache only when the trainer asks for one.
                IDataView trainData = featurized;
                if (trainer.Info.WantCaching)
                    trainData = new CacheDataView(env, featurized, prefetch: null);

                var trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features");

                // Insert a normalizer in front of the trainer only if it needs normalized features.
                NormalizeTransform.CreateIfNeeded(env, ref trainRoles, trainer);
                var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));
            }
        }
Esempio n. 2
0
        public void SetupSentimentPipeline()
        {
            // Example row consumed later by the prediction engine.
            _sentimentExample = new SentimentData()
            {
                SentimentText = "Not a big fan of this."
            };

            string dataPath = Program.GetInvariantCultureDataPath("wikipedia-detox-250-line-data.tsv");

            var env = new MLContext(seed: 1, conc: 1);

            // Loader schema: boolean label in column 0, raw text in column 1; file has a header row.
            var loader = new TextLoader(env,
                                        columns: new[]
                                        {
                                            new TextLoader.Column("Label", DataKind.BL, 0),
                                            new TextLoader.Column("SentimentText", DataKind.Text, 1)
                                        },
                                        hasHeader: true);

            IDataView trainingData = loader.Read(dataPath);

            // Featurize the text, then train an SDCA binary classifier (single thread for determinism).
            var pipeline = new TextFeaturizingEstimator(env, "SentimentText", "Features")
                           .Append(new SdcaBinaryTrainer(env, "Label", "Features",
                                                         advancedSettings: s =>
                                                         {
                                                             s.NumThreads = 1;
                                                             s.ConvergenceTolerance = 1e-2f;
                                                         }));

            var model = pipeline.Fit(trainingData);

            _sentimentModel = model.CreatePredictionEngine <SentimentData, SentimentPrediction>(env);
        }
        public void SetupSentimentPipeline()
        {
            // Example row consumed later by the prediction engine.
            _sentimentExample = new SentimentData()
            {
                SentimentText = "Not a big fan of this."
            };

            string dataPath = BaseTestClass.GetDataPath("wikipedia-detox-250-line-data.tsv");

            var env = new MLContext(seed: 1, conc: 1);

            // Loader schema: boolean label in column 0, raw text in column 1; file has a header row.
            var loader = new TextLoader(env,
                                        columns: new[]
                                        {
                                            new TextLoader.Column("Label", DataKind.BL, 0),
                                            new TextLoader.Column("SentimentText", DataKind.Text, 1)
                                        },
                                        hasHeader: true);

            IDataView trainingData = loader.Read(dataPath);

            // Featurize the text, then train SDCA via the catalog entry point using explicit options.
            var sdcaOptions = new SdcaBinaryTrainer.Options
            {
                NumThreads = 1, ConvergenceTolerance = 1e-2f,
            };
            var pipeline = new TextFeaturizingEstimator(env, "Features", "SentimentText")
                           .Append(env.BinaryClassification.Trainers.StochasticDualCoordinateAscent(sdcaOptions));

            var model = pipeline.Fit(trainingData);

            _sentimentModel = model.CreatePredictionEngine <SentimentData, SentimentPrediction>(env);
        }
Esempio n. 4
0
        // Featurizes sentiment text into raw tokens only (word/char n-gram extraction
        // disabled), then verifies a WordEmbeddings estimator over the token column
        // via TestEstimatorCore (fit/transform/save/load consistency checks).
        public void TestWordEmbeddings()
        {
            var dataPath     = GetDataPath(TestDatasets.Sentiment.trainFilename);
            var testDataPath = GetDataPath(TestDatasets.Sentiment.testFilename);

            // Statically-typed reader: boolean label in column 0, text in column 1.
            var data = TextLoaderStatic.CreateReader(Env, ctx => (
                                                         label: ctx.LoadBool(0),
                                                         SentimentText: ctx.LoadText(1)), hasHeader: true)
                       .Read(dataPath);
            // Configure the featurizer so its only useful output is the token column
            // (OutputTokens = true); vector extractors and normalization are off.
            var dynamicData = new TextFeaturizingEstimator(Env, "SentimentText", "SentimentText_Features", args =>
            {
                args.OutputTokens     = true;
                args.KeepPunctuations = false;
                args.UseStopRemover   = true;
                args.VectorNormalizer = TextFeaturizingEstimator.TextNormKind.None;
                args.UseCharExtractor = false;
                args.UseWordExtractor = false;
            }).Fit(data.AsDynamic).Transform(data.AsDynamic);
            // Re-assert a static schema over the dynamic output so the strongly-typed
            // estimator API can be used below.
            var data2 = dynamicData.AssertStatic(Env, ctx => (
                                                     SentimentText_Features_TransformedText: ctx.Text.VarVector,
                                                     SentimentText: ctx.Text.Scalar,
                                                     label: ctx.Bool.Scalar));

            // Estimator under test: word embeddings over the transformed-token column.
            var est = data2.MakeNewEstimator()
                      .Append(row => row.SentimentText_Features_TransformedText.WordEmbeddings());

            // `data` (pre-featurization) lacks the token column, so it is the invalid input.
            TestEstimatorCore(est.AsDynamic, data2.AsDynamic, invalidInput: data.AsDynamic);
            Done();
        }
Esempio n. 5
0
        public void SetupSentimentPipeline()
        {
            // Example row consumed later by the prediction function.
            _sentimentExample = new SentimentData()
            {
                SentimentText = "Not a big fan of this."
            };

            string dataPath = Program.GetInvariantCultureDataPath("wikipedia-detox-250-line-data.tsv");

            using (var env = new ConsoleEnvironment(seed: 1, conc: 1, verbose: false, sensitivity: MessageSensitivity.None, outWriter: EmptyWriter.Instance))
            {
                // Loader schema: tab-separated, header row, boolean label + raw text.
                var loaderArgs = new TextLoader.Arguments()
                {
                    Separator = "\t",
                    HasHeader = true,
                    Column    = new[]
                    {
                        new TextLoader.Column("Label", DataKind.BL, 0),
                        new TextLoader.Column("SentimentText", DataKind.Text, 1)
                    }
                };
                var loader = new TextLoader(env, loaderArgs);

                IDataView trainingData = loader.Read(dataPath);

                // Featurize the text, then train an SDCA binary classifier (single thread for determinism).
                var pipeline = new TextFeaturizingEstimator(env, "SentimentText", "Features")
                               .Append(new SdcaBinaryTrainer(env, "Label", "Features",
                                                             advancedSettings: s =>
                                                             {
                                                                 s.NumThreads = 1;
                                                                 s.ConvergenceTolerance = 1e-2f;
                                                             }));

                var model = pipeline.Fit(trainingData);

                _sentimentModel = model.MakePredictionFunction <SentimentData, SentimentPrediction>(env);
            }
        }
Esempio n. 6
0
        public void TrainWithValidationSet()
        {
            using (var env = new LocalEnvironment(seed: 1, conc: 1))
            {
                // Training pipeline: loader + text featurizer.
                var trainLoader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename)));
                var trainData = TextFeaturizingEstimator.Create(env, MakeSentimentTextTransformArgs(), trainLoader);

                // Apply the same transformations on the validation set.
                // Sadly, there is no way to easily apply the same loader to different data, so we either have
                // to create another loader, or to save the loader to model file and then reload.
                // A new one is not always feasible, but this time it is.
                var validLoader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.Sentiment.testFilename)));
                var validData = ApplyTransformUtils.ApplyAllTransformsToData(env, trainData, validLoader);

                // Cache both datasets so repeated passes during boosting are cheap.
                var cachedTrain = new CacheDataView(env, trainData, prefetch: null);
                var cachedValid = new CacheDataView(env, validData, prefetch: null);

                // Train a small FastTree model, passing the validation set alongside the training set.
                var trainer = new FastTreeBinaryClassificationTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features, numTrees: 3);
                var trainRoles = new RoleMappedData(cachedTrain, label: "Label", feature: "Features");
                var validRoles = new RoleMappedData(cachedValid, label: "Label", feature: "Features");
                trainer.Train(new Runtime.TrainContext(trainRoles, validRoles));
            }
        }
        void FileBasedSavingOfData()
        {
            using (var env = new LocalEnvironment(seed: 1, conc: 1))
            {
                // Featurize the sentiment training data.
                var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename)));
                var featurized = TextFeaturizingEstimator.Create(env, MakeSentimentTextTransformArgs(), loader);

                // Persist the featurized view to disk in the IDV binary format.
                var saver = new BinarySaver(env, new BinarySaver.Arguments());
                using (var ch = env.Start("SaveData"))
                    using (var file = env.CreateOutputFile("i.idv"))
                    {
                        DataSaverUtils.SaveDataView(ch, saver, featurized, file);
                    }

                // Reload the saved data and train a linear model from it.
                var binData = new BinaryLoader(env, new BinaryLoader.Arguments(), new MultiFileSource("i.idv"));
                var trainRoles = new RoleMappedData(binData, label: "Label", feature: "Features");
                var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
                {
                    NumThreads = 1
                });
                var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));

                // Remove the temporary file.
                DeleteOutputPath("i.idv");
            }
        }
        // Trains an SDCA binary sentiment classifier on the wikipedia-detox sample
        // and returns a scorer transform bound to the featurized (uncached) data.
        private IDataScorerTransform _TrainSentiment()
        {
            bool normalize = true;

            // Loader schema: boolean label in column 0, raw text in column 1.
            // NOTE(review): "tab" appears to be a named separator keyword understood by
            // TextLoader rather than a literal string — confirm against the
            // TextLoader.Arguments docs for this ML.NET version.
            var args = new TextLoader.Arguments()
            {
                Separator = "tab",
                HasHeader = true,
                Column    = new[] {
                    new TextLoader.Column("Label", DataKind.BL, 0),
                    new TextLoader.Column("SentimentText", DataKind.Text, 1)
                }
            };

            // Featurization: lower-cased, diacritics/punctuation stripped, stop words
            // removed, char trigrams plus word uni+bigrams, optional L2 normalization.
            var args2 = new TextFeaturizingEstimator.Arguments()
            {
                Column = new TextFeaturizingEstimator.Column
                {
                    Name   = "Features",
                    Source = new[] { "SentimentText" }
                },
                KeepDiacritics               = false,
                KeepPunctuations             = false,
                TextCase                     = TextNormalizingEstimator.CaseNormalizationMode.Lower,
                OutputTokens                 = true,
                UsePredefinedStopWordRemover = true,
                VectorNormalizer             = normalize ? TextFeaturizingEstimator.TextNormKind.L2 : TextFeaturizingEstimator.TextNormKind.None,
                CharFeatureExtractor         = new NgramExtractorTransform.NgramExtractorArguments()
                {
                    NgramLength = 3, AllLengths = false
                },
                WordFeatureExtractor = new NgramExtractorTransform.NgramExtractorArguments()
                {
                    NgramLength = 2, AllLengths = true
                },
            };

            var trainFilename = FileHelper.GetTestFile("wikipedia-detox-250-line-data.tsv");

            using (var env = EnvHelper.NewTestEnvironment(seed: 1, conc: 1))
            {
                // Pipeline
                var loader = new TextLoader(env, args).Read(new MultiFileSource(trainFilename));
                var trans  = TextFeaturizingEstimator.Create(env, args2, loader);

                // Train
                var trainer = new SdcaBinaryTrainer(env, new SdcaBinaryTrainer.Arguments
                {
                    NumThreads = 1
                });

                // Cache the featurized data so SDCA's repeated passes avoid re-featurizing.
                var cached    = new CacheDataView(env, trans, prefetch: null);
                var predictor = trainer.Fit(cached);

                // Score the uncached view; the role schema comes from the training roles.
                var scoreRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
                var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");
                return(ScoreUtils.GetScorer(predictor.Model, scoreRoles, env, trainRoles.Schema));
            }
        }
Esempio n. 9
0
        // Newer-API variant of _TrainSentiment: same dataset and featurization idea,
        // trained with SdcaLogisticRegressionBinaryTrainer, returning a scorer.
        private static IDataScorerTransform _TrainSentiment()
        {
            bool normalize = true;

            // Loader schema: boolean label in column 0, string text in column 1, tab-separated.
            var args = new TextLoader.Options()
            {
                Separators = new[] { '\t' },
                HasHeader  = true,
                Columns    = new[]
                {
                    new TextLoader.Column("Label", DataKind.Boolean, 0),
                    new TextLoader.Column("SentimentText", DataKind.String, 1)
                }
            };

            // Featurization: lower-case, strip diacritics/punctuation, emit a "tokens"
            // column, char trigrams + word uni/bigrams, optional L2 normalization.
            var args2 = new TextFeaturizingEstimator.Options()
            {
                KeepDiacritics         = false,
                KeepPunctuations       = false,
                CaseMode               = TextNormalizingEstimator.CaseMode.Lower,
                OutputTokensColumnName = "tokens",
                Norm = normalize ? TextFeaturizingEstimator.NormFunction.L2 : TextFeaturizingEstimator.NormFunction.None,
                CharFeatureExtractor = new WordBagEstimator.Options()
                {
                    NgramLength = 3, UseAllLengths = false
                },
                WordFeatureExtractor = new WordBagEstimator.Options()
                {
                    NgramLength = 2, UseAllLengths = true
                },
            };

            var trainFilename = FileHelper.GetTestFile("wikipedia-detox-250-line-data.tsv");

            // NOTE(review): the enclosing `using` was deliberately commented out —
            // presumably the environment returned by EnvHelper.NewTestEnvironment is no
            // longer IDisposable in this API version; confirm its return type before
            // restoring disposal.
            /*using (*/
            var env = EnvHelper.NewTestEnvironment(seed: 1, conc: 1);
            {
                // Pipeline
                var loader = new TextLoader(env, args).Load(new MultiFileSource(trainFilename));

                var trans = TextFeaturizingEstimator.Create(env, args2, loader);

                // Train
                var trainer = new SdcaLogisticRegressionBinaryTrainer(env, new SdcaLogisticRegressionBinaryTrainer.Options
                {
                    LabelColumnName   = "Label",
                    FeatureColumnName = "Features"
                });

                // Cache so the iterative trainer's repeated passes avoid re-featurizing.
                var cached    = new Microsoft.ML.Data.CacheDataView(env, trans, prefetch: null);
                var predictor = trainer.Fit(cached);

                // Score the uncached view using the role schema built from training.
                var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");
                var scoreRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
                return(ScoreUtils.GetScorer(predictor.Model, scoreRoles, env, trainRoles.Schema));
            }
        }
Esempio n. 10
0
                // Builds a default text featurizer from srcColumn to dstColumn and wraps
                // it as a SuggestedTransform. The bigram/trichar extractor configuration
                // this helper is named after is currently disabled (kept below for reference):
                //WordFeatureExtractor = new NgramExtractorTransform.NgramExtractorArguments() { NgramLength = 2 },
                //CharFeatureExtractor = new NgramExtractorTransform.NgramExtractorArguments() { NgramLength = 3 }
                public static SuggestedTransform TextTransformBigramTriChar(MLContext env, string srcColumn, string dstColumn, Type transformType)
                {
                    var estimator = new TextFeaturizingEstimator(env, srcColumn, dstColumn);

                    return TextTransform(srcColumn, dstColumn, estimator);
                }
Esempio n. 11
0
        // Loads the sentiment dataset, splits 80/20 into train/test, featurizes the
        // text column, trains an SDCA logistic-regression binary classifier, prints
        // accuracy metrics on the test split, and saves the trained model to ModelPath.
        private static ITransformer BuildTrainEvaluateAndSaveModel(MLContext mlContext)
        {
            // STEP 1: Common data loading configuration
            IDataView dataView = mlContext.Data.LoadFromTextFile <SentimentIssue>(DataPath, hasHeader: true);

            DataOperationsCatalog.TrainTestData trainTestSplit =
                mlContext.Data.TrainTestSplit(dataView, testFraction: 0.2);
            IDataView trainingData = trainTestSplit.TrainSet;
            IDataView testData     = trainTestSplit.TestSet;

            // STEP 2: Common data process configuration with pipeline data transformations
            TextFeaturizingEstimator dataProcessPipeline = mlContext.Transforms.Text.FeaturizeText(outputColumnName: "Features",
                                                                                                   inputColumnName: nameof(SentimentIssue.Text));

            // (OPTIONAL) Peek data (such as 2 records) in training DataView after applying the ProcessPipeline's transformations into "Features"
            ConsoleHelper.PeekDataViewInConsole(mlContext, dataView, dataProcessPipeline, 2);
            //Peak the transformed features column
            //ConsoleHelper.PeekVectorColumnDataInConsole(mlContext, "Features", dataView, dataProcessPipeline, 1);

            // STEP 3: Set the training algorithm, then create and config the modelBuilder
            SdcaLogisticRegressionBinaryTrainer trainer =
                mlContext.BinaryClassification.Trainers.SdcaLogisticRegression(labelColumnName: "Label",
                                                                               featureColumnName: "Features");
            EstimatorChain <BinaryPredictionTransformer <CalibratedModelParametersBase <LinearBinaryModelParameters, PlattCalibrator> > > trainingPipeline = dataProcessPipeline.Append(trainer);

            //Measure training time
            Stopwatch watch = Stopwatch.StartNew();

            // STEP 4: Train the model fitting to the DataSet
            Console.WriteLine("=============== Training the model ===============");
            ITransformer trainedModel = trainingPipeline.Fit(trainingData);

            //Stop measuring time
            watch.Stop();
            long elapsedMs = watch.ElapsedMilliseconds;

            Console.WriteLine($"***** Training time: {elapsedMs / 1000} seconds *****");

            // STEP 5: Evaluate the model and show accuracy stats
            Console.WriteLine("===== Evaluating Model's accuracy with Test data =====");
            IDataView predictions = trainedModel.Transform(testData);
            CalibratedBinaryClassificationMetrics metrics = mlContext.BinaryClassification.Evaluate(data: predictions, labelColumnName: "Label",
                                                                                                    scoreColumnName: "Score");

            ConsoleHelper.PrintBinaryClassificationMetrics(trainer.ToString(), metrics);

            // STEP 6: Save/persist the trained model to a .ZIP file
            mlContext.Model.Save(trainedModel, trainingData.Schema, ModelPath);

            Console.WriteLine("The model is saved to {0}", ModelPath);

            return(trainedModel);
        }
        // Trains a linear sentiment model, calibrates it with Platt scaling, evaluates
        // it, then demonstrates that scoring can be re-configured (a custom probability
        // threshold) WITHOUT retraining, by rebinding the same predictor to a new scorer.
        void ReconfigurablePrediction()
        {
            var data = GetDataPath(TestDatasets.Sentiment.trainFilename);

            using (var env = new LocalEnvironment(seed: 1, conc: 1))
            {
                // Pipeline
                var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(data));

                var trans = TextFeaturizingEstimator.Create(env, MakeSentimentTextTransformArgs(), loader);

                // Train
                var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
                {
                    NumThreads = 1
                });

                var        cached     = new CacheDataView(env, trans, prefetch: null);
                var        trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");
                IPredictor predictor  = trainer.Train(new Runtime.TrainContext(trainRoles));
                // Wrap the raw predictor with a Platt calibrator so scores map to probabilities.
                using (var ch = env.Start("Calibrator training"))
                {
                    predictor = CalibratorUtils.TrainCalibrator(env, ch, new PlattCalibratorTrainer(env), int.MaxValue, predictor, trainRoles);
                }

                // First evaluation pass with the default scorer configuration.
                var scoreRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
                IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema);

                var dataEval = new RoleMappedData(scorer, label: "Label", feature: "Features", opt: true);

                var evaluator = new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments()
                {
                });
                var metricsDict = evaluator.Evaluate(dataEval);

                var metrics = Legacy.Models.BinaryClassificationMetrics.FromMetrics(env, metricsDict["OverallMetrics"], metricsDict["ConfusionMatrix"])[0];

                // Rebind the SAME predictor into a new scorer with a custom decision
                // threshold on the probability column — no retraining involved.
                var bindable  = ScoreUtils.GetSchemaBindableMapper(env, predictor, null);
                var mapper    = bindable.Bind(env, trainRoles.Schema);
                var newScorer = new BinaryClassifierScorer(env, new BinaryClassifierScorer.Arguments {
                    Threshold = 0.01f, ThresholdColumn = DefaultColumnNames.Probability
                },
                                                           scoreRoles.Data, mapper, trainRoles.Schema);

                // Second evaluation pass using the matching threshold configuration.
                dataEval = new RoleMappedData(newScorer, label: "Label", feature: "Features", opt: true);
                var newEvaluator = new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments()
                {
                    Threshold = 0.01f, UseRawScoreThreshold = false
                });
                metricsDict = newEvaluator.Evaluate(dataEval);
                var newMetrics = Legacy.Models.BinaryClassificationMetrics.FromMetrics(env, metricsDict["OverallMetrics"], metricsDict["ConfusionMatrix"])[0];
            }
        }
Esempio n. 13
0
        /// <summary>
        /// Builds and fits a FastTreeTweedie regression pipeline whose feature set is
        /// driven by the configured _includeDay/_includeMonth/_includeWeek flags.
        /// Day/Month are text-featurized into DayString/MonthString first; Week is
        /// concatenated as-is. The target column named by <paramref name="column"/> is
        /// copied into "Label" before training on _trainData.
        /// </summary>
        private ITransformer BuildAndTrainUsingParams(ColumnEnum column)
        {
            List <string>            features        = new List <string>();
            TextFeaturizingEstimator textTransformer = null;
            // Chain produced only when both the Day and Month featurizers are appended.
            EstimatorChain <ITransformer> estimatorTransformer = null;

            if (_includeDay)
            {
                textTransformer = _mlContext.Transforms.Text.FeaturizeText("DayString", "Day");
                features.Add("DayString");
            }
            if (_includeMonth)
            {
                if (textTransformer != null)
                {
                    // Day already contributed a featurizer; chain the Month featurizer after it.
                    estimatorTransformer = textTransformer.Append(_mlContext.Transforms.Text.FeaturizeText("MonthString", "Month"));
                }
                else
                {
                    textTransformer = _mlContext.Transforms.Text.FeaturizeText("MonthString", "Month");
                }
                features.Add("MonthString");
            }
            if (_includeWeek)
            {
                // Week needs no text featurization; it is concatenated directly.
                features.Add("Week");
            }

            // Resolve the label source column once; it is identical in all branches below.
            string labelSource = System.Enum.GetName(typeof(ColumnEnum), column);

            if (textTransformer == null)
            {
                // No text featurizers at all: concatenate raw features and train.
                var res = _mlContext.Transforms.Concatenate("Features", features.ToArray())
                          .Append(_mlContext.Transforms.CopyColumns("Label", labelSource))
                          .Append(_mlContext.Regression.Trainers.FastTreeTweedie());

                return(res.Fit(_trainData));
            }
            if (estimatorTransformer != null)
            {
                // Both Day and Month featurizers are present.
                var res2 = estimatorTransformer.Append(_mlContext.Transforms.Concatenate("Features", features.ToArray()))
                           .Append(_mlContext.Transforms.CopyColumns("Label", labelSource))
                           .Append(_mlContext.Regression.Trainers.FastTreeTweedie());
                return(res2.Fit(_trainData));
            }
            // Exactly one text featurizer (Day or Month) is present.
            var res3 = textTransformer.Append(_mlContext.Transforms.Concatenate("Features", features.ToArray()))
                       .Append(_mlContext.Transforms.CopyColumns("Label", labelSource))
                       .Append(_mlContext.Regression.Trainers.FastTreeTweedie());

            return(res3.Fit(_trainData));
        }
        // Benchmark body: tokenizes sentiment text, maps tokens to pretrained SSWE
        // word embeddings, trains an SDCA model on the embedded features, and hands
        // the result to the benchmark consumer to prevent dead-code elimination.
        public void TrainSentiment()
        {
            // Pipeline
            // Loader schema: numeric label in column 0, text in column 1; header row,
            // no quoting or sparse parsing.
            var arguments = new TextLoader.Arguments()
            {
                Column = new TextLoader.Column[]
                {
                    new TextLoader.Column()
                    {
                        Name   = "Label",
                        Source = new[] { new TextLoader.Range()
                                         {
                                             Min = 0, Max = 0
                                         } },
                        Type = DataKind.Num
                    },

                    new TextLoader.Column()
                    {
                        Name   = "SentimentText",
                        Source = new[] { new TextLoader.Range()
                                         {
                                             Min = 1, Max = 1
                                         } },
                        Type = DataKind.Text
                    }
                },
                HasHeader    = true,
                AllowQuoting = false,
                AllowSparse  = false
            };
            var loader = _env.Data.ReadFromTextFile(_sentimentDataPath, arguments);
            // Featurizer configured to emit tokens only: word/char n-gram extraction
            // and vector normalization are all disabled.
            var text   = new TextFeaturizingEstimator(_env, "SentimentText", "WordEmbeddings", args =>
            {
                args.OutputTokens     = true;
                args.KeepPunctuations = false;
                args.UseStopRemover   = true;
                args.VectorNormalizer = TextFeaturizingEstimator.TextNormKind.None;
                args.UseCharExtractor = false;
                args.UseWordExtractor = false;
            }).Fit(loader).Transform(loader);
            // Map the token column to pretrained SSWE embeddings as the feature vector.
            var trans = new WordEmbeddingsExtractingEstimator(_env, "WordEmbeddings_TransformedText", "Features",
                                                              WordEmbeddingsExtractingTransformer.PretrainedModelKind.Sswe).Fit(text).Transform(text);
            // Train
            // NOTE(review): the label is loaded as numeric and a multiclass trainer is
            // used even though the data looks binary — presumably intentional for this
            // benchmark; confirm before reusing as a classification example.
            var trainer   = _env.MulticlassClassification.Trainers.StochasticDualCoordinateAscent();
            var predicted = trainer.Fit(trans);

            _consumer.Consume(predicted);
        }
Esempio n. 15
0
        // Trains a linear sentiment model, round-trips it through a saved model file,
        // and checks that predictions from the reloaded model match the labels on a
        // few test rows, with scores well separated from zero.
        public void TrainSaveModelAndPredict()
        {
            using (var env = new LocalEnvironment(seed: 1, conc: 1))
            {
                // Pipeline
                var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename)));

                var trans = TextFeaturizingEstimator.Create(env, MakeSentimentTextTransformArgs(), loader);

                // Train
                var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
                {
                    NumThreads = 1
                });

                // Cache the featurized data so the iterative trainer's passes are cheap.
                var cached     = new CacheDataView(env, trans, prefetch: null);
                var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");
                var predictor  = trainer.Train(new Runtime.TrainContext(trainRoles));

                PredictionEngine <SentimentData, SentimentPrediction> model;
                using (var file = env.CreateTempFile())
                {
                    // Save model.
                    var scoreRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
                    using (var ch = env.Start("saving"))
                        TrainUtils.SaveModel(env, ch, file, predictor, scoreRoles);

                    // Load model.
                    using (var fs = file.OpenReadStream())
                        model = env.CreatePredictionEngine <SentimentData, SentimentPrediction>(fs);
                }

                // Take a couple examples out of the test data and run predictions on top.
                var testLoader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.Sentiment.testFilename)));
                var testData   = testLoader.AsEnumerable <SentimentData>(env, false);
                foreach (var input in testData.Take(5))
                {
                    var prediction = model.Predict(input);
                    // Verify that predictions match and scores are separated from zero.
                    Assert.Equal(input.Sentiment, prediction.Sentiment);
                    Assert.True(input.Sentiment && prediction.Score > 1 || !input.Sentiment && prediction.Score < -1);
                }
            }
        }
Esempio n. 16
0
        private (IEstimator <ITransformer>, IDataView) GetBinaryClassificationPipeline()
        {
            // Loader schema: tab-separated, header row, boolean label in column 0 and
            // raw text in column 1.
            var loaderArgs = new TextLoader.Arguments()
            {
                Separator = "\t",
                HasHeader = true,
                Columns   = new[]
                {
                    new TextLoader.Column("Label", DataKind.BL, 0),
                    new TextLoader.Column("SentimentText", DataKind.Text, 1)
                }
            };
            var data = new TextLoader(Env, loaderArgs).Read(GetDataPath(TestDatasets.Sentiment.trainFilename));

            // The pipeline is just the text featurizer; callers append their own trainer.
            var pipeline = new TextFeaturizingEstimator(Env, "Features", "SentimentText");

            return (pipeline, data);
        }
        private (IEstimator <ITransformer>, IDataView) GetBinaryClassificationPipeline()
        {
            // Loader options: quoted fields allowed, tab-separated, header row, and an
            // extra boolean "LoggedIn" column beyond the label and text columns.
            var loaderOptions = new TextLoader.Options()
            {
                AllowQuoting = true,
                Separator    = "\t",
                HasHeader    = true,
                Columns      = new[]
                {
                    new TextLoader.Column("Label", DataKind.Boolean, 0),
                    new TextLoader.Column("SentimentText", DataKind.String, 1),
                    new TextLoader.Column("LoggedIn", DataKind.Boolean, 2)
                }
            };
            var data = new TextLoader(Env, loaderOptions).Load(GetDataPath(TestDatasets.Sentiment.trainFilename));

            // The pipeline is just the text featurizer; callers append their own trainer.
            var pipeline = new TextFeaturizingEstimator(Env, "Features", "SentimentText");

            return (pipeline, data);
        }
        public void SimpleTrainAndPredict()
        {
            var dataset = TestDatasets.Sentiment;

            using (var env = new LocalEnvironment(seed: 1, conc: 1))
            {
                // Featurize the training text.
                var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(dataset.trainFilename)));
                var featurized = TextFeaturizingEstimator.Create(env, MakeSentimentTextTransformArgs(), loader);

                // Train a linear classifier on a cached view of the features.
                var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
                {
                    NumThreads = 1
                });
                var cached = new CacheDataView(env, featurized, prefetch: null);
                var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");
                var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));

                // Build a scorer over the uncached featurized view and wrap it in a prediction engine.
                var scoreRoles = new RoleMappedData(featurized, label: "Label", feature: "Features");
                IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema);
                var model = env.CreatePredictionEngine <SentimentData, SentimentPrediction>(scorer);

                // Take a couple examples out of the test data and run predictions on top.
                var testLoader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(dataset.testFilename)));
                var testData = testLoader.AsEnumerable <SentimentData>(env, false);
                foreach (var input in testData.Take(5))
                {
                    var prediction = model.Predict(input);
                    // Verify that predictions match and scores are separated from zero.
                    Assert.Equal(input.Sentiment, prediction.Sentiment);
                    Assert.True(input.Sentiment && prediction.Score > 1 || !input.Sentiment && prediction.Score < -1);
                }
            }
        }
Esempio n. 19
0
        void Visibility()
        {
            // Demonstrates schema introspection and typed getter access over the
            // output of the text featurizer.
            using (var env = new LocalEnvironment(seed: 1, conc: 1))
            {
                // Pipeline: loader + text featurizer (no n-gram dictionary variant).
                var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename)));
                var trans = TextFeaturizingEstimator.Create(env, MakeSentimentTextTransformArgs(false), loader);

                // Walk the output schema to discover every column's name and raw CLR type.
                int columnCount = trans.Schema.ColumnCount;
                for (int col = 0; col < columnCount; col++)
                {
                    var name = trans.Schema.GetColumnName(col);
                    var rawType = trans.Schema.GetColumnType(col).RawType;
                }

                // Cursor over all columns; resolve the three columns we care about.
                using (var cursor = trans.GetRowCursor(x => true))
                {
                    Assert.True(cursor.Schema.TryGetColumnIndex("SentimentText", out int textColumn));
                    Assert.True(cursor.Schema.TryGetColumnIndex("Features_TransformedText", out int transformedTextColumn));
                    Assert.True(cursor.Schema.TryGetColumnIndex("Features", out int featureColumn));

                    // Typed getters pull values out of the current cursor row.
                    var textGetter = cursor.GetGetter <ReadOnlyMemory <char> >(textColumn);
                    var transformedTextGetter = cursor.GetGetter <VBuffer <ReadOnlyMemory <char> > >(transformedTextColumn);
                    var featuresGetter = cursor.GetGetter <VBuffer <float> >(featureColumn);

                    ReadOnlyMemory <char> text = default;
                    VBuffer <ReadOnlyMemory <char> > transformedText = default;
                    VBuffer <float> features = default;

                    // Drain the cursor, materializing each row's values into the buffers.
                    while (cursor.MoveNext())
                    {
                        textGetter(ref text);
                        transformedTextGetter(ref transformedText);
                        featuresGetter(ref features);
                    }
                }
            }
        }
        void MultithreadedPrediction()
        {
            // Verifies that a single prediction engine can be shared by concurrent
            // callers when access to it is serialized.
            using (var env = new LocalEnvironment(seed: 1, conc: 1))
            {
                // Pipeline: load the sentiment training data and featurize the text column.
                var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename)));

                var trans = TextFeaturizingEstimator.Create(env, MakeSentimentTextTransformArgs(), loader);

                // Train a single-threaded linear classifier on the cached, featurized data.
                var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
                {
                    NumThreads = 1
                });

                var cached     = new CacheDataView(env, trans, prefetch: null);
                var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");
                var predictor  = trainer.Train(new Runtime.TrainContext(trainRoles));

                var scoreRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
                IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema);

                // Create prediction engine and test predictions.
                var model = env.CreatePredictionEngine <SentimentData, SentimentPrediction>(scorer);

                // Take a couple examples out of the test data and run predictions on top.
                var testLoader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.Sentiment.testFilename)));
                var testData   = testLoader.AsEnumerable <SentimentData>(env, false);

                // The prediction engine is not safe for concurrent use, so every
                // Predict call is serialized. Lock on a dedicated private gate object
                // rather than on the engine instance itself: locking an externally
                // visible object is an anti-pattern, since any other code holding a
                // reference could also lock it and cause contention or deadlock.
                var gate = new object();
                Parallel.ForEach(testData, (input) =>
                {
                    lock (gate)
                    {
                        var prediction = model.Predict(input);
                    }
                });
            }
        }
Esempio n. 21
0
                public override IEnumerable <SuggestedTransform> Apply(IntermediateColumn[] columns)
                {
                    // Suggests a text-featurization transform for every text column whose
                    // purpose is TextFeature, then (optionally) a final concat of the
                    // produced "<name>_tf" columns into the features column.
                    var producedColumns = new List <string>();

                    foreach (var col in columns)
                    {
                        // Only text-typed columns marked as text features participate.
                        bool isTextFeature = col.Type.ItemType().IsText() && col.Purpose == ColumnPurpose.TextFeature;
                        if (!isTextFeature)
                        {
                            continue;
                        }

                        string sourceName = col.ColumnName;
                        string destName = sourceName + "_tf";

                        producedColumns.Add(destName);
                        var estimator = new TextFeaturizingEstimator(Env, destName, sourceName);

                        // Routing metadata: text (non-numeric) in, numeric features out.
                        var sourceAnnotations = new[]
                        {
                            new ColumnRoutingStructure.AnnotatedName { IsNumeric = false, Name = sourceName }
                        };
                        var destAnnotations = new[]
                        {
                            new ColumnRoutingStructure.AnnotatedName { IsNumeric = true, Name = destName }
                        };
                        yield return(new SuggestedTransform(estimator, new ColumnRoutingStructure(sourceAnnotations, destAnnotations)));
                    }

                    // Concat text featurized columns into existing feature column, if transformed at least one column.
                    if (!ExcludeFeaturesConcatTransforms && producedColumns.Count > 0)
                    {
                        yield return(InferenceHelpers.GetRemainingFeatures(producedColumns, columns, GetType(), IncludeFeaturesOverride));

                        IncludeFeaturesOverride = true;
                    }
                }
        public void TrainWithInitialPredictor()
        {
            // Trains one predictor, then warm-starts a second trainer from it on
            // the same role-mapped data.
            using (var env = new LocalEnvironment(seed: 1, conc: 1))
            {
                // Pipeline: sentiment loader + text featurizer.
                var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename)));
                var featurized = TextFeaturizingEstimator.Create(env, MakeSentimentTextTransformArgs(), loader);

                // Cache once so both training passes re-read the data cheaply.
                var cachedTrain = new CacheDataView(env, featurized, prefetch: null);
                var roles = new RoleMappedData(cachedTrain, label: "Label", feature: "Features");

                // First pass: single-threaded linear classifier.
                var firstTrainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
                {
                    NumThreads = 1
                });
                var firstPredictor = firstTrainer.Train(new Runtime.TrainContext(roles));

                // Second pass: averaged perceptron, initialized from the first predictor.
                var secondTrainer  = new AveragedPerceptronTrainer(env, "Label", "Features");
                var finalPredictor = secondTrainer.Train(new TrainContext(roles, initialPredictor: firstPredictor));
            }
        }
Esempio n. 23
0
        public void TestWordEmbeddings()
        {
            // Runs the standard estimator contract checks on a word-embeddings
            // estimator built over tokenized sentiment text.
            // NOTE(review): testDataPath is resolved but never used below — confirm
            // whether a test-set evaluation step was intended.
            var dataPath     = GetDataPath(ScenariosTests.SentimentDataPath);
            var testDataPath = GetDataPath(ScenariosTests.SentimentTestPath);

            // Statically-typed loader: column 0 is the boolean label, column 1 the raw text.
            var data = TextLoader.CreateReader(Env, ctx => (
                                                   label: ctx.LoadBool(0),
                                                   SentimentText: ctx.LoadText(1)), hasHeader: true)
                       .Read(dataPath);

            // Text featurizer configured to emit the tokenized text (OutputTokens = true)
            // with both n-gram extractors disabled, so only the token column feeds
            // the embedding lookup below.
            var dynamicData = TextFeaturizingEstimator.Create(Env, new TextFeaturizingEstimator.Arguments()
            {
                Column = new TextFeaturizingEstimator.Column
                {
                    Name   = "SentimentText_Features",
                    Source = new[] { "SentimentText" }
                },
                OutputTokens                 = true,
                KeepPunctuations             = false,
                UsePredefinedStopWordRemover = true,
                VectorNormalizer             = TextFeaturizingEstimator.TextNormKind.None,
                CharFeatureExtractor         = null,
                WordFeatureExtractor         = null,
            }, data.AsDynamic);

            // Re-assert static typing over the dynamic output so the estimator below
            // can reference the transformed-text column in a type-safe way.
            var data2 = dynamicData.AssertStatic(Env, ctx => (
                                                     SentimentText_Features_TransformedText: ctx.Text.VarVector,
                                                     SentimentText: ctx.Text.Scalar,
                                                     label: ctx.Bool.Scalar));

            // Word-embedding lookup over the transformed (tokenized) text.
            var est = data2.MakeNewEstimator()
                      .Append(row => row.SentimentText_Features_TransformedText.WordEmbeddings());

            // Contract checks: 'data' (which lacks the transformed-text column) must
            // be rejected as invalid input.
            TestEstimatorCore(est.AsDynamic, data2.AsDynamic, invalidInput: data.AsDynamic);
            Done();
        }
Esempio n. 24
0
        void CrossValidation()
        {
            // Manual k-fold cross-validation built from low-level components:
            // a generated stratification column partitions rows into folds, and each
            // fold trains on the complement and evaluates on the fold itself.
            var dataset = TestDatasets.Sentiment;

            int numFolds = 5;

            using (var env = new LocalEnvironment(seed: 1, conc: 1))
            {
                // Pipeline.
                var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(dataset.trainFilename)));

                // Featurize the text, then append a generated random-number column
                // ("StratificationColumn") used to assign rows to folds via range filters.
                var       text  = TextFeaturizingEstimator.Create(env, MakeSentimentTextTransformArgs(false), loader);
                IDataView trans = new GenerateNumberTransform(env, text, "StratificationColumn");
                // Train.
                var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
                {
                    NumThreads           = 1,
                    ConvergenceTolerance = 1f
                });

                var metrics = new List <BinaryClassificationMetrics>();
                for (int fold = 0; fold < numFolds; fold++)
                {
                    // Training split: rows OUTSIDE this fold's range (Complement = true).
                    IDataView trainPipe = new RangeFilter(env, new RangeFilter.Arguments()
                    {
                        Column     = "StratificationColumn",
                        Min        = (Double)fold / numFolds,
                        Max        = (Double)(fold + 1) / numFolds,
                        Complement = true
                    }, trans);
                    trainPipe = new OpaqueDataView(trainPipe);
                    var trainData = new RoleMappedData(trainPipe, label: "Label", feature: "Features");
                    // Auto-normalization.
                    NormalizeTransform.CreateIfNeeded(env, ref trainData, trainer);
                    // Keep a handle to the pre-cache view: the test split must re-apply
                    // the same (normalization) transforms, but not the cache.
                    var preCachedData = trainData;
                    // Auto-caching.
                    if (trainer.Info.WantCaching)
                    {
                        var prefetch  = trainData.Schema.GetColumnRoles().Select(kc => kc.Value.Index).ToArray();
                        var cacheView = new CacheDataView(env, trainData.Data, prefetch);
                        // Because the prefetching worked, we know that these are valid columns.
                        trainData = new RoleMappedData(cacheView, trainData.Schema.GetColumnRoleNames());
                    }

                    var       predictor = trainer.Train(new Runtime.TrainContext(trainData));
                    // Test split: rows INSIDE this fold's range (Complement = false).
                    IDataView testPipe  = new RangeFilter(env, new RangeFilter.Arguments()
                    {
                        Column     = "StratificationColumn",
                        Min        = (Double)fold / numFolds,
                        Max        = (Double)(fold + 1) / numFolds,
                        Complement = false
                    }, trans);
                    testPipe = new OpaqueDataView(testPipe);
                    // Re-apply the training-side transforms (e.g. the learned
                    // normalization) to the test split.
                    var pipe = ApplyTransformUtils.ApplyAllTransformsToData(env, preCachedData.Data, testPipe, trainPipe);

                    var testRoles = new RoleMappedData(pipe, trainData.Schema.GetColumnRoleNames());

                    IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, testRoles, env, testRoles.Schema);

                    // Evaluate this fold and collect its binary-classification metrics.
                    BinaryClassifierMamlEvaluator eval = new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments()
                    {
                    });
                    var dataEval    = new RoleMappedData(scorer, testRoles.Schema.GetColumnRoleNames(), opt: true);
                    var dict        = eval.Evaluate(dataEval);
                    var foldMetrics = BinaryClassificationMetrics.FromMetrics(env, dict["OverallMetrics"], dict["ConfusionMatrix"]);
                    metrics.Add(foldMetrics.Single());
                }
            }
        }
        public void TrainAndPredictSentimentModelWithDirectionInstantiationTestWithWordEmbedding()
        {
            // End-to-end sentiment pipeline using pretrained (SSWE) word embeddings
            // as features for a FastTree binary classifier, with pinned metric values.
            var dataPath     = GetDataPath(SentimentDataPath);
            var testDataPath = GetDataPath(SentimentTestPath);

            using (var env = new ConsoleEnvironment(seed: 1, conc: 1))
            {
                // Pipeline
                var loader = TextLoader.ReadFile(env,
                                                 new TextLoader.Arguments()
                {
                    Separator = "tab",
                    HasHeader = true,
                    Column    = new[]
                    {
                        new TextLoader.Column("Label", DataKind.Num, 0),
                        new TextLoader.Column("SentimentText", DataKind.Text, 1)
                    }
                }, new MultiFileSource(dataPath));

                // Text featurizer emitting only the tokenized text (n-gram extractors
                // disabled); the tokens surface as "WordEmbeddings_TransformedText".
                var text = TextFeaturizingEstimator.Create(env, new TextFeaturizingEstimator.Arguments()
                {
                    Column = new TextFeaturizingEstimator.Column
                    {
                        Name   = "WordEmbeddings",
                        Source = new[] { "SentimentText" }
                    },
                    OutputTokens         = true,
                    KeepPunctuations     = false,
                    StopWordsRemover     = new PredefinedStopWordsRemoverFactory(),
                    VectorNormalizer     = TextFeaturizingEstimator.TextNormKind.None,
                    CharFeatureExtractor = null,
                    WordFeatureExtractor = null,
                },
                                                           loader);

                // Map the tokens to pretrained SSWE embedding vectors, producing "Features".
                var trans = WordEmbeddingsExtractingTransformer.Create(env, new WordEmbeddingsExtractingTransformer.Arguments()
                {
                    Column = new WordEmbeddingsExtractingTransformer.Column[1]
                    {
                        new WordEmbeddingsExtractingTransformer.Column
                        {
                            Name   = "Features",
                            Source = "WordEmbeddings_TransformedText"
                        }
                    },
                    ModelKind = WordEmbeddingsExtractingTransformer.PretrainedModelKind.Sswe,
                }, text);
                // Train
                var trainer = new FastTreeBinaryClassificationTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features, numLeaves: 5, numTrees: 5, minDatapointsInLeaves: 2);

                var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
                var pred       = trainer.Train(trainRoles);
                // Get scorer and evaluate the predictions from test data
                IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath);
                var metrics = EvaluateBinary(env, testDataScorer);

                // SSWE is a simple word embedding model + we train on a really small dataset, so metrics are not great.
                Assert.Equal(.6667, metrics.Accuracy, 4);
                Assert.Equal(.71, metrics.Auc, 1);
                Assert.Equal(.58, metrics.Auprc, 2);
                // Create prediction engine and test predictions
                var model       = env.CreateBatchPredictionEngine <SentimentData, SentimentPrediction>(testDataScorer);
                var sentiments  = GetTestData();
                var predictions = model.Predict(sentiments, false);
                Assert.Equal(2, predictions.Count());
                Assert.True(predictions.ElementAt(0).Sentiment);
                Assert.True(predictions.ElementAt(1).Sentiment);

                // Get feature importance based on feature gain during training
                var summary = ((FeatureWeightsCalibratedPredictor)pred).GetSummaryInKeyValuePairs(trainRoles.Schema);
                Assert.Equal(1.0, (double)summary[0].Value, 1);
            }
        }
Esempio n. 26
0
        public void TrainSentiment()
        {
            // Benchmark body: builds a word-embedding sentiment pipeline with the
            // MLContext API and trains it, feeding the result to the benchmark consumer.
            var env = new MLContext(seed: 1);
            // Pipeline
            var arguments = new TextLoader.Arguments()
            {
                Column = new TextLoader.Column[]
                {
                    // Column 0: numeric label.
                    new TextLoader.Column()
                    {
                        Name   = "Label",
                        Source = new[] { new TextLoader.Range()
                                         {
                                             Min = 0, Max = 0
                                         } },
                        Type = DataKind.Num
                    },

                    // Column 1: raw sentiment text.
                    new TextLoader.Column()
                    {
                        Name   = "SentimentText",
                        Source = new[] { new TextLoader.Range()
                                         {
                                             Min = 1, Max = 1
                                         } },
                        Type = DataKind.Text
                    }
                },
                HasHeader    = true,
                AllowQuoting = false,
                AllowSparse  = false
            };
            var loader = env.Data.ReadFromTextFile(_sentimentDataPath, arguments);

            // Text featurizer emitting only tokenized text (n-gram extraction disabled);
            // tokens surface as "WordEmbeddings_TransformedText".
            var text = TextFeaturizingEstimator.Create(env,
                                                       new TextFeaturizingEstimator.Arguments()
            {
                Column = new TextFeaturizingEstimator.Column
                {
                    Name   = "WordEmbeddings",
                    Source = new[] { "SentimentText" }
                },
                OutputTokens                 = true,
                KeepPunctuations             = false,
                UsePredefinedStopWordRemover = true,
                VectorNormalizer             = TextFeaturizingEstimator.TextNormKind.None,
                CharFeatureExtractor         = null,
                WordFeatureExtractor         = null,
            }, loader);

            // Map tokens to pretrained SSWE embedding vectors as "Features".
            var trans = WordEmbeddingsExtractingTransformer.Create(env,
                                                                   new WordEmbeddingsExtractingTransformer.Arguments()
            {
                Column = new WordEmbeddingsExtractingTransformer.Column[1]
                {
                    new WordEmbeddingsExtractingTransformer.Column
                    {
                        Name   = "Features",
                        Source = "WordEmbeddings_TransformedText"
                    }
                },
                ModelKind = WordEmbeddingsExtractingTransformer.PretrainedModelKind.Sswe,
            }, text);

            // Train
            // NOTE(review): a multi-class SDCA trainer is used for what appears to be a
            // two-class sentiment task, and a similar call elsewhere in this file passes
            // the column arguments in the opposite order ("Features", "Label") — confirm
            // the (label, feature) parameter order against this ML.NET version's ctor.
            var trainer   = new SdcaMultiClassTrainer(env, "Label", "Features", maxIterations: 20);
            var predicted = trainer.Fit(trans);

            _consumer.Consume(predicted);
        }
Esempio n. 27
0
        public void TrainAndPredictSentimentModelWithDirectionInstantiationTest()
        {
            // End-to-end sentiment pipeline using char/word n-gram features (rather
            // than embeddings) for a FastTree binary classifier, with metric validation.
            var dataPath     = GetDataPath(SentimentDataPath);
            var testDataPath = GetDataPath(SentimentTestPath);

            using (var env = new ConsoleEnvironment(seed: 1, conc: 1))
            {
                // Pipeline
                var loader = TextLoader.ReadFile(env,
                                                 new TextLoader.Arguments()
                {
                    Separator = "tab",
                    HasHeader = true,
                    Column    = new[]
                    {
                        new TextLoader.Column("Label", DataKind.Num, 0),
                        new TextLoader.Column("SentimentText", DataKind.Text, 1)
                    }
                }, new MultiFileSource(dataPath));

                // Text featurizer: L2-normalized bag of char 3-grams (exactly length 3)
                // plus word 1- and 2-grams (AllLengths = true), stop words removed.
                var trans = TextFeaturizingEstimator.Create(env, new TextFeaturizingEstimator.Arguments()
                {
                    Column = new TextFeaturizingEstimator.Column
                    {
                        Name   = "Features",
                        Source = new[] { "SentimentText" }
                    },
                    OutputTokens         = true,
                    KeepPunctuations     = false,
                    StopWordsRemover     = new Runtime.TextAnalytics.PredefinedStopWordsRemoverFactory(),
                    VectorNormalizer     = TextFeaturizingEstimator.TextNormKind.L2,
                    CharFeatureExtractor = new NgramExtractorTransform.NgramExtractorArguments()
                    {
                        NgramLength = 3, AllLengths = false
                    },
                    WordFeatureExtractor = new NgramExtractorTransform.NgramExtractorArguments()
                    {
                        NgramLength = 2, AllLengths = true
                    },
                },
                                                            loader);

                // Train
                var trainer = new FastTreeBinaryClassificationTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features,
                                                                      numLeaves: 5, numTrees: 5, minDocumentsInLeafs: 2);

                var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
                var pred       = trainer.Train(trainRoles);

                // Get scorer and evaluate the predictions from test data
                IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath);
                var metrics = EvaluateBinary(env, testDataScorer);
                ValidateBinaryMetrics(metrics);

                // Create prediction engine and test predictions
                var model       = env.CreateBatchPredictionEngine <SentimentData, SentimentPrediction>(testDataScorer);
                var sentiments  = GetTestData();
                var predictions = model.Predict(sentiments, false);
                Assert.Equal(2, predictions.Count());
                Assert.True(predictions.ElementAt(0).Sentiment);
                Assert.True(predictions.ElementAt(1).Sentiment);

                // Get feature importance based on feature gain during training
                var summary = ((FeatureWeightsCalibratedPredictor)pred).GetSummaryInKeyValuePairs(trainRoles.Schema);
                Assert.Equal(1.0, (double)summary[0].Value, 1);
            }
        }
        public void TrainSentiment()
        {
            // Benchmark body (older ConsoleEnvironment API): word-embedding sentiment
            // pipeline trained with SDCA, result fed to the benchmark consumer.
            using (var env = new ConsoleEnvironment(seed: 1))
            {
                // Pipeline
                var loader = TextLoader.ReadFile(env,
                                                 new TextLoader.Arguments()
                {
                    AllowQuoting = false,
                    AllowSparse  = false,
                    Separator    = "tab",
                    HasHeader    = true,
                    Column       = new[]
                    {
                        // Column 0: numeric label.
                        new TextLoader.Column()
                        {
                            Name   = "Label",
                            Source = new [] { new TextLoader.Range()
                                              {
                                                  Min = 0, Max = 0
                                              } },
                            Type = DataKind.Num
                        },

                        // Column 1: raw sentiment text.
                        new TextLoader.Column()
                        {
                            Name   = "SentimentText",
                            Source = new [] { new TextLoader.Range()
                                              {
                                                  Min = 1, Max = 1
                                              } },
                            Type = DataKind.Text
                        }
                    }
                }, new MultiFileSource(_sentimentDataPath));

                // Text featurizer emitting only tokenized text (n-gram extraction
                // disabled); tokens surface as "WordEmbeddings_TransformedText".
                var text = TextFeaturizingEstimator.Create(env,
                                                           new TextFeaturizingEstimator.Arguments()
                {
                    Column = new TextFeaturizingEstimator.Column
                    {
                        Name   = "WordEmbeddings",
                        Source = new[] { "SentimentText" }
                    },
                    OutputTokens         = true,
                    KeepPunctuations     = false,
                    StopWordsRemover     = new Runtime.TextAnalytics.PredefinedStopWordsRemoverFactory(),
                    VectorNormalizer     = TextFeaturizingEstimator.TextNormKind.None,
                    CharFeatureExtractor = null,
                    WordFeatureExtractor = null,
                }, loader);

                // Map tokens to pretrained SSWE embedding vectors as "Features".
                var trans = WordEmbeddingsTransform.Create(env,
                                                           new WordEmbeddingsTransform.Arguments()
                {
                    Column = new WordEmbeddingsTransform.Column[1]
                    {
                        new WordEmbeddingsTransform.Column
                        {
                            Name   = "Features",
                            Source = "WordEmbeddings_TransformedText"
                        }
                    },
                    ModelKind = WordEmbeddingsTransform.PretrainedModelKind.Sswe,
                }, text);

                // Train
                // NOTE(review): the ctor arguments here are ("Features", "Label") while
                // comparable trainer calls elsewhere in this file pass the label column
                // first — if this ML.NET version's SdcaMultiClassTrainer ctor takes
                // (env, labelColumn, featureColumn), these arguments are swapped. Verify
                // against the ctor signature before relying on this snippet.
                var trainer    = new SdcaMultiClassTrainer(env, "Features", "Label", maxIterations: 20);
                var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features");

                var predicted = trainer.Train(trainRoles);
                _consumer.Consume(predicted);
            }
        }
Esempio n. 29
0
        public ML <T, T2> Run()
        {
            // Trains a model according to this.type, saves it to this.modelName, and
            // (for the text path) records evaluation metrics. Returns 'this' so calls
            // can be chained fluently.
            IDataView trainingDataView = mlContext.Data.LoadFromEnumerable(this.data);

            // Hold out 20% of the rows for evaluation.
            DataOperationsCatalog.TrainTestData dataSplit = mlContext.Data
                                                            .TrainTestSplit(trainingDataView, testFraction: 0.2);

            // NOTE(review): no default case — an unrecognized this.type silently skips
            // training yet still sets isTaught = true below.
            switch (this.type)
            {
            case MLType.TextFeaturizingEstimator:
            {
                // Binary text classification: featurize the text column, then train
                // SDCA logistic regression on the result.
                TextFeaturizingEstimator dataProcessPipeline = mlContext.Transforms.Text
                                                               .FeaturizeText(outputColumnName: "Features", inputColumnName: this.inputName);

                SdcaLogisticRegressionBinaryTrainer sdcaRegressionTrainer = mlContext.BinaryClassification.Trainers
                                                                            .SdcaLogisticRegression(labelColumnName: this.labelName, featureColumnName: "Features");

                EstimatorChain <BinaryPredictionTransformer <CalibratedModelParametersBase <LinearBinaryModelParameters, PlattCalibrator> > > trainingPipeline = dataProcessPipeline.Append(sdcaRegressionTrainer);

                // Fit, persist, and evaluate on the held-out split.
                trainedModel = trainingPipeline.Fit(dataSplit.TrainSet);
                mlContext.Model.Save(trainedModel, dataSplit.TrainSet.Schema, this.modelName);
                IDataView testSetTransform = trainedModel.Transform(dataSplit.TestSet);

                this.modelMetrics = mlContext.BinaryClassification
                                    .Evaluate(data: testSetTransform,
                                              labelColumnName: this.labelName,
                                              scoreColumnName: this.scoreName);
                break;
            }

            case MLType.LightGbm:
            {
                // Regression with LightGBM over all public instance properties.
                // NOTE(review): this.thisType.GetType() reflects over the RUNTIME type of
                // 'thisType' — if thisType already holds a System.Type, this enumerates
                // Type's own properties (Name, Assembly, ...), not the target type's.
                // Confirm whether this.thisType.GetProperties(...) was intended.
                var fields = this.thisType
                             .GetType()
                             .GetProperties(BindingFlags.Public | BindingFlags.Instance)
                             .Select(p => p.Name);
                //mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel")

                // If the caller already supplies a "Features" column, skip concatenation.
                var featurePipeline = this.isFeaturesIncluded
                        ? null
                        : mlContext.Transforms.Concatenate("Features", fields.ToArray());

                var trainer = mlContext.Regression.Trainers
                              .LightGbm(new LightGbmRegressionTrainer.Options()
                    {
                        NumberOfIterations                = 100,
                        LearningRate                      = 0.3227682f,
                        NumberOfLeaves                    = 55,
                        MinimumExampleCountPerLeaf        = 10,
                        UseCategoricalSplit               = false,
                        HandleMissingValue                = true,
                        UseZeroAsMissingValue             = false,
                        MinimumExampleCountPerGroup       = 50,
                        MaximumCategoricalSplitPointCount = 32,
                        CategoricalSmoothing              = 20,
                        L2CategoricalRegularization       = 5,
                        Booster = new GradientBooster.Options()
                        {
                            L2Regularization = 0, L1Regularization = 0.5
                        },
                        LabelColumnName   = this.labelName,
                        FeatureColumnName = "Features"
                    });

                var pipeline2 = featurePipeline == null ? null : featurePipeline.Append(trainer);

                // NOTE(review): in both branches below, testSetTransform and
                // crossValidationResults are computed but never consumed; this.modelMetrics
                // is NOT set on this path (the evaluation is commented out further down).
                if (pipeline2 == null)
                {
                    trainedModel = trainer.Fit(dataSplit.TrainSet);
                    mlContext.Model.Save(trainedModel, dataSplit.TrainSet.Schema, this.modelName);

                    IDataView testSetTransform = trainedModel.Transform(dataSplit.TestSet);

                    var crossValidationResults = mlContext.Regression
                                                 .CrossValidate(trainingDataView, trainer, numberOfFolds: 5, labelColumnName: this.labelName);
                }
                else
                {
                    trainedModel = pipeline2.Fit(dataSplit.TrainSet);
                    mlContext.Model.Save(trainedModel, dataSplit.TrainSet.Schema, this.modelName);

                    IDataView testSetTransform = trainedModel.Transform(dataSplit.TestSet);

                    var crossValidationResults = mlContext.Regression
                                                 .CrossValidate(trainingDataView,
                                                                pipeline2,
                                                                numberOfFolds: 5,
                                                                labelColumnName: this.labelName);
                }

                //this.modelMetrics = mlContext.Regression
                //.Evaluate(data: testSetTransform,
                //          labelColumnName: this.labelName,
                //          scoreColumnName: this.scoreName);
                break;
            }
            }

            //var msg = $"Area Under Curve: {modelMetrics.AreaUnderRocCurve:P2}{Environment.NewLine}" +
            //    $"Area Under Precision Recall Curve: {modelMetrics.AreaUnderPrecisionRecallCurve:P2}" +
            //    $"{Environment.NewLine}" +
            //    $"Accuracy: {modelMetrics.Accuracy:P2}{Environment.NewLine}" +
            //    $"F1Score: {modelMetrics.F1Score:P2}{Environment.NewLine}" +
            //    $"Positive Recall: {modelMetrics.PositiveRecall:#.##}{Environment.NewLine}" +
            //    $"Negative Recall: {modelMetrics.NegativeRecall:#.##}{Environment.NewLine}";

            // Mark training complete regardless of which branch ran.
            this.isTaught = true;
            return(this);
        }