Exemple #1
0
        public void New_ReconfigurablePrediction()
        {
            var trainPath = GetDataPath(SentimentDataPath);
            var evalPath  = GetDataPath(SentimentTestPath);

            // Fixed seed and a single thread so the run is reproducible.
            using (var env = new TlcEnvironment(seed: 1, conc: 1))
            {
                // Fit the loader on the training file, then use it to read both splits.
                var loader   = new MyTextLoader(env, MakeSentimentTextLoaderArgs()).Fit(new MultiFileSource(trainPath));
                var trainRaw = loader.Read(new MultiFileSource(trainPath));
                var testRaw  = loader.Read(new MultiFileSource(evalPath));

                // Text featurization, fitted on the training split only.
                var featurizer = new MyTextTransform(env, MakeSentimentTextTransformArgs()).Fit(trainRaw);

                var sdcaArgs = new LinearClassificationTrainer.Arguments { NumThreads = 1 };
                var trainer  = new LinearClassificationTrainer(env, sdcaArgs, "Features", "Label");

                var featurizedTrain = featurizer.Transform(trainRaw);
                var baseModel       = trainer.Fit(featurizedTrain);

                // Baseline: score and evaluate with the trained transformer as-is.
                var defaultScored  = baseModel.Transform(featurizer.Transform(testRaw));
                var defaultMetrics = new MyBinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments())
                    .Evaluate(defaultScored, "Label", "Probability");

                // Re-wrap the same predictor so that PredictedLabel is derived from Probability >= 0.01.
                var lowThresholdModel = new BinaryPredictionTransformer<IPredictorProducing<float>>(
                    env,
                    baseModel.Model,
                    featurizedTrain.Schema,
                    baseModel.FeatureColumn,
                    threshold: 0.01f,
                    thresholdColumn: DefaultColumnNames.Probability);
                var lowThresholdScored = lowThresholdModel.Transform(featurizer.Transform(testRaw));

                // Evaluate with a matching probability threshold.
                var evalArgs = new BinaryClassifierEvaluator.Arguments { Threshold = 0.01f, UseRawScoreThreshold = false };
                var lowThresholdMetrics = new MyBinaryClassifierEvaluator(env, evalArgs)
                    .Evaluate(lowThresholdScored, "Label", "Probability");
            }
        }
Exemple #2
0
        BuildAndTrainModel(IDataView trainingSet)
        {
            // L-BFGS logistic regression over the prepared "Features" column; the fitted
            // transformer wraps a Platt-calibrated linear model.
            var lbfgsTrainer = mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(
                labelColumnName: "Label",
                featureColumnName: "Features");

            Console.WriteLine("=============== Create and Train the Model ===============");

            BinaryPredictionTransformer <Microsoft.ML.Calibrators.CalibratedModelParametersBase <LinearBinaryModelParameters, Microsoft.ML.Calibrators.PlattCalibrator> > trainedModel = lbfgsTrainer.Fit(trainingSet);

            // Pull the uncalibrated linear part to report its coefficients and intercept.
            var linearPart   = trainedModel.Model.SubModel;
            var coefficients = linearPart.Weights;
            var intercept    = linearPart.Bias;

            Output("ML.Net Weights: ");
            foreach (float coefficient in coefficients)
            {
                Output(coefficient.ToString());
            }
            Output(intercept.ToString());

            Console.WriteLine("=============== End of training ===============");
            Console.WriteLine();
            return trainedModel;
        }
        public void ReconfigurablePrediction()
        {
            // Fixed seed and a single thread so the run is reproducible.
            var ml = new MLContext(seed: 1, conc: 1);

            var data     = ml.Data.ReadFromTextFile <SentimentData>(GetDataPath(TestDatasets.Sentiment.trainFilename), hasHeader: true);
            var testData = ml.Data.ReadFromTextFile <SentimentData>(GetDataPath(TestDatasets.Sentiment.testFilename), hasHeader: true);

            // Pipeline.
            var pipeline = ml.Transforms.Text.FeaturizeText("Features", "SentimentText")
                           .Fit(data);

            var trainer = ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent(
                new SdcaBinaryTrainer.Options
                {
                    NumThreads = 1
                });

            var trainData = ml.Data.Cache(pipeline.Transform(data)); // Cache the data right before the trainer to boost the training speed.
            var model     = trainer.Fit(trainData);

            // Baseline evaluation with the trained transformer as-is.
            var scoredTest = model.Transform(pipeline.Transform(testData));
            var metrics    = ml.BinaryClassification.Evaluate(scoredTest);

            // Re-wrap the same predictor with a 0.01 threshold on Probability.
            var newModel      = new BinaryPredictionTransformer <IPredictorProducing <float> >(ml, model.Model, trainData.Schema, model.FeatureColumn, threshold: 0.01f, thresholdColumn: DefaultColumnNames.Probability);
            var newScoredTest = newModel.Transform(pipeline.Transform(testData));

            // Evaluate the re-thresholded scores; evaluating scoredTest here would just repeat the baseline.
            var newMetrics = ml.BinaryClassification.Evaluate(newScoredTest);
        }
