/// <summary>
        /// Builds a strongly-typed <see cref="PredictionModel{TInput, TOutput}"/> from a stream that contains a model.
        /// </summary>
        /// <typeparam name="TInput">Type for incoming data</typeparam>
        /// <typeparam name="TOutput">Type for output data</typeparam>
        /// <param name="stream">Stream with model; must not be null.</param>
        /// <returns>An already-completed task whose result is the model.</returns>
        /// <exception cref="ArgumentNullException">
        /// Thrown directly by this call (not through the returned task) when <paramref name="stream"/> is null.
        /// </exception>
        public static Task<PredictionModel<TInput, TOutput>> ReadAsync<TInput, TOutput>(Stream stream)
            where TInput : class
            where TOutput : class, new()
        {
            if (stream == null)
            {
                throw new ArgumentNullException(nameof(stream));
            }

            // NOTE(review): the environment is disposed when this method returns, while the engine
            // created from it is handed to the returned model - confirm the engine does not need
            // the environment afterwards.
            using (var env = new TlcEnvironment())
            {
                BatchPredictionEngine<TInput, TOutput> engine = env.CreateBatchPredictionEngine<TInput, TOutput>(stream);
                var model = new PredictionModel<TInput, TOutput>(engine, stream);

                // The work above is synchronous; the result is only wrapped in a completed task.
                return Task.FromResult(model);
            }
        }
        /// <summary>
        /// Trains a FastTree binary classifier on the sentiment data set using directly instantiated
        /// components (text loader, text transform, trainer), then validates the evaluation metrics on
        /// the test data, the predictions for the sample inputs, and the first entry of the
        /// predictor's key/value summary.
        /// </summary>
        public void TrainAndPredictSentimentModelWithDirectionInstantiationTest()
        {
            var dataPath = GetDataPath(SentimentDataPath);
            var testDataPath = GetDataPath(SentimentTestPath);

            // Fixed seed and concurrency so the environment is created the same way on every run.
            using (var env = new TlcEnvironment(seed: 1, conc: 1))
            {
                // Pipeline: tab-separated input with a header row; column 0 is the label, column 1 the text.
                var loader = new TextLoader(
                    env,
                    new TextLoader.Arguments()
                    {
                        Separator = "tab",
                        HasHeader = true,
                        Column = new[]
                        {
                            new TextLoader.Column()
                            {
                                Name = "Label",
                                Source = new[] { new TextLoader.Range() { Min = 0, Max = 0 } },
                                Type = DataKind.Num
                            },
                            new TextLoader.Column()
                            {
                                Name = "SentimentText",
                                Source = new[] { new TextLoader.Range() { Min = 1, Max = 1 } },
                                Type = DataKind.Text
                            }
                        }
                    },
                    new MultiFileSource(dataPath));

                // Featurize "SentimentText" into "Features" using char 3-grams and word n-grams.
                var trans = TextTransform.Create(
                    env,
                    new TextTransform.Arguments()
                    {
                        Column = new TextTransform.Column
                        {
                            Name = "Features",
                            Source = new[] { "SentimentText" }
                        },
                        KeepDiacritics = false,
                        KeepPunctuations = false,
                        TextCase = Runtime.TextAnalytics.TextNormalizerTransform.CaseNormalizationMode.Lower,
                        OutputTokens = true,
                        StopWordsRemover = new Runtime.TextAnalytics.PredefinedStopWordsRemoverFactory(),
                        VectorNormalizer = TextTransform.TextNormKind.L2,
                        CharFeatureExtractor = new NgramExtractorTransform.NgramExtractorArguments()
                        {
                            NgramLength = 3,
                            AllLengths = false
                        },
                        WordFeatureExtractor = new NgramExtractorTransform.NgramExtractorArguments()
                        {
                            NgramLength = 2,
                            AllLengths = true
                        },
                    },
                    loader);

                // Train
                var trainer = new FastTreeBinaryClassificationTrainer(
                    env,
                    new FastTreeBinaryClassificationTrainer.Arguments()
                    {
                        NumLeaves = 5,
                        NumTrees = 5,
                        MinDocumentsInLeafs = 2
                    });

                var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
                var pred = trainer.Train(trainRoles);

                // Get scorer and evaluate the predictions from test data
                IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath);
                var metrics = EvaluateBinary(env, testDataScorer);
                ValidateBinaryMetrics(metrics);

                // Create prediction engine and test predictions.
                // Materialize the predictions once instead of enumerating the sequence again for
                // Count() and for every ElementAt() call.
                // NOTE(review): the 'false' argument is assumed to be reuseRowObjects, so every element
                // is a distinct object that is safe to keep in a list - confirm.
                var model = env.CreateBatchPredictionEngine<SentimentData, SentimentPrediction>(testDataScorer);
                var sentiments = GetTestData();
                var predictions = model.Predict(sentiments, false).ToList();
                Assert.Equal(2, predictions.Count);
                Assert.True(predictions[0].Sentiment.IsTrue);
                Assert.True(predictions[1].Sentiment.IsTrue);

                // Get feature importance based on feature gain during training.
                // The cast fails the test if the trainer did not produce a FeatureWeightsCalibratedPredictor.
                var summary = ((FeatureWeightsCalibratedPredictor)pred).GetSummaryInKeyValuePairs(trainRoles.Schema);
                Assert.Equal(1.0, (double)summary[0].Value, 1);
            }
        }
        /// <summary>
        /// Same scenario as the direct-instantiation sentiment test, but the features come from a
        /// pretrained SSWE word-embedding transform applied to the text transform's output instead
        /// of n-gram features. Validates the metrics on the test data, the predictions for the
        /// sample inputs, and the first entry of the predictor's key/value summary.
        /// </summary>
        public void TrainAndPredictSentimentModelWithDirectionInstantiationTestWithWordEmbedding()
        {
            var dataPath = GetDataPath(SentimentDataPath);
            var testDataPath = GetDataPath(SentimentTestPath);

            // Fixed seed and concurrency so the environment is created the same way on every run.
            using (var env = new TlcEnvironment(seed: 1, conc: 1))
            {
                // Pipeline: tab-separated input with a header row; column 0 is the label, column 1 the text.
                var loader = TextLoader.ReadFile(
                    env,
                    new TextLoader.Arguments()
                    {
                        Separator = "tab",
                        HasHeader = true,
                        Column = new[]
                        {
                            new TextLoader.Column("Label", DataKind.Num, 0),
                            new TextLoader.Column("SentimentText", DataKind.Text, 1)
                        }
                    },
                    new MultiFileSource(dataPath));

                // Only normalize and tokenize here: both n-gram extractors are switched off because
                // the features are produced by the word-embedding transform below.
                var text = TextTransform.Create(
                    env,
                    new TextTransform.Arguments()
                    {
                        Column = new TextTransform.Column
                        {
                            Name = "WordEmbeddings",
                            Source = new[] { "SentimentText" }
                        },
                        KeepDiacritics = false,
                        KeepPunctuations = false,
                        TextCase = Runtime.TextAnalytics.TextNormalizerTransform.CaseNormalizationMode.Lower,
                        OutputTokens = true,
                        StopWordsRemover = new Runtime.TextAnalytics.PredefinedStopWordsRemoverFactory(),
                        VectorNormalizer = TextTransform.TextNormKind.None,
                        CharFeatureExtractor = null,
                        WordFeatureExtractor = null,
                    },
                    loader);

                // NOTE(review): "WordEmbeddings_TransformedText" is presumably the token column the text
                // transform emits because OutputTokens is true - confirm the naming rule.
                var trans = new WordEmbeddingsTransform(
                    env,
                    new WordEmbeddingsTransform.Arguments()
                    {
                        Column = new[]
                        {
                            new WordEmbeddingsTransform.Column
                            {
                                Name = "Features",
                                Source = "WordEmbeddings_TransformedText"
                            }
                        },
                        ModelKind = WordEmbeddingsTransform.PretrainedModelKind.Sswe,
                    },
                    text);

                // Train
                var trainer = new FastTreeBinaryClassificationTrainer(
                    env,
                    new FastTreeBinaryClassificationTrainer.Arguments()
                    {
                        NumLeaves = 5,
                        NumTrees = 5,
                        MinDocumentsInLeafs = 2
                    });

                var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
                var pred = trainer.Train(trainRoles);

                // Get scorer and evaluate the predictions from test data
                IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath);
                var metrics = EvaluateBinary(env, testDataScorer);

                // SSWE is a simple word embedding model + we train on a really small dataset, so metrics are not great.
                Assert.Equal(.6667, metrics.Accuracy, 4);
                Assert.Equal(.71, metrics.Auc, 1);
                Assert.Equal(.58, metrics.Auprc, 2);

                // Create prediction engine and test predictions.
                // Materialize the predictions once instead of enumerating the sequence again for
                // Count() and for every ElementAt() call.
                // NOTE(review): the 'false' argument is assumed to be reuseRowObjects, so every element
                // is a distinct object that is safe to keep in a list - confirm.
                var model = env.CreateBatchPredictionEngine<SentimentData, SentimentPrediction>(testDataScorer);
                var sentiments = GetTestData();
                var predictions = model.Predict(sentiments, false).ToList();
                Assert.Equal(2, predictions.Count);
                Assert.True(predictions[0].Sentiment.IsTrue);
                Assert.True(predictions[1].Sentiment.IsTrue);

                // Get feature importance based on feature gain during training.
                // The cast fails the test if the trainer did not produce a FeatureWeightsCalibratedPredictor.
                var summary = ((FeatureWeightsCalibratedPredictor)pred).GetSummaryInKeyValuePairs(trainRoles.Schema);
                Assert.Equal(1.0, (double)summary[0].Value, 1);
            }
        }