/// <summary>
/// Prepares the state used by the prediction benchmarks: trains the model, creates the
/// prediction engine, evaluates the model on the data file and pre-builds the input batches.
/// </summary>
public void SetupPredictBenchmarks()
        {
            // Train, build the engine, and pass one prediction to the consumer.
            _trainedModel     = Train(_dataPath);
            _predictionEngine = _trainedModel.CreatePredictionEngine<IrisData, IrisPrediction>(_env);
            _consumer.Consume(_predictionEngine.Predict(_example));

            // Column layout of the data file: "Label" at index 0, the four measurements at 1-4.
            var irisColumns = new[]
            {
                new TextLoader.Column("Label", DataKind.R4, 0),
                new TextLoader.Column("SepalLength", DataKind.R4, 1),
                new TextLoader.Column("SepalWidth", DataKind.R4, 2),
                new TextLoader.Column("PetalLength", DataKind.R4, 3),
                new TextLoader.Column("PetalWidth", DataKind.R4, 4),
            };
            var loader = new TextLoader(_env, columns: irisColumns, hasHeader: true);

            // Score the same file the model was trained from and keep the resulting metrics.
            IDataView scored    = _trainedModel.Transform(loader.Read(_dataPath));
            var       evaluator = new MultiClassClassifierEvaluator(_env, new MultiClassClassifierEvaluator.Arguments());
            _metrics = evaluator.Evaluate(scored, DefaultColumnNames.Label, DefaultColumnNames.Score, DefaultColumnNames.PredictedLabel);

            // One batch per configured size; every slot holds the same example instance.
            int batchCount = _batchSizes.Length;
            _batches = new IrisData[batchCount][];
            for (int b = 0; b < batchCount; b++)
            {
                var filled = new IrisData[_batchSizes[b]];
                int slot   = 0;
                while (slot < filled.Length)
                {
                    filled[slot] = _example;
                    slot++;
                }
                _batches[b] = filled;
            }
        }
Пример #2
0
        /// <summary>
        /// Generates an iris row class at runtime, converts the source rows into instances of
        /// it, obtains an <see cref="IDataView"/> through a generated static "GetDataView"
        /// method, trains a multiclass classifier on that view and prints one prediction.
        /// </summary>
        /// <param name="args">Command-line arguments (unused).</param>
        /// <exception cref="InvalidOperationException">
        /// Thrown when the generated type does not expose a "GetDataView" method.
        /// </exception>
        static void Main(string[] args)
        {
            MLContext mlContext = new MLContext();

            // Declare and compile a class whose public fields mirror the iris columns.
            ClassGenerator classGenerator = new ClassGenerator("GeneratedIris", "CustomClass");

            classGenerator.AddField("SepalLength", typeof(float), System.CodeDom.MemberAttributes.Public);
            classGenerator.AddField("SepalWidth", typeof(float), System.CodeDom.MemberAttributes.Public);
            classGenerator.AddField("PetalLength", typeof(float), System.CodeDom.MemberAttributes.Public);
            classGenerator.AddField("PetalWidth", typeof(float), System.CodeDom.MemberAttributes.Public);
            classGenerator.AddField("Label", typeof(string), System.CodeDom.MemberAttributes.Public);
            classGenerator.Compile();

            // Map every source row onto a new instance of the generated class.
            List<object> generatedDataSet = new List<object>();
            foreach (var row in dataset)
            {
                generatedDataSet.Add(GetDynamicClass(row, classGenerator.GetInstance()));
            }

            // Build the data-view generator for the generated row type and look up its
            // "GetDataView" method via reflection.
            var instance = classGenerator.GetInstance().GetType();
            DataViewGenerator listGenerator = new DataViewGenerator("ListIris", "CustomGenerator", instance, classGenerator.NamespaceName);
            var type       = listGenerator.GeneratorType;
            var methodInfo = type.GetMethod("GetDataView");
            if (methodInfo == null)
            {
                // Fail with a descriptive message instead of a NullReferenceException on Invoke.
                throw new InvalidOperationException("The generated type does not expose a 'GetDataView' method.");
            }

            // The method is invoked without a target instance; the list is passed as its only argument.
            IDataView trainingDataView = (IDataView)methodInfo.Invoke(null, new object[] { generatedDataSet });

            var pipeline = mlContext.Transforms.Conversion.MapValueToKey("Label")
                           .Append(mlContext.Transforms.Concatenate("Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth"))
                           .AppendCacheCheckpoint(mlContext)
                           .Append(mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent(labelColumnName: "Label", featureColumnName: "Features"))
                           .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));

            TransformerChain<Microsoft.ML.Transforms.KeyToValueMappingTransformer> model = pipeline.Fit(trainingDataView);

            // Predict the label for a single hand-written sample and print it.
            var prediction = model.CreatePredictionEngine<IrisData, IrisPrediction>(mlContext).Predict(
                new IrisData()
                {
                    SepalLength = 5.9f,
                    SepalWidth  = 3.0f,
                    PetalLength = 5.1f,
                    PetalWidth  = 1.8f,
                });

            Console.WriteLine(prediction.PredictedLabels);

            // Keep the console window open until the user presses Enter.
            Console.ReadLine();
        }
