// Builds, trains, and evaluates a multiclass signal classifier.
// frameSize / sensorType select the frame reader; datasets are the files to
// load for training; labels are the class names, kept in `categories` for
// later lookup. Side effects: writes the fitted preprocessing transformer to
// `transformer`, the trained model to `model`, and prints the held-out
// confusion matrix to the console.
public SignalClassifierController(string frameSize, string sensorType, string[] datasets, string[] labels)
{
    mlContext = new MLContext();
    categories = labels;

    var reader = getFrameReader(frameSize, sensorType);
    var trainingDataView = reader.Load(datasets);

    // Hold out 20% of the data for evaluation.
    var split = mlContext.Data.TrainTestSplit(trainingDataView, testFraction: 0.2);

    // Preprocessing only: key-encode the label and min-max normalize the
    // sensor readings. The trainer is deliberately NOT part of this pipeline:
    // the previous version appended the FastTree OVA trainer here AND fitted
    // a second identical estimator below, training the model twice and
    // discarding the first copy.
    estimatorPipeline = mlContext.Transforms.Conversion.MapValueToKey("Label")
        .Append(mlContext.Transforms.NormalizeMinMax("readings", fixZero: true));

    transformer = estimatorPipeline.Fit(split.TrainSet);

    // One-versus-all FastTree was the trainer kept after experiments with
    // NaiveBayes, LbfgsLogisticRegression, and LdSvm.
    var OVAEstimator = mlContext.MulticlassClassification.Trainers
        .OneVersusAll(mlContext.BinaryClassification.Trainers
            .FastTree(featureColumnName: "readings"));

    var transformedTrainingData = transformer.Transform(split.TrainSet);
    model = OVAEstimator.Fit(transformedTrainingData);
    Console.WriteLine("Model fitted");

    // Evaluate on the held-out split and print the confusion matrix.
    var transformedTestData = transformer.Transform(split.TestSet);
    var testPredictions = model.Transform(transformedTestData);
    Console.WriteLine(mlContext.MulticlassClassification
        .Evaluate(testPredictions)
        .ConfusionMatrix.GetFormattedConfusionTable());
}
/// <summary>
/// Computes permutation feature importance (PFI) metrics for a trained
/// multiclass predictor and returns them as an <see cref="IDataView"/>,
/// one row per named feature slot.
/// </summary>
/// <param name="env">Host environment used to construct catalogs.</param>
/// <param name="predictor">The trained predictor to permute features against.</param>
/// <param name="roleMappedData">Data plus role mappings identifying the feature and label columns.</param>
/// <param name="input">PFI options (weight filter, example count, permutation count).</param>
/// <returns>An IDataView of per-feature metric means and standard errors.</returns>
/// <exception cref="InvalidOperationException">Thrown when no feature slot has a name, so no metrics can be produced.</exception>
private static IDataView GetMulticlassMetrics(
    IHostEnvironment env,
    IPredictor predictor,
    RoleMappedData roleMappedData,
    PermutationFeatureImportanceArguments input)
{
    // Resolve the feature and label column names from the role mappings.
    var roles = roleMappedData.Schema.GetColumnRoleNames();
    var featureColumnName = roles
        .First(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).Value;
    var labelColumnName = roles
        .First(x => x.Key.Value == RoleMappedSchema.ColumnRole.Label.Value).Value;

    // Wrap the raw predictor in a prediction transformer so the catalog's
    // PFI entry point can consume it.
    var pred = new MulticlassPredictionTransformer<IPredictorProducing<VBuffer<float>>>(
        env,
        predictor as IPredictorProducing<VBuffer<float>>,
        roleMappedData.Data.Schema,
        featureColumnName,
        labelColumnName);

    var multiclassCatalog = new MulticlassClassificationCatalog(env);
    var permutationMetrics = multiclassCatalog.PermutationFeatureImportance(
        pred,
        roleMappedData.Data,
        labelColumnName: labelColumnName,
        useFeatureWeightFilter: input.UseFeatureWeightFilter,
        numberOfExamplesToUse: input.NumberOfExamplesToUse,
        permutationCount: input.PermutationCount);

    var slotNames = GetSlotNames(roleMappedData.Schema);
    Contracts.Assert(slotNames.Length == permutationMetrics.Length,
        "Mismatch between number of feature slots and number of features permuted.");

    // Flatten each permuted slot's metric statistics into one row,
    // skipping slots with no name.
    var metrics = new List<MulticlassMetrics>();
    for (int i = 0; i < permutationMetrics.Length; i++)
    {
        if (string.IsNullOrWhiteSpace(slotNames[i]))
        {
            continue;
        }

        var pMetric = permutationMetrics[i];
        metrics.Add(new MulticlassMetrics
        {
            FeatureName = slotNames[i],
            MacroAccuracy = pMetric.MacroAccuracy.Mean,
            MacroAccuracyStdErr = pMetric.MacroAccuracy.StandardError,
            MicroAccuracy = pMetric.MicroAccuracy.Mean,
            MicroAccuracyStdErr = pMetric.MicroAccuracy.StandardError,
            LogLoss = pMetric.LogLoss.Mean,
            LogLossStdErr = pMetric.LogLoss.StandardError,
            LogLossReduction = pMetric.LogLossReduction.Mean,
            LogLossReductionStdErr = pMetric.LogLossReduction.StandardError,
            TopKAccuracy = pMetric.TopKAccuracy.Mean,
            TopKAccuracyStdErr = pMetric.TopKAccuracy.StandardError,
            PerClassLogLoss = pMetric.PerClassLogLoss.Select(x => x.Mean).ToArray(),
            PerClassLogLossStdErr = pMetric.PerClassLogLoss.Select(x => x.StandardError).ToArray(),
        });
    }

    // Guard: metrics.First() below would otherwise throw an opaque
    // "Sequence contains no elements" when every slot name is blank.
    if (metrics.Count == 0)
    {
        throw new InvalidOperationException(
            "Permutation feature importance produced no metrics: no feature slot had a name.");
    }

    // The per-class vectors have unknown size in the default schema
    // definition; pin them to the actual class count so LoadFromEnumerable
    // emits fixed-size vector columns.
    var metric = metrics.First();
    SchemaDefinition schema = SchemaDefinition.Create(typeof(MulticlassMetrics));
    ConvertVectorToKnownSize(nameof(metric.PerClassLogLoss), metric.PerClassLogLoss.Length, ref schema);
    ConvertVectorToKnownSize(nameof(metric.PerClassLogLossStdErr), metric.PerClassLogLossStdErr.Length, ref schema);

    var dataOps = new DataOperationsCatalog(env);
    return dataOps.LoadFromEnumerable(metrics, schema);
}