private FoldResult RunFold(int fold)
            {
                var host = GetHost();

                host.Assert(0 <= fold && fold <= _numFolds);
                // REVIEW: Make channels buffered in multi-threaded environments.
                using (var ch = host.Start($"Fold {fold}"))
                {
                    ch.Trace("Constructing trainer");
                    ITrainer trainer = _trainer.CreateComponent(host);

                    // Train pipe.
                    var trainFilter = new RangeFilter.Arguments();
                    trainFilter.Column     = _splitColumn;
                    trainFilter.Min        = (Double)fold / _numFolds;
                    trainFilter.Max        = (Double)(fold + 1) / _numFolds;
                    trainFilter.Complement = true;
                    IDataView trainPipe = new RangeFilter(host, trainFilter, _inputDataView);
                    trainPipe = new OpaqueDataView(trainPipe);
                    var trainData = _createExamples(host, ch, trainPipe, trainer);

                    // Test pipe.
                    var testFilter = new RangeFilter.Arguments();
                    testFilter.Column = trainFilter.Column;
                    testFilter.Min    = trainFilter.Min;
                    testFilter.Max    = trainFilter.Max;
                    ch.Assert(!testFilter.Complement);
                    IDataView testPipe = new RangeFilter(host, testFilter, _inputDataView);
                    testPipe = new OpaqueDataView(testPipe);
                    var testData = _applyTransformsToTestData(host, ch, testPipe, trainData, trainPipe);

                    // Validation pipe and examples.
                    RoleMappedData validData = null;
                    if (_getValidationDataView != null)
                    {
                        ch.Assert(_applyTransformsToValidationData != null);
                        if (!trainer.Info.SupportsValidation)
                        {
                            ch.Warning("Trainer does not accept validation dataset.");
                        }
                        else
                        {
                            ch.Trace("Constructing the validation pipeline");
                            IDataView validLoader = _getValidationDataView();
                            var       validPipe   = ApplyTransformUtils.ApplyAllTransformsToData(host, _inputDataView, validLoader);
                            validPipe = new OpaqueDataView(validPipe);
                            validData = _applyTransformsToValidationData(host, ch, validPipe, trainData, trainPipe);
                        }
                    }

                    // Train.
                    var predictor = TrainUtils.Train(host, ch, trainData, trainer, validData,
                                                     _calibrator, _maxCalibrationExamples, _cacheData, _inputPredictor);

                    // Score.
                    ch.Trace("Scoring and evaluating");
                    ch.Assert(_scorer == null || _scorer is ICommandLineComponentFactory, "CrossValidationCommand should only be used from the command line.");
                    var bindable = ScoreUtils.GetSchemaBindableMapper(host, predictor, scorerFactorySettings: _scorer as ICommandLineComponentFactory);
                    ch.AssertValue(bindable);
                    var mapper     = bindable.Bind(host, testData.Schema);
                    var scorerComp = _scorer ?? ScoreUtils.GetScorerComponent(host, mapper);
                    IDataScorerTransform scorePipe = scorerComp.CreateComponent(host, testData.Data, mapper, trainData.Schema);

                    // Save per-fold model.
                    string modelFileName = ConstructPerFoldName(_outputModelFile, fold);
                    if (modelFileName != null && _loader != null)
                    {
                        using (var file = host.CreateOutputFile(modelFileName))
                        {
                            var rmd = new RoleMappedData(
                                CompositeDataLoader.ApplyTransform(host, _loader, null, null,
                                                                   (e, newSource) => ApplyTransformUtils.ApplyAllTransformsToData(e, trainData.Data, newSource)),
                                trainData.Schema.GetColumnRoleNames());
                            TrainUtils.SaveModel(host, ch, file, predictor, rmd, _cmd);
                        }
                    }

                    // Evaluate.
                    var eval = _evaluator?.CreateComponent(host) ??
                               EvaluateUtils.GetEvaluator(host, scorePipe.Schema);
                    // Note that this doesn't require the provided columns to exist (because of the "opt" parameter).
                    // We don't normally expect the scorer to drop columns, but if it does, we should not require
                    // all the columns in the test pipeline to still be present.
                    var dataEval = new RoleMappedData(scorePipe, testData.Schema.GetColumnRoleNames(), opt: true);

                    var            dict        = eval.Evaluate(dataEval);
                    RoleMappedData perInstance = null;
                    if (_savePerInstance)
                    {
                        var perInst = eval.GetPerInstanceMetrics(dataEval);
                        perInstance = new RoleMappedData(perInst, dataEval.Schema.GetColumnRoleNames(), opt: true);
                    }
                    return(new FoldResult(dict, dataEval.Schema.Schema, perInstance, trainData.Schema));
                }
            }
