예제 #1
0
        /// <summary>
        /// Infers the closing transform(s) that concatenate the numeric columns of
        /// <paramref name="dataSample"/> into a single "Features" vector column, then stamps every
        /// inferred transform with the requested routing level and shifts its atomic group id.
        /// Note: the result may be empty when a Features column is already present and is the only
        /// relevant numeric column (i.e. the concatenation would have nothing to do).
        /// </summary>
        /// <param name="env">Host environment forwarded to the inference call.</param>
        /// <param name="dataSample">Data view whose columns are examined by the inference call.</param>
        /// <param name="excludedColumnIndices">Passed through as <c>ExcludedColumnIndices</c> of the inference arguments.</param>
        /// <param name="level">Value written to <c>RoutingStructure.Level</c> of each inferred transform.</param>
        /// <param name="atomicIdOffset">Amount added to <c>AtomicGroupId</c> of each inferred transform.</param>
        /// <param name="dataRoles">Role mapping forwarded to the inference call.</param>
        /// <returns>A new array holding the adjusted transforms.</returns>
        private static TransformInference.SuggestedTransform[] GetFinalFeatureConcat(IHostEnvironment env,
                                                                                     IDataView dataSample, int[] excludedColumnIndices, int level, int atomicIdOffset, RoleMappedData dataRoles)
        {
            // Concat transforms are explicitly allowed here (the flag is false).
            var concatArgs = new TransformInference.Arguments
            {
                EstimatedSampleFraction = 1.0,
                ExcludeFeaturesConcatTransforms = false,
                ExcludedColumnIndices = excludedColumnIndices
            };

            var concatTransforms =
                TransformInference.InferConcatNumericFeatures(env, dataSample, concatArgs, dataRoles);

            // Re-home each inferred transform: set its level and offset its group id.
            int count = concatTransforms.Length;
            for (int idx = 0; idx < count; idx++)
            {
                concatTransforms[idx].RoutingStructure.Level = level;
                concatTransforms[idx].AtomicGroupId += atomicIdOffset;
            }

            // Hand back a copy rather than the array produced by the inference call.
            return concatTransforms.ToArray();
        }
            /// <summary>
            /// Computes the search space (transforms X learners X hyperparameters): runs transform
            /// inference level by level on the training data, records which transforms are
            /// responsible for which columns, stores the results on this instance and hands them
            /// to the AutoML engine.
            /// </summary>
            /// <param name="numTransformLevels">Maximum number of transform-inference levels to run.</param>
            /// <param name="learners">Candidate learners; checked with <c>IsValidLearnerSet</c>.</param>
            /// <param name="transformInferenceFunction">Callback that suggests transforms for a data view and a set of inference arguments.</param>
            private void ComputeSearchSpace(int numTransformLevels, RecipeInference.SuggestedRecipe.SuggestedLearner[] learners,
                                            Func <IDataView, TransformInference.Arguments, TransformInference.SuggestedTransform[]> transformInferenceFunction)
            {
                _env.AssertValue(_trainData, nameof(_trainData), "Must set training data prior to inferring search space.");

                var host = _env.Register("ComputeSearchSpace");

                using (var channel = host.Start("ComputeSearchSpace"))
                {
                    _env.Check(IsValidLearnerSet(learners), "Unsupported learner encountered, cannot update search space.");

                    // Feature-concat transforms are excluded while inferring the per-level transforms.
                    var args = new TransformInference.Arguments
                    {
                        EstimatedSampleFraction = 1.0,
                        ExcludeFeaturesConcatTransforms = true
                    };

                    // Running view of the data; replaced after every level that applies transforms.
                    var currentData = _trainData;

                    // Level 0 of the column-to-transform map describes the data with no transforms applied.
                    var columnOwners = new DependencyMap();
                    columnOwners.Add(0, AutoMlUtils.ComputeColumnResponsibilities(currentData, new TransformInference.SuggestedTransform[0]));

                    // Collect the suggested transforms of every level; they form part of the search space.
                    var collected = new List<TransformInference.SuggestedTransform>();
                    for (int pass = 0; pass < numTransformLevels; pass++)
                    {
                        // Levels are 1-based.
                        args.Level = pass + 1;

                        var proposed = transformInferenceFunction(currentData, args);

                        // Nothing suggested means the data would not change any further.
                        if (proposed.Length == 0)
                        {
                            break;
                        }

                        // Stop before the group ids outgrow the bitmask.
                        // NOTE(review): the limit of 64 presumably matches the bit width of the long
                        // handed to the verifier; confirm whether an id of exactly 64 is safe.
                        var highestGroupId = proposed.Max(t => t.AtomicGroupId);
                        if (highestGroupId > 64)
                        {
                            break;
                        }

                        // The next level starts numbering its groups after this level's highest id.
                        args.AtomicIdOffset = highestGroupId + 1;

                        // Materialize this level's transforms on the data.
                        currentData = AutoMlUtils.ApplyTransformSet(_env, currentData, proposed);

                        // Remember which of this level's transforms can be responsible for which output columns.
                        columnOwners.Add(args.Level,
                                         AutoMlUtils.ComputeColumnResponsibilities(currentData, proposed));
                        collected.AddRange(proposed);
                    }

                    var allTransforms = collected.ToArray();
                    Func <PipelinePattern, long, bool> validator = AutoMlUtils.ValidationWrapper(allTransforms, columnOwners);

                    // Persist the state so learning can be resumed later.
                    _availableTransforms = allTransforms;
                    _availableLearners   = learners;
                    _dependencyMapping   = columnOwners;
                    _transformedData     = currentData;

                    // Tell the AutoML engine what the search space looks like.
                    AutoMlEngine.SetSpace(_availableTransforms, _availableLearners, validator,
                                          _trainData, _transformedData, _dependencyMapping, Metric.IsMaximizing);

                    channel.Done();
                }
            }
            /// <summary>
            /// Infers transforms for <paramref name="data"/> and, when a set of existing transforms
            /// is supplied, keeps only those inferred transforms that equal one of them.
            /// </summary>
            /// <param name="data">Data view passed to transform inference.</param>
            /// <param name="args">Inference arguments passed to transform inference.</param>
            /// <param name="existingTransforms">Optional allow-list; when null, no filtering is done.</param>
            /// <returns>The inferred transforms, filtered against <paramref name="existingTransforms"/> if it is given.</returns>
            private TransformInference.SuggestedTransform[] InferAndFilter(IDataView data, TransformInference.Arguments args,
                                                                           TransformInference.SuggestedTransform[] existingTransforms = null)
            {
                // Ask the experts for candidate transforms.
                var inferred = TransformInference.InferTransforms(_env, data, args);

                // Without an allow-list the inference result is returned untouched.
                if (existingTransforms == null)
                {
                    return inferred;
                }

                // Keep a candidate only if some existing transform reports equality with it.
                return inferred
                    .Where(candidate => existingTransforms.Any(known => known.Equals(candidate)))
                    .ToArray();
            }