/// <summary>
        /// Computes permutation feature importance metrics for <paramref name="predictor"/>,
        /// dispatching to the task-specific helper that matches its prediction kind.
        /// </summary>
        /// <param name="env">Host environment forwarded to the task-specific helper.</param>
        /// <param name="predictor">The trained model whose <c>PredictionKind</c> selects the helper.</param>
        /// <param name="roleMappedData">Evaluation data with column roles assigned.</param>
        /// <param name="input">Entry-point arguments forwarded to the helper.</param>
        /// <returns>A data view with one row of metric statistics per reported feature.</returns>
        /// <exception cref="System.InvalidOperationException">
        /// Thrown (via <c>Contracts.Except</c>) when the prediction kind is not binary
        /// classification, multiclass classification, regression or ranking.
        /// </exception>
        internal static IDataView GetMetrics(
            IHostEnvironment env,
            IPredictor predictor,
            RoleMappedData roleMappedData,
            PermutationFeatureImportanceArguments input)
        {
            // Each supported task produces its own metrics schema, so pick the matching helper.
            switch (predictor.PredictionKind)
            {
                case PredictionKind.BinaryClassification:
                    return GetBinaryMetrics(env, predictor, roleMappedData, input);

                case PredictionKind.MulticlassClassification:
                    return GetMulticlassMetrics(env, predictor, roleMappedData, input);

                case PredictionKind.Regression:
                    return GetRegressionMetrics(env, predictor, roleMappedData, input);

                case PredictionKind.Ranking:
                    return GetRankingMetrics(env, predictor, roleMappedData, input);

                default:
                    throw Contracts.Except(
                        "Unsupported predictor type. Predictor must be binary classifier, " +
                        "multiclass classifier, regressor, or ranker.");
            }
        }
        /// <summary>
        /// Computes permutation feature importance for a ranking model and packages the
        /// per-feature DCG/NDCG statistics (mean and standard error) as an <see cref="IDataView"/>.
        /// </summary>
        /// <param name="env">Host environment used to create the catalogs and the prediction transformer.</param>
        /// <param name="predictor">The trained ranker, handed to the transformer as <c>IPredictorProducing&lt;float&gt;</c>.</param>
        /// <param name="roleMappedData">Evaluation data; its schema supplies the Feature, Label and Group column names.</param>
        /// <param name="input">Entry-point arguments forwarded to the PFI computation.</param>
        /// <returns>
        /// One row per feature whose slot name is not null or whitespace. If no feature qualifies,
        /// the returned view is empty.
        /// </returns>
        private static IDataView GetRankingMetrics(
            IHostEnvironment env,
            IPredictor predictor,
            RoleMappedData roleMappedData,
            PermutationFeatureImportanceArguments input)
        {
            var roles             = roleMappedData.Schema.GetColumnRoleNames();
            var featureColumnName = roles.First(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).Value;
            var labelColumnName   = roles.First(x => x.Key.Value == RoleMappedSchema.ColumnRole.Label.Value).Value;
            var groupIdColumnName = roles.First(x => x.Key.Value == RoleMappedSchema.ColumnRole.Group.Value).Value;
            var pred = new RankingPredictionTransformer<IPredictorProducing<float>>(
                env, predictor as IPredictorProducing<float>, roleMappedData.Data.Schema, featureColumnName);
            var rankingCatalog     = new RankingCatalog(env);
            var permutationMetrics = rankingCatalog
                                     .PermutationFeatureImportance(pred,
                                                                   roleMappedData.Data,
                                                                   labelColumnName: labelColumnName,
                                                                   rowGroupColumnName: groupIdColumnName,
                                                                   useFeatureWeightFilter: input.UseFeatureWeightFilter,
                                                                   numberOfExamplesToUse: input.NumberOfExamplesToUse,
                                                                   permutationCount: input.PermutationCount);

            var slotNames = GetSlotNames(roleMappedData.Schema);

            Contracts.Assert(slotNames.Length == permutationMetrics.Length,
                             "Mismatch between number of feature slots and number of features permuted.");

            var metrics = new List<RankingMetrics>();

            for (int i = 0; i < permutationMetrics.Length; i++)
            {
                // Features without a usable slot name are omitted from the report.
                if (string.IsNullOrWhiteSpace(slotNames[i]))
                {
                    continue;
                }
                var pMetric = permutationMetrics[i];
                metrics.Add(new RankingMetrics
                {
                    FeatureName                               = slotNames[i],
                    DiscountedCumulativeGains                 = pMetric.DiscountedCumulativeGains.Select(x => x.Mean).ToArray(),
                    DiscountedCumulativeGainsStdErr           = pMetric.DiscountedCumulativeGains.Select(x => x.StandardError).ToArray(),
                    NormalizedDiscountedCumulativeGains       = pMetric.NormalizedDiscountedCumulativeGains.Select(x => x.Mean).ToArray(),
                    NormalizedDiscountedCumulativeGainsStdErr = pMetric.NormalizedDiscountedCumulativeGains.Select(x => x.StandardError).ToArray()
                });
            }

            SchemaDefinition schema = SchemaDefinition.Create(typeof(RankingMetrics));

            // Convert unknown size vectors to known size. The sizes come from the first row;
            // this assumes every row has the same array lengths.
            // If every feature was skipped, there is no row to size them from. In that case keep
            // the default schema and return an empty view, as the regression and binary paths
            // do, instead of throwing "Sequence contains no elements" from metrics.First().
            if (metrics.Count > 0)
            {
                var metric = metrics[0];
                ConvertVectorToKnownSize(nameof(metric.DiscountedCumulativeGains), metric.DiscountedCumulativeGains.Length, ref schema);
                ConvertVectorToKnownSize(nameof(metric.NormalizedDiscountedCumulativeGains), metric.NormalizedDiscountedCumulativeGains.Length, ref schema);
                ConvertVectorToKnownSize(nameof(metric.DiscountedCumulativeGainsStdErr), metric.DiscountedCumulativeGainsStdErr.Length, ref schema);
                ConvertVectorToKnownSize(nameof(metric.NormalizedDiscountedCumulativeGainsStdErr), metric.NormalizedDiscountedCumulativeGainsStdErr.Length, ref schema);
            }

            var dataOps = new DataOperationsCatalog(env);
            var result  = dataOps.LoadFromEnumerable(metrics, schema);

            return result;
        }
        /// <summary>
        /// Computes permutation feature importance for a regression model and returns the
        /// per-feature metric statistics (mean and standard error of each regression metric)
        /// as an <see cref="IDataView"/>.
        /// </summary>
        /// <param name="env">Host environment used to create the catalogs and the prediction transformer.</param>
        /// <param name="predictor">The trained regressor, handed to the transformer as <c>IPredictorProducing&lt;float&gt;</c>.</param>
        /// <param name="roleMappedData">Evaluation data; its schema supplies the Feature and Label column names.</param>
        /// <param name="input">Entry-point arguments forwarded to the PFI computation.</param>
        /// <returns>One row per feature whose slot name is not null or whitespace.</returns>
        private static IDataView GetRegressionMetrics(
            IHostEnvironment env,
            IPredictor predictor,
            RoleMappedData roleMappedData,
            PermutationFeatureImportanceArguments input)
        {
            var columnRoles = roleMappedData.Schema.GetColumnRoleNames();
            var featureColumn = columnRoles.First(role => role.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).Value;
            var labelColumn = columnRoles.First(role => role.Key.Value == RoleMappedSchema.ColumnRole.Label.Value).Value;

            var transformer = new RegressionPredictionTransformer<IPredictorProducing<float>>(
                env,
                predictor as IPredictorProducing<float>,
                roleMappedData.Data.Schema,
                featureColumn);

            var pfiResults = new RegressionCatalog(env).PermutationFeatureImportance(
                transformer,
                roleMappedData.Data,
                labelColumnName: labelColumn,
                useFeatureWeightFilter: input.UseFeatureWeightFilter,
                numberOfExamplesToUse: input.NumberOfExamplesToUse,
                permutationCount: input.PermutationCount);

            var featureNames = GetSlotNames(roleMappedData.Schema);
            Contracts.Assert(featureNames.Length == pfiResults.Length,
                "Mismatch between number of feature slots and number of features permuted.");

            // Build one output row per permuted feature, dropping those without a usable slot name.
            List<RegressionMetrics> rows = Enumerable.Range(0, pfiResults.Length)
                .Where(index => !string.IsNullOrWhiteSpace(featureNames[index]))
                .Select(index =>
                {
                    var stats = pfiResults[index];
                    return new RegressionMetrics
                    {
                        FeatureName = featureNames[index],
                        MeanAbsoluteError = stats.MeanAbsoluteError.Mean,
                        MeanAbsoluteErrorStdErr = stats.MeanAbsoluteError.StandardError,
                        MeanSquaredError = stats.MeanSquaredError.Mean,
                        MeanSquaredErrorStdErr = stats.MeanSquaredError.StandardError,
                        RootMeanSquaredError = stats.RootMeanSquaredError.Mean,
                        RootMeanSquaredErrorStdErr = stats.RootMeanSquaredError.StandardError,
                        LossFunction = stats.LossFunction.Mean,
                        LossFunctionStdErr = stats.LossFunction.StandardError,
                        RSquared = stats.RSquared.Mean,
                        RSquaredStdErr = stats.RSquared.StandardError
                    };
                })
                .ToList();

            return new DataOperationsCatalog(env).LoadFromEnumerable(rows);
        }
        /// <summary>
        /// Computes permutation feature importance for a binary classifier and returns the
        /// per-feature metric statistics (mean and standard error of each metric) as an
        /// <see cref="IDataView"/>.
        /// </summary>
        /// <param name="env">Host environment used to create the catalogs and the prediction transformer.</param>
        /// <param name="predictor">The trained classifier, handed to the transformer as <c>IPredictorProducing&lt;float&gt;</c>.</param>
        /// <param name="roleMappedData">Evaluation data; its schema supplies the Feature and Label column names.</param>
        /// <param name="input">Entry-point arguments forwarded to the PFI computation.</param>
        /// <returns>One row per feature whose slot name is not null or whitespace.</returns>
        private static IDataView GetBinaryMetrics(
            IHostEnvironment env,
            IPredictor predictor,
            RoleMappedData roleMappedData,
            PermutationFeatureImportanceArguments input)
        {
            var columnRoles = roleMappedData.Schema.GetColumnRoleNames();
            var featureColumn = columnRoles.First(role => role.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).Value;
            var labelColumn = columnRoles.First(role => role.Key.Value == RoleMappedSchema.ColumnRole.Label.Value).Value;

            var transformer = new BinaryPredictionTransformer<IPredictorProducing<float>>(
                env,
                predictor as IPredictorProducing<float>,
                roleMappedData.Data.Schema,
                featureColumn);

            var catalog = new BinaryClassificationCatalog(env);
            var pfiResults = catalog.PermutationFeatureImportance(
                transformer,
                roleMappedData.Data,
                labelColumnName: labelColumn,
                useFeatureWeightFilter: input.UseFeatureWeightFilter,
                numberOfExamplesToUse: input.NumberOfExamplesToUse,
                permutationCount: input.PermutationCount);

            var featureNames = GetSlotNames(roleMappedData.Schema);
            Contracts.Assert(featureNames.Length == pfiResults.Length,
                "Mismatch between number of feature slots and number of features permuted.");

            var rows = new List<BinaryMetrics>();
            for (int slot = 0; slot < pfiResults.Length; slot++)
            {
                // Only features that carry a usable slot name are reported.
                if (!string.IsNullOrWhiteSpace(featureNames[slot]))
                {
                    var stats = pfiResults[slot];
                    rows.Add(new BinaryMetrics
                    {
                        FeatureName = featureNames[slot],
                        AreaUnderRocCurve = stats.AreaUnderRocCurve.Mean,
                        AreaUnderRocCurveStdErr = stats.AreaUnderRocCurve.StandardError,
                        Accuracy = stats.Accuracy.Mean,
                        AccuracyStdErr = stats.Accuracy.StandardError,
                        PositivePrecision = stats.PositivePrecision.Mean,
                        PositivePrecisionStdErr = stats.PositivePrecision.StandardError,
                        PositiveRecall = stats.PositiveRecall.Mean,
                        PositiveRecallStdErr = stats.PositiveRecall.StandardError,
                        NegativePrecision = stats.NegativePrecision.Mean,
                        NegativePrecisionStdErr = stats.NegativePrecision.StandardError,
                        NegativeRecall = stats.NegativeRecall.Mean,
                        NegativeRecallStdErr = stats.NegativeRecall.StandardError,
                        F1Score = stats.F1Score.Mean,
                        F1ScoreStdErr = stats.F1Score.StandardError,
                        AreaUnderPrecisionRecallCurve = stats.AreaUnderPrecisionRecallCurve.Mean,
                        AreaUnderPrecisionRecallCurveStdErr = stats.AreaUnderPrecisionRecallCurve.StandardError
                    });
                }
            }

            return new DataOperationsCatalog(env).LoadFromEnumerable(rows);
        }
        /// <summary>
        /// Entry point that computes permutation feature importance (PFI) for the predictor
        /// contained in the input's predictor model, evaluated on the input data.
        /// </summary>
        /// <param name="env">The host environment; must not be null.</param>
        /// <param name="input">Entry-point arguments: the predictor model, the data and the PFI options.</param>
        /// <returns>An output object whose <c>Metrics</c> view holds the per-feature PFI statistics.</returns>
        public static PermutationFeatureImportanceOutput PermutationFeatureImportance(IHostEnvironment env, PermutationFeatureImportanceArguments input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("Pfi");

            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            input.PredictorModel.PrepareData(env, input.Data, out RoleMappedData roleMappedData, out IPredictor predictor);

            // This used to be Contracts.Assert, which only fires in debug builds. In release builds
            // a null predictor would then fail later as a NullReferenceException on
            // predictor.PredictionKind inside GetMetrics. Validate it explicitly instead.
            host.Check(predictor != null, "No predictor found in model");

            IDataView result = PermutationFeatureImportanceUtils.GetMetrics(env, predictor, roleMappedData, input);

            return new PermutationFeatureImportanceOutput
            {
                Metrics = result
            };
        }
        /// <summary>
        /// Computes permutation feature importance for a multiclass classifier and packages the
        /// per-feature metric statistics (mean and standard error) as an <see cref="IDataView"/>.
        /// </summary>
        /// <param name="env">Host environment used to create the catalogs and the prediction transformer.</param>
        /// <param name="predictor">The trained classifier, handed to the transformer as <c>IPredictorProducing&lt;VBuffer&lt;float&gt;&gt;</c>.</param>
        /// <param name="roleMappedData">Evaluation data; its schema supplies the Feature and Label column names.</param>
        /// <param name="input">Entry-point arguments forwarded to the PFI computation.</param>
        /// <returns>
        /// One row per feature whose slot name is not null or whitespace. If no feature qualifies,
        /// the returned view is empty.
        /// </returns>
        private static IDataView GetMulticlassMetrics(
            IHostEnvironment env,
            IPredictor predictor,
            RoleMappedData roleMappedData,
            PermutationFeatureImportanceArguments input)
        {
            var roles             = roleMappedData.Schema.GetColumnRoleNames();
            var featureColumnName = roles.First(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).Value;
            var labelColumnName   = roles.First(x => x.Key.Value == RoleMappedSchema.ColumnRole.Label.Value).Value;
            var pred = new MulticlassPredictionTransformer<IPredictorProducing<VBuffer<float>>>(
                env, predictor as IPredictorProducing<VBuffer<float>>, roleMappedData.Data.Schema, featureColumnName, labelColumnName);
            var multiclassCatalog  = new MulticlassClassificationCatalog(env);
            var permutationMetrics = multiclassCatalog
                                     .PermutationFeatureImportance(pred,
                                                                   roleMappedData.Data,
                                                                   labelColumnName: labelColumnName,
                                                                   useFeatureWeightFilter: input.UseFeatureWeightFilter,
                                                                   numberOfExamplesToUse: input.NumberOfExamplesToUse,
                                                                   permutationCount: input.PermutationCount);

            var slotNames = GetSlotNames(roleMappedData.Schema);

            Contracts.Assert(slotNames.Length == permutationMetrics.Length,
                             "Mismatch between number of feature slots and number of features permuted.");

            var metrics = new List<MulticlassMetrics>();

            for (int i = 0; i < permutationMetrics.Length; i++)
            {
                // Features without a usable slot name are omitted from the report.
                if (string.IsNullOrWhiteSpace(slotNames[i]))
                {
                    continue;
                }
                var pMetric = permutationMetrics[i];
                metrics.Add(new MulticlassMetrics
                {
                    FeatureName            = slotNames[i],
                    MacroAccuracy          = pMetric.MacroAccuracy.Mean,
                    MacroAccuracyStdErr    = pMetric.MacroAccuracy.StandardError,
                    MicroAccuracy          = pMetric.MicroAccuracy.Mean,
                    MicroAccuracyStdErr    = pMetric.MicroAccuracy.StandardError,
                    LogLoss                = pMetric.LogLoss.Mean,
                    LogLossStdErr          = pMetric.LogLoss.StandardError,
                    LogLossReduction       = pMetric.LogLossReduction.Mean,
                    LogLossReductionStdErr = pMetric.LogLossReduction.StandardError,
                    TopKAccuracy           = pMetric.TopKAccuracy.Mean,
                    TopKAccuracyStdErr     = pMetric.TopKAccuracy.StandardError,
                    PerClassLogLoss        = pMetric.PerClassLogLoss.Select(x => x.Mean).ToArray(),
                    PerClassLogLossStdErr  = pMetric.PerClassLogLoss.Select(x => x.StandardError).ToArray()
                });
            }

            SchemaDefinition schema = SchemaDefinition.Create(typeof(MulticlassMetrics));

            // Convert unknown size vectors to known size. The per-class arrays are sized from the
            // first row; this assumes every row has the same number of classes.
            // If every feature was skipped, there is no row to size them from. In that case keep
            // the default schema and return an empty view, as the regression and binary paths
            // do, instead of throwing "Sequence contains no elements" from metrics.First().
            if (metrics.Count > 0)
            {
                var metric = metrics[0];
                ConvertVectorToKnownSize(nameof(metric.PerClassLogLoss), metric.PerClassLogLoss.Length, ref schema);
                ConvertVectorToKnownSize(nameof(metric.PerClassLogLossStdErr), metric.PerClassLogLossStdErr.Length, ref schema);
            }

            var dataOps = new DataOperationsCatalog(env);
            var result  = dataOps.LoadFromEnumerable(metrics, schema);

            return result;
        }