Exemple #4
0
 /// <summary>
 /// Returns a prediction transformer that applies <paramref name="threshold"/> to the same model.
 /// </summary>
 /// <param name="model">Existing transformer whose model, training schema, feature column and threshold column are reused.</param>
 /// <param name="threshold">Desired threshold.</param>
 /// <returns><paramref name="model"/> itself when its threshold already equals <paramref name="threshold"/> exactly; otherwise a new transformer.</returns>
 public BinaryPredictionTransformer <TModel> ChangeModelThreshold <TModel>(BinaryPredictionTransformer <TModel> model, float threshold)
     where TModel : class
 {
     return model.Threshold == threshold
         ? model
         : new BinaryPredictionTransformer <TModel>(Environment, model.Model, model.TrainSchema, model.FeatureColumnName, threshold, model.ThresholdColumn);
 }
Exemple #5
0
        static void Main(string[] args)
        {
            // Fetch and unpack the UCI SMS spam corpus only when the training file is missing.
            if (!File.Exists(Program.TrainDataPath))
            {
                using (var http = new WebClient())
                {
                    http.DownloadFile(@"https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip", "spam.zip");
                }

                ZipFile.ExtractToDirectory("spam.zip", Program.DataDirectoryPath);
            }

            var mlContext = new MLContext();

            // Tab-separated file with a header row: column 0 is the text label, column 1 the message.
            var loaderArgs = new TextLoader.Arguments()
            {
                Separator = "tab",
                HasHeader = true,
                Column    = new[]
                {
                    new TextLoader.Column("Label", DataKind.Text, 0),
                    new TextLoader.Column("Message", DataKind.Text, 1)
                }
            };
            var loader       = new TextLoader(mlContext, loaderArgs);
            var trainingData = loader.Read(new MultiFileSource(Program.TrainDataPath));

            // Custom label mapping, text featurization, then an SDCA binary trainer.
            var pipeline = mlContext.Transforms.CustomMapping <MyInput, MyOutput>(MyLambda.MyAction, "MyLambda")
                           .Append(mlContext.Transforms.Text.FeaturizeText("Message", "Features"))
                           .Append(mlContext.BinaryClassification.Trainers.StochasticDualCoordinateAscent());

            // Report the mean AUC over 5 cross-validation folds.
            var foldResults = mlContext.BinaryClassification.CrossValidate(trainingData, pipeline, numFolds: 5);
            var meanAuc     = foldResults.Select(fold => fold.metrics.Auc).Average();

            Console.WriteLine($"The AUC is {meanAuc}");

            var trainedModel = pipeline.Fit(trainingData);

            // Everything before the scorer; its output schema is the scorer's input schema.
            var upstream = new TransformerChain <ITransformer>(trainedModel.Take(trainedModel.Count() - 1).ToArray());

            // Rebuild the final scorer so it thresholds Probability at 0.15.
            var finalScorer   = trainedModel.LastTransformer;
            var relaxedScorer = new BinaryPredictionTransformer <IPredictorProducing <float> >(
                context: mlContext,
                model: finalScorer.Model,
                inputSchema: upstream.GetOutputSchema(trainingData.Schema),
                featureColumn: finalScorer.FeatureColumn,
                threshold: 0.15f,
                thresholdColumn: DefaultColumnNames.Probability);
            var stages = trainedModel.ToArray();

            stages[stages.Length - 1] = relaxedScorer;
            var adjustedModel = new TransformerChain <ITransformer>(stages);
            var engine        = adjustedModel.MakePredictionFunction <Input, Prediction>(mlContext);

            var samples = new[]
            {
                "That's a great idea. It should work.",
                "Free medicine winner! Congratulations",
                "Yes we should meet over the weekend",
                "You win pills and free entry vouchers"
            };
            foreach (var sample in samples)
            {
                Program.ClassifyMessage(engine, sample);
            }
        }
