/// <summary>
/// Loads the training file, splits it into train and test portions, and runs the
/// AutoML pipeline search over a sample of each portion.
/// </summary>
/// <param name="env">Host environment; must not be null.</param>
/// <param name="autoMlEngine">Optimizer handed to the created <c>AutoMlMlState</c>.</param>
/// <param name="trainDataPath">Path of the text file to load.</param>
/// <param name="schemaDefinitionFile">Schema definition file forwarded to recipe inference.</param>
/// <param name="schemaDefinition">Receives the schema definition produced by recipe inference; it is also used as the custom schema when importing the file.</param>
/// <param name="numTransformLevels">Forwarded to the state's <c>InferPipelines</c> call.</param>
/// <param name="batchSize">Forwarded to the state's <c>InferPipelines</c> call.</param>
/// <param name="metric">Metric handed to the created <c>AutoMlMlState</c>.</param>
/// <param name="bestPipeline">Receives the result of the state's <c>InferPipelines</c> call.</param>
/// <param name="numOfSampleRows">Number of rows taken from each of the train and test portions; also forwarded to the state's <c>InferPipelines</c> call.</param>
/// <param name="terminator">Terminator handed to the created <c>AutoMlMlState</c>.</param>
/// <param name="trainerKind">Trainer kind handed to the created <c>AutoMlMlState</c>.</param>
/// <returns>The <c>AutoMlMlState</c> that performed the search.</returns>
public static AutoMlMlState InferPipelines(IHostEnvironment env, PipelineOptimizerBase autoMlEngine, string trainDataPath,
                                                   string schemaDefinitionFile, out string schemaDefinition, int numTransformLevels, int batchSize, SupportedMetric metric,
                                                   out PipelinePattern bestPipeline, int numOfSampleRows, ITerminator terminator, MacroUtils.TrainerKinds trainerKind)
        {
            Contracts.CheckValue(env, nameof(env));

            // Recipe inference is run only for the schema definition it yields; the other
            // two outputs are discarded.
            // REVIEW: schema inference should not require inferring recipes. Look into this.
            RecipeInference.InferRecipesFromData(env, trainDataPath, schemaDefinitionFile,
                                                 out var _, out schemaDefinition, out var _, true);

#pragma warning disable 0618
            var importArgs = new ImportTextData.Input
            {
                InputFile = new SimpleFileHandle(env, trainDataPath, false, false),
                CustomSchema = schemaDefinition
            };
            var loadedData = ImportTextData.ImportText(env, importArgs).Data;
#pragma warning restore 0618

            // Split with Fraction = 0.8 (presumably the share assigned to the training
            // portion -- confirm against TrainTestSplit).
            var splitArgs = new TrainTestSplit.Input
            {
                Data = loadedData,
                Fraction = 0.8f
            };
            var split = TrainTestSplit.Split(env, splitArgs);

            // Both portions are capped at numOfSampleRows rows before the search.
            var trainSample = split.TrainData.Take(numOfSampleRows);
            var testSample = split.TestData.Take(numOfSampleRows);

            var state = new AutoMlMlState(env, metric, autoMlEngine, terminator, trainerKind,
                                          trainSample, testSample);
            bestPipeline = state.InferPipelines(numTransformLevels, batchSize, numOfSampleRows);
            return state;
        }
            /// <summary>
            /// Restricts the available learners to those allowed for the current trainer kind
            /// whose name appears in <paramref name="learnersToKeep"/>, and passes the
            /// resulting set to the AutoML engine.
            /// </summary>
            /// <param name="learnersToKeep">Names of the learners to retain.</param>
            public void KeepSelectedLearners(IEnumerable <string> learnersToKeep)
            {
                var allLearners = RecipeInference.AllowedLearners(_env, TrainerKind);

                _env.AssertNonEmpty(allLearners);

                // The name sequence is probed once per allowed learner. Materialize it a single
                // time so a lazy or expensive sequence is not re-enumerated on every probe.
                // An existing collection is reused as-is to keep its own Contains semantics.
                var namesToKeep = learnersToKeep as ICollection<string> ?? learnersToKeep.ToList();

                _availableLearners = allLearners.Where(l => namesToKeep.Contains(l.LearnerName)).ToArray();
                AutoMlEngine.UpdateLearners(_availableLearners);
            }
            /// <summary>
            /// Computes the search space from the learners allowed for the current trainer
            /// kind, narrowed to the requested learners when any were requested.
            /// </summary>
            /// <param name="numTransformLevels">Forwarded to <c>ComputeSearchSpace</c>.</param>
            public void InferSearchSpace(int numTransformLevels)
            {
                var candidates = RecipeInference.AllowedLearners(_env, TrainerKind).ToArray();

                // A null or empty request list means "no restriction".
                bool hasRequestedLearners = _requestedLearners != null && _requestedLearners.Length > 0;
                if (hasRequestedLearners)
                {
                    candidates = candidates
                        .Where(candidate => _requestedLearners.Contains(candidate.LearnerName))
                        .ToArray();
                }

                ComputeSearchSpace(numTransformLevels, candidates, (x, y) => InferAndFilter(x, y));
            }
