/// <summary>
/// Picks the next pipeline to evaluate, based on the pipelines that have already been run.
/// </summary>
/// <remarks>
/// While fewer runs have been recorded than there are available trainers, the choice is
/// delegated to <c>GetNextFirstStagePipeline</c>. Afterwards, the top trainers are revisited
/// (least-run first) with freshly sampled hyperparameters until a pipeline that has not
/// been visited before is produced.
/// </remarks>
/// <param name="context">ML context forwarded to every inference/builder helper.</param>
/// <param name="history">Details of the pipelines that have already been run. Enumerated exactly once.</param>
/// <param name="columns">Dataset column descriptions used to infer transforms and column information.</param>
/// <param name="task">The kind of ML task being solved.</param>
/// <param name="isMaximizingMetric">Forwarded to top-trainer selection and hyperparameter sampling.</param>
/// <param name="cacheBeforeTrainer">Caching setting forwarded to the pipeline builders.</param>
/// <param name="logger">Channel forwarded to hyperparameter sampling.</param>
/// <param name="trainerAllowList">Optional list forwarded to <c>RecipeInference.AllowedTrainers</c>.</param>
/// <returns>
/// The next pipeline to run, or <c>null</c> when no previously unvisited pipeline could be
/// produced for any of the top trainers.
/// </returns>
public static SuggestedPipeline GetNextInferredPipeline(MLContext context,
    IEnumerable<SuggestedPipelineRunDetail> history,
    DatasetColumnInfo[] columns,
    TaskKind task,
    bool isMaximizingMetric,
    CacheBeforeTrainer cacheBeforeTrainer,
    IChannel logger,
    IEnumerable<TrainerName> trainerAllowList = null)
{
    // The history is consumed many times below (count, ranking, ordering, visited set and
    // every sampling attempt). Snapshot it once so a lazy or expensive sequence is not
    // re-evaluated each time. The local keeps the parameter's static type so the helper
    // calls bind exactly as before.
    IEnumerable<SuggestedPipelineRunDetail> pastRuns = history.ToList();

    var availableTrainers = RecipeInference.AllowedTrainers(context, task,
        ColumnInformationUtil.BuildColumnInfo(columns), trainerAllowList);
    var transforms = TransformInferenceApi.InferTransforms(context, task, columns).ToList();
    var transformsPostTrainer = TransformInferenceApi.InferTransformsPostTrainer(context, task, columns).ToList();

    // if we haven't run all pipelines once
    if (pastRuns.Count() < availableTrainers.Count())
    {
        return GetNextFirstStagePipeline(context, pastRuns, availableTrainers, transforms,
            transformsPostTrainer, cacheBeforeTrainer);
    }

    // get top trainers from stage 1 runs
    var topTrainers = GetTopTrainers(pastRuns, availableTrainers, isMaximizingMetric);

    // sort top trainers by # of times they've been run, from lowest to highest
    var orderedTopTrainers = OrderTrainersByNumTrials(pastRuns, topTrainers);

    // keep as hash set of previously visited pipelines
    var visitedPipelines = new HashSet<SuggestedPipeline>(pastRuns.Select(h => h.Pipeline));

    // Number of sampling attempts granted to each trainer before moving on to the next one.
    const int maxNumberAttempts = 10;

    // iterate over top trainers (from least run to most run), to find next pipeline
    foreach (var trainer in orderedTopTrainers)
    {
        var newTrainer = trainer.Clone();

        // Repeat until an unseen pipeline is found or the attempts run out. The bound is
        // exclusive so that exactly maxNumberAttempts attempts are made (the previous
        // "++count <= max" do/while performed one extra attempt).
        for (var attempt = 0; attempt < maxNumberAttempts; attempt++)
        {
            // sample new hyperparameters for the learner
            if (!SampleHyperparameters(context, newTrainer, pastRuns, isMaximizingMetric, logger))
            {
                // if unable to sample new hyperparameters for the learner
                // (ie SMAC returned 0 suggestions), give up on this trainer
                break;
            }

            var suggestedPipeline = SuggestedPipelineBuilder.Build(context, transforms,
                transformsPostTrainer, newTrainer, cacheBeforeTrainer);

            // make sure we have not seen pipeline before
            if (!visitedPipelines.Contains(suggestedPipeline))
            {
                return suggestedPipeline;
            }
        }
    }

    return null;
}
/// <summary>
/// Builds the final text-loader options and column information for the file at
/// <paramref name="path"/>, starting from previously computed split and type inference results.
/// </summary>
/// <param name="context">ML context used to create the text loader and run the inference helpers.</param>
/// <param name="path">Path handed to the text loader.</param>
/// <param name="columnInfo">Column information supplied by the caller; every column it names must exist in the loaded data.</param>
/// <param name="hasHeader">Whether the loader should treat the file as having a header.</param>
/// <param name="splitInference">Split inference result supplying separator, quoting, sparsity and multi-line settings.</param>
/// <param name="typeInference">Type inference result whose columns seed the loader columns.</param>
/// <param name="trimWhitespace">Value copied to the loader's <c>TrimWhitespace</c> option.</param>
/// <param name="groupColumns">When true, columns are grouped and named via <c>ColumnGroupingInference</c>.</param>
/// <returns>The resulting loader options together with the column information built from the inferred purposes.</returns>
public static ColumnInferenceResults InferColumns(MLContext context,
    string path,
    ColumnInformation columnInfo,
    bool hasHeader,
    TextFileContents.ColumnSplitResult splitInference,
    ColumnTypeInference.InferenceResult typeInference,
    bool trimWhitespace,
    bool groupColumns)
{
    // First pass: load the data with one loader column per inferred column.
    var rawColumns = ColumnTypeInference.GenerateLoaderColumns(typeInference.Columns);
    var firstPassOptions = new TextLoader.Options
    {
        Columns = rawColumns,
        Separators = new[] { splitInference.Separator.Value },
        AllowSparse = splitInference.AllowSparse,
        AllowQuoting = splitInference.AllowQuote,
        ReadMultilines = splitInference.ReadMultilines,
        HasHeader = hasHeader,
        TrimWhitespace = trimWhitespace
    };
    var data = context.Data.CreateTextLoader(firstPassOptions).Load(path);

    // Fail early when the caller referenced a column the loaded data does not contain.
    ColumnInferenceValidationUtil.ValidateSpecifiedColumnsExist(columnInfo, data);

    var inferredPurposes = PurposeInference.InferPurposes(context, data, columnInfo);

    // Decide which loader columns and (name, purpose) pairs go into the result.
    IEnumerable<TextLoader.Column> finalColumns;
    IEnumerable<(string, ColumnPurpose)> namedPurposes;
    if (!groupColumns)
    {
        // No grouping: keep the first-pass columns and take names from the loaded schema.
        finalColumns = rawColumns;
        namedPurposes = inferredPurposes.Select(p => (data.Schema[p.ColumnIndex].Name, p.Purpose));
    }
    else
    {
        // Grouping requested: infer the grouping along with suggested column names.
        var grouped = ColumnGroupingInference.InferGroupingAndNames(context, hasHeader,
            typeInference.Columns, inferredPurposes);
        finalColumns = grouped.Select(g => g.GenerateTextLoaderColumn());
        namedPurposes = grouped.Select(g => (g.SuggestedName, g.Purpose));
    }

    var finalOptions = new TextLoader.Options()
    {
        Columns = finalColumns.ToArray(),
        AllowQuoting = splitInference.AllowQuote,
        AllowSparse = splitInference.AllowSparse,
        Separators = new char[] { splitInference.Separator.Value },
        ReadMultilines = splitInference.ReadMultilines,
        HasHeader = hasHeader,
        TrimWhitespace = trimWhitespace
    };

    return new ColumnInferenceResults()
    {
        TextLoaderOptions = finalOptions,
        ColumnInformation = ColumnInformationUtil.BuildColumnInfo(namedPurposes)
    };
}
/// <summary>
/// Checks that every column named by <paramref name="columnInfo"/> is present in the
/// schema of <paramref name="dataView"/>.
/// </summary>
/// <param name="columnInfo">Column information whose column names are checked.</param>
/// <param name="dataView">Data view whose schema is searched for each name.</param>
/// <exception cref="ArgumentException">Thrown for the first column name that the schema does not contain.</exception>
public static void ValidateSpecifiedColumnsExist(ColumnInformation columnInfo,
    IDataView dataView)
{
    foreach (var name in ColumnInformationUtil.GetColumnNames(columnInfo))
    {
        // Column is present in the schema; move on to the next name.
        if (dataView.Schema.GetColumnOrNull(name) != null)
        {
            continue;
        }

        throw new ArgumentException($"Specified column {name} is not found in the dataset.");
    }
}