Esempio n. 1
0
        /// <summary>
        /// Suggests the next pipeline to run, based on the runs recorded in <paramref name="history"/>.
        /// </summary>
        /// <remarks>
        /// While <paramref name="history"/> holds fewer runs than there are available trainers, the choice is
        /// delegated to <c>GetNextFirstStagePipeline</c>. After that, the top trainers are visited from least-run
        /// to most-run; for each one, hyperparameters are sampled up to 10 times until a pipeline that does not
        /// already appear in <paramref name="history"/> is produced.
        /// </remarks>
        /// <param name="context">ML context forwarded to the inference, sampling and pipeline-building helpers.</param>
        /// <param name="history">Details of the pipeline runs performed so far.</param>
        /// <param name="columns">Dataset columns used to infer transforms and to build the column information.</param>
        /// <param name="task">Task kind used to select the allowed trainers and to infer transforms.</param>
        /// <param name="isMaximizingMetric">Forwarded to top-trainer selection and hyperparameter sampling
        /// (presumably true when a larger metric value is better; confirm against callers).</param>
        /// <param name="cacheBeforeTrainer">Caching setting forwarded to the pipeline builders.</param>
        /// <param name="logger">Channel forwarded to hyperparameter sampling.</param>
        /// <param name="trainerAllowList">Optional trainer list forwarded to <c>RecipeInference.AllowedTrainers</c>.</param>
        /// <returns>The next pipeline to try, or <c>null</c> if no previously unseen pipeline could be produced.</returns>
        public static SuggestedPipeline GetNextInferredPipeline(MLContext context,
                                                                IEnumerable <SuggestedPipelineRunDetail> history,
                                                                DatasetColumnInfo[] columns,
                                                                TaskKind task,
                                                                bool isMaximizingMetric,
                                                                CacheBeforeTrainer cacheBeforeTrainer,
                                                                IChannel logger,
                                                                IEnumerable <TrainerName> trainerAllowList = null)
        {
            var availableTrainers = RecipeInference.AllowedTrainers(context, task,
                                                                    ColumnInformationUtil.BuildColumnInfo(columns), trainerAllowList);
            var transforms            = TransformInferenceApi.InferTransforms(context, task, columns).ToList();
            var transformsPostTrainer = TransformInferenceApi.InferTransformsPostTrainer(context, task, columns).ToList();

            // if we haven't run all pipelines once
            if (history.Count() < availableTrainers.Count())
            {
                return GetNextFirstStagePipeline(context, history, availableTrainers, transforms, transformsPostTrainer, cacheBeforeTrainer);
            }

            // get top trainers from stage 1 runs
            var topTrainers = GetTopTrainers(history, availableTrainers, isMaximizingMetric);

            // sort top trainers by # of times they've been run, from lowest to highest
            var orderedTopTrainers = OrderTrainersByNumTrials(history, topTrainers);

            // keep as hash set of previously visited pipelines
            var visitedPipelines = new HashSet <SuggestedPipeline>(history.Select(h => h.Pipeline));

            // upper bound on hyperparameter sampling attempts per trainer
            const int maxNumberAttempts = 10;

            // iterate over top trainers (from least run to most run),
            // to find next pipeline
            foreach (var trainer in orderedTopTrainers)
            {
                // sample on a clone rather than on the trainer instance taken from the ordered list
                var newTrainer = trainer.Clone();

                // repeat until an unseen pipeline is found or the attempts run out;
                // the loop body executes at most maxNumberAttempts times
                for (var attempt = 0; attempt < maxNumberAttempts; attempt++)
                {
                    // sample new hyperparameters for the learner
                    if (!SampleHyperparameters(context, newTrainer, history, isMaximizingMetric, logger))
                    {
                        // if unable to sample new hyperparameters for the learner
                        // (ie SMAC returned 0 suggestions), move on to the next trainer
                        break;
                    }

                    var suggestedPipeline = SuggestedPipelineBuilder.Build(context, transforms, transformsPostTrainer, newTrainer, cacheBeforeTrainer);

                    // make sure we have not seen pipeline before
                    if (!visitedPipelines.Contains(suggestedPipeline))
                    {
                        return suggestedPipeline;
                    }
                }
            }

            // no trainer produced a pipeline that has not been visited already
            return null;
        }
