Example #1
0
        /// <summary>
        /// Computes the schema that results from running <paramref name="inputSchema"/> through
        /// <c>HashingTransformer</c> followed by <c>CountTable</c>.
        /// </summary>
        /// <param name="inputSchema">Schema of the incoming data; validated via the host.</param>
        /// <returns>The schema produced by the two chained transformers.</returns>
        public DataViewSchema GetOutputSchema(DataViewSchema inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));

            // Hashing runs first, then the count table; the chain propagates the schema through both.
            var pipeline = new TransformerChain<ITransformer>(HashingTransformer, CountTable);
            var resultSchema = pipeline.GetOutputSchema(inputSchema);
            return resultSchema;
        }
Example #2
0
        /// <summary>
        /// Downloads the UCI SMS spam corpus if needed, cross-validates a text classifier on it,
        /// lowers the decision threshold of the trained model, and classifies a few sample messages.
        /// </summary>
        static void Main(string[] args)
        {
            // Only fetch and unpack the archive when the training file is not already present.
            // "spam.zip" is a relative path, so it lands in the current working directory.
            if (!File.Exists(Program.TrainDataPath))
            {
                const string archiveName = "spam.zip";
                using (var downloader = new WebClient())
                {
                    downloader.DownloadFile(@"https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip", archiveName);
                }

                ZipFile.ExtractToDirectory(archiveName, Program.DataDirectoryPath);
            }

            var ml = new MLContext();

            // Tab-separated file with a header row: column 0 is the label, column 1 the message text.
            var loaderArgs = new TextLoader.Arguments
            {
                Separator = "tab",
                HasHeader = true,
                Column = new[]
                {
                    new TextLoader.Column("Label", DataKind.Text, 0),
                    new TextLoader.Column("Message", DataKind.Text, 1)
                }
            };
            var loader = new TextLoader(ml, loaderArgs);
            var trainingData = loader.Read(new MultiFileSource(Program.TrainDataPath));

            // Custom label mapping -> text featurization of "Message" into "Features" -> SDCA binary trainer.
            var labelMapping = ml.Transforms.CustomMapping<MyInput, MyOutput>(MyLambda.MyAction, "MyLambda");
            var pipeline = labelMapping
                .Append(ml.Transforms.Text.FeaturizeText("Message", "Features"))
                .Append(ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent());

            // 5-fold cross-validation; report the mean AUC across folds.
            var foldResults = ml.BinaryClassification.CrossValidate(trainingData, pipeline, numFolds: 5);
            var meanAuc = foldResults.Select(fold => fold.metrics.Auc).Average();
            Console.WriteLine($"The AUC is {meanAuc}");

            // Fit on all data, then rebuild the final scoring stage with a 0.15 probability threshold.
            var fittedChain = pipeline.Fit(trainingData);
            var featurizer = new TransformerChain<ITransformer>(fittedChain.Take(fittedChain.Count() - 1).ToArray());
            var scorer = fittedChain.LastTransformer;

            var rethresholded = new BinaryPredictionTransformer<IPredictorProducing<float>>(
                ml,
                scorer.Model,
                featurizer.GetOutputSchema(trainingData.Schema),
                scorer.FeatureColumn,
                threshold: 0.15f,
                thresholdColumn: DefaultColumnNames.Probability);

            // Swap the original last stage for the re-thresholded one.
            var stages = fittedChain.ToArray();
            stages[stages.Length - 1] = rethresholded;
            var tunedModel = new TransformerChain<ITransformer>(stages);
            var predictor = tunedModel.MakePredictionFunction<Input, Prediction>(ml);

            var samples = new[]
            {
                "That's a great idea. It should work.",
                "Free medicine winner! Congratulations",
                "Yes we should meet over the weekend",
                "You win pills and free entry vouchers"
            };

            foreach (var sample in samples)
            {
                Program.ClassifyMessage(predictor, sample);
            }
        }