Exemple #6
0
        /// <summary>
        /// Builds the spam prediction function on first use and caches it in <c>_predictor</c>.
        /// </summary>
        /// <returns>The cached prediction function.</returns>
        private static PredictionFunction <SpamInput, SpamPrediction> GetPredictor()
        {
            // NOTE(review): the lazy initialization is not synchronized; concurrent first calls
            // may each train a model. Confirm whether callers can race here.
            if (_predictor == null)
            {
                // Set up the MLContext, which is a catalog of components in ML.NET.
                var mlContext = new MLContext();

                // Create the reader and define which columns from the file should be read.
                var reader = new TextLoader(mlContext, new TextLoader.Arguments()
                {
                    Separator = "tab",
                    HasHeader = true,
                    Column    = new[]
                    {
                        new TextLoader.Column("Label", DataKind.Text, 0),
                        new TextLoader.Column("Message", DataKind.Text, 1)
                    }
                });

                var data = reader.Read(new MultiFileSource(TrainDataPath));

                // Create the estimator which converts the text label to boolean, featurizes the text, and adds a linear trainer.
                var estimator = mlContext.Transforms.CustomMapping <MyInput, MyOutput>(MyLambda.MyAction, "MyLambda")
                                .Append(mlContext.Transforms.Text.FeaturizeText("Message", "Features"))
                                .Append(mlContext.BinaryClassification.Trainers.StochasticDualCoordinateAscent());

                // No cross-validation here: this method only needs the trained predictor, and
                // fold metrics computed here were never reported. Skipping it avoids five extra
                // training runs on the first call.

                // Train a model on the full dataset.
                var model = estimator.Fit(data);

                // The dataset we have is skewed, as there are many more non-spam messages than spam messages.
                // While our model is relatively good at detecting the difference, this skewness leads it to always
                // say the message is not spam. We deal with this by lowering the threshold of the predictor. In reality,
                // it is useful to look at the precision-recall curve to identify the best possible threshold.
                var inPipe          = new TransformerChain <ITransformer>(model.Take(model.Count() - 1).ToArray());
                var lastTransformer = new BinaryPredictionTransformer <IPredictorProducing <float> >(mlContext, model.LastTransformer.Model, inPipe.GetOutputSchema(data.Schema), model.LastTransformer.FeatureColumn, threshold: 0.15f, thresholdColumn: DefaultColumnNames.Probability);

                ITransformer[] parts = model.ToArray();
                parts[parts.Length - 1] = lastTransformer;
                var newModel = new TransformerChain <ITransformer>(parts);

                // Create a PredictionFunction from our model
                _predictor = newModel.MakePredictionFunction <SpamInput, SpamPrediction>(mlContext);
            }

            return(_predictor);
        }
        public void FastTreeClassificationIntrospectiveTraining()
        {
            var ml   = new MLContext(seed: 1, conc: 1);
            var data = ml.Data.ReadFromTextFile <SentimentData>(GetDataPath(TestDatasets.Sentiment.trainFilename), hasHeader: true);

            var trainer = ml.BinaryClassification.Trainers.FastTree(numLeaves: 5, numTrees: 3);

            // Assigned by the on-fit delegate below so the trained predictor can be inspected after Fit.
            BinaryPredictionTransformer <IPredictorWithFeatureWeights <float> > pred = null;

            var pipeline = ml.Transforms.Text.FeaturizeText("Features", "SentimentText")
                           .AppendCacheCheckpoint(ml)
                           .Append(trainer.WithOnFitDelegate(p => pred = p));

            // Train. The fitted chain itself is not inspected; fitting is what populates pred.
            pipeline.Fit(data);

            // Fail with an assertion rather than a NullReferenceException if the delegate never ran.
            Assert.NotNull(pred);

            // Extract the learned GBDT model.
            var treeCollection = ((FastTreeBinaryModelParameters)((Internal.Calibration.FeatureWeightsCalibratedPredictor)pred.Model).SubPredictor).TrainedTreeEnsemble;

            // Inspect properties in the extracted model.
            Assert.Equal(3, treeCollection.Trees.Count);
            Assert.Equal(3, treeCollection.TreeWeights.Count);
            Assert.Equal(0, treeCollection.Bias);
            Assert.All(treeCollection.TreeWeights, weight => Assert.Equal(1.0, weight));

            // Inspect the last tree. xUnit's Assert.Equal takes (expected, actual).
            var tree = treeCollection.Trees[2];

            Assert.Equal(5, tree.NumLeaves);
            Assert.Equal(4, tree.NumNodes);
            Assert.Equal(new int[] { 2, -2, -1, -3 }, tree.LteChild);
            Assert.Equal(new int[] { 1, 3, -4, -5 }, tree.GtChild);
            Assert.Equal(new int[] { 14, 294, 633, 266 }, tree.NumericalSplitFeatureIndexes);
            var expectedThresholds = new float[] { 0.0911167f, 0.06509889f, 0.019873254f, 0.0361835f };

            for (int i = 0; i < tree.NumNodes; ++i)
            {
                Assert.Equal(expectedThresholds[i], tree.NumericalSplitThresholds[i], 6);
            }
            Assert.All(tree.CategoricalSplitFlags, flag => Assert.False(flag));

            Assert.Equal(0, tree.GetCategoricalSplitFeaturesAt(0).Count);
            Assert.Equal(0, tree.GetCategoricalCategoricalSplitFeatureRangeAt(0).Count);
        }
