// End-to-end sentiment test built by instantiating each component directly:
// loads the tab-separated training file, featurizes "SentimentText" into
// "Features", trains a small FastTree binary classifier, evaluates it on the
// test file, checks two sample predictions and finally inspects the first
// entry of the predictor's feature-weight summary.
public void TrainAndPredictSentimentModelWithDirectionInstantiationTest()
        {
            var dataPath     = GetDataPath(SentimentDataPath);
            var testDataPath = GetDataPath(SentimentTestPath);

            // Fixed seed and a concurrency of 1 are passed to the environment;
            // presumably this keeps the run reproducible so the exact asserts below hold.
            using (var env = new ConsoleEnvironment(seed: 1, conc: 1))
            {
                // Pipeline: column 0 is the numeric label, column 1 the raw text; the file has a header row.
                var loader = TextLoader.ReadFile(env,
                                                 new TextLoader.Arguments()
                {
                    Separator = "tab",
                    HasHeader = true,
                    Column    = new[]
                    {
                        new TextLoader.Column("Label", DataKind.Num, 0),
                        new TextLoader.Column("SentimentText", DataKind.Text, 1)
                    }
                }, new MultiFileSource(dataPath));

                // Text featurization: lower-cased, diacritics and punctuation dropped,
                // predefined stop words removed, L2-normalized, with char (length 3)
                // and word (length 2, AllLengths on) n-gram extractors.
                var trans = TextTransform.Create(env, new TextTransform.Arguments()
                {
                    Column = new TextTransform.Column
                    {
                        Name   = "Features",
                        Source = new[] { "SentimentText" }
                    },
                    KeepDiacritics       = false,
                    KeepPunctuations     = false,
                    TextCase             = Runtime.TextAnalytics.TextNormalizerTransform.CaseNormalizationMode.Lower,
                    OutputTokens         = true,
                    StopWordsRemover     = new Runtime.TextAnalytics.PredefinedStopWordsRemoverFactory(),
                    VectorNormalizer     = TextTransform.TextNormKind.L2,
                    CharFeatureExtractor = new NgramExtractorTransform.NgramExtractorArguments()
                    {
                        NgramLength = 3, AllLengths = false
                    },
                    WordFeatureExtractor = new NgramExtractorTransform.NgramExtractorArguments()
                    {
                        NgramLength = 2, AllLengths = true
                    },
                },
                                                 loader);

                // Train a deliberately tiny model (5 trees, 5 leaves each).
                var trainer = new FastTreeBinaryClassificationTrainer(env, new FastTreeBinaryClassificationTrainer.Arguments()
                {
                    NumLeaves           = 5,
                    NumTrees            = 5,
                    MinDocumentsInLeafs = 2
                });

                var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
                var pred       = trainer.Train(trainRoles);

                // Get scorer and evaluate the predictions from test data
                IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath);
                var metrics = EvaluateBinary(env, testDataScorer);
                ValidateBinaryMetrics(metrics);

                // Create prediction engine and test predictions
                var model      = env.CreateBatchPredictionEngine <SentimentData, SentimentPrediction>(testDataScorer);
                var sentiments = GetTestData();

                // Materialize the result once: Count()/ElementAt() on the raw sequence would
                // each enumerate it again (and re-run scoring if it is lazily evaluated).
                // NOTE(review): the second argument looks like ML.NET's reuseRowObjects flag;
                // it must stay false when results are buffered like this -- confirm.
                var predictions = model.Predict(sentiments, false).ToArray();
                Assert.Equal(2, predictions.Length);
                Assert.True(predictions[0].Sentiment.IsTrue);
                Assert.True(predictions[1].Sentiment.IsTrue);

                // Get feature importance based on feature gain during training;
                // the first summary entry is expected to be 1.0 to one decimal place.
                var summary = ((FeatureWeightsCalibratedPredictor)pred).GetSummaryInKeyValuePairs(trainRoles.Schema);
                Assert.Equal(1.0, (double)summary[0].Value, 1);
            }
        }
        // Example #2
        // 0
        // End-to-end sentiment test using pretrained word embeddings as features:
        // loads the tab-separated training file, tokenizes "SentimentText" with the
        // text featurizer (n-gram extractors disabled), maps the tokens through the
        // SSWE embedding model into "Features", trains a small FastTree binary
        // classifier, then checks test-set metrics, two sample predictions and the
        // first entry of the predictor's feature-weight summary.
        public void TrainAndPredictSentimentModelWithDirectionInstantiationTestWithWordEmbedding()
        {
            var dataPath     = GetDataPath(SentimentDataPath);
            var testDataPath = GetDataPath(SentimentTestPath);

            // Fixed seed and a concurrency of 1 are passed to the environment;
            // presumably this keeps the run reproducible so the exact asserts below hold.
            using (var env = new ConsoleEnvironment(seed: 1, conc: 1))
            {
                // Pipeline: column 0 is the numeric label, column 1 the raw text; the file has a header row.
                var loader = TextLoader.ReadFile(env,
                                                 new TextLoader.Arguments()
                {
                    Separator = "tab",
                    HasHeader = true,
                    Column    = new[]
                    {
                        new TextLoader.Column("Label", DataKind.Num, 0),
                        new TextLoader.Column("SentimentText", DataKind.Text, 1)
                    }
                }, new MultiFileSource(dataPath));

                // Only the token output of the featurizer is needed here, so both n-gram
                // extractors are null and no vector normalization is applied.
                var text = TextFeaturizingEstimator.Create(env, new TextFeaturizingEstimator.Arguments()
                {
                    Column = new TextFeaturizingEstimator.Column
                    {
                        Name   = "WordEmbeddings",
                        Source = new[] { "SentimentText" }
                    },
                    OutputTokens         = true,
                    KeepPunctuations     = false,
                    StopWordsRemover     = new PredefinedStopWordsRemoverFactory(),
                    VectorNormalizer     = TextFeaturizingEstimator.TextNormKind.None,
                    CharFeatureExtractor = null,
                    WordFeatureExtractor = null,
                },
                                                           loader);

                // The embedding transform reads "WordEmbeddings_TransformedText";
                // presumably that is the token column emitted because OutputTokens is set above.
                var trans = WordEmbeddingsTransform.Create(env, new WordEmbeddingsTransform.Arguments()
                {
                    Column = new[]
                    {
                        new WordEmbeddingsTransform.Column
                        {
                            Name   = "Features",
                            Source = "WordEmbeddings_TransformedText"
                        }
                    },
                    ModelKind = WordEmbeddingsTransform.PretrainedModelKind.Sswe,
                }, text);

                // Train a deliberately tiny model (5 trees, 5 leaves each).
                var trainer = new FastTreeBinaryClassificationTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features, numLeaves: 5, numTrees: 5, minDocumentsInLeafs: 2);

                var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
                var pred       = trainer.Train(trainRoles);

                // Get scorer and evaluate the predictions from test data
                IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath);
                var metrics = EvaluateBinary(env, testDataScorer);

                // SSWE is a simple word embedding model + we train on a really small dataset, so metrics are not great.
                Assert.Equal(.6667, metrics.Accuracy, 4);
                Assert.Equal(.71, metrics.Auc, 1);
                Assert.Equal(.58, metrics.Auprc, 2);

                // Create prediction engine and test predictions
                var model      = env.CreateBatchPredictionEngine <SentimentData, SentimentPrediction>(testDataScorer);
                var sentiments = GetTestData();

                // Materialize the result once: Count()/ElementAt() on the raw sequence would
                // each enumerate it again (and re-run scoring if it is lazily evaluated).
                // NOTE(review): the second argument looks like ML.NET's reuseRowObjects flag;
                // it must stay false when results are buffered like this -- confirm.
                var predictions = model.Predict(sentiments, false).ToArray();
                Assert.Equal(2, predictions.Length);
                Assert.True(predictions[0].Sentiment);
                Assert.True(predictions[1].Sentiment);

                // Get feature importance based on feature gain during training;
                // the first summary entry is expected to be 1.0 to one decimal place.
                var summary = ((FeatureWeightsCalibratedPredictor)pred).GetSummaryInKeyValuePairs(trainRoles.Schema);
                Assert.Equal(1.0, (double)summary[0].Value, 1);
            }
        }