/// <summary>
        /// FastTree <see cref="BinaryClassificationCatalog"/> extension method.
        /// Predict a target using a decision tree binary classification model trained with the <see cref="FastTreeBinaryClassificationTrainer"/>.
        /// </summary>
        /// <param name="catalog">The <see cref="BinaryClassificationCatalog"/>.</param>
        /// <param name="label">The label column.</param>
        /// <param name="features">The features column.</param>
        /// <param name="weights">The optional weights column.</param>
        /// <param name="options">Algorithm advanced settings.</param>
        /// <param name="onFit">A delegate that is called every time the
        /// <see cref="Estimator{TInShape, TOutShape, TTransformer}.Fit(DataView{TInShape})"/> method is called on the
        /// <see cref="Estimator{TInShape, TOutShape, TTransformer}"/> instance created out of this. This delegate will receive
        /// the FastTree model that was trained. Note that this action cannot change the result in any way;
        /// it is only a way for the caller to be informed about what was learnt.</param>
        /// <returns>The set of output columns including in order the predicted binary classification score (which will range
        /// from negative to positive infinity), the calibrated prediction (from 0 to 1), and the predicted label.</returns>
        /// <example>
        /// <format type="text/markdown">
        /// <![CDATA[
        ///  [!code-csharp[FastTree](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Static/FastTreeBinaryClassification.cs)]
        /// ]]></format>
        /// </example>
        public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> predictedLabel) FastTree(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
            Scalar<bool> label, Vector<float> features, Scalar<float> weights,
            FastTreeBinaryClassificationTrainer.Options options,
            Action<CalibratedModelParametersBase<FastTreeBinaryModelParameters, PlattCalibrator>> onFit = null)
        {
            // The reconciler callback below writes the resolved column names into 'options',
            // so a null instance can never work. Reject it here, at the call site, rather than
            // letting it surface later as a NullReferenceException inside the callback.
            Contracts.CheckValue(options, nameof(options));
            CheckUserValues(label, features, weights, onFit);

            var rec = new TrainerEstimatorReconciler.BinaryClassifier(
                (env, labelName, featuresName, weightsName) =>
                {
                    // NOTE: this mutates the caller-supplied options instance so that the trainer
                    // reads the column names chosen by the reconciler.
                    options.LabelColumnName = labelName;
                    options.FeatureColumnName = featuresName;
                    options.ExampleWeightColumnName = weightsName;

                    var trainer = new FastTreeBinaryClassificationTrainer(env, options);

                    // Without a callback the bare trainer is the estimator.
                    if (onFit == null)
                    {
                        return trainer;
                    }

                    // Otherwise hand the fitted transformer's model to the caller's delegate.
                    return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
                }, label, features, weights);

            return rec.Output;
        }