Пример #3
0
        /// <summary>
        /// Reads the serialized model from "test-model.zip" and stores a prediction engine
        /// built from it in <c>_predictor</c>.
        /// </summary>
        private void LoadModel()
        {
            // The model comes from a file on disk. Shipping it as an embedded assembly resource
            // is an alternative that would let a NuGet package deliver model updates.
            const string modelPath = "test-model.zip";

            using (FileStream modelStream = new FileStream(modelPath, FileMode.Open, FileAccess.Read, FileShare.Read))
            {
                // Every ML.NET operation starts from an MLContext.
                MLContext mlContext = new MLContext();

                // Rebuild the transformer chain from the contents of the zip file.
                TransformerChain<ITransformer> loadedModel = TransformerChain.LoadFrom(mlContext, modelStream);

                // This predictor is what gets used each time the user clicks a button.
                _predictor = loadedModel.CreatePredictionEngine<SentimentData, SentimentPrediction>(mlContext);
            }
        }
Пример #4
0
        /// <summary>
        /// Prepares the state used by the prediction benchmarks: trains the model, creates the
        /// prediction engine, evaluates the model on the data file and pre-builds the input batches.
        /// </summary>
        public void SetupPredictBenchmarks()
        {
            // Train, build the engine, and pass one prediction to the consumer.
            _trainedModel     = Train(_dataPath);
            _predictionEngine = _trainedModel.CreatePredictionEngine<IrisData, IrisPrediction>(mlContext);
            _consumer.Consume(_predictionEngine.Predict(_example));

            // Text loader for a file with a header row: "Label" at index 0, measurements at 1-4.
            var loader = new TextLoader(mlContext, options: new TextLoader.Options()
            {
                Columns = new[]
                {
                    new TextLoader.Column("Label", DataKind.Single, 0),
                    new TextLoader.Column("SepalLength", DataKind.Single, 1),
                    new TextLoader.Column("SepalWidth", DataKind.Single, 2),
                    new TextLoader.Column("PetalLength", DataKind.Single, 3),
                    new TextLoader.Column("PetalWidth", DataKind.Single, 4),
                },
                HasHeader = true,
            });

            // Score the same file the model was trained from and keep the resulting metrics.
            IDataView scored    = _trainedModel.Transform(loader.Load(_dataPath));
            var       evaluator = new MulticlassClassificationEvaluator(mlContext, new MulticlassClassificationEvaluator.Arguments());
            _metrics = evaluator.Evaluate(scored, DefaultColumnNames.Label, DefaultColumnNames.Score, DefaultColumnNames.PredictedLabel);

            // One batch per configured size; every slot holds the same example instance.
            int batchCount = _batchSizes.Length;
            _batches = new IrisData[batchCount][];
            for (int b = 0; b < batchCount; b++)
            {
                var filled = new IrisData[_batchSizes[b]];
                int slot   = 0;
                while (slot < filled.Length)
                {
                    filled[slot] = _example;
                    slot++;
                }
                _batches[b] = filled;
            }
        }
