// Trains a multiclass SDCA model on the Iris training file through the statically-typed
// pipeline API, using a hinge loss instead of the default loss, and then verifies:
//   * the 'onFit' callback hands back the trained model parameters only once Fit has run,
//   * the captured parameters hold 3 weight vectors of 4 values each and 3 biases,
//   * evaluating the scored data yields positive log-loss and top-K accuracy.
public void SdcaMulticlass()
{
    var env = new MLContext(seed: 0);
    var dataPath = GetDataPath(TestDatasets.iris.trainFilename);
    var dataSource = new MultiFileSource(dataPath);

    var ctx = new MulticlassClassificationContext(env);

    // Column 0 is loaded as the text label; columns 1 through 4 form the float feature vector.
    var reader = TextLoaderStatic.CreateReader(env,
        c => (label: c.LoadText(0), features: c.LoadFloat(1, 4)));

    // Filled in by the trainer's 'onFit' callback below.
    MulticlassLogisticRegressionModelParameters pred = null;

    var loss = new HingeLoss(1);

    // With a custom loss function we no longer get calibrated predictions.
    var est = reader.MakeNewEstimator()
        .Append(r => (label: r.label.ToKey(), r.features))
        .Append(r => (r.label, preds: ctx.Trainers.Sdca(
            r.label,
            r.features,
            maxIterations: 2,
            loss: loss,
            onFit: p => pred = p)));

    var pipe = reader.Append(est);

    // Nothing has been fitted yet, so the callback must not have captured anything.
    Assert.Null(pred);
    var model = pipe.Fit(dataSource);
    Assert.NotNull(pred);

    // Each expectation gets its own equality assertion so that a failure reports
    // the expected and actual values rather than a bare "expected True".
    VBuffer<float>[] weights = default;
    pred.GetWeights(ref weights, out int n);
    Assert.Equal(3, n);
    Assert.Equal(n, weights.Length);
    foreach (var w in weights)
    {
        Assert.Equal(4, w.Length);
    }

    var biases = pred.GetBiases();
    Assert.Equal(3, biases.Count());

    var data = model.Read(dataSource);

    // Just output some data on the schema for fun.
    var schema = data.AsDynamic.Schema;
    for (int c = 0; c < schema.Count; ++c)
    {
        Console.WriteLine($"{schema[c].Name}, {schema[c].Type}");
    }

    // NOTE(review): the trailing 2 is presumably the K used for top-K accuracy;
    // confirm against the Evaluate overload being called.
    var metrics = ctx.Evaluate(data, r => r.label, r => r.preds, 2);
    Assert.True(metrics.LogLoss > 0);
    Assert.True(metrics.TopKAccuracy > 0);
}
// Sample: trains a multiclass SDCA model on a comma-separated Iris file and shows how to
// capture the objects learned during fitting (the normalizer scales and the trained
// model parameters) through 'onFit' callbacks so they can be inspected afterwards.
//
// dataPath: path of the text file to load; four float feature columns followed by a
//           text label column, separated by commas.
private void TrainAndInspectWeights(string dataPath)
{
    // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
    // as a catalog of available operations and as the source of randomness.
    var mlContext = new MLContext();

    // Step one: read the data as an IDataView.
    // First, we define the loader: specify the data columns and where to find them in the text file.
    var loader = mlContext.Data.CreateTextLoader(ctx => (
            // The four features of the Iris dataset.
            SepalLength: ctx.LoadFloat(0),
            SepalWidth: ctx.LoadFloat(1),
            PetalLength: ctx.LoadFloat(2),
            PetalWidth: ctx.LoadFloat(3),
            // Label: kind of iris.
            Label: ctx.LoadText(4)
        ),
        // Default separator is tab, but the dataset has comma.
        separator: ',');

    // Retrieve the training data.
    var trainData = loader.Load(dataPath);

    // This is the predictor ('weights collection') that we will train.
    MulticlassLogisticRegressionModelParameters predictor = null;

    // And these are the normalizer scales that we will learn.
    // Initialized explicitly: the only assignment happens inside the 'onFit' lambda below,
    // which the compiler's definite-assignment analysis does not count, yet the variable
    // is read at the end of this method.
    ImmutableArray<float> normScales = default;

    // Build the training pipeline.
    var pipeline = loader.MakeNewEstimator()
        .Append(r => (
            r.Label,
            // Concatenate all the features together into one column 'Features'.
            Features: r.SepalLength.ConcatWith(r.SepalWidth, r.PetalLength, r.PetalWidth)))
        .Append(r => (
            r.Label,
            // Normalize (rescale) the features to be between -1 and 1.
            Features: r.Features.Normalize(
                // When the normalizer is trained, the below delegate is going to be called.
                // We use it to memorize the scales.
                onFit: (scales, offsets) => normScales = scales)))
        // Cache the data in memory because the subsequent trainer needs to access it multiple times.
        .AppendCacheCheckpoint()
        .Append(r => (
            r.Label,
            // Train the multi-class SDCA model to predict the label using features.
            // Note that the label is text, so it needs to be converted to a key using the 'ToKey' estimator.
            Predictions: mlContext.MulticlassClassification.Trainers.Sdca(
                r.Label.ToKey(),
                r.Features,
                // When the model is trained, the below delegate is going to be called.
                // We use that to memorize the predictor object.
                onFit: p => predictor = p)));

    // Train the model. During this call our 'onFit' delegates will be invoked,
    // and our 'predictor' and 'normScales' will be set.
    var model = pipeline.Fit(trainData);

    // Now we can use 'predictor' to look at the weights.
    // 'weights' will be an array of weight vectors, one vector per class.
    // Our problem has 3 classes, so numClasses will be 3, and weights will contain
    // 3 vectors (of 4 values each).
    VBuffer<float>[] weights = null;
    predictor.GetWeights(ref weights, out int numClasses);

    // Similarly we can also inspect the biases for the 3 classes.
    var biases = predictor.GetBiases();

    // Inspect the normalizer scales.
    Console.WriteLine(string.Join(" ", normScales));
}