Example #2
0
        public void TrainWithValidationSet()
        {
            using (var env = new LocalEnvironment(seed: 1, conc: 1))
            {
                // Training pipeline: text loader over the sentiment training file, followed by the text transform.
                var trainLoaderArgs = MakeSentimentTextLoaderArgs();
                var trainSource = new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename));
                var trainLoader = TextLoader.ReadFile(env, trainLoaderArgs, trainSource);
                var trainData = TextTransform.Create(env, MakeSentimentTextTransformArgs(), trainLoader);

                // The validation set must go through the same transformations. A loader cannot simply be
                // re-pointed at different data, so the options are either to build a second loader or to
                // save the loader to a model file and reload it. Building a second loader is not always
                // feasible, but it is here.
                var validLoaderArgs = MakeSentimentTextLoaderArgs();
                var validSource = new MultiFileSource(GetDataPath(TestDatasets.Sentiment.testFilename));
                var validLoader = TextLoader.ReadFile(env, validLoaderArgs, validSource);
                var validData = ApplyTransformUtils.ApplyAllTransformsToData(env, trainData, validLoader);

                // Wrap both datasets in cache views.
                var trainCache = new CacheDataView(env, trainData, prefetch: null);
                var validCache = new CacheDataView(env, validData, prefetch: null);

                // Train a three-tree FastTree model, supplying the validation roles alongside the training roles.
                var fastTree = new FastTreeBinaryClassificationTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features, numTrees: 3);
                var trainingRoles = new RoleMappedData(trainCache, label: "Label", feature: "Features");
                var validationRoles = new RoleMappedData(validCache, label: "Label", feature: "Features");
                var context = new Runtime.TrainContext(trainingRoles, validationRoles);
                fastTree.Train(context);
            }
        }
        /// <summary>
        /// FastTree <see cref="BinaryClassificationContext"/> extension method.
        /// Predict a target using a decision tree binary classification model trained with the <see cref="FastTreeBinaryClassificationTrainer"/>.
        /// </summary>
        /// <param name="ctx">The <see cref="BinaryClassificationContext"/>.</param>
        /// <param name="label">The label column.</param>
        /// <param name="features">The features column.</param>
        /// <param name="weights">The optional weights column.</param>
        /// <param name="numTrees">Total number of decision trees to create in the ensemble.</param>
        /// <param name="numLeaves">The maximum number of leaves per decision tree.</param>
        /// <param name="minDatapointsInLeaves">The minimal number of datapoints allowed in a leaf of the tree, out of the subsampled data.</param>
        /// <param name="learningRate">The learning rate.</param>
        /// <param name="advancedSettings">Algorithm advanced settings.</param>
        /// <param name="onFit">A delegate that is called every time the
        /// <see cref="Estimator{TInShape, TOutShape, TTransformer}.Fit(DataView{TInShape})"/> method is called on the
        /// <see cref="Estimator{TInShape, TOutShape, TTransformer}"/> instance created out of this. This delegate will receive
        /// the FastTree model that was trained. Note that this action cannot change the result in any way;
        /// it is only a way for the caller to be informed about what was learnt.</param>
        /// <returns>The set of output columns including in order the predicted binary classification score (which will range
        /// from negative to positive infinity), the calibrated prediction (from 0 to 1), and the predicted label.</returns>
        /// <example>
        /// <format type="text/markdown">
        /// <![CDATA[
        ///  [!code-csharp[FastTree](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Static/FastTreeBinaryClassification.cs)]
        /// ]]></format>
        /// </example>
        public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> predictedLabel) FastTree(this BinaryClassificationContext.BinaryClassificationTrainers ctx,
            Scalar<bool> label, Vector<float> features, Scalar<float> weights = null,
            int numLeaves = Defaults.NumLeaves,
            int numTrees = Defaults.NumTrees,
            int minDatapointsInLeaves = Defaults.MinDocumentsInLeaves,
            double learningRate = Defaults.LearningRates,
            Action<FastTreeBinaryClassificationTrainer.Arguments> advancedSettings = null,
            Action<IPredictorWithFeatureWeights<float>> onFit = null)
        {
            // Validate every user-supplied value up front, before any estimator is constructed.
            CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeaves, learningRate, advancedSettings, onFit);

            // The reconciler calls back with the resolved column names; the trainer is built from
            // those names plus the hyper-parameters captured from this call.
            var reconciler = new TrainerEstimatorReconciler.BinaryClassifier(
                (env, labelName, featuresName, weightsName) =>
                {
                    var fastTree = new FastTreeBinaryClassificationTrainer(
                        env, labelName, featuresName, weightsName,
                        numLeaves, numTrees, minDatapointsInLeaves, learningRate, advancedSettings);

                    // No callback requested: the trainer itself is the estimator.
                    if (onFit == null)
                    {
                        return fastTree;
                    }

                    // Forward the fitted transformer's model to the caller's delegate.
                    return fastTree.WithOnFitDelegate(fitted => onFit(fitted.Model));
                }, label, features, weights);

            return reconciler.Output;
        }
        /// <summary>
        /// FastTree <see cref="BinaryClassificationCatalog"/> extension method.
        /// Predict a target using a decision tree binary classification model trained with the <see cref="FastTreeBinaryClassificationTrainer"/>.
        /// </summary>
        /// <param name="catalog">The <see cref="BinaryClassificationCatalog"/>.</param>
        /// <param name="label">The label column.</param>
        /// <param name="features">The features column.</param>
        /// <param name="weights">The optional weights column.</param>
        /// <param name="numberOfTrees">Total number of decision trees to create in the ensemble.</param>
        /// <param name="numberOfLeaves">The maximum number of leaves per decision tree.</param>
        /// <param name="minimumExampleCountPerLeaf">The minimal number of data points allowed in a leaf of the tree, out of the subsampled data.</param>
        /// <param name="learningRate">The learning rate.</param>
        /// <param name="onFit">A delegate that is called every time the
        /// <see cref="Estimator{TInShape, TOutShape, TTransformer}.Fit(DataView{TInShape})"/> method is called on the
        /// <see cref="Estimator{TInShape, TOutShape, TTransformer}"/> instance created out of this. This delegate will receive
        /// the FastTree model that was trained. Note that this action cannot change the result in any way;
        /// it is only a way for the caller to be informed about what was learnt.</param>
        /// <returns>The set of output columns including in order the predicted binary classification score (which will range
        /// from negative to positive infinity), the calibrated prediction (from 0 to 1), and the predicted label.</returns>
        /// <example>
        /// <format type="text/markdown">
        /// <![CDATA[
        ///  [!code-csharp[FastTree](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Static/FastTreeBinaryClassification.cs)]
        /// ]]></format>
        /// </example>
        public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> predictedLabel) FastTree(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
            Scalar<bool> label, Vector<float> features, Scalar<float> weights = null,
            int numberOfLeaves = Defaults.NumberOfLeaves,
            int numberOfTrees = Defaults.NumberOfTrees,
            int minimumExampleCountPerLeaf = Defaults.MinimumExampleCountPerLeaf,
            double learningRate = Defaults.LearningRate,
            Action<CalibratedModelParametersBase<FastTreeBinaryModelParameters, PlattCalibrator>> onFit = null)
        {
            // Validate every user-supplied value up front, before any estimator is constructed.
            CheckUserValues(label, features, weights, numberOfLeaves, numberOfTrees, minimumExampleCountPerLeaf, learningRate, onFit);

            // The reconciler calls back with the resolved column names; the trainer is built from
            // those names plus the hyper-parameters captured from this call.
            var reconciler = new TrainerEstimatorReconciler.BinaryClassifier(
                (env, labelName, featuresName, weightsName) =>
                {
                    var fastTree = new FastTreeBinaryClassificationTrainer(
                        env, labelName, featuresName, weightsName,
                        numberOfLeaves, numberOfTrees, minimumExampleCountPerLeaf, learningRate);

                    // No callback requested: the trainer itself is the estimator.
                    if (onFit == null)
                    {
                        return fastTree;
                    }

                    // Forward the fitted transformer's model to the caller's delegate.
                    return fastTree.WithOnFitDelegate(fitted => onFit(fitted.Model));
                }, label, features, weights);

            return reconciler.Output;
        }
