/// <summary>
/// Runs permutation feature importance for a multiclass predictor and returns one row of
/// metric statistics (mean and standard error) per named feature slot as an <see cref="IDataView"/>.
/// </summary>
/// <param name="env">Host environment used to build the transformer and the catalogs.</param>
/// <param name="predictor">
/// The trained predictor; it must implement <c>IPredictorProducing&lt;VBuffer&lt;float&gt;&gt;</c>.
/// </param>
/// <param name="roleMappedData">Data together with its role mapping; supplies the feature and label column names.</param>
/// <param name="input">Arguments forwarded to the permutation feature importance computation.</param>
/// <returns>A data view built from the collected <c>MulticlassMetrics</c> rows.</returns>
private static IDataView GetMulticlassMetrics(
    IHostEnvironment env,
    IPredictor predictor,
    RoleMappedData roleMappedData,
    PermutationFeatureImportanceArguments input)
{
    var roles = roleMappedData.Schema.GetColumnRoleNames();
    var featureColumnName = roles.First(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).Value;
    var labelColumnName = roles.First(x => x.Key.Value == RoleMappedSchema.ColumnRole.Label.Value).Value;

    // Fail early with a clear message instead of handing a null model to the transformer
    // when the predictor does not produce a score vector.
    var typedPredictor = predictor as IPredictorProducing<VBuffer<float>>;
    Contracts.CheckValue(typedPredictor, nameof(predictor),
        "Predictor must produce a vector of floats to compute multiclass permutation feature importance.");

    var pred = new MulticlassPredictionTransformer<IPredictorProducing<VBuffer<float>>>(
        env, typedPredictor, roleMappedData.Data.Schema, featureColumnName, labelColumnName);

    var multiclassCatalog = new MulticlassClassificationCatalog(env);
    var permutationMetrics = multiclassCatalog.PermutationFeatureImportance(
        pred,
        roleMappedData.Data,
        labelColumnName: labelColumnName,
        useFeatureWeightFilter: input.UseFeatureWeightFilter,
        numberOfExamplesToUse: input.NumberOfExamplesToUse,
        permutationCount: input.PermutationCount);

    var slotNames = GetSlotNames(roleMappedData.Schema);
    Contracts.Assert(slotNames.Length == permutationMetrics.Length,
        "Mismatch between number of feature slots and number of features permuted.");

    var metrics = new List<MulticlassMetrics>(permutationMetrics.Length);
    for (int i = 0; i < permutationMetrics.Length; i++)
    {
        // Slots without a usable name are not reported.
        if (string.IsNullOrWhiteSpace(slotNames[i]))
        {
            continue;
        }

        var pMetric = permutationMetrics[i];
        metrics.Add(new MulticlassMetrics
        {
            FeatureName = slotNames[i],
            MacroAccuracy = pMetric.MacroAccuracy.Mean,
            MacroAccuracyStdErr = pMetric.MacroAccuracy.StandardError,
            MicroAccuracy = pMetric.MicroAccuracy.Mean,
            MicroAccuracyStdErr = pMetric.MicroAccuracy.StandardError,
            LogLoss = pMetric.LogLoss.Mean,
            LogLossStdErr = pMetric.LogLoss.StandardError,
            LogLossReduction = pMetric.LogLossReduction.Mean,
            LogLossReductionStdErr = pMetric.LogLossReduction.StandardError,
            TopKAccuracy = pMetric.TopKAccuracy.Mean,
            TopKAccuracyStdErr = pMetric.TopKAccuracy.StandardError,
            PerClassLogLoss = pMetric.PerClassLogLoss.Select(x => x.Mean).ToArray(),
            PerClassLogLossStdErr = pMetric.PerClassLogLoss.Select(x => x.StandardError).ToArray()
        });
    }

    // The vector sizes below are taken from the first row, so at least one row is required.
    // Without this check an empty list surfaces as an opaque "Sequence contains no elements".
    Contracts.Check(metrics.Count > 0,
        "No named feature slots were found; cannot produce permutation feature importance metrics.");

    // Convert unknown size vectors to known size.
    var metric = metrics[0];
    SchemaDefinition schema = SchemaDefinition.Create(typeof(MulticlassMetrics));
    ConvertVectorToKnownSize(nameof(metric.PerClassLogLoss), metric.PerClassLogLoss.Length, ref schema);
    ConvertVectorToKnownSize(nameof(metric.PerClassLogLossStdErr), metric.PerClassLogLossStdErr.Length, ref schema);

    var dataOps = new DataOperationsCatalog(env);
    return dataOps.LoadFromEnumerable(metrics, schema);
}