예제 #1
0
        /// <summary>
        /// Builds a train/test experiment for the current pipeline through entrypoints, runs it,
        /// and reads the requested metric from both the test and the training overall-metrics outputs.
        /// </summary>
        /// <param name="trainData">Data the pipeline is trained on.</param>
        /// <param name="testData">Data the trained pipeline is evaluated on.</param>
        /// <param name="metric">Metric whose <c>Name</c> selects the value extracted from each metrics output.</param>
        /// <param name="trainerKind">Trainer kind forwarded to the experiment builder.</param>
        /// <param name="testMetricValue">Receives the metric value taken from <c>OverallMetrics</c>.</param>
        /// <param name="trainMetricValue">Receives the metric value taken from <c>TrainingOverallMetrics</c>.</param>
        public void RunTrainTestExperiment(IDataView trainData, IDataView testData,
                                           SupportedMetric metric, MacroUtils.TrainerKinds trainerKind, out double testMetricValue,
                                           out double trainMetricValue)
        {
            // The literal 'true' is passed straight through to CreateTrainTestExperiment; presumably it
            // asks for training-set metrics as well, since TrainingOverallMetrics is read below — confirm.
            var exp = CreateTrainTestExperiment(trainData, testData, trainerKind, true, out var outputs);
            exp.Run();

            var testMetricsView = exp.GetOutput(outputs.OverallMetrics);
            var trainMetricsView = exp.GetOutput(outputs.TrainingOverallMetrics);

            // Test value is extracted first, then the training value (same order as before).
            testMetricValue = AutoMlUtils.ExtractValueFromIdv(_env, testMetricsView, metric.Name);
            trainMetricValue = AutoMlUtils.ExtractValueFromIdv(_env, trainMetricsView, metric.Name);
        }
        /// <summary>
        /// Splits <paramref name="data"/> with <c>TrainTestSplit</c> (Fraction = 0.8, presumably the training share),
        /// applies <c>Take(numOfSampleRows)</c> to each side, builds an <see cref="AutoMlMlState"/> over the
        /// two views and runs pipeline inference on it.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="autoMlEngine">Optimizer that proposes candidate pipelines.</param>
        /// <param name="data">Full dataset to split; must not be null.</param>
        /// <param name="numTransformLevels">Forwarded to the instance-level InferPipelines.</param>
        /// <param name="batchSize">Forwarded to the instance-level InferPipelines.</param>
        /// <param name="metric">Metric used to rank candidate pipelines.</param>
        /// <param name="bestPipeline">Receives the pipeline returned by the instance-level inference.</param>
        /// <param name="numOfSampleRows">Row count passed to <c>Take</c> on both split views and to inference.</param>
        /// <param name="terminator">Stopping criterion for the search.</param>
        /// <param name="trainerKind">Kind of trainer to search over.</param>
        /// <returns>The state object that performed the search.</returns>
        public static AutoMlMlState InferPipelines(IHostEnvironment env, PipelineOptimizerBase autoMlEngine, IDataView data, int numTransformLevels,
                                                   int batchSize, SupportedMetric metric, out PipelinePattern bestPipeline, int numOfSampleRows,
                                                   ITerminator terminator, MacroUtils.TrainerKinds trainerKind)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(data, nameof(data));

            var splitInput = new TrainTestSplit.Input
            {
                Data = data,
                Fraction = 0.8f
            };
            var split = TrainTestSplit.Split(env, splitInput);

            var sampledTrain = split.TrainData.Take(numOfSampleRows);
            var sampledTest = split.TestData.Take(numOfSampleRows);
            var state = new AutoMlMlState(env, metric, autoMlEngine, terminator, trainerKind, sampledTrain, sampledTest);

            bestPipeline = state.InferPipelines(numTransformLevels, batchSize, numOfSampleRows);
            return state;
        }
 /// <summary>
 /// Creates the AutoML search state: an ordered store of sampled pipelines (descending order when the
 /// metric is maximizing, ascending otherwise), an empty history, and the supplied engine, terminator,
 /// data and learner restrictions.
 /// </summary>
 /// <param name="env">Host environment; must not be null.</param>
 /// <param name="metric">Metric used to order sampled pipelines; must not be null.</param>
 /// <param name="autoMlEngine">Optimizer that proposes candidate pipelines.</param>
 /// <param name="terminator">Stopping criterion for the search.</param>
 /// <param name="trainerKind">Kind of trainer to search over.</param>
 /// <param name="trainData">Optional training data.</param>
 /// <param name="testData">Optional evaluation data.</param>
 /// <param name="requestedLearners">Optional learner names; stored as-is.</param>
 public AutoMlMlState(IHostEnvironment env, SupportedMetric metric, IPipelineOptimizer autoMlEngine,
                      ITerminator terminator, MacroUtils.TrainerKinds trainerKind, IDataView trainData = null, IDataView testData = null,
                      string[] requestedLearners = null)
 {
     Contracts.CheckValue(env, nameof(env));
     // metric is dereferenced immediately below; report a null argument explicitly rather than
     // surfacing a NullReferenceException from inside the constructor.
     env.CheckValue(metric, nameof(metric));

     // Maximizing metrics are sorted in reverse so the best score is always first.
     _sortedSampledElements =
         metric.IsMaximizing ? new SortedList<double, PipelinePattern>(new ReversedComparer<double>()) :
         new SortedList<double, PipelinePattern>();
     _history           = new List<PipelinePattern>();
     _env               = env;
     _host              = _env.Register("AutoMlState");
     _trainData         = trainData;
     _testData          = testData;
     _terminator        = terminator;
     _requestedLearners = requestedLearners;
     AutoMlEngine       = autoMlEngine;
     BatchCandidates    = Array.Empty<PipelinePattern>();
     Metric             = metric;
     TrainerKind        = trainerKind;
 }
        /// <summary>
        /// Loads training data from a text file and runs pipeline inference on it. The file is
        /// first passed through recipe inference to obtain a schema definition, the file is imported
        /// with that schema, and the result then goes through the same split/sample/infer flow as
        /// the <see cref="IDataView"/>-based overload.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="autoMlEngine">Optimizer that proposes candidate pipelines.</param>
        /// <param name="trainDataPath">Path of the training text file; must be non-empty.</param>
        /// <param name="schemaDefinitionFile">Forwarded to <c>RecipeInference.InferRecipesFromData</c>.</param>
        /// <param name="schemaDefinition">Receives the schema definition produced by recipe inference.</param>
        /// <param name="numTransformLevels">Forwarded to the instance-level InferPipelines.</param>
        /// <param name="batchSize">Forwarded to the instance-level InferPipelines.</param>
        /// <param name="metric">Metric used to rank candidate pipelines.</param>
        /// <param name="bestPipeline">Receives the best pipeline found.</param>
        /// <param name="numOfSampleRows">Row count used to sample each split and passed to inference.</param>
        /// <param name="terminator">Stopping criterion for the search.</param>
        /// <param name="trainerKind">Kind of trainer to search over.</param>
        /// <returns>The state object that performed the search.</returns>
        public static AutoMlMlState InferPipelines(IHostEnvironment env, PipelineOptimizerBase autoMlEngine, string trainDataPath,
                                                   string schemaDefinitionFile, out string schemaDefinition, int numTransformLevels, int batchSize, SupportedMetric metric,
                                                   out PipelinePattern bestPipeline, int numOfSampleRows, ITerminator terminator, MacroUtils.TrainerKinds trainerKind)
        {
            Contracts.CheckValue(env, nameof(env));
            // Fail fast with a clear argument error instead of deep inside recipe inference / file import.
            env.CheckNonEmpty(trainDataPath, nameof(trainDataPath));

            // REVIEW: Should be able to infer schema by itself, without having to
            // infer recipes. Look into this.
            // Set loader settings through inference.
            RecipeInference.InferRecipesFromData(env, trainDataPath, schemaDefinitionFile,
                                                 out _, out schemaDefinition, out _, true);

#pragma warning disable 0618
            var data = ImportTextData.ImportText(env, new ImportTextData.Input
            {
                InputFile    = new SimpleFileHandle(env, trainDataPath, false, false),
                CustomSchema = schemaDefinition
            }).Data;
#pragma warning restore 0618

            // Reuse the IDataView overload so the split fraction and sampling logic live in one place.
            return InferPipelines(env, autoMlEngine, data, numTransformLevels, batchSize, metric, out bestPipeline,
                                  numOfSampleRows, terminator, trainerKind);
        }
        /// <summary>
        /// The InferPipelines overloads are public entry points to the instance-level inference routine.
        /// They differ in how data is supplied: a single training IDataView (split internally), a path to a
        /// training text file, or — this overload — separate, caller-provided train and test IDataViews.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="autoMlEngine">Optimizer that proposes candidate pipelines.</param>
        /// <param name="trainData">Training data; must not be null.</param>
        /// <param name="testData">Evaluation data; must not be null.</param>
        /// <param name="numTransformLevels">Forwarded to the instance-level InferPipelines.</param>
        /// <param name="batchSize">Forwarded to the instance-level InferPipelines.</param>
        /// <param name="metric">Metric used to rank candidate pipelines.</param>
        /// <param name="bestPipeline">Receives the best pipeline found.</param>
        /// <param name="terminator">Stopping criterion for the search.</param>
        /// <param name="trainerKind">Kind of trainer to search over.</param>
        /// <returns>The state object that performed the search.</returns>
        public static AutoMlMlState InferPipelines(IHostEnvironment env, PipelineOptimizerBase autoMlEngine,
                                                   IDataView trainData, IDataView testData, int numTransformLevels, int batchSize, SupportedMetric metric,
                                                   out PipelinePattern bestPipeline, ITerminator terminator, MacroUtils.TrainerKinds trainerKind)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(trainData, nameof(trainData));
            env.CheckValue(testData, nameof(testData));

            // When GetRowCount(false) yields no value, fall back to 1000 rows. The count is a long:
            // clamp before narrowing, since a bare (int) cast wraps to a negative value above int.MaxValue.
            long rowCount = trainData.GetRowCount(false) ?? 1000;
            int numOfRows = (int)Math.Min(rowCount, int.MaxValue);

            AutoMlMlState amls = new AutoMlMlState(env, metric, autoMlEngine, terminator, trainerKind, trainData, testData);

            bestPipeline = amls.InferPipelines(numTransformLevels, batchSize, numOfRows);
            return amls;
        }
 /// <summary>
 /// Creates the AutoML state from an <see cref="Arguments"/> instance: the metric is looked up by the
 /// name of <c>args.Metric</c>, the engine and terminator are instantiated from their component
 /// factories, and the trainer kind and requested learners are forwarded unchanged.
 /// </summary>
 /// <param name="env">Host environment used to create the components.</param>
 /// <param name="args">Settings to build the state from.</param>
 public AutoMlMlState(IHostEnvironment env, Arguments args)
     : this(
         env,
         metric: SupportedMetric.ByName(Enum.GetName(typeof(Arguments.Metrics), args.Metric)),
         autoMlEngine: args.Engine.CreateComponent(env),
         terminator: args.TerminatorArgs.CreateComponent(env),
         trainerKind: args.TrainerKind,
         requestedLearners: args.RequestedLearners)
 {
 }