Example #5
0
        /// <summary>
        /// FastTree <see cref="BinaryClassificationCatalog"/> extension method.
        /// Predict a target using a decision tree binary classification model trained with the <see cref="FastTreeBinaryClassificationTrainer"/>.
        /// </summary>
        /// <param name="catalog">The <see cref="BinaryClassificationCatalog"/>.</param>
        /// <param name="label">The label column.</param>
        /// <param name="features">The features column.</param>
        /// <param name="weights">The optional weights column.</param>
        /// <param name="options">Algorithm advanced settings.</param>
        /// <param name="onFit">A delegate that is called every time the
        /// <see cref="Estimator{TInShape, TOutShape, TTransformer}.Fit(DataView{TInShape})"/> method is called on the
        /// <see cref="Estimator{TInShape, TOutShape, TTransformer}"/> instance created out of this. This delegate will receive
        /// the FastTree model that was trained. Note that this action cannot change the result in any way;
        /// it is only a way for the caller to be informed about what was learnt.</param>
        /// <returns>The set of output columns including in order the predicted binary classification score (which will range
        /// from negative to positive infinity), the calibrated prediction (from 0 to 1), and the predicted label.</returns>
        /// <example>
        /// <format type="text/markdown">
        /// <![CDATA[
        ///  [!code-csharp[FastTree](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Static/FastTreeBinaryClassification.cs)]
        /// ]]></format>
        /// </example>
        public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> predictedLabel) FastTree(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
            Scalar<bool> label, Vector<float> features, Scalar<float> weights,
            FastTreeBinaryClassificationTrainer.Options options,
            Action<IPredictorWithFeatureWeights<float>> onFit = null)
        {
            // The reconciler callback below writes the resolved column names into 'options',
            // so a null instance can never work. Reject it here, at the call site, rather than
            // letting it surface later as a NullReferenceException inside the callback.
            Contracts.CheckValue(options, nameof(options));
            CheckUserValues(label, features, weights, onFit);

            var rec = new TrainerEstimatorReconciler.BinaryClassifier(
                (env, labelName, featuresName, weightsName) =>
                {
                    // NOTE: this mutates the caller-supplied options instance so that the trainer
                    // reads the column names chosen by the reconciler.
                    options.LabelColumn = labelName;
                    options.FeatureColumn = featuresName;

                    // A resolved weight column is recorded as an explicit choice; otherwise the
                    // default weight column name is recorded as an implicit one.
                    options.WeightColumn = weightsName != null
                        ? Optional<string>.Explicit(weightsName)
                        : Optional<string>.Implicit(DefaultColumnNames.Weight);

                    var trainer = new FastTreeBinaryClassificationTrainer(env, options);

                    // Without a callback the bare trainer is the estimator.
                    if (onFit == null)
                    {
                        return trainer;
                    }

                    // Otherwise hand the fitted transformer's model to the caller's delegate.
                    return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
                }, label, features, weights);

            return rec.Output;
        }
        /// <summary>
        /// FastTree <see cref="BinaryClassificationContext"/> extension method.
        /// Predict a target using a decision tree binary classification model trained with the <see cref="FastTreeBinaryClassificationTrainer"/>.
        /// </summary>
        /// <param name="ctx">The <see cref="BinaryClassificationContext"/>.</param>
        /// <param name="label">The label column.</param>
        /// <param name="features">The features column.</param>
        /// <param name="weights">The optional weights column.</param>
        /// <param name="options">Algorithm advanced settings.</param>
        /// <param name="onFit">A delegate that is called every time the
        /// <see cref="Estimator{TInShape, TOutShape, TTransformer}.Fit(DataView{TInShape})"/> method is called on the
        /// <see cref="Estimator{TInShape, TOutShape, TTransformer}"/> instance created out of this. This delegate will receive
        /// the FastTree model that was trained. Note that this action cannot change the result in any way;
        /// it is only a way for the caller to be informed about what was learnt.</param>
        /// <returns>The set of output columns including in order the predicted binary classification score (which will range
        /// from negative to positive infinity), the calibrated prediction (from 0 to 1), and the predicted label.</returns>
        /// <example>
        /// <format type="text/markdown">
        /// <![CDATA[
        ///  [!code-csharp[FastTree](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Static/FastTreeBinaryClassification.cs)]
        /// ]]></format>
        /// </example>
        public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> predictedLabel) FastTree(this BinaryClassificationContext.BinaryClassificationTrainers ctx,
            Scalar<bool> label, Vector<float> features, Scalar<float> weights,
            FastTreeBinaryClassificationTrainer.Options options,
            Action<IPredictorWithFeatureWeights<float>> onFit = null)
        {
            // The reconciler callback below writes the resolved column names into 'options',
            // so a null instance can never work. Reject it here, at the call site, rather than
            // letting it surface later as a NullReferenceException inside the callback.
            Contracts.CheckValue(options, nameof(options));
            CheckUserValues(label, features, weights, onFit);

            var rec = new TrainerEstimatorReconciler.BinaryClassifier(
                (env, labelName, featuresName, weightsName) =>
                {
                    // The column names resolved by the reconciler must be pushed into the options.
                    // Previously they were ignored, so the trainer only ever saw whatever names
                    // 'options' already carried, regardless of the label/features/weights passed in.
                    // NOTE: this mutates the caller-supplied options instance.
                    options.LabelColumn = labelName;
                    options.FeatureColumn = featuresName;

                    // A resolved weight column is recorded as an explicit choice; otherwise the
                    // default weight column name is recorded as an implicit one.
                    options.WeightColumn = weightsName != null
                        ? Optional<string>.Explicit(weightsName)
                        : Optional<string>.Implicit(DefaultColumnNames.Weight);

                    var trainer = new FastTreeBinaryClassificationTrainer(env, options);

                    // Without a callback the bare trainer is the estimator.
                    if (onFit == null)
                    {
                        return trainer;
                    }

                    // Otherwise hand the fitted transformer's model to the caller's delegate.
                    return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
                }, label, features, weights);

            return rec.Output;
        }
