예제 #1
0
 /// <summary>
 /// Saves this transform by delegating to LambdaTransform.SaveCustomTransformer
 /// with the host, the supplied context and the stored contract name.
 /// </summary>
 /// <param name="ctx">Save context handed through to the save helper.</param>
 /// <remarks>
 /// Throws the exception produced by the host's Except(...) when no contract
 /// name has been set.
 /// </remarks>
 public void Save(ModelSaveContext ctx)
 {
     // Happy path first: a contract name is present, so the helper can do the work.
     if (_contractName != null)
     {
         LambdaTransform.SaveCustomTransformer(_host, ctx, _contractName);
         return;
     }

     // No contract name: this transform cannot be saved.
     throw _host.Except("Empty contract name for a transform: the transform cannot be saved");
 }
        // Computes permutation feature importance: for each evaluated slot of the
        // features column, the slot's values are shuffled across rows, the model is
        // re-scored, and the change of the metric against the unpermuted baseline is
        // recorded. (The return type and modifiers of this method are declared on a
        // line that precedes this header.)
        //
        // Parameters:
        //   env                    - environment used to register the host; must not be null.
        //   model                  - prediction transformer applied to both baseline and permuted data.
        //   data                   - input data containing the features column.
        //   resultInitializer      - creates the per-feature accumulator that receives one delta per permutation.
        //   evaluationFunc         - turns scored data into a metric value.
        //   deltaFunc              - called as deltaFunc(permutedMetrics, baselineMetrics).
        //   features               - name of the features column; must be non-empty.
        //   permutationCount       - number of shuffles performed for every evaluated feature.
        //   useFeatureWeightFilter - when true and model.Model exposes feature weights, only
        //                            features with a non-zero weight are evaluated.
        //   topExamples            - maximum number of rows to use; null means Utils.ArrayMaxSize.
        //
        // Returns an immutable array with one accumulated result per evaluated feature,
        // in the order of the working feature indices. The array is empty when every
        // feature has zero weight or when the data yields no rows.
        GetImportanceMetricsMatrix(
            IHostEnvironment env,
            IPredictionTransformer <TModel> model,
            IDataView data,
            Func <TResult> resultInitializer,
            Func <IDataView, TMetric> evaluationFunc,
            Func <TMetric, TMetric, TMetric> deltaFunc,
            string features,
            int permutationCount,
            bool useFeatureWeightFilter = false,
            int?topExamples             = null)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register(nameof(PermutationFeatureImportance <TModel, TMetric, TResult>));

            host.CheckValue(model, nameof(model));
            host.CheckValue(data, nameof(data));
            host.CheckNonEmpty(features, nameof(features));

            topExamples = topExamples ?? Utils.ArrayMaxSize;
            host.Check(topExamples > 0, "Provide how many examples to use (positive number) or set to null to use whole dataset.");

            VBuffer <ReadOnlyMemory <char> > slotNames = default;
            var metricsDelta = new List <TResult>();

            using (var ch = host.Start("GetImportanceMetrics"))
            {
                ch.Trace("Scoring and evaluating baseline.");
                var baselineMetrics = evaluationFunc(model.Transform(data));

                // Get slot names.
                var featuresColumn = data.Schema[features];
                int numSlots       = featuresColumn.Type.GetVectorSize();
                data.Schema.TryGetColumnIndex(features, out int featuresColumnIndex);

                ch.Info("Number of slots: " + numSlots);
                if (data.Schema[featuresColumnIndex].HasSlotNames(numSlots))
                {
                    data.Schema[featuresColumnIndex].Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref slotNames);
                }

                // Fall back to an empty buffer of the right length when no usable slot names were read.
                if (slotNames.Length != numSlots)
                {
                    slotNames = VBufferUtils.CreateEmpty <ReadOnlyMemory <char> >(numSlots);
                }

                VBuffer <float> weights = default;
                // Unless the weight filter below narrows it down, every slot is evaluated.
                var             workingFeatureIndices = Enumerable.Range(0, numSlots).ToList();
                int             zeroWeightsCount      = 0;

                // By default set to the number of all features available.
                var evaluatedFeaturesCount = numSlots;
                if (useFeatureWeightFilter)
                {
                    var predictorWithWeights = model.Model as IPredictorWithFeatureWeights <Single>;
                    if (predictorWithWeights != null)
                    {
                        // Start from an empty list: only non-zero-weight features are added back
                        // below. Without this the pre-filled list would keep the zero-weight
                        // features and would contain every non-zero-weight feature twice.
                        workingFeatureIndices.Clear();
                        predictorWithWeights.GetFeatureWeights(ref weights);

                        const int     maxReportedZeroFeatures = 10;
                        StringBuilder msgFilteredOutFeatures  = new StringBuilder("The following features have zero weight and will not be evaluated: \n \t");
                        var           prefix = "";
                        foreach (var k in weights.Items(all: true))
                        {
                            if (k.Value == 0)
                            {
                                zeroWeightsCount++;

                                // Print info about first few features we're not going to evaluate.
                                if (zeroWeightsCount <= maxReportedZeroFeatures)
                                {
                                    msgFilteredOutFeatures.Append(prefix);
                                    msgFilteredOutFeatures.Append(GetSlotName(slotNames, k.Key));
                                    prefix = ", ";
                                }
                            }
                            else
                            {
                                workingFeatureIndices.Add(k.Key);
                            }
                        }

                        // Old FastTree models has less weights than slots.
                        if (weights.Length < numSlots)
                        {
                            ch.Warning(
                                "Predictor had fewer features than slots. All unknown features will get default 0 weight.");
                            zeroWeightsCount += numSlots - weights.Length;
                            var indexes = weights.GetIndices().ToArray();
                            var values  = weights.GetValues().ToArray();
                            var count   = values.Length;
                            weights = new VBuffer <float>(numSlots, count, values, indexes);
                        }

                        evaluatedFeaturesCount = workingFeatureIndices.Count;
                        ch.Info("Number of zero weights: {0} out of {1}.", zeroWeightsCount, weights.Length);

                        // Print what features have 0 weight
                        if (zeroWeightsCount > 0)
                        {
                            if (zeroWeightsCount > maxReportedZeroFeatures)
                            {
                                msgFilteredOutFeatures.Append(string.Format("... (printing out  {0} features here).\n Use 'Index' column in the report for info on what features are not evaluated.", maxReportedZeroFeatures));
                            }
                            ch.Info(msgFilteredOutFeatures.ToString());
                        }
                    }
                }

                // The filter produced neither working features nor zero weights.
                if (workingFeatureIndices.Count == 0 && zeroWeightsCount == 0)
                {
                    // Use all features otherwise.
                    workingFeatureIndices.AddRange(Enumerable.Range(0, numSlots));
                }

                if (zeroWeightsCount == numSlots)
                {
                    ch.Warning("All features have 0 weight thus can not do thorough evaluation");
                    return(metricsDelta.ToImmutableArray());
                }

                // Note: this will not work on the huge dataset.
                var          maxSize = topExamples;
                List <float> initialfeatureValuesList = new List <float>();

                // Cursor through the data to cache the values of the first working slot for the upcoming permutation.
                var valuesRowCount = 0;
                // REVIEW: Seems like if the labels are NaN, so that all metrics are NaN, this command will be useless.
                // In which case probably erroring out is probably the most useful thing.
                using (var cursor = data.GetRowCursor(featuresColumn))
                {
                    var featuresGetter = cursor.GetGetter <VBuffer <float> >(featuresColumn);
                    var featuresBuffer = default(VBuffer <float>);

                    while (initialfeatureValuesList.Count < maxSize && cursor.MoveNext())
                    {
                        featuresGetter(ref featuresBuffer);
                        initialfeatureValuesList.Add(featuresBuffer.GetItemOrDefault(workingFeatureIndices[0]));
                    }

                    valuesRowCount = initialfeatureValuesList.Count;
                }

                if (valuesRowCount > 0)
                {
                    ch.Info("Detected {0} examples for evaluation.", valuesRowCount);
                }
                else
                {
                    ch.Warning("Detected no examples for evaluation.");
                    return(metricsDelta.ToImmutableArray());
                }

                // featureValuesBuffer holds the values being shuffled for the current feature;
                // nextValues collects the unshuffled values of the feature that follows.
                float[] featureValuesBuffer = initialfeatureValuesList.ToArray();
                float[] nextValues          = new float[valuesRowCount];

                // Now iterate through all the working slots, do permutation and calc the delta of metrics.
                int processedCnt     = 0;
                int nextFeatureIndex = 0;
                var shuffleRand      = RandomUtils.Create(host.Rand.Next());
                using (var pch = host.StartProgressChannel("Calculating Permutation Feature Importance"))
                {
                    pch.SetHeader(new ProgressHeader("processed slots"), e => e.SetProgress(0, processedCnt));
                    foreach (var workingIndx in workingFeatureIndices)
                    {
                        // Index for the feature we will permute next.  Needed to build in advance a buffer for the permutation.
                        if (processedCnt < workingFeatureIndices.Count - 1)
                        {
                            nextFeatureIndex = workingFeatureIndices[processedCnt + 1];
                        }

                        // Used for pre-caching the next feature
                        int nextValuesIndex = 0;

                        SchemaDefinition input = SchemaDefinition.Create(typeof(FeaturesBuffer));
                        Contracts.Assert(input.Count == 1);
                        input[0].ColumnName = features;

                        SchemaDefinition output = SchemaDefinition.Create(typeof(FeaturesBuffer));
                        Contracts.Assert(output.Count == 1);
                        output[0].ColumnName = features;
                        output[0].ColumnType = featuresColumn.Type;

                        // Perform multiple permutations for one feature to build a confidence interval
                        var metricsDeltaForFeature = resultInitializer();
                        for (int permutationIteration = 0; permutationIteration < permutationCount; permutationIteration++)
                        {
                            Utils.Shuffle <float>(shuffleRand, featureValuesBuffer);

                            // Row mapper: copies the source features and overwrites the current
                            // slot with the next shuffled value.
                            Action <FeaturesBuffer, FeaturesBuffer, PermuterState> permuter =
                                (src, dst, state) =>
                            {
                                src.Features.CopyTo(ref dst.Features);
                                VBufferUtils.ApplyAt(ref dst.Features, workingIndx,
                                                     (int ii, ref float d) =>
                                                     d = featureValuesBuffer[state.SampleIndex++]);

                                // Is it time to pre-cache the next feature?
                                if (permutationIteration == permutationCount - 1 &&
                                    processedCnt < workingFeatureIndices.Count - 1)
                                {
                                    // Fill out the featureValueBuffer for the next feature while updating the current feature
                                    // This is the reason I need PermuterState in LambdaTransform.CreateMap.
                                    nextValues[nextValuesIndex] = src.Features.GetItemOrDefault(nextFeatureIndex);
                                    if (nextValuesIndex < valuesRowCount - 1)
                                    {
                                        nextValuesIndex++;
                                    }
                                }
                            };

                            IDataView viewPermuted = LambdaTransform.CreateMap(
                                host, data, permuter, null, input, output);
                            // When the row cap was hit, restrict the permuted view to the cached row count.
                            if (valuesRowCount == topExamples)
                            {
                                viewPermuted = SkipTakeFilter.Create(host, new SkipTakeFilter.TakeOptions()
                                {
                                    Count = valuesRowCount
                                }, viewPermuted);
                            }

                            var metrics = evaluationFunc(model.Transform(viewPermuted));

                            var delta = deltaFunc(metrics, baselineMetrics);
                            metricsDeltaForFeature.Add(delta);
                        }

                        // Add the metrics delta to the list
                        metricsDelta.Add(metricsDeltaForFeature);

                        // Swap values for next iteration of permutation.
                        if (processedCnt < workingFeatureIndices.Count - 1)
                        {
                            Array.Clear(featureValuesBuffer, 0, featureValuesBuffer.Length);
                            nextValues.CopyTo(featureValuesBuffer, 0);
                            Array.Clear(nextValues, 0, nextValues.Length);
                        }
                        processedCnt++;
                    }
                    pch.Checkpoint(processedCnt, processedCnt);
                }
            }

            return(metricsDelta.ToImmutableArray());
        }