コード例 #2
0
        public static IDataScorerTransform GetScorer(IPredictor predictor, RoleMappedData data, IHostEnvironment env, RoleMappedSchema trainSchema)
        {
            var sc = GetScorerComponentAndMapper(predictor, null, data.Schema, env, null, out var mapper);

            return(sc.CreateComponent(env, data.Data, mapper, trainSchema));
        }
コード例 #3
0
        private void RunCore(IChannel ch, string cmd)
        {
            Host.AssertValue(ch);
            Host.AssertNonEmpty(cmd);

            ch.Trace("Constructing trainer");
            ITrainer trainer = Args.Trainer.CreateInstance(Host);

            IPredictor inputPredictor = null;

            if (Args.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, Args.InputModelFile, out inputPredictor))
            {
                ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");
            }

            ch.Trace("Constructing the training pipeline");
            IDataView trainPipe = CreateLoader();

            ISchema schema = trainPipe.Schema;
            string  label  = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.LabelColumn),
                                                                 Args.LabelColumn, DefaultColumnNames.Label);
            string features = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.FeatureColumn),
                                                                  Args.FeatureColumn, DefaultColumnNames.Features);
            string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.GroupColumn),
                                                               Args.GroupColumn, DefaultColumnNames.GroupId);
            string weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.WeightColumn),
                                                                Args.WeightColumn, DefaultColumnNames.Weight);
            string name = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.NameColumn),
                                                              Args.NameColumn, DefaultColumnNames.Name);

            TrainUtils.AddNormalizerIfNeeded(Host, ch, trainer, ref trainPipe, features, Args.NormalizeFeatures);

            ch.Trace("Binding columns");
            var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);
            var data       = TrainUtils.CreateExamples(trainPipe, label, features, group, weight, name, customCols);

            RoleMappedData validData = null;

            if (!string.IsNullOrWhiteSpace(Args.ValidationFile))
            {
                if (!TrainUtils.CanUseValidationData(trainer))
                {
                    ch.Warning("Ignoring validationFile: Trainer does not accept validation dataset.");
                }
                else
                {
                    ch.Trace("Constructing the validation pipeline");
                    IDataView validPipe = CreateRawLoader(dataFile: Args.ValidationFile);
                    validPipe = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainPipe, validPipe);
                    validData = RoleMappedData.Create(validPipe, data.Schema.GetColumnRoleNames());
                }
            }

            var predictor = TrainUtils.Train(Host, ch, data, trainer, _info.LoadNames[0], validData,
                                             Args.Calibrator, Args.MaxCalibrationExamples, Args.CacheData, inputPredictor);

            IDataLoader testPipe;

            using (var file = !string.IsNullOrEmpty(Args.OutputModelFile) ?
                              Host.CreateOutputFile(Args.OutputModelFile) : Host.CreateTempFile(".zip"))
            {
                TrainUtils.SaveModel(Host, ch, file, predictor, data, cmd);

                ch.Trace("Constructing the testing pipeline");
                using (var stream = file.OpenReadStream())
                    using (var rep = RepositoryReader.Open(stream, ch))
                        testPipe = LoadLoader(rep, Args.TestFile, true);
            }

            // Score.
            ch.Trace("Scoring and evaluating");
            IDataScorerTransform scorePipe = ScoreUtils.GetScorer(Args.Scorer, predictor, testPipe, features, group, customCols, Host, data.Schema);

            // Evaluate.
            var evalComp = Args.Evaluator;

            if (!evalComp.IsGood())
            {
                evalComp = EvaluateUtils.GetEvaluatorType(ch, scorePipe.Schema);
            }
            var evaluator = evalComp.CreateInstance(Host);
            var dataEval  = TrainUtils.CreateExamplesOpt(scorePipe, label, features,
                                                         group, weight, name, customCols);
            var metrics = evaluator.Evaluate(dataEval);

            MetricWriter.PrintWarnings(ch, metrics);
            evaluator.PrintFoldResults(ch, metrics);
            evaluator.PrintOverallResults(ch, Args.SummaryFilename, metrics);
            Dictionary <string, IDataView>[] metricValues = { metrics };
            SendTelemetryMetric(metricValues);
            if (!string.IsNullOrWhiteSpace(Args.OutputDataFile))
            {
                var perInst     = evaluator.GetPerInstanceMetrics(dataEval);
                var perInstData = TrainUtils.CreateExamples(perInst, label, null, group, weight, name, customCols);
                var idv         = evaluator.GetPerInstanceDataViewToSave(perInstData);
                MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, idv);
            }
        }
 public FoldResult(Dictionary <string, IDataView> metrics, Schema scoreSchema, RoleMappedData perInstance, RoleMappedSchema trainSchema)
 {
     Metrics            = metrics;
     ScoreSchema        = scoreSchema;
     PerInstanceResults = perInstance;
     TrainSchema        = trainSchema;
 }