Example #7
0
        public void FastTreeBinaryEstimator()
        {
            var (pipe, dataView) = GetBinaryClassificationPipeline();

            // FastTree trainer with 10 trees of 5 leaves each, with NumThreads set to 1.
            var fastTree = new FastTreeBinaryClassificationTrainer(
                Env,
                "Label",
                "Features",
                numTrees: 10,
                numLeaves: 5,
                advancedSettings: settings => settings.NumThreads = 1);

            // Run the shared estimator checks on the pipeline with the trainer appended.
            TestEstimatorCore(pipe.Append(fastTree), dataView);

            // Also train directly on the featurized data; the same view is passed as both arguments.
            var featurized = pipe.Fit(dataView).Transform(dataView);
            var trainedModel = fastTree.Train(featurized, featurized);

            Done();
        }
Example #8
0
        public void IntrospectiveTraining()
        {
            using (var env = new LocalEnvironment(seed: 1, conc: 1))
            {
                // Pipeline: loader -> word bag ("Tokenize") -> LDA ("Features").
                var loaderArgs = MakeSentimentTextLoaderArgs();
                var trainSource = new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename));
                var loader = TextLoader.ReadFile(env, loaderArgs, trainSource);

                var wordBagArgs = new WordBagTransform.Arguments
                {
                    NgramLength = 1,
                    Column = new[]
                    {
                        new WordBagTransform.Column { Name = "Tokenize", Source = new[] { "SentimentText" } }
                    }
                };
                var words = WordBagTransform.Create(env, wordBagArgs, loader);

                var ldaArgs = new LdaTransform.Arguments
                {
                    NumTopic = 10,
                    NumIterations = 3,
                    NumThreads = 1,
                    Column = new[]
                    {
                        new LdaTransform.Column { Source = "Tokenize", Name = "Features" }
                    }
                };
                var lda = new LdaTransform(env, ldaArgs, words);

                var cachedTrain = new CacheDataView(env, lda, prefetch: null);

                // First predictor: a linear classifier, from which the feature weights are read.
                var linearArgs = new LinearClassificationTrainer.Arguments { NumThreads = 1 };
                var linearTrainer = new LinearClassificationTrainer(env, linearArgs);
                var trainRoles = new RoleMappedData(cachedTrain, label: "Label", feature: "Features");
                var linearPredictor = linearTrainer.Train(new Runtime.TrainContext(trainRoles));
                VBuffer<float> featureWeights = default;
                linearPredictor.GetFeatureWeights(ref featureWeights);

                // Topic summary of the LDA transform.
                var topicSummary = lda.GetTopicSummary();

                // Second predictor: a two-tree FastTree model.
                var treeTrainer = new FastTreeBinaryClassificationTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features, numTrees: 2);
                var ftPredictor = treeTrainer.Train(new Runtime.TrainContext(trainRoles));

                // The trained predictor may be wrapped in a calibrator; if so, use its sub-predictor.
                FastTreeBinaryPredictor treePredictor = ftPredictor is CalibratedPredictorBase calibrated
                    ? (FastTreeBinaryPredictor)calibrated.SubPredictor
                    : (FastTreeBinaryPredictor)ftPredictor;

                var featureNames = FeatureNameCollection.Create(trainRoles.Schema);
                foreach (var tree in treePredictor.GetTrees())
                {
                    var lteChild = tree.LteChild;
                    var gteChild = tree.GtChild;

                    // Nodes requested with the second argument false: read split information.
                    for (var nodeIdx = 0; nodeIdx < tree.NumNodes; nodeIdx++)
                    {
                        var node = tree.GetNode(nodeIdx, false, featureNames);
                        var gainValue = GetValue<double>(node.KeyValues, "GainValue");
                        var splitGain = GetValue<double>(node.KeyValues, "SplitGain");
                        var featureName = GetValue<string>(node.KeyValues, "SplitName");
                        var previousLeafValue = GetValue<double>(node.KeyValues, "PreviousLeafValue");

                        // Keep only the part of the "Threshold" text after the first space.
                        var thresholdText = GetValue<string>(node.KeyValues, "Threshold");
                        var threshold = thresholdText.Split(new[] { ' ' }, 2)[1];
                        var nodeIndex = nodeIdx;
                    }

                    // Nodes requested with the second argument true: read leaf information.
                    for (var leafIdx = 0; leafIdx < tree.NumLeaves; leafIdx++)
                    {
                        var leaf = tree.GetNode(leafIdx, true, featureNames);
                        var leafValue = GetValue<double>(leaf.KeyValues, "LeafValue");
                        var extras = GetValue<string>(leaf.KeyValues, "Extras");

                        // The leaf's index is recorded as the bitwise complement of its position.
                        var nodeIndex = ~leafIdx;
                    }
                }
            }
        }
        public void TrainAndPredictSentimentModelWithDirectionInstantiationTestWithWordEmbedding()
        {
            var dataPath = GetDataPath(SentimentDataPath);
            var testDataPath = GetDataPath(SentimentTestPath);

            using (var env = new ConsoleEnvironment(seed: 1, conc: 1))
            {
                // Loader: tab-separated file with a header; column 0 is the label, column 1 the text.
                var loaderArgs = new TextLoader.Arguments
                {
                    Separator = "tab",
                    HasHeader = true,
                    Column = new[]
                    {
                        new TextLoader.Column("Label", DataKind.Num, 0),
                        new TextLoader.Column("SentimentText", DataKind.Text, 1)
                    }
                };
                var loader = TextLoader.ReadFile(env, loaderArgs, new MultiFileSource(dataPath));

                // Text featurizer: tokens are output; the char and word feature extractors are set to null.
                var featurizerArgs = new TextFeaturizingEstimator.Arguments
                {
                    Column = new TextFeaturizingEstimator.Column
                    {
                        Name = "WordEmbeddings",
                        Source = new[] { "SentimentText" }
                    },
                    OutputTokens = true,
                    KeepPunctuations = false,
                    StopWordsRemover = new PredefinedStopWordsRemoverFactory(),
                    VectorNormalizer = TextFeaturizingEstimator.TextNormKind.None,
                    CharFeatureExtractor = null,
                    WordFeatureExtractor = null,
                };
                var text = TextFeaturizingEstimator.Create(env, featurizerArgs, loader);

                // Word embeddings (SSWE) over the transformed-text column become "Features".
                var embeddingArgs = new WordEmbeddingsExtractingTransformer.Arguments
                {
                    Column = new WordEmbeddingsExtractingTransformer.Column[1]
                    {
                        new WordEmbeddingsExtractingTransformer.Column
                        {
                            Name = "Features",
                            Source = "WordEmbeddings_TransformedText"
                        }
                    },
                    ModelKind = WordEmbeddingsExtractingTransformer.PretrainedModelKind.Sswe,
                };
                var trans = WordEmbeddingsExtractingTransformer.Create(env, embeddingArgs, text);

                // Train.
                var fastTree = new FastTreeBinaryClassificationTrainer(
                    env, DefaultColumnNames.Label, DefaultColumnNames.Features,
                    numLeaves: 5, numTrees: 5, minDatapointsInLeaves: 2);
                var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
                var pred = fastTree.Train(trainRoles);

                // Score the test data and evaluate.
                IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath);
                var metrics = EvaluateBinary(env, testDataScorer);

                // SSWE is a simple word embedding model and the training set is very small,
                // so the expected metrics are modest.
                Assert.Equal(.6667, metrics.Accuracy, 4);
                Assert.Equal(.71, metrics.Auc, 1);
                Assert.Equal(.58, metrics.Auprc, 2);

                // Batch prediction engine: both test sentiments are expected to be predicted true.
                var engine = env.CreateBatchPredictionEngine<SentimentData, SentimentPrediction>(testDataScorer);
                var sentiments = GetTestData();
                var predictions = engine.Predict(sentiments, false);
                Assert.Equal(2, predictions.Count());
                Assert.True(predictions.ElementAt(0).Sentiment);
                Assert.True(predictions.ElementAt(1).Sentiment);

                // Feature importance summary: the first entry's value is expected to be 1.0 (one decimal place).
                var calibratedPredictor = (FeatureWeightsCalibratedPredictor)pred;
                var summary = calibratedPredictor.GetSummaryInKeyValuePairs(trainRoles.Schema);
                Assert.Equal(1.0, (double)summary[0].Value, 1);
            }
        }
        public void TrainAndPredictSentimentModelWithDirectionInstantiationTest()
        {
            var dataPath = GetDataPath(SentimentDataPath);
            var testDataPath = GetDataPath(SentimentTestPath);

            using (var env = new TlcEnvironment(seed: 1, conc: 1))
            {
                // Loader: tab-separated file with a header; column 0 is the label, column 1 the text.
                var labelColumn = new TextLoader.Column()
                {
                    Name = "Label",
                    Source = new[] { new TextLoader.Range() { Min = 0, Max = 0 } },
                    Type = DataKind.Num
                };
                var textColumn = new TextLoader.Column()
                {
                    Name = "SentimentText",
                    Source = new[] { new TextLoader.Range() { Min = 1, Max = 1 } },
                    Type = DataKind.Text
                };
                var loaderArgs = new TextLoader.Arguments()
                {
                    Separator = "tab",
                    HasHeader = true,
                    Column = new[] { labelColumn, textColumn }
                };
                var loader = new TextLoader(env, loaderArgs, new MultiFileSource(dataPath));

                // Text transform: lower-cased, punctuation and diacritics dropped, L2-normalized,
                // with char n-grams of length 3 and word n-grams up to length 2.
                var textArgs = new TextTransform.Arguments()
                {
                    Column = new TextTransform.Column
                    {
                        Name = "Features",
                        Source = new[] { "SentimentText" }
                    },
                    KeepDiacritics = false,
                    KeepPunctuations = false,
                    TextCase = Runtime.TextAnalytics.TextNormalizerTransform.CaseNormalizationMode.Lower,
                    OutputTokens = true,
                    StopWordsRemover = new Runtime.TextAnalytics.PredefinedStopWordsRemoverFactory(),
                    VectorNormalizer = TextTransform.TextNormKind.L2,
                    CharFeatureExtractor = new NgramExtractorTransform.NgramExtractorArguments()
                    {
                        NgramLength = 3,
                        AllLengths = false
                    },
                    WordFeatureExtractor = new NgramExtractorTransform.NgramExtractorArguments()
                    {
                        NgramLength = 2,
                        AllLengths = true
                    },
                };
                var trans = TextTransform.Create(env, textArgs, loader);

                // Train.
                var trainerArgs = new FastTreeBinaryClassificationTrainer.Arguments()
                {
                    NumLeaves = 5,
                    NumTrees = 5,
                    MinDocumentsInLeafs = 2
                };
                var fastTree = new FastTreeBinaryClassificationTrainer(env, trainerArgs);
                var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
                var pred = fastTree.Train(trainRoles);

                // Score the test data and validate the evaluation metrics.
                IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath);
                var metrics = EvaluateBinary(env, testDataScorer);
                ValidateBinaryMetrics(metrics);

                // Batch prediction engine: both test sentiments are expected to be predicted true.
                var engine = env.CreateBatchPredictionEngine<SentimentData, SentimentPrediction>(testDataScorer);
                var sentiments = GetTestData();
                var predictions = engine.Predict(sentiments, false);
                Assert.Equal(2, predictions.Count());
                Assert.True(predictions.ElementAt(0).Sentiment.IsTrue);
                Assert.True(predictions.ElementAt(1).Sentiment.IsTrue);

                // Feature importance summary: the first entry's value is expected to be 1.0 (one decimal place).
                var calibratedPredictor = (FeatureWeightsCalibratedPredictor)pred;
                var summary = calibratedPredictor.GetSummaryInKeyValuePairs(trainRoles.Schema);
                Assert.Equal(1.0, (double)summary[0].Value, 1);
            }
        }