Пример #5
0
        /// <summary>
        /// Downloads the SMS spam dataset when it is missing, cross-validates and trains a
        /// binary classifier on it, rebuilds the model with a 0.15 decision threshold and
        /// classifies a few sample messages.
        /// </summary>
        /// <param name="args">Command-line arguments (unused).</param>
        static void Main(string[] args)
        {
            const string archiveName = "spam.zip";

            // Fetch and unpack the dataset only when the training file is not on disk yet.
            if (!File.Exists(TrainDataPath))
            {
                //The code below will download a dataset from a third-party, UCI (link), and may be governed by separate third-party terms.
                //By proceeding, you agree to those separate terms.
                using (var webClient = new WebClient())
                {
                    webClient.DownloadFile("https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip", archiveName);
                }

                ZipFile.ExtractToDirectory(archiveName, DataDirectoryPath);
            }

            // The MLContext is the catalog through which all ML.NET components are reached.
            MLContext mlContext = new MLContext();

            // Load the tab-separated training file into a DataView shaped by SpamInput.
            var data = mlContext.Data.ReadFromTextFile<SpamInput>(path: TrainDataPath, hasHeader: true, separatorChar: '\t');

            // Data preparation: the custom mapping converts the text label to a boolean and
            // the message text is featurized.
            var featurization = mlContext.Transforms.CustomMapping<MyInput, MyOutput>(mapAction: MyLambda.MyAction, contractName: "MyLambda")
                                .Append(mlContext.Transforms.Text.FeaturizeText(outputColumnName: DefaultColumnNames.Features, inputColumnName: nameof(SpamInput.Message)));

            // Full training pipeline: data preparation followed by the linear trainer.
            Console.WriteLine("=============== Training the model ===============");
            var trainingPipeLine = featurization.Append(mlContext.BinaryClassification.Trainers.StochasticDualCoordinateAscent());

            // Five-fold cross-validation: each fold is held out once while the model is trained
            // on the others, giving five metric sets. The reported number is the mean AUC,
            // which should lie between 0.5 and 1 (higher is better).
            Console.WriteLine("=============== Cross-validating to get model's accuracy metrics ===============");
            var foldResults = mlContext.BinaryClassification.CrossValidate(data: data, estimator: trainingPipeLine, numFolds: 5);
            var meanAuc     = foldResults.Select(fold => fold.metrics.Auc).Average();

            Console.WriteLine("The AUC is {0}", meanAuc);

            // Train on the whole dataset to get the final model.
            var model = trainingPipeLine.Fit(data);

            // The data is skewed towards non-spam, which pushes the model to always answer
            // "not spam". Lowering the predictor's threshold counters that; in practice the
            // precision-recall curve is the right tool for choosing the value.
            ITransformer[] stages       = model.ToArray();
            var leadingStages           = new TransformerChain<ITransformer>(stages.Take(stages.Length - 1).ToArray());
            var schemaBeforePredictor   = leadingStages.GetOutputSchema(data.Schema);
            var trainedPredictor        = model.LastTransformer;
            var rethresholdedPredictor  = new BinaryPredictionTransformer<IPredictorProducing<float>>(
                mlContext,
                trainedPredictor.Model,
                schemaBeforePredictor,
                trainedPredictor.FeatureColumn,
                threshold: 0.15f,
                thresholdColumn: DefaultColumnNames.Probability);

            // Swap the final stage for the re-thresholded predictor and rebuild the chain.
            stages[stages.Length - 1] = rethresholdedPredictor;
            ITransformer newModel = new TransformerChain<ITransformer>(stages);

            // Prediction engine built from the adjusted model.
            var predictor = newModel.CreatePredictionEngine<SpamInput, SpamPrediction>(mlContext);

            Console.WriteLine("=============== Predictions for below data===============");
            // Try the model on a handful of sample messages.
            var sampleMessages = new[]
            {
                "That's a great idea. It should work.",
                "free medicine winner! congratulations",
                "Yes we should meet over the weekend!",
                "you win pills and free entry vouchers",
            };
            foreach (var message in sampleMessages)
            {
                ClassifyMessage(predictor, message);
            }

            Console.WriteLine("=============== End of process, hit any key to finish =============== ");
            Console.ReadLine();
        }