Esempio n. 1
0
        /// <summary>
        /// Trains an SDCA multi-class model on the Iris data through an explicitly assembled pipeline
        /// (loader, term transform on "Label", concatenation into "Features"), then re-roots the scorer
        /// directly on the loader so the term transform is cut out and predictions can be made from
        /// label-less <c>IrisDataNoLabel</c> rows.
        /// </summary>
        void DecomposableTrainAndPredict()
        {
            using (var env = new LocalEnvironment()
                             .AddStandardComponents()) // ScoreUtils.GetScorer requires scorers to be registered in the ComponentCatalog
            {
                var loader  = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.irisData.trainFilename)));
                var term    = TermTransform.Create(env, loader, "Label");
                var concat  = new ConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth").Transform(term);
                // Fixed iteration count, shuffling and a single thread.
                var trainer = new SdcaMultiClassTrainer(env, "Features", "Label", advancedSettings: (s) => { s.MaxIterations = 100; s.Shuffle = true; s.NumThreads = 1; });

                // Insert a cache only when the trainer reports that it wants one.
                IDataView trainData  = trainer.Info.WantCaching ? (IDataView) new CacheDataView(env, concat, prefetch: null) : concat;
                var       trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features");

                // Auto-normalization. trainRoles is passed by ref, so it may be replaced by a normalized view.
                NormalizeTransform.CreateIfNeeded(env, ref trainRoles, trainer);
                var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));

                // Score the un-cached pipeline, binding against the schema the model was trained on.
                var scoreRoles = new RoleMappedData(concat, label: "Label", feature: "Features");
                IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema);

                // Cut out term transform from pipeline: rebuild the scorer's transform chain on top of
                // 'loader' instead of 'term', so inputs without a label can be scored.
                var newScorer  = ApplyTransformUtils.ApplyAllTransformsToData(env, scorer, loader, term);
                // Turn the predicted key back into its original label value so it can be compared with a string.
                var keyToValue = new KeyToValueTransform(env, "PredictedLabel").Transform(newScorer);
                var model      = env.CreatePredictionEngine <IrisDataNoLabel, IrisPrediction>(keyToValue);

                const string expectedLabel = "Iris-setosa";
                // Predict over the training file itself; the assertion expects its first 20 rows to be setosa.
                var testData = loader.AsEnumerable <IrisDataNoLabel>(env, false);
                foreach (var input in testData.Take(20))
                {
                    var prediction = model.Predict(input);
                    // Assert.Equal reports expected vs. actual on failure, whereas Assert.True(a == b) only reports "False".
                    Assert.Equal(expectedLabel, prediction.PredictedLabel);
                }
            }
        }
        /// <summary>
        /// Builds an Iris training pipeline by instantiating the loader, transforms and SDCA trainer directly.
        /// It then checks the evaluation metrics, individual predictions, and the first entry of the
        /// predictor's weight summary.
        /// </summary>
        public void TrainAndPredictIrisModelUsingDirectInstantiationTest()
        {
            string dataPath = GetDataPath("iris.txt");
            // Evaluation and prediction reuse the training file.
            string testDataPath = dataPath;

            using (var env = new TlcEnvironment(seed: 1, conc: 1))
            {
                // Column 0 is the label; columns 1-4 are the four measurements, all read as R4.
                var columns = new[]
                {
                    new TextLoader.Column("Label", DataKind.R4, 0),
                    new TextLoader.Column("SepalLength", DataKind.R4, 1),
                    new TextLoader.Column("SepalWidth", DataKind.R4, 2),
                    new TextLoader.Column("PetalLength", DataKind.R4, 3),
                    new TextLoader.Column("PetalWidth", DataKind.R4, 4)
                };
                var loaderArgs = new TextLoader.Arguments()
                {
                    HasHeader = false,
                    Column    = columns
                };
                var loader = new TextLoader(env, loaderArgs, new MultiFileSource(dataPath));

                IDataTransform features = new ConcatTransform(env, loader, "Features",
                                                              "SepalLength", "SepalWidth", "PetalLength", "PetalWidth");

                // The trainer does not add a normalizer by itself here, even with 'NormalizeFeatures' On/Auto,
                // so apply min-max normalization explicitly.
                IDataTransform normalized = NormalizeTransform.CreateMinMaxNormalizer(env, features, "Features");

                var sdcaArgs = new SdcaMultiClassTrainer.Arguments()
                {
                    NumThreads = 1
                };
                var trainer = new SdcaMultiClassTrainer(env, sdcaArgs);

                // Caching is likewise not applied automatically even though the trainer's 'Caching' is On/Auto,
                // so wrap the training data in a CacheDataView explicitly.
                var cachedData = new CacheDataView(env, normalized, prefetch: null);
                var trainRoles = new RoleMappedData(cachedData, label: "Label", feature: "Features");
                var predictor  = trainer.Train(trainRoles);

                // Score the test file and compare the evaluation metrics with the expected values.
                IDataScorerTransform scored = GetScorer(env, normalized, predictor, testDataPath);
                CompareMatrics(Evaluate(env, scored));

                // Spot-check individual predictions through an engine built on the scorer.
                var engine = env.CreatePredictionEngine <IrisData, IrisPrediction>(scored);
                ComparePredictions(engine);

                // Feature importance, i.e. the weight vector: check the first summary value to 5 decimals.
                var lrPredictor = (MulticlassLogisticRegressionPredictor)predictor;
                var summary     = lrPredictor.GetSummaryInKeyValuePairs(trainRoles.Schema);
                Assert.Equal(7.757864, Convert.ToDouble(summary[0].Value), 5);
            }
        }
        /// <summary>
        /// Plugs a user-defined row mapping (<c>LambdaTransform.CreateMap</c>) into an Iris training
        /// pipeline and trains SDCA on the result. It then checks that each of the first 20 rows of the
        /// file is predicted with its own label.
        /// </summary>
        void Extensibility()
        {
            var dataPath = GetDataPath(IrisDataPath);

            using (var env = new LocalEnvironment())
            {
                var loader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath));
                // Custom row mapping: copy every field, except that PetalLength is replaced by SepalLength
                // whenever SepalLength is not greater than 3.
                Action <IrisData, IrisData> action = (i, j) =>
                {
                    j.Label       = i.Label;
                    j.PetalLength = i.SepalLength > 3 ? i.PetalLength : i.SepalLength;
                    j.PetalWidth  = i.PetalWidth;
                    j.SepalLength = i.SepalLength;
                    j.SepalWidth  = i.SepalWidth;
                };
                var lambda = LambdaTransform.CreateMap(env, loader, action);
                var term   = TermTransform.Create(env, lambda, "Label");
                var concat = new ConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")
                             .Transform(term);

                // Fixed iteration count, shuffling and a single thread.
                var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments {
                    MaxIterations = 100, Shuffle = true, NumThreads = 1
                });

                // Insert a cache only when the trainer reports that it wants one.
                IDataView trainData  = trainer.Info.WantCaching ? (IDataView) new CacheDataView(env, concat, prefetch: null) : concat;
                var       trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features");

                // Auto-normalization. trainRoles is passed by ref, so it may be replaced by a normalized view.
                NormalizeTransform.CreateIfNeeded(env, ref trainRoles, trainer);
                var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));

                // Score the un-cached pipeline, binding against the schema the model was trained on.
                var scoreRoles = new RoleMappedData(concat, label: "Label", feature: "Features");
                IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema);

                // Turn the predicted key back into its original label value so it can be compared with the input label.
                var keyToValue = new KeyToValueTransform(env, "PredictedLabel").Transform(scorer);
                var model      = env.CreatePredictionEngine <IrisData, IrisPrediction>(keyToValue);

                // Re-read the same file and compare each of the first 20 predictions with the row's own label.
                var testLoader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath));
                var testData   = testLoader.AsEnumerable <IrisData>(env, false);
                foreach (var input in testData.Take(20))
                {
                    var prediction = model.Predict(input);
                    // Assert.Equal reports expected vs. actual on failure, whereas Assert.True(a == b) only reports "False".
                    Assert.Equal(input.Label, prediction.PredictedLabel);
                }
            }
        }
