Example #1
        /// <summary>
        /// Read generic model from file.
        /// </summary>
        /// <typeparam name="TInput">Type for incoming data</typeparam>
        /// <typeparam name="TOutput">Type for output data</typeparam>
        /// <param name="stream">Stream with model</param>
        /// <returns>Model</returns>
        public static Task<PredictionModel<TInput, TOutput>> ReadAsync<TInput, TOutput>(Stream stream)
            where TInput : class
            where TOutput : class, new()
        {
            if (stream == null)
            {
                throw new ArgumentNullException(nameof(stream));
            }

            var environment = new MLContext();

            AssemblyRegistration.RegisterAssemblies(environment);

            BatchPredictionEngine<TInput, TOutput> predictor =
                environment.CreateBatchPredictionEngine<TInput, TOutput>(stream);

            return Task.FromResult(new PredictionModel<TInput, TOutput>(predictor, stream));
        }
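
For context, here is a minimal usage sketch of the helper above. The HousingData/HousingPrediction classes, their fields, and the "model.zip" path are hypothetical, and it assumes the legacy PredictionModel<TInput, TOutput>.Predict(TInput) method; it is an illustration, not part of the original source.

        // Usage sketch (assumptions: hypothetical HousingData/HousingPrediction classes and
        // "model.zip" path; legacy PredictionModel.Predict(TInput) method).
        public static async Task PredictPriceAsync()
        {
            using (var stream = File.OpenRead("model.zip"))
            {
                PredictionModel<HousingData, HousingPrediction> model =
                    await ReadAsync<HousingData, HousingPrediction>(stream);

                HousingPrediction prediction = model.Predict(new HousingData { Size = 120f, Rooms = 3f });
                Console.WriteLine(prediction.Price);
            }
        }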
Example #2
        private static BatchPredictionEngine<MLSankakuPost, MLSankakuPostLikeagePrediciton> Train(long userId)
        {
            var context = new DiscordContext();

            MLContext mlContext = new MLContext(seed: 0);

            // Load the user's posts and describe their schema to ML.NET.
            var data = GetPosts(userId);
            var schemaDef = SchemaDefinition.Create(typeof(MLSankakuPost));
            var trainData = mlContext.CreateStreamingDataView<MLSankakuPost>(data, schemaDef);

            // Train a FastTree regression model on the "Label" / "Features" columns.
            var pipeline = mlContext.Regression.Trainers.FastTree("Label", "Features",
                numLeaves: 50, numTrees: 50, minDatapointsInLeaves: 20);
            var model = pipeline.Fit(trainData);

            // Note: the engine is created over the raw training data view; the fitted
            // model is not applied here (see the commented-out alternative below).
            return mlContext.CreateBatchPredictionEngine<MLSankakuPost, MLSankakuPostLikeagePrediciton>(
                trainData, true, schemaDef);

            //return model.MakePredictionFunction<MLSankakuPost, MLSankakuPost>(mlContext);
        }
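
A short sketch of how the engine returned by Train might be consumed. It assumes the BatchPredictionEngine.Predict(IEnumerable<TSrc>, bool) overload called in examples #3 and #4, a hypothetical Score field on MLSankakuPostLikeagePrediciton, and an arbitrary user id; none of this is from the original source.

        // Usage sketch (assumptions: Predict(IEnumerable, bool) as called in examples #3/#4,
        // a hypothetical Score field on the prediction class, an arbitrary user id).
        var engine = Train(userId: 123456789L);
        IEnumerable<MLSankakuPost> candidates = GetPosts(123456789L);
        var predictions = engine.Predict(candidates, false);

        foreach (var prediction in predictions)
        {
            Console.WriteLine(prediction.Score);
        }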
Example #3
        public void TrainAndPredictSentimentModelWithDirectionInstantiationTestWithWordEmbedding()
        {
            var dataPath     = GetDataPath(SentimentDataPath);
            var testDataPath = GetDataPath(SentimentTestPath);

            var env = new MLContext(seed: 1, conc: 1);
            // Pipeline
            var loader = env.Data.ReadFromTextFile(dataPath,
                columns: new[]
                {
                    new TextLoader.Column("Label", DataKind.Num, 0),
                    new TextLoader.Column("SentimentText", DataKind.Text, 1)
                },
                hasHeader: true);

            var text = TextFeaturizingEstimator.Create(env, new TextFeaturizingEstimator.Arguments()
            {
                Column = new TextFeaturizingEstimator.Column
                {
                    Name   = "WordEmbeddings",
                    Source = new[] { "SentimentText" }
                },
                OutputTokens                 = true,
                KeepPunctuations             = false,
                UsePredefinedStopWordRemover = true,
                VectorNormalizer             = TextFeaturizingEstimator.TextNormKind.None,
                CharFeatureExtractor         = null,
                WordFeatureExtractor         = null,
            }, loader);

            var trans = WordEmbeddingsExtractingTransformer.Create(env, new WordEmbeddingsExtractingTransformer.Arguments()
            {
                Column = new WordEmbeddingsExtractingTransformer.Column[1]
                {
                    new WordEmbeddingsExtractingTransformer.Column
                    {
                        Name   = "Features",
                        Source = "WordEmbeddings_TransformedText"
                    }
                },
                ModelKind = WordEmbeddingsExtractingTransformer.PretrainedModelKind.Sswe,
            }, text);
            // Train
            var trainer = new FastTreeBinaryClassificationTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features, numLeaves: 5, numTrees: 5, minDatapointsInLeaves: 2);

            var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
            var pred       = trainer.Train(trainRoles);
            // Get scorer and evaluate the predictions from test data
            IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath);
            var metrics = EvaluateBinary(env, testDataScorer);

            // SSWE is a simple word embedding model + we train on a really small dataset, so metrics are not great.
            Assert.Equal(.6667, metrics.Accuracy, 4);
            Assert.Equal(.71, metrics.Auc, 1);
            Assert.Equal(.58, metrics.Auprc, 2);
            // Create prediction engine and test predictions
            var model       = env.CreateBatchPredictionEngine<SentimentData, SentimentPrediction>(testDataScorer);
            var sentiments  = GetTestData();
            var predictions = model.Predict(sentiments, false);

            Assert.Equal(2, predictions.Count());
            Assert.True(predictions.ElementAt(0).Sentiment);
            Assert.True(predictions.ElementAt(1).Sentiment);

            // Get feature importance based on feature gain during training
            var summary = ((FeatureWeightsCalibratedPredictor)pred).GetSummaryInKeyValuePairs(trainRoles.Schema);

            Assert.Equal(1.0, (double)summary[0].Value, 1);
        }
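
The CreateBatchPredictionEngine call above binds scored rows to the SentimentData and SentimentPrediction classes defined elsewhere in the test project. One plausible shape for them, assuming ColumnName mappings that match the loader columns and the trainer's PredictedLabel output (an illustrative sketch, not the actual definitions):

        // Illustrative sketch only: field names and attributes are assumptions
        // based on the loader columns and the boolean Sentiment checked in the test.
        public class SentimentData
        {
            [ColumnName("Label")]
            public float Sentiment;

            public string SentimentText;
        }

        public class SentimentPrediction
        {
            [ColumnName("PredictedLabel")]
            public bool Sentiment;
        }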
Example #4
        public void TrainAndPredictSentimentModelWithDirectionInstantiationTest()
        {
            var dataPath     = GetDataPath(SentimentDataPath);
            var testDataPath = GetDataPath(SentimentTestPath);

            var env = new MLContext(seed: 1, conc: 1);
            // Pipeline
            var loader = env.Data.ReadFromTextFile(dataPath,
                columns: new[]
                {
                    new TextLoader.Column("Label", DataKind.Num, 0),
                    new TextLoader.Column("SentimentText", DataKind.Text, 1)
                },
                hasHeader: true);

            var trans = TextFeaturizingEstimator.Create(env, new TextFeaturizingEstimator.Arguments()
            {
                Column = new TextFeaturizingEstimator.Column
                {
                    Name   = "Features",
                    Source = new[] { "SentimentText" }
                },
                OutputTokens                 = true,
                KeepPunctuations             = false,
                UsePredefinedStopWordRemover = true,
                VectorNormalizer             = TextFeaturizingEstimator.TextNormKind.L2,
                CharFeatureExtractor         = new NgramExtractorTransform.NgramExtractorArguments()
                {
                    NgramLength = 3, AllLengths = false
                },
                WordFeatureExtractor = new NgramExtractorTransform.NgramExtractorArguments()
                {
                    NgramLength = 2, AllLengths = true
                },
            }, loader);

            // Train
            var trainer = new FastTreeBinaryClassificationTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features,
                                                                  numLeaves: 5, numTrees: 5, minDatapointsInLeaves: 2);

            var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
            var pred       = trainer.Train(trainRoles);

            // Get scorer and evaluate the predictions from test data
            IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath);
            var metrics = EvaluateBinary(env, testDataScorer);

            ValidateBinaryMetrics(metrics);

            // Create prediction engine and test predictions
            var model       = env.CreateBatchPredictionEngine<SentimentData, SentimentPrediction>(testDataScorer);
            var sentiments  = GetTestData();
            var predictions = model.Predict(sentiments, false);

            Assert.Equal(2, predictions.Count());
            Assert.True(predictions.ElementAt(0).Sentiment);
            Assert.True(predictions.ElementAt(1).Sentiment);

            // Get feature importance based on feature gain during training
            var summary = ((FeatureWeightsCalibratedPredictor)pred).GetSummaryInKeyValuePairs(trainRoles.Schema);

            Assert.Equal(1.0, (double)summary[0].Value, 1);
        }
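
Examples #3 and #4 follow the same overall pipeline (text loader, TextFeaturizingEstimator, FastTree binary trainer, batch prediction engine) and differ only in featurization: #3 disables the character and word n-gram extractors and feeds the tokenized text through the pretrained SSWE word-embedding transformer, while #4 uses character trigrams plus word unigrams and bigrams with L2 normalization. That difference is why the comment in #3 notes that its metrics are modest: SSWE is a simple embedding model and the training set is very small.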