Exemple #8
0
        SetThreshold(BinaryPredictionTransformer <Microsoft.ML.Calibrators.CalibratedModelParametersBase <LinearBinaryModelParameters, Microsoft.ML.Calibrators.PlattCalibrator> > lrModel, IDataView testSet)
        {
            // Walks the decision threshold down from 1.0 in steps of `tick`. Each candidate is
            // evaluated on testSet at exactly the threshold that is logged. The walk stops once
            // specificity (negative recall) is no longer above minimumSpecificity.
            // Returns the last candidate whose specificity stayed above the minimum. If even the
            // first candidate failed, returns that first candidate.
            // NOTE(review): assumes tick > 0; otherwise the loop may never terminate.
            BinaryPredictionTransformer <Microsoft.ML.Calibrators.CalibratedModelParametersBase <LinearBinaryModelParameters, Microsoft.ML.Calibrators.PlattCalibrator> > accepted = null;
            BinaryPredictionTransformer <Microsoft.ML.Calibrators.CalibratedModelParametersBase <LinearBinaryModelParameters, Microsoft.ML.Calibrators.PlattCalibrator> > candidate;
            float  threshold = 1.0F;
            double currentSpecificity;

            do
            {
                threshold -= tick;

                // Apply the threshold before evaluating, so the metrics belong to this threshold.
                candidate = mlContext.BinaryClassification.ChangeModelThreshold(lrModel, threshold);

                CalibratedBinaryClassificationMetrics metrics = GetMetrics((ITransformer)candidate, testSet);
                currentSpecificity = metrics.NegativeRecall;

                double AUC = metrics.AreaUnderRocCurve;

                Console.WriteLine("Threshold: {0:0.00}; Specificity: {1:0.00}; AUC: {2:0.00}", threshold, currentSpecificity, AUC);

                if (currentSpecificity > minimumSpecificity)
                {
                    accepted = candidate;
                }
            } while (currentSpecificity > minimumSpecificity);

            return accepted ?? candidate;
        }
        public void New_ReconfigurablePrediction()
        {
            // Fixed seed and a single thread so the run is reproducible.
            var ml         = new MLContext(seed: 1, conc: 1);
            var dataReader = ml.Data.TextReader(MakeSentimentTextLoaderArgs());

            var data     = dataReader.Read(GetDataPath(TestDatasets.Sentiment.trainFilename));
            var testData = dataReader.Read(GetDataPath(TestDatasets.Sentiment.testFilename));

            // Pipeline.
            var pipeline = ml.Transforms.Text.FeaturizeText("SentimentText", "Features")
                           .Fit(data);

            var trainer   = ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent(advancedSettings: (s) => s.NumThreads = 1);
            var trainData = pipeline.Transform(data);
            var model     = trainer.Fit(trainData);

            // Baseline evaluation with the trained transformer as-is.
            var scoredTest = model.Transform(pipeline.Transform(testData));
            var metrics    = ml.BinaryClassification.Evaluate(scoredTest);

            // Re-wrap the same predictor with a 0.01 threshold on Probability.
            var newModel      = new BinaryPredictionTransformer <IPredictorProducing <float> >(ml, model.Model, trainData.Schema, model.FeatureColumn, threshold: 0.01f, thresholdColumn: DefaultColumnNames.Probability);
            var newScoredTest = newModel.Transform(pipeline.Transform(testData));

            // Evaluate the re-thresholded scores; evaluating scoredTest here would just repeat the baseline.
            var newMetrics = ml.BinaryClassification.Evaluate(newScoredTest);
        }
        static void Main(string[] args)
        {
            // Download and unpack the dataset only when the training file is missing.
            if (!File.Exists(TrainDataPath))
            {
                using (var downloader = new WebClient())
                {
                    //The code below will download a dataset from a third-party, UCI (link), and may be governed by separate third-party terms.
                    //By proceeding, you agree to those separate terms.
                    downloader.DownloadFile("https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip", "spam.zip");
                }

                ZipFile.ExtractToDirectory("spam.zip", DataDirectoryPath);
            }

            // The MLContext is the entry point to every ML.NET component used below.
            MLContext mlContext = new MLContext();

            // Tab-separated file with a header row, mapped onto SpamInput.
            var spamData = mlContext.Data.ReadFromTextFile <SpamInput>(path: TrainDataPath, hasHeader: true, separatorChar: '\t');

            // Custom label mapping, then featurization of the message text into the Features column.
            var featurization = mlContext.Transforms
                                .CustomMapping <MyInput, MyOutput>(mapAction: MyLambda.MyAction, contractName: "MyLambda")
                                .Append(mlContext.Transforms.Text.FeaturizeText(outputColumnName: DefaultColumnNames.Features, inputColumnName: nameof(SpamInput.Message)));

            Console.WriteLine("=============== Training the model ===============");
            var trainingPipeline = featurization.Append(mlContext.BinaryClassification.Trainers.StochasticDualCoordinateAscent());

            // 5-fold cross-validation; the mean AUC across folds is reported.
            Console.WriteLine("=============== Cross-validating to get model's accuracy metrics ===============");
            var folds    = mlContext.BinaryClassification.CrossValidate(data: spamData, estimator: trainingPipeline, numFolds: 5);
            var foldAucs = folds.Select(fold => fold.metrics.Auc);

            Console.WriteLine("The AUC is {0}", foldAucs.Average());

            // Final model trained on the whole dataset.
            var trainedModel = trainingPipeline.Fit(spamData);

            // Non-spam messages greatly outnumber spam, so the final scorer is rebuilt to
            // threshold Probability at 0.15. The upstream chain's output schema is what the
            // scorer consumes.
            var upstream      = new TransformerChain <ITransformer>(trainedModel.Take(trainedModel.Count() - 1).ToArray());
            var finalScorer   = trainedModel.LastTransformer;
            var relaxedScorer = new BinaryPredictionTransformer <IPredictorProducing <float> >(
                mlContext,
                finalScorer.Model,
                upstream.GetOutputSchema(spamData.Schema),
                finalScorer.FeatureColumn,
                threshold: 0.15f,
                thresholdColumn: DefaultColumnNames.Probability);

            ITransformer[] stages = trainedModel.ToArray();
            stages[stages.Length - 1] = relaxedScorer;
            ITransformer adjustedModel = new TransformerChain <ITransformer>(stages);

            var engine = adjustedModel.CreatePredictionEngine <SpamInput, SpamPrediction>(mlContext);

            Console.WriteLine("=============== Predictions for below data===============");
            var samples = new[]
            {
                "That's a great idea. It should work.",
                "free medicine winner! congratulations",
                "Yes we should meet over the weekend!",
                "you win pills and free entry vouchers"
            };
            foreach (var sample in samples)
            {
                ClassifyMessage(engine, sample);
            }

            Console.WriteLine("=============== End of process, hit any key to finish =============== ");
            Console.ReadLine();
        }