Example #3
0
 /// <summary>Returns the schema that the wrapped <c>_transformer</c> produces for <paramref name="inputSchema"/>.</summary>
 public DataViewSchema GetOutputSchema(DataViewSchema inputSchema)
 {
     return _transformer.GetOutputSchema(inputSchema);
 }
 /// <summary>Returns the schema that the wrapped <c>_transformer</c> produces for <paramref name="inputSchema"/>.</summary>
 public Schema GetOutputSchema(Schema inputSchema)
 {
     return _transformer.GetOutputSchema(inputSchema);
 }
        /// <summary>
        /// End-to-end SMS spam sample: fetch the data, cross-validate the pipeline, train on all rows,
        /// lower the decision threshold, print predictions for four messages, then wait for Enter.
        /// </summary>
        static void Main(string[] args)
        {
            // First run only: pull the corpus and unpack it into DataDirectoryPath.
            if (!File.Exists(TrainDataPath))
            {
                // The dataset is downloaded from a third party (UCI) and may be governed by separate
                // third-party terms; by proceeding you agree to those terms.
                const string archiveName = "spam.zip";
                using (var web = new WebClient())
                {
                    web.DownloadFile("https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip", archiveName);
                }

                ZipFile.ExtractToDirectory(archiveName, DataDirectoryPath);
            }

            // MLContext is the entry point / catalog for ML.NET components.
            var ml = new MLContext();

            // Load tab-separated rows (with a header line) into SpamInput-shaped columns.
            var rows = ml.Data.ReadFromTextFile<SpamInput>(path: TrainDataPath, hasHeader: true, separatorChar: '\t');

            // Data preparation: MyLambda custom mapping, then featurize the message text into the Features column.
            var preprocessing = ml.Transforms
                .CustomMapping<MyInput, MyOutput>(mapAction: MyLambda.MyAction, contractName: "MyLambda")
                .Append(ml.Transforms.Text.FeaturizeText(outputColumnName: DefaultColumnNames.Features, inputColumnName: nameof(SpamInput.Message)));

            // Attach the SDCA linear trainer (nothing is fitted yet at this point).
            Console.WriteLine("=============== Training the model ===============");
            var trainingPipeline = preprocessing.Append(ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent());

            // 5-fold cross-validation: each fold is held out once; report the mean AUC
            // (0.5 = chance, 1.0 = perfect).
            Console.WriteLine("=============== Cross-validating to get model's accuracy metrics ===============");
            var folds = ml.BinaryClassification.CrossValidate(data: rows, estimator: trainingPipeline, numFolds: 5);
            Console.WriteLine("The AUC is {0}", folds.Select(fold => fold.metrics.Auc).Average());

            // Final model trained on every row.
            var fitted = trainingPipeline.Fit(rows);

            // Ham heavily outnumbers spam, which biases predictions toward "not spam". Rebuild the last
            // stage so that it thresholds the Probability column at 0.15 instead. A precision-recall curve
            // is the principled way to choose this value.
            var upstream = new TransformerChain<ITransformer>(fitted.Take(fitted.Count() - 1).ToArray());
            var scorer = fitted.LastTransformer;
            var retuned = new BinaryPredictionTransformer<IPredictorProducing<float>>(
                ml,
                scorer.Model,
                upstream.GetOutputSchema(rows.Schema),
                scorer.FeatureColumn,
                threshold: 0.15f,
                thresholdColumn: DefaultColumnNames.Probability);

            ITransformer[] stages = fitted.ToArray();
            stages[stages.Length - 1] = retuned;
            ITransformer tunedModel = new TransformerChain<ITransformer>(stages);

            // Prediction engine for single-message scoring.
            var engine = tunedModel.CreatePredictionEngine<SpamInput, SpamPrediction>(ml);

            Console.WriteLine("=============== Predictions for below data===============");
            var samples = new[]
            {
                "That's a great idea. It should work.",
                "free medicine winner! congratulations",
                "Yes we should meet over the weekend!",
                "you win pills and free entry vouchers"
            };
            foreach (var text in samples)
            {
                ClassifyMessage(engine, text);
            }

            Console.WriteLine("=============== End of process, hit any key to finish =============== ");
            Console.ReadLine();
        }