コード例 #5
0
        /// <summary>
        /// Save the model to the stream.
        /// The method saves the loader and the transformations of dataPipe and saves optionally predictor
        /// and command. It also uses featureColumn, if provided, to extract feature names.
        /// </summary>
        /// <param name="env">The host environment to use.</param>
        /// <param name="ch">The communication channel to use.</param>
        /// <param name="outputStream">The output model stream.</param>
        /// <param name="predictor">The predictor.</param>
        /// <param name="data">The training examples.</param>
        /// <param name="command">The command string.</param>
        public static void SaveModel(IHostEnvironment env, IChannel ch, Stream outputStream, IPredictor predictor, RoleMappedData data, string command = null)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ch, nameof(ch));
            ch.CheckValue(outputStream, nameof(outputStream));
            ch.CheckValueOrNull(predictor);
            ch.CheckValue(data, nameof(data));
            ch.CheckValueOrNull(command);

            using (var ch2 = env.Start("SaveModel"))
                using (var pch = env.StartProgressChannel("Saving model"))
                {
                    using (var rep = RepositoryWriter.CreateNew(outputStream, ch2))
                    {
                        if (predictor != null)
                        {
                            ch2.Trace("Saving predictor");
                            ModelSaveContext.SaveModel(rep, predictor, ModelFileUtils.DirPredictor);
                        }

                        ch2.Trace("Saving loader and transformations");
                        var dataPipe = data.Data;
                        if (dataPipe is IDataLoader)
                        {
                            ModelSaveContext.SaveModel(rep, dataPipe, ModelFileUtils.DirDataLoaderModel);
                        }
                        else
                        {
                            SaveDataPipe(env, rep, dataPipe);
                        }

                        // REVIEW: Handle statistics.
                        // ModelSaveContext.SaveModel(rep, dataStats, DirDataStats);
                        if (!string.IsNullOrWhiteSpace(command))
                        {
                            using (var ent = rep.CreateEntry(ModelFileUtils.DirTrainingInfo, "Command.txt"))
                                using (var writer = Utils.OpenWriter(ent.Stream))
                                    writer.WriteLine(command);
                        }
                        ModelFileUtils.SaveRoleMappings(env, ch, data.Schema, rep);

                        rep.Commit();
                    }
                    ch2.Done();
                }
        }
コード例 #6
0
        private static bool AddCacheIfWanted(IHostEnvironment env, IChannel ch, ITrainer trainer, ref RoleMappedData data, bool?cacheData)
        {
            Contracts.AssertValue(env, nameof(env));
            env.AssertValue(ch, nameof(ch));
            ch.AssertValue(trainer, nameof(trainer));
            ch.AssertValue(data, nameof(data));

            ITrainerEx trainerEx   = trainer as ITrainerEx;
            bool       shouldCache = cacheData ?? (!(data.Data is BinaryLoader) && (trainerEx == null || trainerEx.WantCaching));

            if (shouldCache)
            {
                ch.Trace("Caching");
                var prefetch  = data.Schema.GetColumnRoles().Select(kc => kc.Value.Index).ToArray();
                var cacheView = new CacheDataView(env, data.Data, prefetch);
                // Because the prefetching worked, we know that these are valid columns.
                data = new RoleMappedData(cacheView, data.Schema.GetColumnRoleNames());
            }
            else
            {
                ch.Trace("Not caching");
            }
            return(shouldCache);
        }
