public void TrainWithValidationSet()
{
    using (var env = new LocalEnvironment(seed: 1, conc: 1))
    {
        // Pipeline.
        var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(),
            new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename)));
        var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader);
        var trainData = trans;

        // Apply the same transformations to the validation set.
        // Sadly, there is no way to easily apply the same loader to different data, so we either
        // have to create another loader, or save the loader to a model file and then reload it.
        // A new loader is not always feasible, but this time it is.
        var validLoader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(),
            new MultiFileSource(GetDataPath(TestDatasets.Sentiment.testFilename)));
        var validData = ApplyTransformUtils.ApplyAllTransformsToData(env, trainData, validLoader);

        // Cache both datasets.
        var cachedTrain = new CacheDataView(env, trainData, prefetch: null);
        var cachedValid = new CacheDataView(env, validData, prefetch: null);

        // Train, passing the validation set alongside the training set.
        var trainer = new FastTreeBinaryClassificationTrainer(env, DefaultColumnNames.Label,
            DefaultColumnNames.Features, numTrees: 3);
        var trainRoles = new RoleMappedData(cachedTrain, label: "Label", feature: "Features");
        var validRoles = new RoleMappedData(cachedValid, label: "Label", feature: "Features");
        trainer.Train(new Runtime.TrainContext(trainRoles, validRoles));
    }
}

void FileBasedSavingOfData()
{
    var dataPath = GetDataPath(SentimentDataPath);
    var testDataPath = GetDataPath(SentimentTestPath);

    using (var env = new LocalEnvironment(seed: 1, conc: 1))
    {
        // Pipeline.
        var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));
        var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader);

        // Save the featurized data to a binary IDV file.
        var saver = new BinarySaver(env, new BinarySaver.Arguments());
        using (var ch = env.Start("SaveData"))
        using (var file = env.CreateOutputFile("i.idv"))
        {
            DataSaverUtils.SaveDataView(ch, saver, trans, file);
        }

        // Load the saved data back and train on it.
        var binData = new BinaryLoader(env, new BinaryLoader.Arguments(), new MultiFileSource("i.idv"));
        var trainRoles = new RoleMappedData(binData, label: "Label", feature: "Features");
        var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
        {
            NumThreads = 1
        });
        var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));

        DeleteOutputPath("i.idv");
    }
}

public void AutoNormalizationAndCaching()
{
    var dataPath = GetDataPath(SentimentDataPath);
    var testDataPath = GetDataPath(SentimentTestPath);

    using (var env = new LocalEnvironment(seed: 1, conc: 1))
    {
        // Pipeline.
        var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));
        var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(false), loader);

        // Train.
        var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
        {
            NumThreads = 1,
            ConvergenceTolerance = 1f
        });

        // Auto-caching: cache the data only if the trainer says it benefits from caching.
        IDataView trainData = trainer.Info.WantCaching
            ? (IDataView)new CacheDataView(env, trans, prefetch: null)
            : trans;
        var trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features");

        // Auto-normalization: insert a normalizer only if the trainer requires one.
        NormalizeTransform.CreateIfNeeded(env, ref trainRoles, trainer);
        var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));
    }
}

public void TrainWithInitialPredictor()
{
    var dataPath = GetDataPath(SentimentDataPath);

    using (var env = new TlcEnvironment(seed: 1, conc: 1))
    {
        // Pipeline
        var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));
        var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader);
        var trainData = trans;
        var cachedTrain = new CacheDataView(env, trainData, prefetch: null);

        // Train the first predictor.
        var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
        {
            NumThreads = 1
        });
        var trainRoles = new RoleMappedData(cachedTrain, label: "Label", feature: "Features");
        var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));

        // Train the second predictor on the same data.
        var secondTrainer = new AveragedPerceptronTrainer(env, new AveragedPerceptronTrainer.Arguments());
        var finalPredictor = secondTrainer.Train(new TrainContext(trainRoles, initialPredictor: predictor));
    }
}

public TransformWrapper Fit(IDataView input)
{
    // Note: this method is a fragment; _env and _args are fields of the containing class.
    // Fit the text transform on the input, then re-root the resulting transform chain
    // onto an empty data view so it can later be applied to any compatible data.
    var xf = TextTransform.Create(_env, _args, input);
    var empty = new EmptyDataView(_env, input.Schema);
    var chunk = ApplyTransformUtils.ApplyAllTransformsToData(_env, xf, empty, input);
    return new TransformWrapper(_env, chunk);
}

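// A minimal sketch of a containing class the Fit method above assumes. The class
// name, constructor, and field layout here are assumptions for illustration only,
// not the original source; only the _env and _args fields are implied by the snippet.
public sealed class TextTransformEstimator
{
    private readonly IHostEnvironment _env;
    private readonly TextTransform.Arguments _args;

    public TextTransformEstimator(IHostEnvironment env, TextTransform.Arguments args)
    {
        _env = env;
        _args = args;
    }

    // Fit(IDataView input), as defined above, would live here.
}
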
void ReconfigurablePrediction()
{
    var dataPath = GetDataPath(SentimentDataPath);
    var testDataPath = GetDataPath(SentimentTestPath);

    using (var env = new TlcEnvironment(seed: 1, conc: 1))
    {
        // Pipeline
        var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));
        var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader);

        // Train
        var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
        {
            NumThreads = 1
        });
        var cached = new CacheDataView(env, trans, prefetch: null);
        var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");
        IPredictor predictor = trainer.Train(new Runtime.TrainContext(trainRoles));

        // Calibrate the predictor so it outputs probabilities.
        using (var ch = env.Start("Calibrator training"))
        {
            predictor = CalibratorUtils.TrainCalibrator(env, ch, new PlattCalibratorTrainer(env), int.MaxValue, predictor, trainRoles);
        }

        // Score and evaluate with the default threshold.
        var scoreRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
        IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema);

        var dataEval = new RoleMappedData(scorer, label: "Label", feature: "Features", opt: true);
        var evaluator = new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments() { });
        var metricsDict = evaluator.Evaluate(dataEval);
        var metrics = BinaryClassificationMetrics.FromMetrics(env, metricsDict["OverallMetrics"], metricsDict["ConfusionMatrix"])[0];

        // Reconfigure: re-score with a different probability threshold and re-evaluate,
        // without retraining the predictor.
        var bindable = ScoreUtils.GetSchemaBindableMapper(env, predictor, null);
        var mapper = bindable.Bind(env, trainRoles.Schema);
        var newScorer = new BinaryClassifierScorer(env, new BinaryClassifierScorer.Arguments
        {
            Threshold = 0.01f,
            ThresholdColumn = DefaultColumnNames.Probability
        }, scoreRoles.Data, mapper, trainRoles.Schema);

        dataEval = new RoleMappedData(newScorer, label: "Label", feature: "Features", opt: true);
        var newEvaluator = new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments()
        {
            Threshold = 0.01f,
            UseRawScoreThreshold = false
        });
        metricsDict = newEvaluator.Evaluate(dataEval);
        var newMetrics = BinaryClassificationMetrics.FromMetrics(env, metricsDict["OverallMetrics"], metricsDict["ConfusionMatrix"])[0];
    }
}

public void TrainSaveModelAndPredict()
{
    var dataPath = GetDataPath(SentimentDataPath);
    var testDataPath = GetDataPath(SentimentTestPath);

    using (var env = new LocalEnvironment(seed: 1, conc: 1))
    {
        // Pipeline
        var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));
        var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader);

        // Train
        var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
        {
            NumThreads = 1
        });
        var cached = new CacheDataView(env, trans, prefetch: null);
        var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");
        var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));

        PredictionEngine<SentimentData, SentimentPrediction> model;
        using (var file = env.CreateTempFile())
        {
            // Save model.
            var scoreRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
            using (var ch = env.Start("saving"))
                TrainUtils.SaveModel(env, ch, file, predictor, scoreRoles);

            // Load model.
            using (var fs = file.OpenReadStream())
                model = env.CreatePredictionEngine<SentimentData, SentimentPrediction>(fs);
        }

        // Take a couple examples out of the test data and run predictions on top.
        var testLoader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(SentimentTestPath)));
        var testData = testLoader.AsEnumerable<SentimentData>(env, false);
        foreach (var input in testData.Take(5))
        {
            var prediction = model.Predict(input);
            // Verify that predictions match and scores are separated from zero.
            Assert.Equal(input.Sentiment, prediction.Sentiment);
            Assert.True(input.Sentiment && prediction.Score > 1 || !input.Sentiment && prediction.Score < -1);
        }
    }
}

public void Evaluation()
{
    var dataPath = GetDataPath(SentimentDataPath);
    var testDataPath = GetDataPath(SentimentTestPath);

    using (var env = new TlcEnvironment(seed: 1, conc: 1))
    {
        // Pipeline
        var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));
        var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader);

        // Train
        var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
        {
            NumThreads = 1
        });
        var cached = new CacheDataView(env, trans, prefetch: null);
        var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");
        var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));

        var scoreRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
        IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema);

        // Create prediction engine and test predictions.
        var model = env.CreatePredictionEngine<SentimentData, SentimentPrediction>(scorer);

        // Take a couple examples out of the test data and run predictions on top.
        var testLoader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(SentimentTestPath)));
        var testData = testLoader.AsEnumerable<SentimentData>(env, false);

        // Evaluate the scored data.
        var dataEval = new RoleMappedData(scorer, label: "Label", feature: "Features", opt: true);
        var evaluator = new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments() { });
        var metricsDict = evaluator.Evaluate(dataEval);
        var metrics = BinaryClassificationMetrics.FromMetrics(env, metricsDict["OverallMetrics"], metricsDict["ConfusionMatrix"])[0];
    }
}

void Visibility()
{
    var dataPath = GetDataPath(SentimentDataPath);
    var testDataPath = GetDataPath(SentimentTestPath);

    using (var env = new TlcEnvironment(seed: 1, conc: 1))
    {
        // Pipeline.
        var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));
        var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(false), loader);

        // To find out the available column names, you can go through the schema and check
        // each column's name and the appropriate type for its getter.
        for (int i = 0; i < trans.Schema.ColumnCount; i++)
        {
            var columnName = trans.Schema.GetColumnName(i);
            var columnType = trans.Schema.GetColumnType(i).RawType;
        }

        using (var cursor = trans.GetRowCursor(x => true))
        {
            Assert.True(cursor.Schema.TryGetColumnIndex("SentimentText", out int textColumn));
            Assert.True(cursor.Schema.TryGetColumnIndex("Features_TransformedText", out int transformedTextColumn));
            Assert.True(cursor.Schema.TryGetColumnIndex("Features", out int featureColumn));

            var originalTextGetter = cursor.GetGetter<DvText>(textColumn);
            var transformedTextGetter = cursor.GetGetter<VBuffer<DvText>>(transformedTextColumn);
            var featureGetter = cursor.GetGetter<VBuffer<float>>(featureColumn);

            DvText text = default;
            VBuffer<DvText> transformedText = default;
            VBuffer<float> features = default;
            while (cursor.MoveNext())
            {
                originalTextGetter(ref text);
                transformedTextGetter(ref transformedText);
                featureGetter(ref features);
            }
        }
    }
}

void MultithreadedPrediction()
{
    var dataPath = GetDataPath(SentimentDataPath);
    var testDataPath = GetDataPath(SentimentTestPath);

    using (var env = new LocalEnvironment(seed: 1, conc: 1))
    {
        // Pipeline
        var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));
        var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader);

        // Train
        var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
        {
            NumThreads = 1
        });
        var cached = new CacheDataView(env, trans, prefetch: null);
        var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");
        var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));

        var scoreRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
        IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema);

        // Create prediction engine and test predictions.
        var model = env.CreatePredictionEngine<SentimentData, SentimentPrediction>(scorer);

        // Take a couple examples out of the test data and run predictions on top.
        var testLoader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(SentimentTestPath)));
        var testData = testLoader.AsEnumerable<SentimentData>(env, false);

        // The prediction engine is stateful and not thread-safe, so concurrent
        // callers must synchronize access to it.
        Parallel.ForEach(testData, (input) =>
        {
            lock (model)
            {
                var prediction = model.Predict(input);
            }
        });
    }
}

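// A possible lock-free variant of the scenario above (a sketch, not part of the
// original tests): since the engine is stateful, giving each thread its own engine
// avoids serializing all predictions behind one lock. The helper name and the idea
// of creating several engines over the same scorer are assumptions for illustration.
void MultithreadedPredictionWithThreadLocalEngine(
    LocalEnvironment env, IDataScorerTransform scorer, IEnumerable<SentimentData> testData)
{
    // Lazily create one engine per thread over the shared scorer.
    var engine = new ThreadLocal<PredictionEngine<SentimentData, SentimentPrediction>>(
        () => env.CreatePredictionEngine<SentimentData, SentimentPrediction>(scorer));

    Parallel.ForEach(testData, input =>
    {
        // Each thread uses its own engine, so no lock is needed.
        var prediction = engine.Value.Predict(input);
    });
}
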
public void SimpleTrainAndPredict()
{
    var dataPath = GetDataPath(SentimentDataPath);
    var testDataPath = GetDataPath(SentimentTestPath);

    using (var env = new TlcEnvironment(seed: 1, conc: 1))
    {
        // Pipeline
        var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));
        var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader);

        // Train
        var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
        {
            NumThreads = 1
        });
        var cached = new CacheDataView(env, trans, prefetch: null);
        var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");
        var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));

        var scoreRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
        IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema);

        // Create prediction engine and test predictions.
        var model = env.CreatePredictionEngine<SentimentData, SentimentPrediction>(scorer);

        // Take a couple examples out of the test data and run predictions on top.
        var testLoader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(SentimentTestPath)));
        var testData = testLoader.AsEnumerable<SentimentData>(env, false);
        foreach (var input in testData.Take(5))
        {
            var prediction = model.Predict(input);
            // Verify that predictions match and scores are separated from zero.
            Assert.Equal(input.Sentiment, prediction.Sentiment);
            Assert.True(input.Sentiment && prediction.Score > 1 || !input.Sentiment && prediction.Score < -1);
        }
    }
}

public void TestWordEmbeddings()
{
    var dataPath = GetDataPath(ScenariosTests.SentimentDataPath);
    var testDataPath = GetDataPath(ScenariosTests.SentimentTestPath);

    var data = TextLoader.CreateReader(Env, ctx => (
            label: ctx.LoadBool(0),
            SentimentText: ctx.LoadText(1)),
            hasHeader: true)
        .Read(new MultiFileSource(dataPath));

    var dynamicData = TextTransform.Create(Env, new TextTransform.Arguments()
    {
        Column = new TextTransform.Column
        {
            Name = "SentimentText_Features",
            Source = new[] { "SentimentText" }
        },
        KeepDiacritics = false,
        KeepPunctuations = false,
        TextCase = Runtime.TextAnalytics.TextNormalizerTransform.CaseNormalizationMode.Lower,
        OutputTokens = true,
        StopWordsRemover = new Runtime.TextAnalytics.PredefinedStopWordsRemoverFactory(),
        VectorNormalizer = TextTransform.TextNormKind.None,
        CharFeatureExtractor = null,
        WordFeatureExtractor = null,
    }, data.AsDynamic);

    var data2 = dynamicData.AssertStatic(Env, ctx => (
        SentimentText_Features_TransformedText: ctx.Text.VarVector,
        SentimentText: ctx.Text.Scalar,
        label: ctx.Bool.Scalar));

    var est = data2.MakeNewEstimator()
        .Append(row => row.SentimentText_Features_TransformedText.WordEmbeddings());

    TestEstimatorCore(est.AsDynamic, data2.AsDynamic, invalidInput: data.AsDynamic);
    Done();
}

void CrossValidation()
{
    var dataPath = GetDataPath(SentimentDataPath);
    var testDataPath = GetDataPath(SentimentTestPath);

    int numFolds = 5;
    using (var env = new TlcEnvironment(seed: 1, conc: 1))
    {
        // Pipeline.
        var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));
        var text = TextTransform.Create(env, MakeSentimentTextTransformArgs(false), loader);

        // Append a pseudo-random number in [0, 1) to each row; the folds are carved
        // out of this column with the range filters below.
        IDataView trans = new GenerateNumberTransform(env, text, "StratificationColumn");

        // Train.
        var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
        {
            NumThreads = 1,
            ConvergenceTolerance = 1f
        });

        var metrics = new List<BinaryClassificationMetrics>();
        for (int fold = 0; fold < numFolds; fold++)
        {
            // Train on everything outside the [fold/numFolds, (fold+1)/numFolds) range.
            IDataView trainPipe = new RangeFilter(env, new RangeFilter.Arguments()
            {
                Column = "StratificationColumn",
                Min = (Double)fold / numFolds,
                Max = (Double)(fold + 1) / numFolds,
                Complement = true
            }, trans);
            trainPipe = new OpaqueDataView(trainPipe);
            var trainData = new RoleMappedData(trainPipe, label: "Label", feature: "Features");

            // Auto-normalization.
            NormalizeTransform.CreateIfNeeded(env, ref trainData, trainer);
            var preCachedData = trainData;

            // Auto-caching.
            if (trainer.Info.WantCaching)
            {
                var prefetch = trainData.Schema.GetColumnRoles().Select(kc => kc.Value.Index).ToArray();
                var cacheView = new CacheDataView(env, trainData.Data, prefetch);
                // Because the prefetching worked, we know that these are valid columns.
                trainData = new RoleMappedData(cacheView, trainData.Schema.GetColumnRoleNames());
            }

            var predictor = trainer.Train(new Runtime.TrainContext(trainData));

            // Test on the complementary range, i.e. the held-out fold.
            IDataView testPipe = new RangeFilter(env, new RangeFilter.Arguments()
            {
                Column = "StratificationColumn",
                Min = (Double)fold / numFolds,
                Max = (Double)(fold + 1) / numFolds,
                Complement = false
            }, trans);
            testPipe = new OpaqueDataView(testPipe);
            var pipe = ApplyTransformUtils.ApplyAllTransformsToData(env, preCachedData.Data, testPipe, trainPipe);
            var testRoles = new RoleMappedData(pipe, trainData.Schema.GetColumnRoleNames());

            IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, testRoles, env, testRoles.Schema);
            var eval = new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments() { });
            var dataEval = new RoleMappedData(scorer, testRoles.Schema.GetColumnRoleNames(), opt: true);
            var dict = eval.Evaluate(dataEval);
            var foldMetrics = BinaryClassificationMetrics.FromMetrics(env, dict["OverallMetrics"], dict["ConfusionMatrix"]);
            metrics.Add(foldMetrics.Single());
        }
    }
}

public void TrainAndPredictSentimentModelWithDirectionInstantiationTest()
{
    var dataPath = GetDataPath(SentimentDataPath);
    var testDataPath = GetDataPath(SentimentTestPath);

    using (var env = new TlcEnvironment(seed: 1, conc: 1))
    {
        // Pipeline
        var loader = new TextLoader(env, new TextLoader.Arguments()
        {
            Separator = "tab",
            HasHeader = true,
            Column = new[]
            {
                new TextLoader.Column()
                {
                    Name = "Label",
                    Source = new [] { new TextLoader.Range() { Min = 0, Max = 0 } },
                    Type = DataKind.Num
                },
                new TextLoader.Column()
                {
                    Name = "SentimentText",
                    Source = new [] { new TextLoader.Range() { Min = 1, Max = 1 } },
                    Type = DataKind.Text
                }
            }
        }, new MultiFileSource(dataPath));

        var trans = TextTransform.Create(env, new TextTransform.Arguments()
        {
            Column = new TextTransform.Column
            {
                Name = "Features",
                Source = new[] { "SentimentText" }
            },
            KeepDiacritics = false,
            KeepPunctuations = false,
            TextCase = Runtime.TextAnalytics.TextNormalizerTransform.CaseNormalizationMode.Lower,
            OutputTokens = true,
            StopWordsRemover = new Runtime.TextAnalytics.PredefinedStopWordsRemoverFactory(),
            VectorNormalizer = TextTransform.TextNormKind.L2,
            CharFeatureExtractor = new NgramExtractorTransform.NgramExtractorArguments() { NgramLength = 3, AllLengths = false },
            WordFeatureExtractor = new NgramExtractorTransform.NgramExtractorArguments() { NgramLength = 2, AllLengths = true },
        }, loader);

        // Train
        var trainer = new FastTreeBinaryClassificationTrainer(env, new FastTreeBinaryClassificationTrainer.Arguments()
        {
            NumLeaves = 5,
            NumTrees = 5,
            MinDocumentsInLeafs = 2
        });

        var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
        var pred = trainer.Train(trainRoles);

        // Get scorer and evaluate the predictions from test data
        IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath);
        var metrics = EvaluateBinary(env, testDataScorer);
        ValidateBinaryMetrics(metrics);

        // Create prediction engine and test predictions
        var model = env.CreateBatchPredictionEngine<SentimentData, SentimentPrediction>(testDataScorer);
        var sentiments = GetTestData();
        var predictions = model.Predict(sentiments, false);
        Assert.Equal(2, predictions.Count());
        Assert.True(predictions.ElementAt(0).Sentiment.IsTrue);
        Assert.True(predictions.ElementAt(1).Sentiment.IsTrue);

        // Get feature importance based on feature gain during training
        var summary = ((FeatureWeightsCalibratedPredictor)pred).GetSummaryInKeyValuePairs(trainRoles.Schema);
        Assert.Equal(1.0, (double)summary[0].Value, 1);
    }
}

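// Illustrative aside (not part of the original test): GetSummaryInKeyValuePairs
// returns (feature name, gain) pairs, so the most influential features can be
// inspected directly. A minimal sketch, assuming the pairs are ordered by gain
// (the test's check that summary[0] has the top, normalized gain suggests this).
void PrintTopFeatures(IList<KeyValuePair<string, object>> summary)
{
    // Print the ten highest-gain features and their gains.
    foreach (var pair in summary.Take(10))
        Console.WriteLine($"{pair.Key}: {pair.Value}");
}
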
public void TrainAndPredictSentimentModelWithDirectionInstantiationTestWithWordEmbedding()
{
    var dataPath = GetDataPath(SentimentDataPath);
    var testDataPath = GetDataPath(SentimentTestPath);

    using (var env = new ConsoleEnvironment(seed: 1, conc: 1))
    {
        // Pipeline
        var loader = TextLoader.ReadFile(env, new TextLoader.Arguments()
        {
            Separator = "tab",
            HasHeader = true,
            Column = new[]
            {
                new TextLoader.Column("Label", DataKind.Num, 0),
                new TextLoader.Column("SentimentText", DataKind.Text, 1)
            }
        }, new MultiFileSource(dataPath));

        var text = TextTransform.Create(env, new TextTransform.Arguments()
        {
            Column = new TextTransform.Column
            {
                Name = "WordEmbeddings",
                Source = new[] { "SentimentText" }
            },
            KeepDiacritics = false,
            KeepPunctuations = false,
            TextCase = Runtime.TextAnalytics.TextNormalizerTransform.CaseNormalizationMode.Lower,
            OutputTokens = true,
            StopWordsRemover = new Runtime.TextAnalytics.PredefinedStopWordsRemoverFactory(),
            VectorNormalizer = TextTransform.TextNormKind.None,
            CharFeatureExtractor = null,
            WordFeatureExtractor = null,
        }, loader);

        var trans = WordEmbeddingsTransform.Create(env, new WordEmbeddingsTransform.Arguments()
        {
            Column = new WordEmbeddingsTransform.Column[1]
            {
                new WordEmbeddingsTransform.Column
                {
                    Name = "Features",
                    Source = "WordEmbeddings_TransformedText"
                }
            },
            ModelKind = WordEmbeddingsTransform.PretrainedModelKind.Sswe,
        }, text);

        // Train
        var trainer = new FastTreeBinaryClassificationTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features,
            advancedSettings: s =>
            {
                s.NumLeaves = 5;
                s.NumTrees = 5;
                s.MinDocumentsInLeafs = 2;
            });

        var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
        var pred = trainer.Train(trainRoles);

        // Get scorer and evaluate the predictions from test data
        IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath);
        var metrics = EvaluateBinary(env, testDataScorer);

        // SSWE is a simple word embedding model and we train on a really small dataset,
        // so metrics are not great.
        Assert.Equal(.6667, metrics.Accuracy, 4);
        Assert.Equal(.71, metrics.Auc, 1);
        Assert.Equal(.58, metrics.Auprc, 2);

        // Create prediction engine and test predictions
        var model = env.CreateBatchPredictionEngine<SentimentData, SentimentPrediction>(testDataScorer);
        var sentiments = GetTestData();
        var predictions = model.Predict(sentiments, false);
        Assert.Equal(2, predictions.Count());
        Assert.True(predictions.ElementAt(0).Sentiment);
        Assert.True(predictions.ElementAt(1).Sentiment);

        // Get feature importance based on feature gain during training
        var summary = ((FeatureWeightsCalibratedPredictor)pred).GetSummaryInKeyValuePairs(trainRoles.Schema);
        Assert.Equal(1.0, (double)summary[0].Value, 1);
    }
}

public void TrainSentiment()
{
    using (var env = new ConsoleEnvironment(seed: 1))
    {
        // Pipeline
        var loader = TextLoader.ReadFile(env, new TextLoader.Arguments()
        {
            AllowQuoting = false,
            AllowSparse = false,
            Separator = "tab",
            HasHeader = true,
            Column = new[]
            {
                new TextLoader.Column()
                {
                    Name = "Label",
                    Source = new [] { new TextLoader.Range() { Min = 0, Max = 0 } },
                    Type = DataKind.Num
                },
                new TextLoader.Column()
                {
                    Name = "SentimentText",
                    Source = new [] { new TextLoader.Range() { Min = 1, Max = 1 } },
                    Type = DataKind.Text
                }
            }
        }, new MultiFileSource(_sentimentDataPath));

        var text = TextTransform.Create(env, new TextTransform.Arguments()
        {
            Column = new TextTransform.Column
            {
                Name = "WordEmbeddings",
                Source = new[] { "SentimentText" }
            },
            KeepDiacritics = false,
            KeepPunctuations = false,
            TextCase = Runtime.TextAnalytics.TextNormalizerTransform.CaseNormalizationMode.Lower,
            OutputTokens = true,
            StopWordsRemover = new Runtime.TextAnalytics.PredefinedStopWordsRemoverFactory(),
            VectorNormalizer = TextTransform.TextNormKind.None,
            CharFeatureExtractor = null,
            WordFeatureExtractor = null,
        }, loader);

        var trans = new WordEmbeddingsTransform(env, new WordEmbeddingsTransform.Arguments()
        {
            Column = new WordEmbeddingsTransform.Column[1]
            {
                new WordEmbeddingsTransform.Column
                {
                    Name = "Features",
                    Source = "WordEmbeddings_TransformedText"
                }
            },
            ModelKind = WordEmbeddingsTransform.PretrainedModelKind.Sswe,
        }, text);

        // Train
        var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments() { MaxIterations = 20 });
        var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
        var predicted = trainer.Train(trainRoles);
        _consumer.Consume(predicted);
    }
}