Example #6
0
        /// <summary>
        /// Returns the cached spam <c>PredictionFunction</c>, building and training it on first use.
        /// </summary>
        /// <remarks>
        /// The first call reads <c>TrainDataPath</c>, fits the pipeline on the full dataset, lowers the
        /// decision threshold, and stores the result in <c>_predictor</c>. Subsequent calls return that
        /// cached instance.
        /// NOTE(review): the lazy initialization is not synchronized; confirm callers are single-threaded
        /// or accept that concurrent first calls may each train a model.
        /// </remarks>
        /// <returns>A prediction function mapping <c>SpamInput</c> to <c>SpamPrediction</c>.</returns>
        private static PredictionFunction<SpamInput, SpamPrediction> GetPredictor()
        {
            if (_predictor == null)
            {
                // Set up the MLContext, which is a catalog of components in ML.NET.
                var mlContext = new MLContext();

                // Tab-separated file with a header row: column 0 = label text, column 1 = message text.
                var reader = new TextLoader(mlContext, new TextLoader.Arguments()
                {
                    Separator = "tab",
                    HasHeader = true,
                    Column    = new[]
                    {
                        new TextLoader.Column("Label", DataKind.Text, 0),
                        new TextLoader.Column("Message", DataKind.Text, 1)
                    }
                });

                var data = reader.Read(new MultiFileSource(TrainDataPath));

                // MyLambda custom mapping (presumably label text -> boolean; see MyLambda), then text
                // featurization of "Message" into "Features", then a linear SDCA trainer.
                var estimator = mlContext.Transforms.CustomMapping<MyInput, MyOutput>(MyLambda.MyAction, "MyLambda")
                                .Append(mlContext.Transforms.Text.FeaturizeText("Message", "Features"))
                                .Append(mlContext.BinaryClassification.Trainers.StochasticDualCoordinateAscent());

                // Cross-validation is intentionally not run here. Its AUC values were never used, yet it
                // trained five extra models and delayed the first call for nothing.

                // Train on the full dataset.
                var model = estimator.Fit(data);

                // The dataset is skewed (far more ham than spam), which biases the model toward "not spam".
                // Rebuild the last stage so that it thresholds the Probability column at 0.15 instead. A
                // precision-recall curve is the principled way to choose this value.
                var inPipe          = new TransformerChain<ITransformer>(model.Take(model.Count() - 1).ToArray());
                var lastTransformer = new BinaryPredictionTransformer<IPredictorProducing<float>>(mlContext, model.LastTransformer.Model, inPipe.GetOutputSchema(data.Schema), model.LastTransformer.FeatureColumn, threshold: 0.15f, thresholdColumn: DefaultColumnNames.Probability);

                ITransformer[] parts = model.ToArray();
                parts[parts.Length - 1] = lastTransformer;
                var newModel = new TransformerChain<ITransformer>(parts);

                // Cache the prediction function built from the re-thresholded model.
                _predictor = newModel.MakePredictionFunction<SpamInput, SpamPrediction>(mlContext);
            }

            return _predictor;
        }
Example #7
0
        /// <summary>
        /// Trains an SMS spam classifier on the file at <c>DataPath</c>, prints the cross-validated AUC,
        /// and prints spam/not-spam predictions for four sample messages.
        /// </summary>
        static void Main(string[] args)
        {
            DownloadTrainingData();

            // Create the ML context.
            MLContext mlContext = new MLContext();

            // Create the text loader: tab-separated, no header row, label in column 0, message in column 1.
            TextLoader textLoader = mlContext.Data.TextReader(new TextLoader.Arguments()
            {
                Separator = "tab",
                HasHeader = false,
                Column    = new[]
                {
                    new TextLoader.Column("Label", DataKind.Text, 0),
                    new TextLoader.Column("Message", DataKind.Text, 1)
                }
            });

            // Read the dataset.
            var fullData = textLoader.Read(DataPath);

            // Feature engineering (custom label mapping + text featurization) and the training algorithm.
            var estimator = mlContext.Transforms.CustomMapping<MyInput, MyOutput>(MyLambda.MyAction, "MyLambda")
                            .Append(mlContext.Transforms.Text.FeaturizeText("Message", "Features"))
                            .Append(mlContext.BinaryClassification.Trainers.StochasticDualCoordinateAscent());

            // Evaluate the pipeline with 5-fold cross-validation.
            var cvResults = mlContext.BinaryClassification.CrossValidate(fullData, estimator, numFolds: 5);
            var aucs      = cvResults.Select(r => r.metrics.Auc);

            Console.WriteLine($"The AUC is {aucs.Average()}");

            // Train on the full dataset.
            var model = estimator.Fit(fullData);

            // Rebuild the final scoring stage with a lower decision threshold. The threshold must be applied
            // to the Probability column, as in the sibling samples. Without an explicit thresholdColumn, the
            // constructor thresholds its default column instead (presumably the raw Score; confirm for the
            // ML.NET version in use). On that column 0.15 is stricter than the default 0, the opposite of
            // the intent.
            var inPipe          = new TransformerChain<ITransformer>(model.Take(model.Count() - 1).ToArray());
            var lastTransFormer = new BinaryPredictionTransformer<IPredictorProducing<float>>(
                mlContext,
                model.LastTransformer.Model,
                inPipe.GetOutputSchema(fullData.Schema),
                model.LastTransformer.FeatureColumn,
                threshold: 0.15f,
                thresholdColumn: DefaultColumnNames.Probability);

            var parts = model.ToArray();

            parts[parts.Length - 1] = lastTransFormer;
            var newModel = new TransformerChain<ITransformer>(parts);

            var predictor = newModel.MakePredictionFunction<SpamData, SpamPrediction>(mlContext);

            var testMsgs = new string[]
            {
                "That's a great idea. It should work.",
                "free medicine winner! congratulations",
                "Yes we should meet over the weekend!",
                "you win pills and free entry vouchers"
            };

            foreach (var message in testMsgs)
            {
                var input = new SpamData {
                    Message = message
                };
                var prediction = predictor.Predict(input);

                // Composite formatting calls ToString() on the value itself; no explicit call is needed.
                Console.WriteLine("The message '{0}' is spam? {1}!", input.Message, prediction.IsSpam);
            }

            Console.WriteLine("Hello World!");
        }