コード例 #7
0
        public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, string name, RoleMappedData validData,
                                       SubComponent <ICalibratorTrainer, SignatureCalibrator> calibrator, int maxCalibrationExamples, bool?cacheData, IPredictor inpPredictor = null)
        {
            ICalibratorTrainer caliTrainer = !calibrator.IsGood() ? null : calibrator.CreateInstance(env);

            return(TrainCore(env, ch, data, trainer, name, validData, caliTrainer, maxCalibrationExamples, cacheData, inpPredictor));
        }
コード例 #8
0
        private static IPredictor TrainCore(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, string name, RoleMappedData validData,
                                            ICalibratorTrainer calibrator, int maxCalibrationExamples, bool?cacheData, IPredictor inpPredictor = null)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ch, nameof(ch));
            ch.CheckValue(data, nameof(data));
            ch.CheckValue(trainer, nameof(trainer));
            ch.CheckNonEmpty(name, nameof(name));
            ch.CheckValueOrNull(validData);
            ch.CheckValueOrNull(inpPredictor);

            var trainerRmd = trainer as ITrainer <RoleMappedData>;

            if (trainerRmd == null)
            {
                throw ch.ExceptUserArg(nameof(TrainCommand.Arguments.Trainer), "Trainer '{0}' does not accept known training data type", name);
            }

            Action <IChannel, ITrainer, Action <object>, object, object, object> trainCoreAction = TrainCore;
            IPredictor predictor;

            AddCacheIfWanted(env, ch, trainer, ref data, cacheData);
            ch.Trace("Training");
            if (validData != null)
            {
                AddCacheIfWanted(env, ch, trainer, ref validData, cacheData);
            }

            var genericExam = trainCoreAction.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(
                typeof(RoleMappedData),
                inpPredictor != null ? inpPredictor.GetType() : typeof(IPredictor));
            Action <RoleMappedData> trainExam = trainerRmd.Train;

            genericExam.Invoke(null, new object[] { ch, trainerRmd, trainExam, data, validData, inpPredictor });

            ch.Trace("Constructing predictor");
            predictor = trainerRmd.CreatePredictor();
            return(CalibratorUtils.TrainCalibratorIfNeeded(env, ch, calibrator, maxCalibrationExamples, trainer, predictor, data));
        }
コード例 #9
0
 public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, RoleMappedData validData,
                                IComponentFactory <ICalibratorTrainer> calibrator, int maxCalibrationExamples, bool?cacheData, IPredictor inputPredictor = null, RoleMappedData testData = null)
 {
     return(TrainCore(env, ch, data, trainer, validData, calibrator, maxCalibrationExamples, cacheData, inputPredictor, testData));
 }
コード例 #10
0
        private static IPredictor TrainCore(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, RoleMappedData validData,
                                            IComponentFactory <ICalibratorTrainer> calibrator, int maxCalibrationExamples, bool?cacheData, IPredictor inputPredictor = null, RoleMappedData testData = null)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ch, nameof(ch));
            ch.CheckValue(data, nameof(data));
            ch.CheckValue(trainer, nameof(trainer));
            ch.CheckValueOrNull(validData);
            ch.CheckValueOrNull(inputPredictor);

            AddCacheIfWanted(env, ch, trainer, ref data, cacheData);
            ch.Trace("Training");
            if (validData != null)
            {
                AddCacheIfWanted(env, ch, trainer, ref validData, cacheData);
            }

            if (inputPredictor != null && !trainer.Info.SupportsIncrementalTraining)
            {
                ch.Warning("Ignoring " + nameof(TrainCommand.Arguments.InputModelFile) +
                           ": Trainer does not support incremental training.");
                inputPredictor = null;
            }
            ch.Assert(validData == null || trainer.Info.SupportsValidation);
            var predictor   = trainer.Train(new TrainContext(data, validData, testData, inputPredictor));
            var caliTrainer = calibrator?.CreateComponent(env);

            return(CalibratorUtils.TrainCalibratorIfNeeded(env, ch, caliTrainer, maxCalibrationExamples, trainer, predictor, data));
        }
