/// <summary>
/// Builds a train-test experiment for the current pipeline through entrypoints, runs it,
/// and extracts the value of <paramref name="metric"/> from both the test-set and the
/// training-set overall metrics.
/// </summary>
/// <param name="trainData">Data handed to the experiment as the training set.</param>
/// <param name="testData">Data handed to the experiment as the test set.</param>
/// <param name="metric">Metric whose <c>Name</c> selects the value to read from the metrics outputs.</param>
/// <param name="trainerKind">Trainer kind forwarded to the experiment construction.</param>
/// <param name="testMetricValue">Receives the metric value read from the experiment's overall metrics.</param>
/// <param name="trainMetricValue">Receives the metric value read from the experiment's training overall metrics.</param>
public void RunTrainTestExperiment(IDataView trainData, IDataView testData, SupportedMetric metric,
    MacroUtils.TrainerKinds trainerKind, out double testMetricValue, out double trainMetricValue)
{
    // NOTE(review): the literal 'true' presumably requests training metrics as well,
    // since TrainingOverallMetrics is read below - confirm against CreateTrainTestExperiment.
    var trainTestExperiment = CreateTrainTestExperiment(trainData, testData, trainerKind, true, out var experimentOutput);
    trainTestExperiment.Run();

    // Fetch both metric views before extracting any value from them.
    var testMetricsView = trainTestExperiment.GetOutput(experimentOutput.OverallMetrics);
    var trainMetricsView = trainTestExperiment.GetOutput(experimentOutput.TrainingOverallMetrics);

    var metricName = metric.Name;
    testMetricValue = AutoMlUtils.ExtractValueFromIdv(_env, testMetricsView, metricName);
    trainMetricValue = AutoMlUtils.ExtractValueFromIdv(_env, trainMetricsView, metricName);
}
/// <summary>
/// Infers pipelines from a single data view. The data is split with
/// <c>TrainTestSplit</c> (Fraction = 0.8), both partitions are passed through
/// <c>Take(numOfSampleRows)</c>, and the resulting views drive a new
/// <see cref="AutoMlMlState"/>.
/// </summary>
/// <param name="env">Host environment; must not be null.</param>
/// <param name="autoMlEngine">Pipeline optimizer used by the created state.</param>
/// <param name="data">Data view to split; must not be null.</param>
/// <param name="numTransformLevels">Forwarded to the instance-level inference call.</param>
/// <param name="batchSize">Forwarded to the instance-level inference call.</param>
/// <param name="metric">Metric the created state optimizes for.</param>
/// <param name="bestPipeline">Receives the pipeline returned by the instance-level inference call.</param>
/// <param name="numOfSampleRows">Row count passed to <c>Take</c> on each partition and to the inference call.</param>
/// <param name="terminator">Terminator used by the created state.</param>
/// <param name="trainerKind">Trainer kind used by the created state.</param>
/// <returns>The state object that performed the inference.</returns>
public static AutoMlMlState InferPipelines(IHostEnvironment env, PipelineOptimizerBase autoMlEngine, IDataView data,
    int numTransformLevels, int batchSize, SupportedMetric metric, out PipelinePattern bestPipeline,
    int numOfSampleRows, ITerminator terminator, MacroUtils.TrainerKinds trainerKind)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(data, nameof(data));

    // NOTE(review): Fraction = 0.8f is presumably the share of rows that goes to the training partition - confirm.
    var split = TrainTestSplit.Split(env, new TrainTestSplit.Input { Data = data, Fraction = 0.8f });
    var sampledTrain = split.TrainData.Take(numOfSampleRows);
    var sampledTest = split.TestData.Take(numOfSampleRows);

    var state = new AutoMlMlState(env, metric, autoMlEngine, terminator, trainerKind, sampledTrain, sampledTest);
    bestPipeline = state.InferPipelines(numTransformLevels, batchSize, numOfSampleRows);
    return state;
}
/// <summary>
/// Creates the AutoML state: validates the environment, prepares the (initially empty)
/// history and sorted collection of sampled pipelines, registers a host named
/// "AutoMlState", and stores the supplied configuration.
/// </summary>
/// <param name="env">Host environment; must not be null.</param>
/// <param name="metric">Metric to optimize; its <c>IsMaximizing</c> flag selects the sort order.</param>
/// <param name="autoMlEngine">Pipeline optimizer exposed through <c>AutoMlEngine</c>.</param>
/// <param name="terminator">Terminator stored for later use.</param>
/// <param name="trainerKind">Trainer kind exposed through <c>TrainerKind</c>.</param>
/// <param name="trainData">Optional training data.</param>
/// <param name="testData">Optional test data.</param>
/// <param name="requestedLearners">Optional list of requested learners.</param>
public AutoMlMlState(IHostEnvironment env, SupportedMetric metric, IPipelineOptimizer autoMlEngine,
    ITerminator terminator, MacroUtils.TrainerKinds trainerKind, IDataView trainData = null,
    IDataView testData = null, string[] requestedLearners = null)
{
    Contracts.CheckValue(env, nameof(env));

    // A maximizing metric uses the reversed comparer; a minimizing one uses the default ordering.
    if (metric.IsMaximizing)
    {
        _sortedSampledElements = new SortedList<double, PipelinePattern>(new ReversedComparer<double>());
    }
    else
    {
        _sortedSampledElements = new SortedList<double, PipelinePattern>();
    }

    _history = new List<PipelinePattern>();

    _env = env;
    _host = _env.Register("AutoMlState");

    _trainData = trainData;
    _testData = testData;
    _terminator = terminator;
    _requestedLearners = requestedLearners;

    AutoMlEngine = autoMlEngine;
    BatchCandidates = new PipelinePattern[0];
    Metric = metric;
    TrainerKind = trainerKind;
}
/// <summary>
/// Infers pipelines from a training file on disk. Loader settings are obtained through
/// recipe inference, the file is imported as text with the inferred schema, and the
/// resulting data is split (Fraction = 0.8) and passed through <c>Take(numOfSampleRows)</c>
/// before being handed to a new <see cref="AutoMlMlState"/>.
/// </summary>
/// <param name="env">Host environment; must not be null.</param>
/// <param name="autoMlEngine">Pipeline optimizer used by the created state.</param>
/// <param name="trainDataPath">Path of the training data file.</param>
/// <param name="schemaDefinitionFile">Schema definition file forwarded to recipe inference.</param>
/// <param name="schemaDefinition">Receives the schema definition produced by recipe inference.</param>
/// <param name="numTransformLevels">Forwarded to the instance-level inference call.</param>
/// <param name="batchSize">Forwarded to the instance-level inference call.</param>
/// <param name="metric">Metric the created state optimizes for.</param>
/// <param name="bestPipeline">Receives the pipeline returned by the instance-level inference call.</param>
/// <param name="numOfSampleRows">Row count passed to <c>Take</c> on each partition and to the inference call.</param>
/// <param name="terminator">Terminator used by the created state.</param>
/// <param name="trainerKind">Trainer kind used by the created state.</param>
/// <returns>The state object that performed the inference.</returns>
public static AutoMlMlState InferPipelines(IHostEnvironment env, PipelineOptimizerBase autoMlEngine,
    string trainDataPath, string schemaDefinitionFile, out string schemaDefinition, int numTransformLevels,
    int batchSize, SupportedMetric metric, out PipelinePattern bestPipeline, int numOfSampleRows,
    ITerminator terminator, MacroUtils.TrainerKinds trainerKind)
{
    Contracts.CheckValue(env, nameof(env));

    // REVIEW: Should be able to infer schema by itself, without having to
    // infer recipes. Look into this.
    // Set loader settings through inference; only the schema definition output is kept.
    RecipeInference.InferRecipesFromData(env, trainDataPath, schemaDefinitionFile,
        out var _, out schemaDefinition, out var _, true);

    // Warning 0618 (use of an obsolete member) is suppressed around the text import.
#pragma warning disable 0618
    var importInput = new ImportTextData.Input
    {
        InputFile = new SimpleFileHandle(env, trainDataPath, false, false),
        CustomSchema = schemaDefinition
    };
    var loadedData = ImportTextData.ImportText(env, importInput).Data;
#pragma warning restore 0618

    var split = TrainTestSplit.Split(env, new TrainTestSplit.Input { Data = loadedData, Fraction = 0.8f });
    var sampledTrain = split.TrainData.Take(numOfSampleRows);
    var sampledTest = split.TestData.Take(numOfSampleRows);

    var state = new AutoMlMlState(env, metric, autoMlEngine, terminator, trainerKind, sampledTrain, sampledTest);
    bestPipeline = state.InferPipelines(numTransformLevels, batchSize, numOfSampleRows);
    return state;
}
/// <summary>
/// The InferPipelines methods are public portals to the internal inference function; each one
/// handles a different form of input: a single training IDataView, a path to a training file,
/// or (this overload) separate train and test IDataViews.
/// </summary>
/// <param name="env">Host environment; must not be null.</param>
/// <param name="autoMlEngine">Pipeline optimizer used by the created state.</param>
/// <param name="trainData">Training data; must not be null.</param>
/// <param name="testData">Test data; must not be null.</param>
/// <param name="numTransformLevels">Forwarded to the instance-level inference call.</param>
/// <param name="batchSize">Forwarded to the instance-level inference call.</param>
/// <param name="metric">Metric the created state optimizes for.</param>
/// <param name="bestPipeline">Receives the pipeline returned by the instance-level inference call.</param>
/// <param name="terminator">Terminator used by the created state.</param>
/// <param name="trainerKind">Trainer kind used by the created state.</param>
/// <returns>The state object that performed the inference.</returns>
public static AutoMlMlState InferPipelines(IHostEnvironment env, PipelineOptimizerBase autoMlEngine,
    IDataView trainData, IDataView testData, int numTransformLevels, int batchSize, SupportedMetric metric,
    out PipelinePattern bestPipeline, ITerminator terminator, MacroUtils.TrainerKinds trainerKind)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(trainData, nameof(trainData));
    env.CheckValue(testData, nameof(testData));

    // Fall back to 1000 rows when the training view cannot report its row count.
    long rowCount = trainData.GetRowCount(false) ?? 1000;

    // Clamp instead of casting blindly: a direct (int) cast of a count above int.MaxValue
    // would silently wrap (or throw in a checked build) and pass a bogus row count on.
    int numOfRows = rowCount > int.MaxValue ? int.MaxValue : (int)rowCount;

    AutoMlMlState amls = new AutoMlMlState(env, metric, autoMlEngine, terminator, trainerKind, trainData, testData);
    bestPipeline = amls.InferPipelines(numTransformLevels, batchSize, numOfRows);
    return amls;
}
/// <summary>
/// Creates the AutoML state from an <see cref="Arguments"/> object by delegating to the
/// main constructor: the metric is looked up by the name of <c>args.Metric</c>, and the
/// engine and terminator are created from their component factories. No train or test
/// data is supplied through this constructor.
/// </summary>
/// <param name="env">Host environment, also used to create the engine and terminator components.</param>
/// <param name="args">Arguments carrying the metric, engine, terminator, trainer kind and requested learners.</param>
public AutoMlMlState(IHostEnvironment env, Arguments args)
    : this(
        env: env,
        metric: SupportedMetric.ByName(Enum.GetName(typeof(Arguments.Metrics), args.Metric)),
        autoMlEngine: args.Engine.CreateComponent(env),
        terminator: args.TerminatorArgs.CreateComponent(env),
        trainerKind: args.TrainerKind,
        requestedLearners: args.RequestedLearners)
{
}
/// <summary>
/// Maps a <see cref="Metrics"/> enum value to its <see cref="SupportedMetric"/> descriptor,
/// pairing the metric's field name with a flag telling whether the metric is to be
/// maximized (<c>true</c>) or minimized (<c>false</c>).
/// </summary>
/// <param name="metric">The metric to look up.</param>
/// <returns>A new <see cref="SupportedMetric"/> describing <paramref name="metric"/>.</returns>
/// <exception cref="NotSupportedException">Thrown when <paramref name="metric"/> has no mapping.</exception>
public static SupportedMetric GetSupportedMetric(Metrics metric)
{
    switch (metric)
    {
        case Metrics.Auc:
            return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.Auc, true);
        case Metrics.AccuracyMicro:
            return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.AccuracyMicro, true);
        case Metrics.AccuracyMacro:
            return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.AccuracyMacro, true);
        case Metrics.L1:
            return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.L1, false);
        case Metrics.L2:
            return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.L2, false);
        case Metrics.F1:
            return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.F1, true);
        case Metrics.AuPrc:
            return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.AuPrc, true);
        case Metrics.TopKAccuracy:
            return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.TopKAccuracy, true);
        case Metrics.Rms:
            return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.Rms, false);
        case Metrics.LossFn:
            return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.LossFn, false);
        case Metrics.RSquared:
            // R-squared is a goodness-of-fit score where higher is better, so it must be
            // maximized; treating it as a minimizing metric would rank the worst fits first.
            return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.RSquared, true);
        case Metrics.LogLoss:
            return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.LogLoss, false);
        case Metrics.LogLossReduction:
            return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.LogLossReduction, true);
        case Metrics.Ndcg:
            return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.Ndcg, true);
        case Metrics.Dcg:
            return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.Dcg, true);
        case Metrics.PositivePrecision:
            return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.PositivePrecision, true);
        case Metrics.PositiveRecall:
            return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.PositiveRecall, true);
        case Metrics.NegativePrecision:
            return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.NegativePrecision, true);
        case Metrics.NegativeRecall:
            return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.NegativeRecall, true);
        case Metrics.DrAtK:
            return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.DrAtK, true);
        case Metrics.DrAtPFpr:
            return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.DrAtPFpr, true);
        case Metrics.DrAtNumPos:
            return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.DrAtNumPos, true);
        case Metrics.NumAnomalies:
            return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.NumAnomalies, true);
        case Metrics.ThreshAtK:
            return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.ThreshAtK, false);
        case Metrics.ThreshAtP:
            return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.ThreshAtP, false);
        case Metrics.ThreshAtNumPos:
            return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.ThreshAtNumPos, false);
        case Metrics.Nmi:
            return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.Nmi, true);
        case Metrics.AvgMinScore:
            return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.AvgMinScore, false);
        case Metrics.Dbi:
            return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.Dbi, false);
        default:
            throw new NotSupportedException($"Metric '{metric}' not supported.");
    }
}