Example #11
0
        /// <summary>
        /// Trains a FastTree binary classifier on the sentiment data set, constructing the trainer
        /// directly with explicit hyperparameters (5 leaves, 5 trees, 2 datapoints per leaf), and then
        /// verifies the evaluation metrics, the predictions for the test sentences, and the
        /// feature-importance summary of the trained predictor.
        /// </summary>
        public void TrainAndPredictSentimentModelWithDirectionInstantiationTest()
        {
            var dataPath     = GetDataPath(SentimentDataPath);
            var testDataPath = GetDataPath(SentimentTestPath);

            // Fixed seed and a concurrency of 1 are passed to the context.
            var env = new MLContext(seed: 1, conc: 1);

            // Pipeline: read the label (column 0) and the raw text (column 1) from a file with a header row.
            var loader = env.Data.ReadFromTextFile(
                dataPath,
                columns: new[]
                {
                    new TextLoader.Column("Label", DataKind.Num, 0),
                    new TextLoader.Column("SentimentText", DataKind.Text, 1)
                },
                hasHeader: true);

            // Featurize "SentimentText" into "Features" using char 3-grams and word n-grams up to length 2.
            var trans = TextFeaturizingEstimator.Create(
                env,
                new TextFeaturizingEstimator.Arguments()
                {
                    Column = new TextFeaturizingEstimator.Column
                    {
                        Name   = "Features",
                        Source = new[] { "SentimentText" }
                    },
                    OutputTokens                 = true,
                    KeepPunctuations             = false,
                    UsePredefinedStopWordRemover = true,
                    VectorNormalizer             = TextFeaturizingEstimator.TextNormKind.L2,
                    CharFeatureExtractor         = new NgramExtractorTransform.NgramExtractorArguments()
                    {
                        NgramLength = 3, AllLengths = false
                    },
                    WordFeatureExtractor = new NgramExtractorTransform.NgramExtractorArguments()
                    {
                        NgramLength = 2, AllLengths = true
                    },
                },
                loader);

            // Train
            var trainer = new FastTreeBinaryClassificationTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features,
                                                                  numLeaves: 5, numTrees: 5, minDatapointsInLeaves: 2);

            var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
            var pred       = trainer.Train(trainRoles);

            // Get scorer and evaluate the predictions from test data
            IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath);
            var metrics = EvaluateBinary(env, testDataScorer);

            ValidateBinaryMetrics(metrics);

            // Create prediction engine and test predictions
            var model      = env.CreateBatchPredictionEngine<SentimentData, SentimentPrediction>(testDataScorer);
            var sentiments = GetTestData();

            // Materialize once: the sequence returned by Predict would otherwise be enumerated
            // separately by each Count()/ElementAt() call below.
            // NOTE(review): the 'false' flag is presumably "do not reuse row objects", which is what
            // makes buffering the results safe - confirm against the Predict signature.
            var predictions = model.Predict(sentiments, false).ToList();

            Assert.Equal(2, predictions.Count);
            Assert.True(predictions[0].Sentiment);
            Assert.True(predictions[1].Sentiment);

            // Get feature importance based on feature gain during training
            var summary = ((FeatureWeightsCalibratedPredictor)pred).GetSummaryInKeyValuePairs(trainRoles.Schema);

            // The first summary entry is expected to be 1.0, compared with a precision of 1 decimal place.
            Assert.Equal(1.0, (double)summary[0].Value, 1);
        }