コード例 #11
0
 public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer,
                                IComponentFactory <ICalibratorTrainer> calibrator, int maxCalibrationExamples)
 {
     return(TrainCore(env, ch, data, trainer, null, calibrator, maxCalibrationExamples, false));
 }
コード例 #12
0
        private void RunCore(IChannel ch, string cmd)
        {
            Host.AssertValue(ch);
            Host.AssertNonEmpty(cmd);

            ch.Trace("Constructing trainer");
            ITrainer trainer = _trainer.CreateComponent(Host);

            IPredictor inputPredictor = null;

            if (Args.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, Args.InputModelFile, out inputPredictor))
            {
                ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");
            }

            ch.Trace("Constructing data pipeline");
            IDataView view = CreateLoader();

            ISchema schema  = view.Schema;
            var     label   = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.LabelColumn), _labelColumn, DefaultColumnNames.Label);
            var     feature = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.FeatureColumn), _featureColumn, DefaultColumnNames.Features);
            var     group   = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.GroupColumn), _groupColumn, DefaultColumnNames.GroupId);
            var     weight  = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.WeightColumn), _weightColumn, DefaultColumnNames.Weight);
            var     name    = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.NameColumn), _nameColumn, DefaultColumnNames.Name);

            TrainUtils.AddNormalizerIfNeeded(Host, ch, trainer, ref view, feature, Args.NormalizeFeatures);

            ch.Trace("Binding columns");

            var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);
            var data       = new RoleMappedData(view, label, feature, group, weight, name, customCols);

            // REVIEW: Unify the code that creates validation examples in Train, TrainTest and CV commands.
            RoleMappedData validData = null;

            if (!string.IsNullOrWhiteSpace(Args.ValidationFile))
            {
                if (!trainer.Info.SupportsValidation)
                {
                    ch.Warning("Ignoring validationFile: Trainer does not accept validation dataset.");
                }
                else
                {
                    ch.Trace("Constructing the validation pipeline");
                    IDataView validPipe = CreateRawLoader(dataFile: Args.ValidationFile);
                    validPipe = ApplyTransformUtils.ApplyAllTransformsToData(Host, view, validPipe);
                    validData = new RoleMappedData(validPipe, data.Schema.GetColumnRoleNames());
                }
            }

            // In addition to the training set, some trainers can accept two extra data sets, validation set and test set,
            // in training phase. The major difference between validation set and test set is that training process may
            // indirectly use validation set to improve the model but the learned model should totally independent of test set.
            // Similar to validation set, the trainer can report the scores computed using test set.
            RoleMappedData testDataUsedInTrainer = null;

            if (!string.IsNullOrWhiteSpace(Args.TestFile))
            {
                // In contrast to the if-else block for validation above, we do not throw a warning if test file is provided
                // because this is TrainTest command.
                if (trainer.Info.SupportsTest)
                {
                    ch.Trace("Constructing the test pipeline");
                    IDataView testPipeUsedInTrainer = CreateRawLoader(dataFile: Args.TestFile);
                    testPipeUsedInTrainer = ApplyTransformUtils.ApplyAllTransformsToData(Host, view, testPipeUsedInTrainer);
                    testDataUsedInTrainer = new RoleMappedData(testPipeUsedInTrainer, data.Schema.GetColumnRoleNames());
                }
            }

            var predictor = TrainUtils.Train(Host, ch, data, trainer, validData,
                                             Args.Calibrator, Args.MaxCalibrationExamples, Args.CacheData, inputPredictor, testDataUsedInTrainer);

            using (var file = Host.CreateOutputFile(Args.OutputModelFile))
                TrainUtils.SaveModel(Host, ch, file, predictor, data, cmd);
        }
コード例 #13
0
 public override IDataTransform GetPerInstanceMetrics(RoleMappedData data)
 {
     return(NopTransform.CreateIfNeeded(Host, data.Data));
 }
