コード例 #1
0
        /// <summary>
        /// Suggests the next pipeline to evaluate, based on the runs completed so far.
        /// </summary>
        /// <remarks>
        /// While <paramref name="history"/> has fewer entries than there are allowed trainers,
        /// the next untried trainer is used (first stage). After that, the top trainers chosen
        /// from <paramref name="history"/> are visited from least-run to most-run. For each,
        /// new hyperparameters are sampled, up to a fixed number of attempts, until a pipeline
        /// that is not already in <paramref name="history"/> is produced.
        /// </remarks>
        /// <param name="context">ML.NET context used for trainer, transform and pipeline construction.</param>
        /// <param name="history">Previously run pipelines and their run details. Enumerated once.</param>
        /// <param name="columns">Dataset column information used to infer transforms and allowed trainers.</param>
        /// <param name="task">The ML task kind being solved.</param>
        /// <param name="isMaximizingMetric">Forwarded to trainer ranking and hyperparameter sampling
        /// (presumably true when a larger metric value is better).</param>
        /// <param name="cacheBeforeTrainer">Caching option forwarded to the pipeline builder.</param>
        /// <param name="logger">Channel forwarded to hyperparameter sampling.</param>
        /// <param name="trainerAllowList">Optional trainer restriction forwarded to
        /// <c>RecipeInference.AllowedTrainers</c> (null presumably means "no restriction").</param>
        /// <returns>
        /// The next pipeline to try, or <c>null</c> when no unseen pipeline could be produced.
        /// </returns>
        public static SuggestedPipeline GetNextInferredPipeline(MLContext context,
                                                                IEnumerable<SuggestedPipelineRunDetail> history,
                                                                DatasetColumnInfo[] columns,
                                                                TaskKind task,
                                                                bool isMaximizingMetric,
                                                                CacheBeforeTrainer cacheBeforeTrainer,
                                                                IChannel logger,
                                                                IEnumerable<TrainerName> trainerAllowList = null)
        {
            // Materialize both sequences once. Each is enumerated several times below:
            // counts, ranking, ordering, hash-set seeding and every sampling attempt.
            var historyList = history.ToList();
            var availableTrainers = RecipeInference.AllowedTrainers(context, task,
                                                                    ColumnInformationUtil.BuildColumnInfo(columns), trainerAllowList).ToList();
            var transforms            = TransformInferenceApi.InferTransforms(context, task, columns).ToList();
            var transformsPostTrainer = TransformInferenceApi.InferTransformsPostTrainer(context, task, columns).ToList();

            // First stage: until every allowed trainer has been run once, hand out the next untried one.
            if (historyList.Count < availableTrainers.Count)
            {
                return GetNextFirstStagePipeline(context, historyList, availableTrainers, transforms, transformsPostTrainer, cacheBeforeTrainer);
            }

            // Get the top trainers from the stage-1 runs, then sort them by the number of
            // times each has been run, from lowest to highest.
            var topTrainers        = GetTopTrainers(historyList, availableTrainers, isMaximizingMetric);
            var orderedTopTrainers = OrderTrainersByNumTrials(historyList, topTrainers);

            // Previously visited pipelines, used to reject duplicate suggestions.
            var visitedPipelines = new HashSet<SuggestedPipeline>(historyList.Select(h => h.Pipeline));

            // Upper bound on sampling attempts per trainer. The previous do/while form
            // (`++count <= max`) performed one attempt more than this constant states.
            const int maxNumberAttempts = 10;

            foreach (var trainer in orderedTopTrainers)
            {
                var newTrainer = trainer.Clone();

                for (var attempt = 0; attempt < maxNumberAttempts; attempt++)
                {
                    // If no new hyperparameters can be sampled for this trainer
                    // (i.e. the sampler returned no suggestions), move on to the next trainer.
                    if (!SampleHyperparameters(context, newTrainer, historyList, isMaximizingMetric, logger))
                    {
                        break;
                    }

                    var suggestedPipeline = SuggestedPipelineBuilder.Build(context, transforms, transformsPostTrainer, newTrainer, cacheBeforeTrainer);

                    // Only return a pipeline that has not been seen before.
                    if (!visitedPipelines.Contains(suggestedPipeline))
                    {
                        return suggestedPipeline;
                    }
                }
            }

            return null;
        }
コード例 #2
0
        /// <summary>
        /// Builds the first-stage pipeline for the trainer whose position in
        /// <paramref name="availableTrainers"/> equals the number of runs already in
        /// <paramref name="history"/>.
        /// </summary>
        /// <param name="context">ML.NET context passed to the pipeline builder.</param>
        /// <param name="history">Completed runs; only its element count is used.</param>
        /// <param name="availableTrainers">Trainers in the order they are tried during the first stage.</param>
        /// <param name="transforms">Transforms placed before the trainer.</param>
        /// <param name="transformsPostTrainer">Transforms placed after the trainer.</param>
        /// <param name="cacheBeforeTrainer">Caching option passed to the pipeline builder.</param>
        /// <returns>The pipeline built around the selected trainer.</returns>
        /// <exception cref="ArgumentOutOfRangeException">
        /// Thrown by <c>ElementAt</c> when <paramref name="history"/> has at least as many
        /// entries as <paramref name="availableTrainers"/>. The caller checks this beforehand.
        /// </exception>
        private static SuggestedPipeline GetNextFirstStagePipeline(MLContext context,
            IEnumerable<SuggestedPipelineRunDetail> history,
            IEnumerable<SuggestedTrainer> availableTrainers,
            ICollection<SuggestedTransform> transforms,
            ICollection<SuggestedTransform> transformsPostTrainer,
            CacheBeforeTrainer cacheBeforeTrainer)
        {
            // The count of completed runs doubles as the index of the next untried trainer.
            int nextTrainerIndex = history.Count();
            SuggestedTrainer nextTrainer = availableTrainers.ElementAt(nextTrainerIndex);

            return SuggestedPipelineBuilder.Build(context, transforms, transformsPostTrainer, nextTrainer, cacheBeforeTrainer);
        }