Exemple #11
0
        public void FastTreeClassificationIntrospectiveTraining()
        {
            var ml   = new MLContext(seed: 1, conc: 1);
            var data = ml.Data.LoadFromTextFile <SentimentData>(GetDataPath(TestDatasets.Sentiment.trainFilename), hasHeader: true, allowQuoting: true);

            var trainer = ml.BinaryClassification.Trainers.FastTree(numberOfLeaves: 5, numberOfTrees: 3);

            // Assigned by the on-fit delegate below so the trained predictor can be inspected after Fit.
            BinaryPredictionTransformer <CalibratedModelParametersBase <FastTreeBinaryModelParameters, PlattCalibrator> > pred = null;

            var pipeline = ml.Transforms.Text.FeaturizeText("Features", "SentimentText")
                           .AppendCacheCheckpoint(ml)
                           .Append(trainer.WithOnFitDelegate(p => pred = p));

            // Train. The fitted chain itself is not inspected; fitting is what populates pred.
            pipeline.Fit(data);

            // Fail with an assertion rather than a NullReferenceException if the delegate never ran.
            Assert.NotNull(pred);

            // Extract the learned GBDT model.
            var treeCollection = pred.Model.SubModel.TrainedTreeEnsemble;

            // Inspect properties in the extracted model.
            Assert.Equal(3, treeCollection.Trees.Count);
            Assert.Equal(3, treeCollection.TreeWeights.Count);
            Assert.Equal(0, treeCollection.Bias);
            Assert.All(treeCollection.TreeWeights, weight => Assert.Equal(1.0, weight));

            // Inspect the last tree. xUnit's Assert.Equal takes (expected, actual).
            var tree = treeCollection.Trees[2];

            Assert.Equal(5, tree.NumberOfLeaves);
            Assert.Equal(4, tree.NumberOfNodes);
            Assert.Equal(new int[] { 2, -2, -1, -3 }, tree.LeftChild);
            Assert.Equal(new int[] { 1, 3, -4, -5 }, tree.RightChild);
            Assert.Equal(new int[] { 14, 294, 633, 266 }, tree.NumericalSplitFeatureIndexes);
            Assert.Equal(tree.NumberOfNodes, tree.SplitGains.Count);
            Assert.Equal(tree.NumberOfNodes, tree.NumericalSplitThresholds.Count);
            var expectedSplitGains = new double[] { 0.52634223978445616, 0.45899249367725858, 0.44142707650267105, 0.38348634823264854 };
            var expectedThresholds = new float[] { 0.0911167f, 0.06509889f, 0.019873254f, 0.0361835f };

            for (int i = 0; i < tree.NumberOfNodes; ++i)
            {
                Assert.Equal(expectedSplitGains[i], tree.SplitGains[i], 6);
                Assert.Equal(expectedThresholds[i], tree.NumericalSplitThresholds[i], 6);
            }
            Assert.All(tree.CategoricalSplitFlags, flag => Assert.False(flag));

            Assert.Equal(0, tree.GetCategoricalSplitFeaturesAt(0).Count);
            Assert.Equal(0, tree.GetCategoricalCategoricalSplitFeatureRangeAt(0).Count);
        }
        /// <summary>
        /// Runs permutation feature importance for a binary predictor and returns one row of
        /// metric means and standard errors per named feature slot.
        /// </summary>
        private static IDataView GetBinaryMetrics(
            IHostEnvironment env,
            IPredictor predictor,
            RoleMappedData roleMappedData,
            PermutationFeatureImportanceArguments input)
        {
            // Resolve the feature and label column names from the role mapping.
            var columnRoles   = roleMappedData.Schema.GetColumnRoleNames();
            var featureColumn = columnRoles.First(r => r.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).Value;
            var labelColumn   = columnRoles.First(r => r.Key.Value == RoleMappedSchema.ColumnRole.Label.Value).Value;

            // Wrap the predictor in a transformer so the catalog's PFI API can consume it.
            var transformer = new BinaryPredictionTransformer <IPredictorProducing <float> >(
                env, predictor as IPredictorProducing <float>, roleMappedData.Data.Schema, featureColumn);

            var permutationMetrics = new BinaryClassificationCatalog(env).PermutationFeatureImportance(
                transformer,
                roleMappedData.Data,
                labelColumnName: labelColumn,
                useFeatureWeightFilter: input.UseFeatureWeightFilter,
                numberOfExamplesToUse: input.NumberOfExamplesToUse,
                permutationCount: input.PermutationCount);

            var slotNames = GetSlotNames(roleMappedData.Schema);

            Contracts.Assert(slotNames.Length == permutationMetrics.Length,
                             "Mismatch between number of feature slots and number of features permuted.");

            // One row per permuted slot; slots with a blank name are skipped. Materialized so the
            // data view below is built over a fixed list.
            var rows = Enumerable.Range(0, permutationMetrics.Length)
                       .Where(i => !string.IsNullOrWhiteSpace(slotNames[i]))
                       .Select(i =>
                       {
                           var stats = permutationMetrics[i];
                           return new BinaryMetrics
                           {
                               FeatureName                         = slotNames[i],
                               AreaUnderRocCurve                   = stats.AreaUnderRocCurve.Mean,
                               AreaUnderRocCurveStdErr             = stats.AreaUnderRocCurve.StandardError,
                               Accuracy                            = stats.Accuracy.Mean,
                               AccuracyStdErr                      = stats.Accuracy.StandardError,
                               PositivePrecision                   = stats.PositivePrecision.Mean,
                               PositivePrecisionStdErr             = stats.PositivePrecision.StandardError,
                               PositiveRecall                      = stats.PositiveRecall.Mean,
                               PositiveRecallStdErr                = stats.PositiveRecall.StandardError,
                               NegativePrecision                   = stats.NegativePrecision.Mean,
                               NegativePrecisionStdErr             = stats.NegativePrecision.StandardError,
                               NegativeRecall                      = stats.NegativeRecall.Mean,
                               NegativeRecallStdErr                = stats.NegativeRecall.StandardError,
                               F1Score                             = stats.F1Score.Mean,
                               F1ScoreStdErr                       = stats.F1Score.StandardError,
                               AreaUnderPrecisionRecallCurve       = stats.AreaUnderPrecisionRecallCurve.Mean,
                               AreaUnderPrecisionRecallCurveStdErr = stats.AreaUnderPrecisionRecallCurve.StandardError
                           };
                       })
                       .ToList();

            return new DataOperationsCatalog(env).LoadFromEnumerable(rows);
        }