Esempio n. 4
0
        /// <summary>
        /// Trains an SDCA multi-class model on the Iris data through an explicitly assembled pipeline
        /// (loader, term transform on "Label", concatenation into "Features"), then re-roots the scorer
        /// directly on the loader so the term transform is cut out and predictions can be made from
        /// label-less <c>IrisDataNoLabel</c> rows.
        /// </summary>
        void DecomposableTrainAndPredict()
        {
            var dataPath = GetDataPath(IrisDataPath);

            using (var env = new TlcEnvironment())
            {
                var loader  = new TextLoader(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath));
                var term    = new TermTransform(env, loader, "Label");
                var concat  = new ConcatTransform(env, term, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth");
                // Fixed iteration count, shuffling and a single thread.
                var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments {
                    MaxIterations = 100, Shuffle = true, NumThreads = 1
                });

                // Insert a cache only when the trainer reports that it wants one.
                IDataView trainData  = trainer.Info.WantCaching ? (IDataView) new CacheDataView(env, concat, prefetch: null) : concat;
                var       trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features");

                // Auto-normalization. trainRoles is passed by ref, so it may be replaced by a normalized view.
                NormalizeTransform.CreateIfNeeded(env, ref trainRoles, trainer);
                var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));

                // Score the un-cached pipeline, binding against the schema the model was trained on.
                var scoreRoles = new RoleMappedData(concat, label: "Label", feature: "Features");
                IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema);

                // Cut out term transform from pipeline: rebuild the scorer's transform chain on top of
                // 'loader' instead of 'term', so inputs without a label can be scored.
                var newScorer  = ApplyTransformUtils.ApplyAllTransformsToData(env, scorer, loader, term);
                // Turn the predicted key back into its original label value so it can be compared with a string.
                var keyToValue = new KeyToValueTransform(env, newScorer, "PredictedLabel");
                var model      = env.CreatePredictionEngine <IrisDataNoLabel, IrisPrediction>(keyToValue);

                const string expectedLabel = "Iris-setosa";
                // Re-read the training file; the assertion expects its first 20 rows to be setosa.
                var testLoader = new TextLoader(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath));
                var testData   = testLoader.AsEnumerable <IrisDataNoLabel>(env, false);
                foreach (var input in testData.Take(20))
                {
                    var prediction = model.Predict(input);
                    // Assert.Equal reports expected vs. actual on failure, whereas Assert.True(a == b) only reports "False".
                    Assert.Equal(expectedLabel, prediction.PredictedLabel);
                }
            }
        }
        /// <summary>
        /// Loads the tab-separated sentiment data and featurizes the review text into tokens.
        /// It embeds those tokens with the pretrained SSWE word-embedding model, trains an SDCA
        /// multi-class model for 20 iterations, and passes the trained predictor to <c>_consumer</c>.
        /// </summary>
        public void TrainSentiment()
        {
            using (var env = new ConsoleEnvironment(seed: 1))
            {
                // Input schema: column 0 is the numeric label, column 1 is the raw text.
                var labelColumn = new TextLoader.Column()
                {
                    Name   = "Label",
                    Source = new[] { new TextLoader.Range() { Min = 0, Max = 0 } },
                    Type   = DataKind.Num
                };
                var textColumn = new TextLoader.Column()
                {
                    Name   = "SentimentText",
                    Source = new[] { new TextLoader.Range() { Min = 1, Max = 1 } },
                    Type   = DataKind.Text
                };
                var loaderArgs = new TextLoader.Arguments()
                {
                    AllowQuoting = false,
                    AllowSparse  = false,
                    Separator    = "tab",
                    HasHeader    = true,
                    Column       = new[] { labelColumn, textColumn }
                };
                var loader = TextLoader.ReadFile(env, loaderArgs, new MultiFileSource(_sentimentDataPath));

                // Text featurization with both the character and word feature extractors set to null.
                // The embedding step below reads tokens from "WordEmbeddings_TransformedText",
                // presumably the column that OutputTokens = true produces.
                var textArgs = new TextFeaturizingEstimator.Arguments()
                {
                    Column = new TextFeaturizingEstimator.Column
                    {
                        Name   = "WordEmbeddings",
                        Source = new[] { "SentimentText" }
                    },
                    OutputTokens         = true,
                    KeepPunctuations     = false,
                    StopWordsRemover     = new Runtime.TextAnalytics.PredefinedStopWordsRemoverFactory(),
                    VectorNormalizer     = TextFeaturizingEstimator.TextNormKind.None,
                    CharFeatureExtractor = null,
                    WordFeatureExtractor = null,
                };
                var text = TextFeaturizingEstimator.Create(env, textArgs, loader);

                // Map the token column to "Features" using the pretrained SSWE embeddings.
                var embeddingColumns = new[]
                {
                    new WordEmbeddingsTransform.Column
                    {
                        Name   = "Features",
                        Source = "WordEmbeddings_TransformedText"
                    }
                };
                var embeddingArgs = new WordEmbeddingsTransform.Arguments()
                {
                    Column    = embeddingColumns,
                    ModelKind = WordEmbeddingsTransform.PretrainedModelKind.Sswe,
                };
                var features = WordEmbeddingsTransform.Create(env, embeddingArgs, text);

                // Train for a fixed 20 iterations and hand the result to the consumer.
                var trainer    = new SdcaMultiClassTrainer(env, "Features", "Label", maxIterations: 20);
                var trainRoles = new RoleMappedData(features, label: "Label", feature: "Features");
                _consumer.Consume(trainer.Train(trainRoles));
            }
        }
        /// <summary>
        /// Loads the tab-separated sentiment data, normalizes and tokenizes the review text, and embeds
        /// the tokens with the pretrained SSWE word-embedding model.
        /// </summary>
        /// <returns>The SDCA multi-class predictor trained for 20 iterations on the embedded features.</returns>
        private static IPredictor TrainSentimentCore()
        {
            var sentimentFile = s_sentimentDataPath;

            using (var env = new TlcEnvironment(seed: 1))
            {
                // Input schema: column 0 is the numeric label, column 1 is the raw text.
                var labelColumn = new TextLoader.Column()
                {
                    Name   = "Label",
                    Source = new[] { new TextLoader.Range() { Min = 0, Max = 0 } },
                    Type   = DataKind.Num
                };
                var textColumn = new TextLoader.Column()
                {
                    Name   = "SentimentText",
                    Source = new[] { new TextLoader.Range() { Min = 1, Max = 1 } },
                    Type   = DataKind.Text
                };
                var loaderArgs = new TextLoader.Arguments()
                {
                    AllowQuoting = false,
                    AllowSparse  = false,
                    Separator    = "tab",
                    HasHeader    = true,
                    Column       = new[] { labelColumn, textColumn }
                };
                var loader = new TextLoader(env, loaderArgs, new MultiFileSource(sentimentFile));

                // Lower-case the text and strip diacritics and punctuation. Both feature extractors are set
                // to null. The embedding step below reads tokens from "WordEmbeddings_TransformedText",
                // presumably the column that OutputTokens = true produces.
                var textArgs = new TextTransform.Arguments()
                {
                    Column = new TextTransform.Column
                    {
                        Name   = "WordEmbeddings",
                        Source = new[] { "SentimentText" }
                    },
                    KeepDiacritics       = false,
                    KeepPunctuations     = false,
                    TextCase             = Runtime.TextAnalytics.TextNormalizerTransform.CaseNormalizationMode.Lower,
                    OutputTokens         = true,
                    StopWordsRemover     = new Runtime.TextAnalytics.PredefinedStopWordsRemoverFactory(),
                    VectorNormalizer     = TextTransform.TextNormKind.None,
                    CharFeatureExtractor = null,
                    WordFeatureExtractor = null,
                };
                var text = TextTransform.Create(env, textArgs, loader);

                // Map the token column to "Features" using the pretrained SSWE embeddings.
                var embeddingColumns = new[]
                {
                    new WordEmbeddingsTransform.Column
                    {
                        Name   = "Features",
                        Source = "WordEmbeddings_TransformedText"
                    }
                };
                var embeddingArgs = new WordEmbeddingsTransform.Arguments()
                {
                    Column    = embeddingColumns,
                    ModelKind = WordEmbeddingsTransform.PretrainedModelKind.Sswe,
                };
                var features = new WordEmbeddingsTransform(env, embeddingArgs, text);

                // Train for a fixed 20 iterations.
                var sdcaArgs = new SdcaMultiClassTrainer.Arguments()
                {
                    MaxIterations = 20
                };
                var trainer    = new SdcaMultiClassTrainer(env, sdcaArgs);
                var trainRoles = new RoleMappedData(features, label: "Label", feature: "Features");
                return trainer.Train(trainRoles);
            }
        }