Esempio n. 2
0
        /// <summary>
        /// Produces the text loader options and column information for the file at <paramref name="path"/>.
        /// </summary>
        /// <remarks>
        /// The file is first loaded with the columns generated from <paramref name="typeInference"/>; the columns
        /// named in <paramref name="columnInfo"/> are validated against that data view and column purposes are
        /// inferred from it. When <paramref name="groupColumns"/> is set, columns are additionally grouped and
        /// named by <c>ColumnGroupingInference.InferGroupingAndNames</c>; otherwise the generated loader columns
        /// and the schema names of the loaded data view are used as they are.
        /// </remarks>
        /// <param name="context">ML context used to create the text loader and forwarded to the inference helpers.</param>
        /// <param name="path">Path handed to the text loader.</param>
        /// <param name="columnInfo">Column information specified by the caller.</param>
        /// <param name="hasHeader">Value copied into the loader options' <c>HasHeader</c>.</param>
        /// <param name="splitInference">Source of the separator, sparse, quoting and multi-line loader settings.</param>
        /// <param name="typeInference">Column type inference result the loader columns are generated from.</param>
        /// <param name="trimWhitespace">Value copied into the loader options' <c>TrimWhitespace</c>.</param>
        /// <param name="groupColumns">Whether to infer column grouping and generate column names.</param>
        /// <returns>The final text loader options together with the built column information.</returns>
        public static ColumnInferenceResults InferColumns(MLContext context, string path, ColumnInformation columnInfo, bool hasHeader,
                                                          TextFileContents.ColumnSplitResult splitInference, ColumnTypeInference.InferenceResult typeInference,
                                                          bool trimWhitespace, bool groupColumns)
        {
            // Both the intermediate and the final loader options differ only in their columns.
            TextLoader.Options CreateLoaderOptions(TextLoader.Column[] optionColumns)
            {
                return new TextLoader.Options
                {
                    Columns        = optionColumns,
                    Separators     = new[] { splitInference.Separator.Value },
                    AllowSparse    = splitInference.AllowSparse,
                    AllowQuoting   = splitInference.AllowQuote,
                    ReadMultilines = splitInference.ReadMultilines,
                    HasHeader      = hasHeader,
                    TrimWhitespace = trimWhitespace
                };
            }

            // Load the data once using the type-inferred columns.
            var inferredColumns = ColumnTypeInference.GenerateLoaderColumns(typeInference.Columns);
            var typedLoader     = context.Data.CreateTextLoader(CreateLoaderOptions(inferredColumns));
            var typedView       = typedLoader.Load(path);

            // Every column named in the column info must be present in the loaded data view.
            ColumnInferenceValidationUtil.ValidateSpecifiedColumnsExist(columnInfo, typedView);

            var inferredPurposes = PurposeInference.InferPurposes(context, typedView, columnInfo);

            IEnumerable <TextLoader.Column>       finalColumns;
            IEnumerable <(string, ColumnPurpose)> namedPurposes;

            if (!groupColumns)
            {
                // No grouping: keep the generated columns and take the names from the loaded schema.
                finalColumns  = inferredColumns;
                namedPurposes = inferredPurposes.Select(p => (typedView.Schema[p.ColumnIndex].Name, p.Purpose));
            }
            else
            {
                // Grouping requested: infer column groups and their suggested names.
                var groups = ColumnGroupingInference.InferGroupingAndNames(context, hasHeader,
                                                                           typeInference.Columns, inferredPurposes);

                finalColumns  = groups.Select(g => g.GenerateTextLoaderColumn());
                namedPurposes = groups.Select(g => (g.SuggestedName, g.Purpose));
            }

            var finalLoaderOptions = CreateLoaderOptions(finalColumns.ToArray());

            return new ColumnInferenceResults()
            {
                TextLoaderOptions = finalLoaderOptions,
                ColumnInformation = ColumnInformationUtil.BuildColumnInfo(namedPurposes)
            };
        }
        /// <summary>
        /// Checks that every column named in <paramref name="columnInfo"/> is present in the schema of
        /// <paramref name="dataView"/>.
        /// </summary>
        /// <param name="columnInfo">Column information whose column names are checked.</param>
        /// <param name="dataView">Data view whose schema is searched for each column name.</param>
        /// <exception cref="ArgumentException">Thrown for the first column name that the schema does not contain.</exception>
        public static void ValidateSpecifiedColumnsExist(ColumnInformation columnInfo,
                                                         IDataView dataView)
        {
            foreach (var name in ColumnInformationUtil.GetColumnNames(columnInfo))
            {
                // Columns that resolve in the schema need no further handling.
                if (dataView.Schema.GetColumnOrNull(name) != null)
                {
                    continue;
                }

                var message = $"Specified column {name} " +
                              $"is not found in the dataset.";
                throw new ArgumentException(message);
            }
        }