コード例 #14
0
        private void RunCore(IChannel ch, string cmd)
        {
            Host.AssertValue(ch);
            Host.AssertNonEmpty(cmd);

            ch.Trace("Constructing trainer");
            ITrainer trainer = Args.Trainer.CreateComponent(Host);

            IPredictor inputPredictor = null;

            if (Args.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, Args.InputModelFile, out inputPredictor))
            {
                ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");
            }

            ch.Trace("Constructing the training pipeline");
            IDataView trainPipe = CreateLoader();

            var    schema = trainPipe.Schema;
            string label  = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.LabelColumn),
                                                                Args.LabelColumn, DefaultColumnNames.Label);
            string features = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.FeatureColumn),
                                                                  Args.FeatureColumn, DefaultColumnNames.Features);
            string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.GroupColumn),
                                                               Args.GroupColumn, DefaultColumnNames.GroupId);
            string weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.WeightColumn),
                                                                Args.WeightColumn, DefaultColumnNames.Weight);
            string name = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.NameColumn),
                                                              Args.NameColumn, DefaultColumnNames.Name);

            TrainUtils.AddNormalizerIfNeeded(Host, ch, trainer, ref trainPipe, features, Args.NormalizeFeatures);

            ch.Trace("Binding columns");
            var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);
            var data       = new RoleMappedData(trainPipe, label, features, group, weight, name, customCols);

            RoleMappedData validData = null;

            if (!string.IsNullOrWhiteSpace(Args.ValidationFile))
            {
                if (!trainer.Info.SupportsValidation)
                {
                    ch.Warning("Ignoring validationFile: Trainer does not accept validation dataset.");
                }
                else
                {
                    ch.Trace("Constructing the validation pipeline");
                    IDataView validPipe = CreateRawLoader(dataFile: Args.ValidationFile);
                    validPipe = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainPipe, validPipe);
                    validData = new RoleMappedData(validPipe, data.Schema.GetColumnRoleNames());
                }
            }

            // In addition to the training set, some trainers can accept two data sets, validation set and test set,
            // in training phase. The major difference between validation set and test set is that training process may
            // indirectly use validation set to improve the model but the learned model should totally independent of test set.
            // Similar to validation set, the trainer can report the scores computed using test set.
            RoleMappedData testDataUsedInTrainer = null;

            if (!string.IsNullOrWhiteSpace(Args.TestFile))
            {
                // In contrast to the if-else block for validation above, we do not throw a warning if test file is provided
                // because this is TrainTest command.
                if (trainer.Info.SupportsTest)
                {
                    ch.Trace("Constructing the test pipeline");
                    IDataView testPipeUsedInTrainer = CreateRawLoader(dataFile: Args.TestFile);
                    testPipeUsedInTrainer = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainPipe, testPipeUsedInTrainer);
                    testDataUsedInTrainer = new RoleMappedData(testPipeUsedInTrainer, data.Schema.GetColumnRoleNames());
                }
            }

            var predictor = TrainUtils.Train(Host, ch, data, trainer, validData,
                                             Args.Calibrator, Args.MaxCalibrationExamples, Args.CacheData, inputPredictor, testDataUsedInTrainer);

            IDataLoader testPipe;
            bool        hasOutfile   = !string.IsNullOrEmpty(Args.OutputModelFile);
            var         tempFilePath = hasOutfile ? null : Path.GetTempFileName();

            using (var file = new SimpleFileHandle(ch, hasOutfile ? Args.OutputModelFile : tempFilePath, true, !hasOutfile))
            {
                TrainUtils.SaveModel(Host, ch, file, predictor, data, cmd);
                ch.Trace("Constructing the testing pipeline");
                using (var stream = file.OpenReadStream())
                    using (var rep = RepositoryReader.Open(stream, ch))
                        testPipe = LoadLoader(rep, Args.TestFile, true);
            }

            // Score.
            ch.Trace("Scoring and evaluating");
            ch.Assert(Args.Scorer == null || Args.Scorer is ICommandLineComponentFactory, "TrainTestCommand should only be used from the command line.");
            IDataScorerTransform scorePipe = ScoreUtils.GetScorer(Args.Scorer, predictor, testPipe, features, group, customCols, Host, data.Schema);

            // Evaluate.
            var evaluator = Args.Evaluator?.CreateComponent(Host) ??
                            EvaluateUtils.GetEvaluator(Host, scorePipe.Schema);
            var dataEval = new RoleMappedData(scorePipe, label, features,
                                              group, weight, name, customCols, opt: true);
            var metrics = evaluator.Evaluate(dataEval);

            MetricWriter.PrintWarnings(ch, metrics);
            evaluator.PrintFoldResults(ch, metrics);
            if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out var overall))
            {
                throw ch.Except("No overall metrics found");
            }
            overall = evaluator.GetOverallResults(overall);
            MetricWriter.PrintOverallMetrics(Host, ch, Args.SummaryFilename, overall, 1);
            evaluator.PrintAdditionalMetrics(ch, metrics);
            Dictionary <string, IDataView>[] metricValues = { metrics };
            SendTelemetryMetric(metricValues);
            if (!string.IsNullOrWhiteSpace(Args.OutputDataFile))
            {
                var perInst     = evaluator.GetPerInstanceMetrics(dataEval);
                var perInstData = new RoleMappedData(perInst, label, null, group, weight, name, customCols);
                var idv         = evaluator.GetPerInstanceDataViewToSave(perInstData);
                MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, idv);
            }
        }
