/// <summary>
/// Selects the metric builder that matches the predictor's <see cref="PredictionKind"/>
/// and returns whatever that builder produces.
/// </summary>
/// <param name="env">Host environment forwarded to the selected builder.</param>
/// <param name="predictor">Predictor whose <c>PredictionKind</c> drives the dispatch.</param>
/// <param name="roleMappedData">Role-mapped data forwarded to the selected builder.</param>
/// <param name="input">Permutation feature importance arguments forwarded to the selected builder.</param>
/// <returns>The data view returned by the binary, multiclass, regression or ranking builder.</returns>
/// <remarks>
/// Any other prediction kind results in the exception created by <c>Contracts.Except</c>.
/// </remarks>
internal static IDataView GetMetrics(
    IHostEnvironment env,
    IPredictor predictor,
    RoleMappedData roleMappedData,
    PermutationFeatureImportanceArguments input)
{
    var kind = predictor.PredictionKind;

    if (kind == PredictionKind.BinaryClassification)
    {
        return GetBinaryMetrics(env, predictor, roleMappedData, input);
    }

    if (kind == PredictionKind.MulticlassClassification)
    {
        return GetMulticlassMetrics(env, predictor, roleMappedData, input);
    }

    if (kind == PredictionKind.Regression)
    {
        return GetRegressionMetrics(env, predictor, roleMappedData, input);
    }

    if (kind == PredictionKind.Ranking)
    {
        return GetRankingMetrics(env, predictor, roleMappedData, input);
    }

    // None of the supported kinds matched.
    throw Contracts.Except(
        "Unsupported predictor type. Predictor must be binary classifier, " +
        "multiclass classifier, regressor, or ranker.");
}
/// <summary>
/// Runs ranking permutation feature importance and returns one row per named feature slot
/// holding the mean and standard error of the DCG and NDCG vectors.
/// </summary>
/// <param name="env">Environment used to create the transformer, the catalog and the output view.</param>
/// <param name="predictor">Model to evaluate; must implement IPredictorProducing&lt;float&gt;.</param>
/// <param name="roleMappedData">
/// Data and role mapping; the Feature, Label and Group role columns are looked up in its schema.
/// </param>
/// <param name="input">
/// Supplies <c>UseFeatureWeightFilter</c>, <c>NumberOfExamplesToUse</c> and <c>PermutationCount</c>.
/// </param>
/// <returns>
/// A data view built from the collected rows. Slots whose name is null, empty or whitespace
/// are skipped; when that leaves no rows, an empty view is returned.
/// </returns>
private static IDataView GetRankingMetrics(
    IHostEnvironment env,
    IPredictor predictor,
    RoleMappedData roleMappedData,
    PermutationFeatureImportanceArguments input)
{
    var roles = roleMappedData.Schema.GetColumnRoleNames();
    var featureColumnName = roles.First(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).Value;
    var labelColumnName = roles.First(x => x.Key.Value == RoleMappedSchema.ColumnRole.Label.Value).Value;
    var groupIdColumnName = roles.First(x => x.Key.Value == RoleMappedSchema.ColumnRole.Group.Value).Value;

    // Report an incompatible predictor here instead of passing null to the transformer constructor.
    if (!(predictor is IPredictorProducing<float> typedPredictor))
    {
        throw Contracts.Except(
            "Predictor must implement IPredictorProducing<float> to compute ranking metrics.");
    }

    var pred = new RankingPredictionTransformer<IPredictorProducing<float>>(
        env, typedPredictor, roleMappedData.Data.Schema, featureColumnName);

    var rankingCatalog = new RankingCatalog(env);
    var permutationMetrics = rankingCatalog.PermutationFeatureImportance(
        pred,
        roleMappedData.Data,
        labelColumnName: labelColumnName,
        rowGroupColumnName: groupIdColumnName,
        useFeatureWeightFilter: input.UseFeatureWeightFilter,
        numberOfExamplesToUse: input.NumberOfExamplesToUse,
        permutationCount: input.PermutationCount);

    var slotNames = GetSlotNames(roleMappedData.Schema);
    Contracts.Assert(slotNames.Length == permutationMetrics.Length,
        "Mismatch between number of feature slots and number of features permuted.");

    var metrics = new List<RankingMetrics>();
    for (int i = 0; i < permutationMetrics.Length; i++)
    {
        // Only slots with a usable name are reported.
        if (string.IsNullOrWhiteSpace(slotNames[i]))
        {
            continue;
        }

        var pMetric = permutationMetrics[i];
        metrics.Add(new RankingMetrics
        {
            FeatureName = slotNames[i],
            DiscountedCumulativeGains =
                pMetric.DiscountedCumulativeGains.Select(x => x.Mean).ToArray(),
            DiscountedCumulativeGainsStdErr =
                pMetric.DiscountedCumulativeGains.Select(x => x.StandardError).ToArray(),
            NormalizedDiscountedCumulativeGains =
                pMetric.NormalizedDiscountedCumulativeGains.Select(x => x.Mean).ToArray(),
            NormalizedDiscountedCumulativeGainsStdErr =
                pMetric.NormalizedDiscountedCumulativeGains.Select(x => x.StandardError).ToArray()
        });
    }

    SchemaDefinition schema = SchemaDefinition.Create(typeof(RankingMetrics));

    // Convert unknown size vectors to known size. The sizes are taken from the first row,
    // so the step is skipped when no slot had a name; the schema is then used as created.
    if (metrics.Count > 0)
    {
        var metric = metrics[0];
        ConvertVectorToKnownSize(nameof(metric.DiscountedCumulativeGains),
            metric.DiscountedCumulativeGains.Length, ref schema);
        ConvertVectorToKnownSize(nameof(metric.NormalizedDiscountedCumulativeGains),
            metric.NormalizedDiscountedCumulativeGains.Length, ref schema);
        ConvertVectorToKnownSize(nameof(metric.DiscountedCumulativeGainsStdErr),
            metric.DiscountedCumulativeGainsStdErr.Length, ref schema);
        ConvertVectorToKnownSize(nameof(metric.NormalizedDiscountedCumulativeGainsStdErr),
            metric.NormalizedDiscountedCumulativeGainsStdErr.Length, ref schema);
    }

    var dataOps = new DataOperationsCatalog(env);
    return dataOps.LoadFromEnumerable(metrics, schema);
}
/// <summary>
/// Runs regression permutation feature importance and returns one row per named feature slot
/// holding the mean and standard error of each regression metric.
/// </summary>
/// <param name="env">Environment used to create the transformer, the catalog and the output view.</param>
/// <param name="predictor">Model to evaluate; must implement IPredictorProducing&lt;float&gt;.</param>
/// <param name="roleMappedData">
/// Data and role mapping; the Feature and Label role columns are looked up in its schema.
/// </param>
/// <param name="input">
/// Supplies <c>UseFeatureWeightFilter</c>, <c>NumberOfExamplesToUse</c> and <c>PermutationCount</c>.
/// </param>
/// <returns>
/// A data view built from the collected rows. Slots whose name is null, empty or whitespace
/// are skipped.
/// </returns>
private static IDataView GetRegressionMetrics(
    IHostEnvironment env,
    IPredictor predictor,
    RoleMappedData roleMappedData,
    PermutationFeatureImportanceArguments input)
{
    var roles = roleMappedData.Schema.GetColumnRoleNames();
    var featureColumnName = roles.First(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).Value;
    var labelColumnName = roles.First(x => x.Key.Value == RoleMappedSchema.ColumnRole.Label.Value).Value;

    // Report an incompatible predictor here instead of passing null to the transformer constructor.
    if (!(predictor is IPredictorProducing<float> typedPredictor))
    {
        throw Contracts.Except(
            "Predictor must implement IPredictorProducing<float> to compute regression metrics.");
    }

    var pred = new RegressionPredictionTransformer<IPredictorProducing<float>>(
        env, typedPredictor, roleMappedData.Data.Schema, featureColumnName);

    var regressionCatalog = new RegressionCatalog(env);
    var permutationMetrics = regressionCatalog.PermutationFeatureImportance(
        pred,
        roleMappedData.Data,
        labelColumnName: labelColumnName,
        useFeatureWeightFilter: input.UseFeatureWeightFilter,
        numberOfExamplesToUse: input.NumberOfExamplesToUse,
        permutationCount: input.PermutationCount);

    var slotNames = GetSlotNames(roleMappedData.Schema);
    Contracts.Assert(slotNames.Length == permutationMetrics.Length,
        "Mismatch between number of feature slots and number of features permuted.");

    var metrics = new List<RegressionMetrics>();
    for (int i = 0; i < permutationMetrics.Length; i++)
    {
        // Only slots with a usable name are reported.
        if (string.IsNullOrWhiteSpace(slotNames[i]))
        {
            continue;
        }

        var pMetric = permutationMetrics[i];
        metrics.Add(new RegressionMetrics
        {
            FeatureName = slotNames[i],
            MeanAbsoluteError = pMetric.MeanAbsoluteError.Mean,
            MeanAbsoluteErrorStdErr = pMetric.MeanAbsoluteError.StandardError,
            MeanSquaredError = pMetric.MeanSquaredError.Mean,
            MeanSquaredErrorStdErr = pMetric.MeanSquaredError.StandardError,
            RootMeanSquaredError = pMetric.RootMeanSquaredError.Mean,
            RootMeanSquaredErrorStdErr = pMetric.RootMeanSquaredError.StandardError,
            LossFunction = pMetric.LossFunction.Mean,
            LossFunctionStdErr = pMetric.LossFunction.StandardError,
            RSquared = pMetric.RSquared.Mean,
            RSquaredStdErr = pMetric.RSquared.StandardError
        });
    }

    var dataOps = new DataOperationsCatalog(env);
    return dataOps.LoadFromEnumerable(metrics);
}
/// <summary>
/// Runs binary classification permutation feature importance and returns one row per named
/// feature slot holding the mean and standard error of each binary classification metric.
/// </summary>
/// <param name="env">Environment used to create the transformer, the catalog and the output view.</param>
/// <param name="predictor">Model to evaluate; must implement IPredictorProducing&lt;float&gt;.</param>
/// <param name="roleMappedData">
/// Data and role mapping; the Feature and Label role columns are looked up in its schema.
/// </param>
/// <param name="input">
/// Supplies <c>UseFeatureWeightFilter</c>, <c>NumberOfExamplesToUse</c> and <c>PermutationCount</c>.
/// </param>
/// <returns>
/// A data view built from the collected rows. Slots whose name is null, empty or whitespace
/// are skipped.
/// </returns>
private static IDataView GetBinaryMetrics(
    IHostEnvironment env,
    IPredictor predictor,
    RoleMappedData roleMappedData,
    PermutationFeatureImportanceArguments input)
{
    var roles = roleMappedData.Schema.GetColumnRoleNames();
    var featureColumnName = roles.First(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).Value;
    var labelColumnName = roles.First(x => x.Key.Value == RoleMappedSchema.ColumnRole.Label.Value).Value;

    // Report an incompatible predictor here instead of passing null to the transformer constructor.
    if (!(predictor is IPredictorProducing<float> typedPredictor))
    {
        throw Contracts.Except(
            "Predictor must implement IPredictorProducing<float> to compute binary classification metrics.");
    }

    var pred = new BinaryPredictionTransformer<IPredictorProducing<float>>(
        env, typedPredictor, roleMappedData.Data.Schema, featureColumnName);

    var binaryCatalog = new BinaryClassificationCatalog(env);
    var permutationMetrics = binaryCatalog.PermutationFeatureImportance(
        pred,
        roleMappedData.Data,
        labelColumnName: labelColumnName,
        useFeatureWeightFilter: input.UseFeatureWeightFilter,
        numberOfExamplesToUse: input.NumberOfExamplesToUse,
        permutationCount: input.PermutationCount);

    var slotNames = GetSlotNames(roleMappedData.Schema);
    Contracts.Assert(slotNames.Length == permutationMetrics.Length,
        "Mismatch between number of feature slots and number of features permuted.");

    var metrics = new List<BinaryMetrics>();
    for (int i = 0; i < permutationMetrics.Length; i++)
    {
        // Only slots with a usable name are reported.
        if (string.IsNullOrWhiteSpace(slotNames[i]))
        {
            continue;
        }

        var pMetric = permutationMetrics[i];
        metrics.Add(new BinaryMetrics
        {
            FeatureName = slotNames[i],
            AreaUnderRocCurve = pMetric.AreaUnderRocCurve.Mean,
            AreaUnderRocCurveStdErr = pMetric.AreaUnderRocCurve.StandardError,
            Accuracy = pMetric.Accuracy.Mean,
            AccuracyStdErr = pMetric.Accuracy.StandardError,
            PositivePrecision = pMetric.PositivePrecision.Mean,
            PositivePrecisionStdErr = pMetric.PositivePrecision.StandardError,
            PositiveRecall = pMetric.PositiveRecall.Mean,
            PositiveRecallStdErr = pMetric.PositiveRecall.StandardError,
            NegativePrecision = pMetric.NegativePrecision.Mean,
            NegativePrecisionStdErr = pMetric.NegativePrecision.StandardError,
            NegativeRecall = pMetric.NegativeRecall.Mean,
            NegativeRecallStdErr = pMetric.NegativeRecall.StandardError,
            F1Score = pMetric.F1Score.Mean,
            F1ScoreStdErr = pMetric.F1Score.StandardError,
            AreaUnderPrecisionRecallCurve = pMetric.AreaUnderPrecisionRecallCurve.Mean,
            AreaUnderPrecisionRecallCurveStdErr = pMetric.AreaUnderPrecisionRecallCurve.StandardError
        });
    }

    var dataOps = new DataOperationsCatalog(env);
    return dataOps.LoadFromEnumerable(metrics);
}
/// <summary>
/// Entry point that prepares the data for the model in <paramref name="input"/>, computes
/// permutation feature importance metrics for its predictor and wraps them in an output object.
/// </summary>
/// <param name="env">Host environment; must not be null. A host named "Pfi" is registered on it.</param>
/// <param name="input">
/// Entry point arguments; must not be null. <c>PredictorModel</c> and <c>Data</c> are read from it.
/// </param>
/// <returns>An output object whose <c>Metrics</c> is the computed data view.</returns>
/// <remarks>
/// When the model yields no predictor, the exception created by <c>Contracts.Except</c> is thrown.
/// </remarks>
public static PermutationFeatureImportanceOutput PermutationFeatureImportance(
    IHostEnvironment env,
    PermutationFeatureImportanceArguments input)
{
    Contracts.CheckValue(env, nameof(env));
    var host = env.Register("Pfi");
    host.CheckValue(input, nameof(input));
    EntryPointUtils.CheckInputArgs(host, input);

    input.PredictorModel.PrepareData(env, input.Data, out RoleMappedData roleMappedData, out IPredictor predictor);

    // A model without a predictor is a problem with the supplied input rather than an
    // internal invariant, so it is reported with an exception instead of an assertion;
    // otherwise the null would only surface later when GetMetrics dereferences it.
    if (predictor == null)
    {
        throw Contracts.Except("No predictor found in model");
    }

    IDataView result = PermutationFeatureImportanceUtils.GetMetrics(env, predictor, roleMappedData, input);
    return new PermutationFeatureImportanceOutput { Metrics = result };
}
/// <summary>
/// Runs multiclass permutation feature importance and returns one row per named feature slot
/// holding the mean and standard error of each multiclass metric, including the per-class
/// log-loss vectors.
/// </summary>
/// <param name="env">Environment used to create the transformer, the catalog and the output view.</param>
/// <param name="predictor">
/// Model to evaluate; must implement IPredictorProducing&lt;VBuffer&lt;float&gt;&gt;.
/// </param>
/// <param name="roleMappedData">
/// Data and role mapping; the Feature and Label role columns are looked up in its schema.
/// </param>
/// <param name="input">
/// Supplies <c>UseFeatureWeightFilter</c>, <c>NumberOfExamplesToUse</c> and <c>PermutationCount</c>.
/// </param>
/// <returns>
/// A data view built from the collected rows. Slots whose name is null, empty or whitespace
/// are skipped; when that leaves no rows, an empty view is returned.
/// </returns>
private static IDataView GetMulticlassMetrics(
    IHostEnvironment env,
    IPredictor predictor,
    RoleMappedData roleMappedData,
    PermutationFeatureImportanceArguments input)
{
    var roles = roleMappedData.Schema.GetColumnRoleNames();
    var featureColumnName = roles.First(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).Value;
    var labelColumnName = roles.First(x => x.Key.Value == RoleMappedSchema.ColumnRole.Label.Value).Value;

    // Report an incompatible predictor here instead of passing null to the transformer constructor.
    if (!(predictor is IPredictorProducing<VBuffer<float>> typedPredictor))
    {
        throw Contracts.Except(
            "Predictor must implement IPredictorProducing<VBuffer<float>> to compute multiclass metrics.");
    }

    var pred = new MulticlassPredictionTransformer<IPredictorProducing<VBuffer<float>>>(
        env, typedPredictor, roleMappedData.Data.Schema, featureColumnName, labelColumnName);

    var multiclassCatalog = new MulticlassClassificationCatalog(env);
    var permutationMetrics = multiclassCatalog.PermutationFeatureImportance(
        pred,
        roleMappedData.Data,
        labelColumnName: labelColumnName,
        useFeatureWeightFilter: input.UseFeatureWeightFilter,
        numberOfExamplesToUse: input.NumberOfExamplesToUse,
        permutationCount: input.PermutationCount);

    var slotNames = GetSlotNames(roleMappedData.Schema);
    Contracts.Assert(slotNames.Length == permutationMetrics.Length,
        "Mismatch between number of feature slots and number of features permuted.");

    var metrics = new List<MulticlassMetrics>();
    for (int i = 0; i < permutationMetrics.Length; i++)
    {
        // Only slots with a usable name are reported.
        if (string.IsNullOrWhiteSpace(slotNames[i]))
        {
            continue;
        }

        var pMetric = permutationMetrics[i];
        metrics.Add(new MulticlassMetrics
        {
            FeatureName = slotNames[i],
            MacroAccuracy = pMetric.MacroAccuracy.Mean,
            MacroAccuracyStdErr = pMetric.MacroAccuracy.StandardError,
            MicroAccuracy = pMetric.MicroAccuracy.Mean,
            MicroAccuracyStdErr = pMetric.MicroAccuracy.StandardError,
            LogLoss = pMetric.LogLoss.Mean,
            LogLossStdErr = pMetric.LogLoss.StandardError,
            LogLossReduction = pMetric.LogLossReduction.Mean,
            LogLossReductionStdErr = pMetric.LogLossReduction.StandardError,
            TopKAccuracy = pMetric.TopKAccuracy.Mean,
            TopKAccuracyStdErr = pMetric.TopKAccuracy.StandardError,
            PerClassLogLoss = pMetric.PerClassLogLoss.Select(x => x.Mean).ToArray(),
            PerClassLogLossStdErr = pMetric.PerClassLogLoss.Select(x => x.StandardError).ToArray()
        });
    }

    SchemaDefinition schema = SchemaDefinition.Create(typeof(MulticlassMetrics));

    // Convert unknown size vectors to known size. The sizes are taken from the first row,
    // so the step is skipped when no slot had a name; the schema is then used as created.
    if (metrics.Count > 0)
    {
        var metric = metrics[0];
        ConvertVectorToKnownSize(nameof(metric.PerClassLogLoss),
            metric.PerClassLogLoss.Length, ref schema);
        ConvertVectorToKnownSize(nameof(metric.PerClassLogLossStdErr),
            metric.PerClassLogLossStdErr.Length, ref schema);
    }

    var dataOps = new DataOperationsCatalog(env);
    return dataOps.LoadFromEnumerable(metrics, schema);
}