示例#4
0
        /// <summary>
        /// Builds one sweep candidate for every combination of a non-hidden trainer class
        /// and a recipe inferred from the data file.
        /// </summary>
        /// <param name="env">Host environment used for recipe inference and the component catalog lookup.</param>
        /// <param name="dataFile">Data file passed to recipe inference.</param>
        /// <param name="schemaDefinitionFile">Schema definition file passed to recipe inference.</param>
        /// <returns>The list of generated sweeps, grouped by trainer and then by recipe.</returns>
        public static List <Sweep> GenerateCandidates(IHostEnvironment env, string dataFile, string schemaDefinitionFile)
        {
            // Infer the starting recipes, together with the predictor type and loader settings.
            RecipeInference.SuggestedRecipe[] recipes = RecipeInference.InferRecipesFromData(
                env, dataFile, schemaDefinitionFile,
                out Type predictorType, out string loaderSettings, out TransformInference.InferenceResult inferenceResult);

            // Trainer classes registered for the inferred predictor type, minus hidden ones.
            var trainers = env.ComponentCatalog
                              .GetAllDerivedClasses(typeof(ITrainer), predictorType)
                              .Where(info => !info.IsHidden);

            // Quote non-empty loader settings before appending them to the loader argument.
            if (!string.IsNullOrEmpty(loaderSettings))
            {
                var quoted = new StringBuilder();
                CmdQuoter.QuoteValue(loaderSettings, quoted, true);
                loaderSettings = quoted.ToString();
            }

            string loader = " loader=TextLoader" + loaderSettings;

            // REVIEW: there are currently more learners than recipes, hence trainers on the
            // outer loop. Swap the nesting if that cardinality changes.
            var candidates = new List <Sweep>();
            foreach (ComponentCatalog.LoadableClassInfo trainerInfo in trainers)
            {
                // One sweeper per trainer, shared by all of that trainer's sweeps.
                var sweeper = new TrainerSweeper();
                sweeper.Parameters.AddRange(
                    RecipeInference.GetLearnerSettingsAndSweepParams(env, trainerInfo, out string learnerSettings));

                foreach (var recipe in recipes)
                {
                    // A fresh learner instance per pattern, so patterns do not share one object.
                    var learner = new RecipeInference.SuggestedRecipe.SuggestedLearner
                    {
                        LoadableClassInfo = trainerInfo,
                        Settings = learnerSettings
                    };

                    candidates.Add(new Sweep(new Pattern(recipe.Transforms, learner, loader), sweeper));
                }
            }

            return candidates;
        }
            /// <summary>
            /// Checks that every learner in <paramref name="learners"/> has the same name as
            /// one of the learners allowed for the current trainer kind.
            /// </summary>
            /// <param name="learners">Learners to validate.</param>
            /// <returns><c>true</c> when no learner has an unrecognized name; otherwise <c>false</c>.</returns>
            private bool IsValidLearnerSet(RecipeInference.SuggestedRecipe.SuggestedLearner[] learners)
            {
                var allowed = RecipeInference.AllowedLearners(_env, TrainerKind);

                // The set is invalid as soon as one learner matches none of the allowed names.
                bool hasUnknownLearner = learners.Any(
                    candidate => allowed.All(known => known.LearnerName != candidate.LearnerName));

                return !hasUnknownLearner;
            }