Code Example #1
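Trains a FastTree binary classifier with a separate validation set using the legacy ML.NET pipeline API: the training file is loaded and text-featurized, the same transforms are re-applied to the validation file with ApplyTransformUtils, both views are cached, and the trainer receives the training and validation data through a TrainContext.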
        public void TrainWithValidationSet()
        {
            using (var env = new LocalEnvironment(seed: 1, conc: 1))
            {
                // Pipeline
                var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename)));

                var trans     = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader);
                var trainData = trans;

                // Apply the same transformations to the validation set.
                // Unfortunately, there is no easy way to apply the same loader to different data, so we either
                // have to create another loader, or save the loader to a model file and then reload it.

                // Creating a new loader is not always feasible, but here it is.
                var validLoader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.Sentiment.testFilename)));
                var validData   = ApplyTransformUtils.ApplyAllTransformsToData(env, trainData, validLoader);

                // Cache both datasets.
                var cachedTrain = new CacheDataView(env, trainData, prefetch: null);
                var cachedValid = new CacheDataView(env, validData, prefetch: null);

                // Train.
                var trainer    = new FastTreeBinaryClassificationTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features, numTrees: 3);
                var trainRoles = new RoleMappedData(cachedTrain, label: "Label", feature: "Features");
                var validRoles = new RoleMappedData(cachedValid, label: "Label", feature: "Features");
                trainer.Train(new Runtime.TrainContext(trainRoles, validRoles));
            }
        }
Code Example #2
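Saves the featurized data to a binary IDV file with BinarySaver, reloads it through BinaryLoader, trains a linear classifier on the reloaded view, and finally deletes the temporary file.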
        void FileBasedSavingOfData()
        {
            var dataPath     = GetDataPath(SentimentDataPath);
            var testDataPath = GetDataPath(SentimentTestPath);

            using (var env = new LocalEnvironment(seed: 1, conc: 1))
            {
                // Pipeline
                var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));

                var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader);
                var saver = new BinarySaver(env, new BinarySaver.Arguments());
                using (var ch = env.Start("SaveData"))
                    using (var file = env.CreateOutputFile("i.idv"))
                    {
                        DataSaverUtils.SaveDataView(ch, saver, trans, file);
                    }

                var binData    = new BinaryLoader(env, new BinaryLoader.Arguments(), new MultiFileSource("i.idv"));
                var trainRoles = new RoleMappedData(binData, label: "Label", feature: "Features");
                var trainer    = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
                {
                    NumThreads = 1
                });
                var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));

                DeleteOutputPath("i.idv");
            }
        }
Code Example #3
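Demonstrates trainer-driven auto-caching and auto-normalization: the data is wrapped in a CacheDataView only if the trainer reports WantCaching, and NormalizeTransform.CreateIfNeeded adds a normalizer only if the trainer requires one.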
        public void AutoNormalizationAndCaching()
        {
            var dataPath     = GetDataPath(SentimentDataPath);
            var testDataPath = GetDataPath(SentimentTestPath);

            using (var env = new LocalEnvironment(seed: 1, conc: 1))
            {
                // Pipeline.
                var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));

                var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(false), loader);

                // Train.
                var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
                {
                    NumThreads           = 1,
                    ConvergenceTolerance = 1f
                });

                // Auto-caching.
                IDataView trainData  = trainer.Info.WantCaching ? (IDataView) new CacheDataView(env, trans, prefetch: null) : trans;
                var       trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features");

                // Auto-normalization.
                NormalizeTransform.CreateIfNeeded(env, ref trainRoles, trainer);
                var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));
            }
        }
Code Example #4
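Trains a linear classifier first, then trains an AveragedPerceptron on the same cached data, passing the first predictor as initialPredictor in the TrainContext so the second trainer starts from it.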
        public void TrainWithInitialPredictor()
        {
            var dataPath = GetDataPath(SentimentDataPath);

            using (var env = new TlcEnvironment(seed: 1, conc: 1))
            {
                // Pipeline
                var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));

                var trans     = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader);
                var trainData = trans;

                var cachedTrain = new CacheDataView(env, trainData, prefetch: null);
                // Train the first predictor.
                var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
                {
                    NumThreads = 1
                });
                var trainRoles = new RoleMappedData(cachedTrain, label: "Label", feature: "Features");
                var predictor  = trainer.Train(new Runtime.TrainContext(trainRoles));

                // Train the second predictor on the same data.
                var secondTrainer  = new AveragedPerceptronTrainer(env, new AveragedPerceptronTrainer.Arguments());
                var finalPredictor = secondTrainer.Train(new TrainContext(trainRoles, initialPredictor: predictor));
            }
        }
Code Example #5
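A Fit method for an estimator-style wrapper: it builds the TextTransform over the input, re-applies the resulting transform chain onto an EmptyDataView that shares the input's schema, and returns a TransformWrapper that carries the fitted transforms without the training data.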
        public TransformWrapper Fit(IDataView input)
        {
            var xf    = TextTransform.Create(_env, _args, input);
            var empty = new EmptyDataView(_env, input.Schema);
            var chunk = ApplyTransformUtils.ApplyAllTransformsToData(_env, xf, empty, input);

            return new TransformWrapper(_env, chunk);
        }
Code Example #6
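Trains and calibrates a linear classifier, evaluates it, then rebinds the same predictor into a new BinaryClassifierScorer with a 0.01 probability threshold and re-evaluates with a matching evaluator configuration, all without retraining.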
        void ReconfigurablePrediction()
        {
            var dataPath     = GetDataPath(SentimentDataPath);
            var testDataPath = GetDataPath(SentimentTestPath);

            using (var env = new TlcEnvironment(seed: 1, conc: 1))
            {
                // Pipeline
                var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));

                var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader);

                // Train
                var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
                {
                    NumThreads = 1
                });

                var        cached     = new CacheDataView(env, trans, prefetch: null);
                var        trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");
                IPredictor predictor  = trainer.Train(new Runtime.TrainContext(trainRoles));
                using (var ch = env.Start("Calibrator training"))
                {
                    predictor = CalibratorUtils.TrainCalibrator(env, ch, new PlattCalibratorTrainer(env), int.MaxValue, predictor, trainRoles);
                }

                var scoreRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
                IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema);

                var dataEval = new RoleMappedData(scorer, label: "Label", feature: "Features", opt: true);

                var evaluator = new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments()
                {
                });
                var metricsDict = evaluator.Evaluate(dataEval);

                var metrics = BinaryClassificationMetrics.FromMetrics(env, metricsDict["OverallMetrics"], metricsDict["ConfusionMatrix"])[0];

                var bindable  = ScoreUtils.GetSchemaBindableMapper(env, predictor, null);
                var mapper    = bindable.Bind(env, trainRoles.Schema);
                var newScorer = new BinaryClassifierScorer(env,
                    new BinaryClassifierScorer.Arguments { Threshold = 0.01f, ThresholdColumn = DefaultColumnNames.Probability },
                    scoreRoles.Data, mapper, trainRoles.Schema);

                dataEval = new RoleMappedData(newScorer, label: "Label", feature: "Features", opt: true);
                var newEvaluator = new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments()
                {
                    Threshold = 0.01f, UseRawScoreThreshold = false
                });
                metricsDict = newEvaluator.Evaluate(dataEval);
                var newMetrics = BinaryClassificationMetrics.FromMetrics(env, metricsDict["OverallMetrics"], metricsDict["ConfusionMatrix"])[0];
            }
        }
Code Example #7
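Trains a model, saves the full pipeline (transforms plus predictor) to a temporary file with TrainUtils.SaveModel, reloads it into a PredictionEngine, and checks predictions on a few rows of the test set.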
        public void TrainSaveModelAndPredict()
        {
            var dataPath     = GetDataPath(SentimentDataPath);
            var testDataPath = GetDataPath(SentimentTestPath);

            using (var env = new LocalEnvironment(seed: 1, conc: 1))
            {
                // Pipeline
                var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));

                var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader);

                // Train
                var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
                {
                    NumThreads = 1
                });

                var cached     = new CacheDataView(env, trans, prefetch: null);
                var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");
                var predictor  = trainer.Train(new Runtime.TrainContext(trainRoles));

                PredictionEngine <SentimentData, SentimentPrediction> model;
                using (var file = env.CreateTempFile())
                {
                    // Save model.
                    var scoreRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
                    using (var ch = env.Start("saving"))
                        TrainUtils.SaveModel(env, ch, file, predictor, scoreRoles);

                    // Load model.
                    using (var fs = file.OpenReadStream())
                        model = env.CreatePredictionEngine <SentimentData, SentimentPrediction>(fs);
                }

                // Take a couple of examples from the test data and run predictions on them.
                var testLoader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(SentimentTestPath)));
                var testData   = testLoader.AsEnumerable <SentimentData>(env, false);
                foreach (var input in testData.Take(5))
                {
                    var prediction = model.Predict(input);
                    // Verify that predictions match and scores are separated from zero.
                    Assert.Equal(input.Sentiment, prediction.Sentiment);
                    Assert.True(input.Sentiment && prediction.Score > 1 || !input.Sentiment && prediction.Score < -1);
                }
            }
        }
Code Example #8
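Trains a linear classifier, scores the data, and computes binary classification metrics with BinaryClassifierMamlEvaluator, pulling the overall metrics and confusion matrix out of the result dictionary.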
        public void Evaluation()
        {
            var dataPath     = GetDataPath(SentimentDataPath);
            var testDataPath = GetDataPath(SentimentTestPath);

            using (var env = new TlcEnvironment(seed: 1, conc: 1))
            {
                // Pipeline
                var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));

                var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader);

                // Train
                var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
                {
                    NumThreads = 1
                });

                var cached     = new CacheDataView(env, trans, prefetch: null);
                var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");
                var predictor  = trainer.Train(new Runtime.TrainContext(trainRoles));
                var scoreRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
                IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema);

                // Create prediction engine and test predictions.
                var model = env.CreatePredictionEngine <SentimentData, SentimentPrediction>(scorer);

                // Take a couple of examples from the test data and run predictions on them.
                var testLoader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(SentimentTestPath)));
                var testData   = testLoader.AsEnumerable <SentimentData>(env, false);

                var dataEval = new RoleMappedData(scorer, label: "Label", feature: "Features", opt: true);

                var evaluator = new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments()
                {
                });
                var metricsDict = evaluator.Evaluate(dataEval);

                var metrics = BinaryClassificationMetrics.FromMetrics(env, metricsDict["OverallMetrics"], metricsDict["ConfusionMatrix"])[0];
            }
        }
Code Example #9
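Shows schema and row-level visibility: the loop over the transformed schema reads each column's name and type, and the row cursor pulls the original text, the tokenized text, and the feature vector through typed getters.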
        void Visibility()
        {
            var dataPath     = GetDataPath(SentimentDataPath);
            var testDataPath = GetDataPath(SentimentTestPath);

            using (var env = new TlcEnvironment(seed: 1, conc: 1))
            {
                // Pipeline.
                var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));

                var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(false), loader);

                // To find out which columns are available, you can iterate over the schema and check
                // each column's name and the type needed for its getter.
                for (int i = 0; i < trans.Schema.ColumnCount; i++)
                {
                    var columnName = trans.Schema.GetColumnName(i);
                    var columnType = trans.Schema.GetColumnType(i).RawType;
                }

                using (var cursor = trans.GetRowCursor(x => true))
                {
                    Assert.True(cursor.Schema.TryGetColumnIndex("SentimentText", out int textColumn));
                    Assert.True(cursor.Schema.TryGetColumnIndex("Features_TransformedText", out int transformedTextColumn));
                    Assert.True(cursor.Schema.TryGetColumnIndex("Features", out int featureColumn));

                    var              originalTextGetter    = cursor.GetGetter <DvText>(textColumn);
                    var              transformedTextGetter = cursor.GetGetter <VBuffer <DvText> >(transformedTextColumn);
                    var              featureGetter         = cursor.GetGetter <VBuffer <float> >(featureColumn);
                    DvText           text            = default;
                    VBuffer <DvText> transformedText = default;
                    VBuffer <float>  features        = default;
                    while (cursor.MoveNext())
                    {
                        originalTextGetter(ref text);
                        transformedTextGetter(ref transformedText);
                        featureGetter(ref features);
                    }
                    }
                }
            }
        }
Code Example #10
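Runs predictions from multiple threads with Parallel.ForEach; each call to Predict on the shared PredictionEngine is serialized with a lock.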
        void MultithreadedPrediction()
        {
            var dataPath     = GetDataPath(SentimentDataPath);
            var testDataPath = GetDataPath(SentimentTestPath);

            using (var env = new LocalEnvironment(seed: 1, conc: 1))
            {
                // Pipeline
                var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));

                var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader);

                // Train
                var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
                {
                    NumThreads = 1
                });

                var cached     = new CacheDataView(env, trans, prefetch: null);
                var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");
                var predictor  = trainer.Train(new Runtime.TrainContext(trainRoles));

                var scoreRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
                IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema);

                // Create prediction engine and test predictions.
                var model = env.CreatePredictionEngine <SentimentData, SentimentPrediction>(scorer);

                // Take a couple of examples from the test data and run predictions on them.
                var testLoader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(SentimentTestPath)));
                var testData   = testLoader.AsEnumerable <SentimentData>(env, false);

                Parallel.ForEach(testData, (input) =>
                {
                    lock (model)
                    {
                        var prediction = model.Predict(input);
                    }
                });
            }
        }
Code Example #11
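A minimal end-to-end scenario: load, featurize, train a linear classifier, build a scorer, create a PredictionEngine, and verify predictions on a few test examples.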
        public void SimpleTrainAndPredict()
        {
            var dataPath     = GetDataPath(SentimentDataPath);
            var testDataPath = GetDataPath(SentimentTestPath);

            using (var env = new TlcEnvironment(seed: 1, conc: 1))
            {
                // Pipeline
                var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));

                var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader);

                // Train
                var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
                {
                    NumThreads = 1
                });

                var cached     = new CacheDataView(env, trans, prefetch: null);
                var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");
                var predictor  = trainer.Train(new Runtime.TrainContext(trainRoles));

                var scoreRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
                IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema);

                // Create prediction engine and test predictions.
                var model = env.CreatePredictionEngine <SentimentData, SentimentPrediction>(scorer);

                // Take a couple of examples from the test data and run predictions on them.
                var testLoader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(SentimentTestPath)));
                var testData   = testLoader.AsEnumerable <SentimentData>(env, false);
                foreach (var input in testData.Take(5))
                {
                    var prediction = model.Predict(input);
                    // Verify that predictions match and scores are separated from zero.
                    Assert.Equal(input.Sentiment, prediction.Sentiment);
                    Assert.True(input.Sentiment && prediction.Score > 1 || !input.Sentiment && prediction.Score < -1);
                }
            }
        }
Code Example #12
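Mixes the statically-typed and dynamic APIs: data is read with TextLoader.CreateReader, featurized with the dynamic TextTransform (keeping the output tokens), asserted back to a static schema, and extended with a WordEmbeddings estimator that is then validated with TestEstimatorCore.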
        public void TestWordEmbeddings()
        {
            var dataPath     = GetDataPath(ScenariosTests.SentimentDataPath);
            var testDataPath = GetDataPath(ScenariosTests.SentimentTestPath);

            var data = TextLoader.CreateReader(Env, ctx => (
                                                   label: ctx.LoadBool(0),
                                                   SentimentText: ctx.LoadText(1)), hasHeader: true)
                       .Read(new MultiFileSource(dataPath));

            var dynamicData = TextTransform.Create(Env, new TextTransform.Arguments()
            {
                Column = new TextTransform.Column
                {
                    Name   = "SentimentText_Features",
                    Source = new[] { "SentimentText" }
                },
                KeepDiacritics       = false,
                KeepPunctuations     = false,
                TextCase             = Runtime.TextAnalytics.TextNormalizerTransform.CaseNormalizationMode.Lower,
                OutputTokens         = true,
                StopWordsRemover     = new Runtime.TextAnalytics.PredefinedStopWordsRemoverFactory(),
                VectorNormalizer     = TextTransform.TextNormKind.None,
                CharFeatureExtractor = null,
                WordFeatureExtractor = null,
            }, data.AsDynamic);

            var data2 = dynamicData.AssertStatic(Env, ctx => (
                                                     SentimentText_Features_TransformedText: ctx.Text.VarVector,
                                                     SentimentText: ctx.Text.Scalar,
                                                     label: ctx.Bool.Scalar));

            var est = data2.MakeNewEstimator()
                      .Append(row => row.SentimentText_Features_TransformedText.WordEmbeddings());

            TestEstimatorCore(est.AsDynamic, data2.AsDynamic, invalidInput: data.AsDynamic);
            Done();
        }
Code Example #13
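Manual five-fold cross-validation: GenerateNumberTransform adds a stratification column, RangeFilter carves each fold into train and test partitions, and every fold is normalized, cached, trained, scored, and evaluated, collecting the per-fold metrics in a list.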
        void CrossValidation()
        {
            var dataPath     = GetDataPath(SentimentDataPath);
            var testDataPath = GetDataPath(SentimentTestPath);

            int numFolds = 5;

            using (var env = new TlcEnvironment(seed: 1, conc: 1))
            {
                // Pipeline.
                var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));

                var       text  = TextTransform.Create(env, MakeSentimentTextTransformArgs(false), loader);
                IDataView trans = new GenerateNumberTransform(env, text, "StratificationColumn");
                // Train.
                var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
                {
                    NumThreads           = 1,
                    ConvergenceTolerance = 1f
                });


                var metrics = new List <BinaryClassificationMetrics>();
                for (int fold = 0; fold < numFolds; fold++)
                {
                    IDataView trainPipe = new RangeFilter(env, new RangeFilter.Arguments()
                    {
                        Column     = "StratificationColumn",
                        Min        = (Double)fold / numFolds,
                        Max        = (Double)(fold + 1) / numFolds,
                        Complement = true
                    }, trans);
                    trainPipe = new OpaqueDataView(trainPipe);
                    var trainData = new RoleMappedData(trainPipe, label: "Label", feature: "Features");
                    // Auto-normalization.
                    NormalizeTransform.CreateIfNeeded(env, ref trainData, trainer);
                    var preCachedData = trainData;
                    // Auto-caching.
                    if (trainer.Info.WantCaching)
                    {
                        var prefetch  = trainData.Schema.GetColumnRoles().Select(kc => kc.Value.Index).ToArray();
                        var cacheView = new CacheDataView(env, trainData.Data, prefetch);
                        // Because the prefetching worked, we know that these are valid columns.
                        trainData = new RoleMappedData(cacheView, trainData.Schema.GetColumnRoleNames());
                    }

                    var       predictor = trainer.Train(new Runtime.TrainContext(trainData));
                    IDataView testPipe  = new RangeFilter(env, new RangeFilter.Arguments()
                    {
                        Column     = "StratificationColumn",
                        Min        = (Double)fold / numFolds,
                        Max        = (Double)(fold + 1) / numFolds,
                        Complement = false
                    }, trans);
                    testPipe = new OpaqueDataView(testPipe);
                    var pipe = ApplyTransformUtils.ApplyAllTransformsToData(env, preCachedData.Data, testPipe, trainPipe);

                    var testRoles = new RoleMappedData(pipe, trainData.Schema.GetColumnRoleNames());

                    IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, testRoles, env, testRoles.Schema);

                    BinaryClassifierMamlEvaluator eval = new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments()
                    {
                    });
                    var dataEval    = new RoleMappedData(scorer, testRoles.Schema.GetColumnRoleNames(), opt: true);
                    var dict        = eval.Evaluate(dataEval);
                    var foldMetrics = BinaryClassificationMetrics.FromMetrics(env, dict["OverallMetrics"], dict["ConfusionMatrix"]);
                    metrics.Add(foldMetrics.Single());
                }
            }
        }
Code Example #14
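Builds the loader and text featurization by instantiating their Arguments objects directly, trains a FastTree binary classifier, evaluates it on the test file, runs batch predictions, and reads feature importance from the trained predictor's summary.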
        public void TrainAndPredictSentimentModelWithDirectionInstantiationTest()
        {
            var dataPath     = GetDataPath(SentimentDataPath);
            var testDataPath = GetDataPath(SentimentTestPath);

            using (var env = new TlcEnvironment(seed: 1, conc: 1))
            {
                // Pipeline
                var loader = new TextLoader(env,
                                            new TextLoader.Arguments()
                {
                    Separator = "tab",
                    HasHeader = true,
                    Column    = new[]
                    {
                        new TextLoader.Column()
                        {
                            Name   = "Label",
                            Source = new [] { new TextLoader.Range()
                                              {
                                                  Min = 0, Max = 0
                                              } },
                            Type = DataKind.Num
                        },

                        new TextLoader.Column()
                        {
                            Name   = "SentimentText",
                            Source = new [] { new TextLoader.Range()
                                              {
                                                  Min = 1, Max = 1
                                              } },
                            Type = DataKind.Text
                        }
                    }
                }, new MultiFileSource(dataPath));

                var trans = TextTransform.Create(env, new TextTransform.Arguments()
                {
                    Column = new TextTransform.Column
                    {
                        Name   = "Features",
                        Source = new[] { "SentimentText" }
                    },
                    KeepDiacritics       = false,
                    KeepPunctuations     = false,
                    TextCase             = Runtime.TextAnalytics.TextNormalizerTransform.CaseNormalizationMode.Lower,
                    OutputTokens         = true,
                    StopWordsRemover     = new Runtime.TextAnalytics.PredefinedStopWordsRemoverFactory(),
                    VectorNormalizer     = TextTransform.TextNormKind.L2,
                    CharFeatureExtractor = new NgramExtractorTransform.NgramExtractorArguments()
                    {
                        NgramLength = 3, AllLengths = false
                    },
                    WordFeatureExtractor = new NgramExtractorTransform.NgramExtractorArguments()
                    {
                        NgramLength = 2, AllLengths = true
                    },
                },
                                                 loader);

                // Train
                var trainer = new FastTreeBinaryClassificationTrainer(env, new FastTreeBinaryClassificationTrainer.Arguments()
                {
                    NumLeaves           = 5,
                    NumTrees            = 5,
                    MinDocumentsInLeafs = 2
                });

                var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
                var pred       = trainer.Train(trainRoles);

                // Get scorer and evaluate the predictions from test data
                IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath);
                var metrics = EvaluateBinary(env, testDataScorer);
                ValidateBinaryMetrics(metrics);

                // Create prediction engine and test predictions
                var model       = env.CreateBatchPredictionEngine <SentimentData, SentimentPrediction>(testDataScorer);
                var sentiments  = GetTestData();
                var predictions = model.Predict(sentiments, false);
                Assert.Equal(2, predictions.Count());
                Assert.True(predictions.ElementAt(0).Sentiment.IsTrue);
                Assert.True(predictions.ElementAt(1).Sentiment.IsTrue);

                // Get feature importance based on feature gain during training
                var summary = ((FeatureWeightsCalibratedPredictor)pred).GetSummaryInKeyValuePairs(trainRoles.Schema);
                Assert.Equal(1.0, (double)summary[0].Value, 1);
            }
        }
Code Example #15
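The same direct-instantiation pipeline, except that the features come from a pretrained SSWE WordEmbeddingsTransform applied to the tokenized text before FastTree training and evaluation.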
        public void TrainAndPredictSentimentModelWithDirectionInstantiationTestWithWordEmbedding()
        {
            var dataPath     = GetDataPath(SentimentDataPath);
            var testDataPath = GetDataPath(SentimentTestPath);

            using (var env = new ConsoleEnvironment(seed: 1, conc: 1))
            {
                // Pipeline
                var loader = TextLoader.ReadFile(env,
                                                 new TextLoader.Arguments()
                {
                    Separator = "tab",
                    HasHeader = true,
                    Column    = new[]
                    {
                        new TextLoader.Column("Label", DataKind.Num, 0),
                        new TextLoader.Column("SentimentText", DataKind.Text, 1)
                    }
                }, new MultiFileSource(dataPath));

                var text = TextTransform.Create(env, new TextTransform.Arguments()
                {
                    Column = new TextTransform.Column
                    {
                        Name   = "WordEmbeddings",
                        Source = new[] { "SentimentText" }
                    },
                    KeepDiacritics       = false,
                    KeepPunctuations     = false,
                    TextCase             = Runtime.TextAnalytics.TextNormalizerTransform.CaseNormalizationMode.Lower,
                    OutputTokens         = true,
                    StopWordsRemover     = new Runtime.TextAnalytics.PredefinedStopWordsRemoverFactory(),
                    VectorNormalizer     = TextTransform.TextNormKind.None,
                    CharFeatureExtractor = null,
                    WordFeatureExtractor = null,
                },
                                                loader);

                var trans = WordEmbeddingsTransform.Create(env, new WordEmbeddingsTransform.Arguments()
                {
                    Column = new WordEmbeddingsTransform.Column[1]
                    {
                        new WordEmbeddingsTransform.Column
                        {
                            Name   = "Features",
                            Source = "WordEmbeddings_TransformedText"
                        }
                    },
                    ModelKind = WordEmbeddingsTransform.PretrainedModelKind.Sswe,
                }, text);
                // Train
                var trainer = new FastTreeBinaryClassificationTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features, advancedSettings: s =>
                {
                    s.NumLeaves           = 5;
                    s.NumTrees            = 5;
                    s.MinDocumentsInLeafs = 2;
                });

                var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
                var pred       = trainer.Train(trainRoles);
                // Get scorer and evaluate the predictions from test data
                IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath);
                var metrics = EvaluateBinary(env, testDataScorer);

                // SSWE is a simple word-embedding model, and we train on a very small dataset, so the metrics are not great.
                Assert.Equal(.6667, metrics.Accuracy, 4);
                Assert.Equal(.71, metrics.Auc, 1);
                Assert.Equal(.58, metrics.Auprc, 2);
                // Create prediction engine and test predictions
                var model       = env.CreateBatchPredictionEngine <SentimentData, SentimentPrediction>(testDataScorer);
                var sentiments  = GetTestData();
                var predictions = model.Predict(sentiments, false);
                Assert.Equal(2, predictions.Count());
                Assert.True(predictions.ElementAt(0).Sentiment);
                Assert.True(predictions.ElementAt(1).Sentiment);

                // Get feature importance based on feature gain during training
                var summary = ((FeatureWeightsCalibratedPredictor)pred).GetSummaryInKeyValuePairs(trainRoles.Schema);
                Assert.Equal(1.0, (double)summary[0].Value, 1);
            }
        }
Code Example #16
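A benchmark-style pipeline: text featurization with output tokens, SSWE word embeddings as features, and an SdcaMultiClassTrainer; the trained predictor is handed to _consumer.Consume (presumably a BenchmarkDotNet Consumer used to keep the result alive).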
        public void TrainSentiment()
        {
            using (var env = new ConsoleEnvironment(seed: 1))
            {
                // Pipeline
                var loader = TextLoader.ReadFile(env,
                                                 new TextLoader.Arguments()
                {
                    AllowQuoting = false,
                    AllowSparse  = false,
                    Separator    = "tab",
                    HasHeader    = true,
                    Column       = new[]
                    {
                        new TextLoader.Column()
                        {
                            Name   = "Label",
                            Source = new [] { new TextLoader.Range()
                                              {
                                                  Min = 0, Max = 0
                                              } },
                            Type = DataKind.Num
                        },

                        new TextLoader.Column()
                        {
                            Name   = "SentimentText",
                            Source = new [] { new TextLoader.Range()
                                              {
                                                  Min = 1, Max = 1
                                              } },
                            Type = DataKind.Text
                        }
                    }
                }, new MultiFileSource(_sentimentDataPath));

                var text = TextTransform.Create(env,
                                                new TextTransform.Arguments()
                {
                    Column = new TextTransform.Column
                    {
                        Name   = "WordEmbeddings",
                        Source = new[] { "SentimentText" }
                    },
                    KeepDiacritics       = false,
                    KeepPunctuations     = false,
                    TextCase             = Runtime.TextAnalytics.TextNormalizerTransform.CaseNormalizationMode.Lower,
                    OutputTokens         = true,
                    StopWordsRemover     = new Runtime.TextAnalytics.PredefinedStopWordsRemoverFactory(),
                    VectorNormalizer     = TextTransform.TextNormKind.None,
                    CharFeatureExtractor = null,
                    WordFeatureExtractor = null,
                }, loader);

                var trans = new WordEmbeddingsTransform(env,
                                                        new WordEmbeddingsTransform.Arguments()
                {
                    Column = new WordEmbeddingsTransform.Column[1]
                    {
                        new WordEmbeddingsTransform.Column
                        {
                            Name   = "Features",
                            Source = "WordEmbeddings_TransformedText"
                        }
                    },
                    ModelKind = WordEmbeddingsTransform.PretrainedModelKind.Sswe,
                }, text);

                // Train
                var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments()
                {
                    MaxIterations = 20
                });
                var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features");

                var predicted = trainer.Train(trainRoles);
                _consumer.Consume(predicted);
            }
        }