Example #8
0
        /// <summary>
        /// Console sample: ensure the SMS spam corpus is on disk, report the 5-fold cross-validated AUC,
        /// train a re-thresholded classifier on all rows, classify four messages, and wait for Enter.
        /// </summary>
        static void Main(string[] args)
        {
            // Fetch and unpack the corpus only when the training file is missing.
            if (!File.Exists(TrainDataPath))
            {
                var archivePath = "spam.zip";
                using (var http = new WebClient())
                {
                    http.DownloadFile("https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip", archivePath);
                }

                ZipFile.ExtractToDirectory(archivePath, DataDirectoryPath);
            }

            // MLContext: the catalog through which ML.NET components are created.
            var ctx = new MLContext();

            // Column layout of the tab-separated file (first line is a header).
            var columns = new[]
            {
                new TextLoader.Column("Label", DataKind.Text, 0),
                new TextLoader.Column("Message", DataKind.Text, 1)
            };
            var loader = new TextLoader(ctx, new TextLoader.Arguments
            {
                Separator = "tab",
                HasHeader = true,
                Column    = columns
            });
            var dataset = loader.Read(new MultiFileSource(TrainDataPath));

            // MyLambda custom mapping -> featurize "Message" into "Features" -> SDCA linear trainer.
            var mapping      = ctx.Transforms.CustomMapping<MyInput, MyOutput>(MyLambda.MyAction, "MyLambda");
            var featurize    = ctx.Transforms.Text.FeaturizeText("Message", "Features");
            var withFeatures = mapping.Append(featurize);
            var trainer      = ctx.BinaryClassification.Trainers.StochasticDualCoordinateAscent();
            var pipeline     = withFeatures.Append(trainer);

            // Average AUC over 5 held-out folds (0.5 = chance, 1.0 = perfect).
            var perFold = ctx.BinaryClassification.CrossValidate(dataset, pipeline, numFolds: 5);
            Console.WriteLine("The AUC is {0}", perFold.Select(result => result.metrics.Auc).Average());

            // Fit on every row.
            var trained = pipeline.Fit(dataset);

            // Ham outnumbers spam, so the default decision rule rarely flags spam. Replace the final stage
            // with one that thresholds the Probability column at 0.15. A precision-recall curve is the
            // principled way to pick this value.
            var prefix = new TransformerChain<ITransformer>(trained.Take(trained.Count() - 1).ToArray());
            var finalStage = trained.LastTransformer;
            var lowered = new BinaryPredictionTransformer<IPredictorProducing<float>>(
                ctx,
                finalStage.Model,
                prefix.GetOutputSchema(dataset.Schema),
                finalStage.FeatureColumn,
                threshold: 0.15f,
                thresholdColumn: DefaultColumnNames.Probability);

            ITransformer[] stages = trained.ToArray();
            stages[stages.Length - 1] = lowered;
            var adjusted = new TransformerChain<ITransformer>(stages);

            // Single-message prediction function built from the adjusted model.
            var classify = adjusted.MakePredictionFunction<SpamInput, SpamPrediction>(ctx);

            string[] probes =
            {
                "That's a great idea. It should work.",
                "free medicine winner! congratulations",
                "Yes we should meet over the weekend!",
                "you win pills and free entry vouchers"
            };
            foreach (var probe in probes)
            {
                ClassifyMessage(classify, probe);
            }

            Console.ReadLine();
        }