Exemple #13
0
        static void Main(string[] args)
        {
            DownloadTrainingData();

            // Create the ML context.
            MLContext mlContext = new MLContext();

            // Create the text data loader: tab-separated, no header row;
            // column 0 is the text label, column 1 the message.
            TextLoader textLoader = mlContext.Data.TextReader(new TextLoader.Arguments()
            {
                Separator = "tab",
                HasHeader = false,
                Column    = new[]
                {
                    new TextLoader.Column("Label", DataKind.Text, 0),
                    new TextLoader.Column("Message", DataKind.Text, 1)
                }
            });

            // Read the dataset.
            var fullData = textLoader.Read(DataPath);

            // Feature engineering (custom label mapping + text featurization) and the training algorithm.
            var estimator = mlContext.Transforms.CustomMapping <MyInput, MyOutput>(MyLambda.MyAction, "MyLambda")
                            .Append(mlContext.Transforms.Text.FeaturizeText("Message", "Features"))
                            .Append(mlContext.BinaryClassification.Trainers.StochasticDualCoordinateAscent());

            // Evaluate the model with 5-fold cross-validation.
            var cvResults = mlContext.BinaryClassification.CrossValidate(fullData, estimator, numFolds: 5);
            var aucs      = cvResults.Select(r => r.metrics.Auc);

            Console.WriteLine($"The AUC is {aucs.Average()}");

            // Train on the full dataset.
            var model = estimator.Fit(fullData);

            // Rebuild the final scorer with a 0.15 threshold on the "Probability" column, as the
            // other spam samples do. Per the ML.NET API, omitting thresholdColumn makes the
            // transformer threshold the raw "Score" column instead, which defeats the intent of
            // lowering the decision threshold.
            var inPipe          = new TransformerChain <ITransformer>(model.Take(model.Count() - 1).ToArray());
            var lastTransFormer = new BinaryPredictionTransformer <IPredictorProducing <float> >(mlContext,
                                                                                                 model.LastTransformer.Model,
                                                                                                 inPipe.GetOutputSchema(fullData.Schema),
                                                                                                 model.LastTransformer.FeatureColumn,
                                                                                                 threshold: 0.15f,
                                                                                                 thresholdColumn: "Probability");
            var parts = model.ToArray();

            parts[parts.Length - 1] = lastTransFormer;
            var newModel = new TransformerChain <ITransformer>(parts);

            var predictor = newModel.MakePredictionFunction <SpamData, SpamPrediction>(mlContext);

            var testMsgs = new string[]
            {
                "That's a great idea. It should work.",
                "free medicine winner! congratulations",
                "Yes we should meet over the weekend!",
                "you win pills and free entry vouchers"
            };

            foreach (var message in testMsgs)
            {
                var input = new SpamData {
                    Message = message
                };
                var prediction = predictor.Predict(input);

                Console.WriteLine("The message '{0}' is spam? {1}!", input.Message, prediction.IsSpam.ToString());
            }

            // NOTE(review): looks like leftover template output; consider removing.
            Console.WriteLine("Hello World!");
        }