コード例 #15
0
        private void RunCore(IChannel ch)
        {
            ch.Trace("Constructing data pipeline");
            IDataLoader      loader;
            IPredictor       predictor;
            RoleMappedSchema trainSchema;

            LoadModelObjects(ch, true, out predictor, true, out trainSchema, out loader);
            ch.AssertValue(predictor);
            ch.AssertValueOrNull(trainSchema);
            ch.AssertValue(loader);

            ch.Trace("Binding columns");
            ISchema schema = loader.Schema;
            string  label  = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.LabelColumn),
                                                                 Args.LabelColumn, DefaultColumnNames.Label);
            string features = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.FeatureColumn),
                                                                  Args.FeatureColumn, DefaultColumnNames.Features);
            string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.GroupColumn),
                                                               Args.GroupColumn, DefaultColumnNames.GroupId);
            string weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.WeightColumn),
                                                                Args.WeightColumn, DefaultColumnNames.Weight);
            string name = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.NameColumn),
                                                              Args.NameColumn, DefaultColumnNames.Name);
            var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);

            // Score.
            ch.Trace("Scoring and evaluating");
            ch.Assert(Args.Scorer == null || Args.Scorer is ICommandLineComponentFactory, "TestCommand should only be used from the command line.");
            IDataScorerTransform scorePipe = ScoreUtils.GetScorer(Args.Scorer, predictor, loader, features, group, customCols, Host, trainSchema);

            // Evaluate.
            var evalComp = Args.Evaluator;

            if (!evalComp.IsGood())
            {
                evalComp = EvaluateUtils.GetEvaluatorType(ch, scorePipe.Schema);
            }
            var evaluator = evalComp.CreateInstance(Host);
            var data      = new RoleMappedData(scorePipe, label, null, group, weight, name, customCols);
            var metrics   = evaluator.Evaluate(data);

            MetricWriter.PrintWarnings(ch, metrics);
            evaluator.PrintFoldResults(ch, metrics);
            if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out var overall))
            {
                throw ch.Except("No overall metrics found");
            }
            overall = evaluator.GetOverallResults(overall);
            MetricWriter.PrintOverallMetrics(Host, ch, Args.SummaryFilename, overall, 1);
            evaluator.PrintAdditionalMetrics(ch, metrics);
            Dictionary <string, IDataView>[] metricValues = { metrics };
            SendTelemetryMetric(metricValues);
            if (!string.IsNullOrWhiteSpace(Args.OutputDataFile))
            {
                var perInst     = evaluator.GetPerInstanceMetrics(data);
                var perInstData = new RoleMappedData(perInst, label, null, group, weight, name, customCols);
                var idv         = evaluator.GetPerInstanceDataViewToSave(perInstData);
                MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, idv);
            }
        }