예제 #7
0
        /// <summary>
        /// Maps a <see cref="Metrics"/> value to a <see cref="SupportedMetric"/> that carries the metric's
        /// field name and whether larger values are better (the second constructor argument).
        /// </summary>
        /// <param name="metric">The metric to map.</param>
        /// <returns>A new <see cref="SupportedMetric"/> for <paramref name="metric"/>.</returns>
        /// <exception cref="NotSupportedException">The metric has no mapping.</exception>
        public static SupportedMetric GetSupportedMetric(Metrics metric)
        {
            switch (metric)
            {
                case Metrics.Auc:
                    return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.Auc, true);
                case Metrics.AccuracyMicro:
                    return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.AccuracyMicro, true);
                case Metrics.AccuracyMacro:
                    return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.AccuracyMacro, true);
                case Metrics.L1:
                    return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.L1, false);
                case Metrics.L2:
                    return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.L2, false);
                case Metrics.F1:
                    return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.F1, true);
                case Metrics.AuPrc:
                    return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.AuPrc, true);
                case Metrics.TopKAccuracy:
                    return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.TopKAccuracy, true);
                case Metrics.Rms:
                    return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.Rms, false);
                case Metrics.LossFn:
                    return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.LossFn, false);
                case Metrics.RSquared:
                    // R^2 (coefficient of determination) improves as it approaches 1, so it must be
                    // maximized; marking it minimizing would rank the worst-fitting models first.
                    return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.RSquared, true);
                case Metrics.LogLoss:
                    return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.LogLoss, false);
                case Metrics.LogLossReduction:
                    return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.LogLossReduction, true);
                case Metrics.Ndcg:
                    return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.Ndcg, true);
                case Metrics.Dcg:
                    return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.Dcg, true);
                case Metrics.PositivePrecision:
                    return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.PositivePrecision, true);
                case Metrics.PositiveRecall:
                    return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.PositiveRecall, true);
                case Metrics.NegativePrecision:
                    return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.NegativePrecision, true);
                case Metrics.NegativeRecall:
                    return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.NegativeRecall, true);
                case Metrics.DrAtK:
                    return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.DrAtK, true);
                case Metrics.DrAtPFpr:
                    return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.DrAtPFpr, true);
                case Metrics.DrAtNumPos:
                    return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.DrAtNumPos, true);
                case Metrics.NumAnomalies:
                    return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.NumAnomalies, true);
                case Metrics.ThreshAtK:
                    return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.ThreshAtK, false);
                case Metrics.ThreshAtP:
                    return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.ThreshAtP, false);
                case Metrics.ThreshAtNumPos:
                    return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.ThreshAtNumPos, false);
                case Metrics.Nmi:
                    return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.Nmi, true);
                case Metrics.AvgMinScore:
                    return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.AvgMinScore, false);
                case Metrics.Dbi:
                    return new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.Dbi, false);
                default:
                    throw new NotSupportedException($"Metric '{metric}' not supported.");
            }
        }