Example #1
        public void SdcaMulticlass()
        {
            var env        = new ConsoleEnvironment(seed: 0);
            var dataPath   = GetDataPath(TestDatasets.iris.trainFilename);
            var dataSource = new MultiFileSource(dataPath);

            var ctx    = new MulticlassClassificationContext(env);
            var reader = TextLoader.CreateReader(env,
                                                 c => (label: c.LoadText(0), features: c.LoadFloat(1, 4)));

            MulticlassLogisticRegressionPredictor pred = null;

            var loss = new HingeLoss(new HingeLoss.Arguments()
            {
                Margin = 1
            });

            // With a custom loss function we no longer get calibrated predictions.
            var est = reader.MakeNewEstimator()
                      .Append(r => (label: r.label.ToKey(), r.features))
                      .Append(r => (r.label, preds: ctx.Trainers.Sdca(
                                        r.label,
                                        r.features,
                                        maxIterations: 2,
                                        loss: loss, onFit: p => pred = p)));

            var pipe = reader.Append(est);

            Assert.Null(pred);
            var model = pipe.Fit(dataSource);

            Assert.NotNull(pred);
            VBuffer<float>[] weights = default;
            pred.GetWeights(ref weights, out int n);
            Assert.True(n == 3 && n == weights.Length);
            foreach (var w in weights)
            {
                Assert.True(w.Length == 4);
            }

            var biases = pred.GetBiases();

            Assert.True(biases.Count() == 3);

            var data = model.Read(dataSource);

            // Just output some data on the schema for fun.
            var schema = data.AsDynamic.Schema;

            for (int c = 0; c < schema.ColumnCount; ++c)
            {
                Console.WriteLine($"{schema.GetColumnName(c)}, {schema.GetColumnType(c)}");
            }

            var metrics = ctx.Evaluate(data, r => r.label, r => r.preds, 2);

            Assert.True(metrics.LogLoss > 0);
            Assert.True(metrics.TopKAccuracy > 0);
        }
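
The onFit callback above hands back the trained MulticlassLogisticRegressionPredictor, so the per-class parameters retrieved through GetWeights and GetBiases can also be printed. A minimal sketch that would slot in at the end of the test, assuming VBuffer<float>.DenseValues() is available in this version of the API (biasArray is an illustrative local):

            // Minimal sketch: print each class's four weights and its bias.
            var biasArray = biases.ToArray();
            for (int i = 0; i < weights.Length; i++)
            {
                var wText = string.Join(", ", weights[i].DenseValues());
                Console.WriteLine($"Class {i}: weights = [{wText}], bias = {biasArray[i]}");
            }
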
Example #2
        public void MulticlassLogisticRegression()
        {
            var env        = new MLContext(seed: 0);
            var dataPath   = GetDataPath(TestDatasets.iris.trainFilename);
            var dataSource = new MultiFileSource(dataPath);

            var ctx    = new MulticlassClassificationContext(env);
            var reader = TextLoader.CreateReader(env,
                                                 c => (label: c.LoadText(0), features: c.LoadFloat(1, 4)));

            MulticlassLogisticRegressionPredictor pred = null;

            // Use the 'onFit' delegate to capture the trained predictor so its weights can be inspected after fitting.
            var est = reader.MakeNewEstimator()
                      .Append(r => (label: r.label.ToKey(), r.features))
                      .Append(r => (r.label, preds: ctx.Trainers.MultiClassLogisticRegression(
                                        r.label,
                                        r.features, onFit: p => pred = p,
                                        advancedSettings: s => s.NumThreads = 1)));

            var pipe = reader.Append(est);

            Assert.Null(pred);
            var model = pipe.Fit(dataSource);

            Assert.NotNull(pred);
            VBuffer<float>[] weights = default;
            pred.GetWeights(ref weights, out int n);
            Assert.True(n == 3 && n == weights.Length);
            foreach (var w in weights)
            {
                Assert.True(w.Length == 4);
            }

            var data = model.Read(dataSource);

            // Just output some data on the schema for fun.
            var schema = data.AsDynamic.Schema;

            for (int c = 0; c < schema.Count; ++c)
            {
                Console.WriteLine($"{schema[c].Name}, {schema[c].Type}");
            }

            var metrics = ctx.Evaluate(data, r => r.label, r => r.preds, 2);

            Assert.True(metrics.LogLoss > 0);
            Assert.True(metrics.TopKAccuracy > 0);
        }
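
After Evaluate returns, the metrics object can be dumped for a quick sanity check. A minimal sketch that would sit at the end of the test above, using only the two metrics it already asserts on:

            // Minimal sketch: report the metrics asserted above.
            Console.WriteLine($"Log-loss:       {metrics.LogLoss:0.###}");
            Console.WriteLine($"Top-2 accuracy: {metrics.TopKAccuracy:0.###}");
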
Example #3
        private void TrainAndInspectWeights(string dataPath)
        {
            // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
            // as a catalog of available operations and as the source of randomness.
            var mlContext = new MLContext();

            // Step one: read the data as an IDataView.
            // First, we define the reader: specify the data columns and where to find them in the text file.
            var reader = mlContext.Data.TextReader(ctx => (
                                                       // The four features of the Iris dataset.
                                                       SepalLength: ctx.LoadFloat(0),
                                                       SepalWidth: ctx.LoadFloat(1),
                                                       PetalLength: ctx.LoadFloat(2),
                                                       PetalWidth: ctx.LoadFloat(3),
                                                       // Label: kind of iris.
                                                       Label: ctx.LoadText(4)
                                                       ),
                                                   // Default separator is tab, but the dataset has comma.
                                                   separator: ',');

            // Retrieve the training data.
            var trainData = reader.Read(dataPath);

            // This is the predictor ('weights collection') that we will train.
            MulticlassLogisticRegressionPredictor predictor = null;
            // And these are the normalizer scales that we will learn.
            ImmutableArray<float> normScales = default;
            // Build the training pipeline.
            var learningPipeline = reader.MakeNewEstimator()
                                   .Append(r => (
                                               r.Label,
                                               // Concatenate all the features together into one column 'Features'.
                                               Features: r.SepalLength.ConcatWith(r.SepalWidth, r.PetalLength, r.PetalWidth)))
                                   .Append(r => (
                                               r.Label,
                                               // Normalize (rescale) the features to be between -1 and 1.
                                               Features: r.Features.Normalize(
                                                   // When the normalizer is trained, the below delegate is going to be called.
                                                   // We use it to memorize the scales.
                                                   onFit: (scales, offsets) => normScales = scales)))
                                   .Append(r => (
                                               r.Label,
                                               // Train the multi-class SDCA model to predict the label using features.
                                               // Note that the label is a text, so it needs to be converted to key using 'ToKey' estimator.
                                               Predictions: mlContext.MulticlassClassification.Trainers.Sdca(r.Label.ToKey(), r.Features,
                                                                                                             // When the model is trained, the below delegate is going to be called.
                                                                                                             // We use that to memorize the predictor object.
                                                                                                             onFit: p => predictor = p)));

            // Train the model. During this call our 'onFit' delegate will be invoked,
            // and our 'predictor' will be set.
            var model = learningPipeline.Fit(trainData);

            // Now we can use 'predictor' to look at the weights.
            // 'weights' will be an array of weight vectors, one vector per class.
            // Our problem has 3 classes, so numClasses will be 3, and weights will contain
            // 3 vectors (of 4 values each).
            VBuffer<float>[] weights = null;
            predictor.GetWeights(ref weights, out int numClasses);

            // Similarly we can also inspect the biases for the 3 classes.
            var biases = predictor.GetBiases();

            // Inspect the normalizer scales.
            Console.WriteLine(string.Join(" ", normScales));
        }
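
The scales memorized by the normalizer's onFit delegate describe the learned rescaling. As a rough illustration that would sit at the end of TrainAndInspectWeights, and assuming the default min-max normalizer with fix-zero behaviour (offsets of zero), each raw feature is simply multiplied by its learned scale; the sample row below is hypothetical:

            // Minimal sketch: apply the learned scales to one hypothetical Iris row.
            float[] rawRow = { 5.1f, 3.5f, 1.4f, 0.2f };   // SepalLength, SepalWidth, PetalLength, PetalWidth
            for (int i = 0; i < rawRow.Length; i++)
            {
                Console.WriteLine($"feature[{i}]: raw = {rawRow[i]}, scaled = {rawRow[i] * normScales[i]}");
            }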