Exemple #14
0
        static void Main(string[] args)
        {
            // Fetch and unpack the UCI SMS spam corpus only when the training file is missing.
            if (!File.Exists(TrainDataPath))
            {
                using (var http = new WebClient())
                {
                    http.DownloadFile("https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip", "spam.zip");
                }

                ZipFile.ExtractToDirectory("spam.zip", DataDirectoryPath);
            }

            // The MLContext is the entry point to every ML.NET component used below.
            var mlContext = new MLContext();

            // Tab-separated file with a header row: column 0 is the text label, column 1 the message.
            var loaderArgs = new TextLoader.Arguments()
            {
                Separator = "tab",
                HasHeader = true,
                Column    = new[]
                {
                    new TextLoader.Column("Label", DataKind.Text, 0),
                    new TextLoader.Column("Message", DataKind.Text, 1)
                }
            };
            var loader       = new TextLoader(mlContext, loaderArgs);
            var trainingData = loader.Read(new MultiFileSource(TrainDataPath));

            // Custom label mapping, text featurization, then an SDCA binary trainer.
            var pipeline = mlContext.Transforms.CustomMapping <MyInput, MyOutput>(MyLambda.MyAction, "MyLambda")
                           .Append(mlContext.Transforms.Text.FeaturizeText("Message", "Features"))
                           .Append(mlContext.BinaryClassification.Trainers.StochasticDualCoordinateAscent());

            // 5-fold cross-validation; the mean AUC across folds is reported.
            var foldResults = mlContext.BinaryClassification.CrossValidate(trainingData, pipeline, numFolds: 5);
            var foldAucs    = foldResults.Select(fold => fold.metrics.Auc);

            Console.WriteLine("The AUC is {0}", foldAucs.Average());

            // Final model trained on the whole dataset.
            var trainedModel = pipeline.Fit(trainingData);

            // Non-spam messages greatly outnumber spam, so the final scorer is rebuilt to
            // threshold Probability at 0.15. The upstream chain's output schema is what the
            // scorer consumes.
            var upstream      = new TransformerChain <ITransformer>(trainedModel.Take(trainedModel.Count() - 1).ToArray());
            var finalScorer   = trainedModel.LastTransformer;
            var relaxedScorer = new BinaryPredictionTransformer <IPredictorProducing <float> >(
                mlContext,
                finalScorer.Model,
                upstream.GetOutputSchema(trainingData.Schema),
                finalScorer.FeatureColumn,
                threshold: 0.15f,
                thresholdColumn: DefaultColumnNames.Probability);

            ITransformer[] stages = trainedModel.ToArray();
            stages[stages.Length - 1] = relaxedScorer;
            var adjustedModel = new TransformerChain <ITransformer>(stages);

            var engine = adjustedModel.MakePredictionFunction <SpamInput, SpamPrediction>(mlContext);

            var samples = new[]
            {
                "That's a great idea. It should work.",
                "free medicine winner! congratulations",
                "Yes we should meet over the weekend!",
                "you win pills and free entry vouchers"
            };
            foreach (var sample in samples)
            {
                ClassifyMessage(engine, sample);
            }

            Console.ReadLine();
        }