示例#1
0
        /// <summary>
        /// Creates the trainer and the training data, optionally prepares validation and test
        /// sets for the trainer, trains a predictor and writes the model to the output model file.
        /// </summary>
        /// <param name="ch">Channel used for tracing and warnings.</param>
        /// <param name="cmd">Non-empty command string; passed through to <c>TrainUtils.SaveModel</c>.</param>
        private void RunCore(IChannel ch, string cmd)
        {
            Host.AssertValue(ch);
            Host.AssertNonEmpty(cmd);

            ch.Trace("Constructing trainer");
            ITrainer trainer = _trainer.CreateComponent(Host);

            // When continuing training, try to seed the trainer with the predictor from the input model.
            IPredictor inputPredictor = null;
            if (ImplOptions.ContinueTrain)
            {
                bool predictorLoaded = TrainUtils.TryLoadPredictor(ch, Host, ImplOptions.InputModelFile, out inputPredictor);
                if (!predictorLoaded)
                {
                    ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");
                }
            }

            ch.Trace("Constructing data pipeline");
            IDataView trainView = CreateLoader();

            // Resolve the role columns against the loader schema (before any normalizer is added).
            var loaderSchema = trainView.Schema;
            string labelCol = TrainUtils.MatchNameOrDefaultOrNull(
                ch, loaderSchema, nameof(Arguments.LabelColumn), _labelColumn, DefaultColumnNames.Label);
            string featureCol = TrainUtils.MatchNameOrDefaultOrNull(
                ch, loaderSchema, nameof(Arguments.FeatureColumn), _featureColumn, DefaultColumnNames.Features);
            string groupCol = TrainUtils.MatchNameOrDefaultOrNull(
                ch, loaderSchema, nameof(Arguments.GroupColumn), _groupColumn, DefaultColumnNames.GroupId);
            string weightCol = TrainUtils.MatchNameOrDefaultOrNull(
                ch, loaderSchema, nameof(Arguments.WeightColumn), _weightColumn, DefaultColumnNames.Weight);
            string nameCol = TrainUtils.MatchNameOrDefaultOrNull(
                ch, loaderSchema, nameof(Arguments.NameColumn), _nameColumn, DefaultColumnNames.Name);

            // May replace trainView with a view that has a normalizer appended.
            TrainUtils.AddNormalizerIfNeeded(Host, ch, trainer, ref trainView, featureCol, ImplOptions.NormalizeFeatures);

            ch.Trace("Binding columns");

            var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, ImplOptions.CustomColumns);
            var data = new RoleMappedData(trainView, labelCol, featureCol, groupCol, weightCol, nameCol, customCols);

            // REVIEW: Unify the code that creates validation examples in Train, TrainTest and CV commands.
            RoleMappedData validData = null;
            bool validationRequested = !string.IsNullOrWhiteSpace(ImplOptions.ValidationFile);

            if (validationRequested && !trainer.Info.SupportsValidation)
            {
                ch.Warning("Ignoring validationFile: Trainer does not accept validation dataset.");
            }
            else if (validationRequested)
            {
                ch.Trace("Constructing the validation pipeline");
                IDataView rawValidation = CreateRawLoader(dataFile: ImplOptions.ValidationFile);
                // Replay the training view's transforms on the validation data.
                IDataView validationView = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainView, rawValidation);
                validData = new RoleMappedData(validationView, data.Schema.GetColumnRoleNames());
            }

            // Besides the training set, some trainers accept a validation set and a test set during
            // training. Training may indirectly use the validation set to improve the model, whereas
            // the learned model should be totally independent of the test set. As with the validation
            // set, the trainer can report scores computed on the test set.
            RoleMappedData testDataUsedInTrainer = null;

            // Unlike the validation branch, no warning is issued when the trainer does not support a test set.
            if (!string.IsNullOrWhiteSpace(ImplOptions.TestFile) && trainer.Info.SupportsTest)
            {
                ch.Trace("Constructing the test pipeline");
                IDataView rawTest = CreateRawLoader(dataFile: ImplOptions.TestFile);
                IDataView testView = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainView, rawTest);
                testDataUsedInTrainer = new RoleMappedData(testView, data.Schema.GetColumnRoleNames());
            }

            var predictor = TrainUtils.Train(
                Host, ch, data, trainer, validData,
                ImplOptions.Calibrator, ImplOptions.MaxCalibrationExamples, ImplOptions.CacheData,
                inputPredictor, testDataUsedInTrainer);

            using (var modelFile = Host.CreateOutputFile(ImplOptions.OutputModelFile))
            {
                TrainUtils.SaveModel(Host, ch, modelFile, predictor, data, cmd);
            }
        }
示例#2
0
        /// <summary>
        /// Trains a predictor (optionally with validation and test sets supplied to the trainer),
        /// saves the model, reloads the loader from the saved model over the test file, scores and
        /// evaluates it, prints the metrics and optionally saves per-instance results.
        /// </summary>
        /// <param name="ch">Channel used for tracing, warnings, assertions and exceptions.</param>
        /// <param name="cmd">Non-empty command string; passed through to <c>TrainUtils.SaveModel</c>.</param>
        private void RunCore(IChannel ch, string cmd)
        {
            Host.AssertValue(ch);
            Host.AssertNonEmpty(cmd);

            ch.Trace("Constructing trainer");
            ITrainer trainer = ImplOptions.Trainer.CreateComponent(Host);

            // When continuing training, try to seed the trainer with the predictor from the input model.
            IPredictor inputPredictor = null;
            if (ImplOptions.ContinueTrain)
            {
                bool predictorLoaded = TrainUtils.TryLoadPredictor(ch, Host, ImplOptions.InputModelFile, out inputPredictor);
                if (!predictorLoaded)
                {
                    ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");
                }
            }

            ch.Trace("Constructing the training pipeline");
            IDataView trainPipe = CreateLoader();

            // Resolve the role columns against the loader schema (before any normalizer is added).
            var loaderSchema = trainPipe.Schema;
            string labelCol = TrainUtils.MatchNameOrDefaultOrNull(
                ch, loaderSchema, nameof(Arguments.LabelColumn), ImplOptions.LabelColumn, DefaultColumnNames.Label);
            string featureCol = TrainUtils.MatchNameOrDefaultOrNull(
                ch, loaderSchema, nameof(Arguments.FeatureColumn), ImplOptions.FeatureColumn, DefaultColumnNames.Features);
            string groupCol = TrainUtils.MatchNameOrDefaultOrNull(
                ch, loaderSchema, nameof(Arguments.GroupColumn), ImplOptions.GroupColumn, DefaultColumnNames.GroupId);
            string weightCol = TrainUtils.MatchNameOrDefaultOrNull(
                ch, loaderSchema, nameof(Arguments.WeightColumn), ImplOptions.WeightColumn, DefaultColumnNames.Weight);
            string nameCol = TrainUtils.MatchNameOrDefaultOrNull(
                ch, loaderSchema, nameof(Arguments.NameColumn), ImplOptions.NameColumn, DefaultColumnNames.Name);

            // May replace trainPipe with a view that has a normalizer appended.
            TrainUtils.AddNormalizerIfNeeded(Host, ch, trainer, ref trainPipe, featureCol, ImplOptions.NormalizeFeatures);

            ch.Trace("Binding columns");
            var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, ImplOptions.CustomColumns);
            var data = new RoleMappedData(trainPipe, labelCol, featureCol, groupCol, weightCol, nameCol, customCols);

            RoleMappedData validData = null;
            bool validationRequested = !string.IsNullOrWhiteSpace(ImplOptions.ValidationFile);

            if (validationRequested && !trainer.Info.SupportsValidation)
            {
                ch.Warning("Ignoring validationFile: Trainer does not accept validation dataset.");
            }
            else if (validationRequested)
            {
                ch.Trace("Constructing the validation pipeline");
                IDataView rawValidation = CreateRawLoader(dataFile: ImplOptions.ValidationFile);
                // Replay the training pipeline's transforms on the validation data.
                IDataView validationView = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainPipe, rawValidation);
                validData = new RoleMappedData(validationView, data.Schema.GetColumnRoleNames());
            }

            // Besides the training set, some trainers accept a validation set and a test set during
            // training. Training may indirectly use the validation set to improve the model, whereas
            // the learned model should be totally independent of the test set. As with the validation
            // set, the trainer can report scores computed on the test set.
            RoleMappedData testDataUsedInTrainer = null;

            // Unlike the validation branch, no warning is issued when the trainer does not support
            // a test set, because this is the TrainTest command.
            if (!string.IsNullOrWhiteSpace(ImplOptions.TestFile) && trainer.Info.SupportsTest)
            {
                ch.Trace("Constructing the test pipeline");
                IDataView rawTest = CreateRawLoader(dataFile: ImplOptions.TestFile);
                IDataView testView = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainPipe, rawTest);
                testDataUsedInTrainer = new RoleMappedData(testView, data.Schema.GetColumnRoleNames());
            }

            var predictor = TrainUtils.Train(
                Host, ch, data, trainer, validData,
                ImplOptions.Calibrator, ImplOptions.MaxCalibrationExamples, ImplOptions.CacheData,
                inputPredictor, testDataUsedInTrainer);

            // Save the model (to a temp file when no output model file was given), then
            // read the loader back from the saved model, pointed at the test file.
            ILegacyDataLoader testPipe;
            bool hasOutfile = !string.IsNullOrEmpty(ImplOptions.OutputModelFile);
            string modelPath = hasOutfile ? ImplOptions.OutputModelFile : Path.GetTempFileName();

            using (var modelFile = new SimpleFileHandle(ch, modelPath, true, !hasOutfile))
            {
                TrainUtils.SaveModel(Host, ch, modelFile, predictor, data, cmd);
                ch.Trace("Constructing the testing pipeline");
                using (var modelStream = modelFile.OpenReadStream())
                {
                    using (var repository = RepositoryReader.Open(modelStream, ch))
                    {
                        testPipe = LoadLoader(repository, ImplOptions.TestFile, true);
                    }
                }
            }

            // Score.
            ch.Trace("Scoring and evaluating");
            ch.Assert(ImplOptions.Scorer == null || ImplOptions.Scorer is ICommandLineComponentFactory, "TrainTestCommand should only be used from the command line.");
            IDataScorerTransform scorePipe = ScoreUtils.GetScorer(
                ImplOptions.Scorer, predictor, testPipe, featureCol, groupCol, customCols, Host, data.Schema);

            // Evaluate: use the configured evaluator if there is one, otherwise pick one from the scored schema.
            var evaluator = ImplOptions.Evaluator?.CreateComponent(Host)
                            ?? EvaluateUtils.GetEvaluator(Host, scorePipe.Schema);
            var dataEval = new RoleMappedData(
                scorePipe, labelCol, featureCol, groupCol, weightCol, nameCol, customCols, opt: true);
            var metrics = evaluator.Evaluate(dataEval);

            MetricWriter.PrintWarnings(ch, metrics);
            evaluator.PrintFoldResults(ch, metrics);

            bool hasOverall = metrics.TryGetValue(MetricKinds.OverallMetrics, out var overall);
            if (!hasOverall)
            {
                throw ch.Except("No overall metrics found");
            }

            overall = evaluator.GetOverallResults(overall);
            MetricWriter.PrintOverallMetrics(Host, ch, ImplOptions.SummaryFilename, overall, 1);
            evaluator.PrintAdditionalMetrics(ch, metrics);
            SendTelemetryMetric(new Dictionary<string, IDataView>[] { metrics });

            if (string.IsNullOrWhiteSpace(ImplOptions.OutputDataFile))
            {
                return;
            }

            // Per-instance results are only produced when an output data file was requested.
            var perInstance = evaluator.GetPerInstanceMetrics(dataEval);
            var perInstanceData = new RoleMappedData(perInstance, labelCol, null, groupCol, weightCol, nameCol, customCols);
            var viewToSave = evaluator.GetPerInstanceDataViewToSave(perInstanceData);
            MetricWriter.SavePerInstance(Host, ch, ImplOptions.OutputDataFile, viewToSave);
        }
示例#3
0
        /// <summary>
        /// Loads the model objects, appends a scorer and any post-transforms to the loader,
        /// optionally saves the resulting pipeline, and writes the selected columns of the
        /// scored data to the output data file.
        /// </summary>
        /// <param name="ch">Channel used for tracing, warnings, assertions and exceptions.</param>
        /// <remarks>
        /// Throws via <c>ch.ExceptUserArg</c> when a requested output column is missing from the
        /// scored schema, and via <c>ch.Check</c> when no column can be saved.
        /// </remarks>
        private void RunCore(IChannel ch)
        {
            Host.AssertValue(ch);

            ch.Trace("Creating loader");

            LoadModelObjects(ch, true, out var predictor, true, out var trainSchema, out var loader);
            ch.AssertValue(predictor);
            ch.AssertValueOrNull(trainSchema);
            ch.AssertValue(loader);

            ch.Trace("Creating pipeline");
            var scorer = ImplOptions.Scorer;

            ch.Assert(scorer == null || scorer is ICommandLineComponentFactory, "ScoreCommand should only be used from the command line.");
            var bindable = ScoreUtils.GetSchemaBindableMapper(Host, predictor, scorerFactorySettings: scorer as ICommandLineComponentFactory);

            ch.AssertValue(bindable);

            // REVIEW: We probably ought to prefer role mappings from the training schema.
            string feat = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema,
                                                              nameof(ImplOptions.FeatureColumn), ImplOptions.FeatureColumn, DefaultColumnNames.Features);
            string group = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema,
                                                               nameof(ImplOptions.GroupColumn), ImplOptions.GroupColumn, DefaultColumnNames.GroupId);
            var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, ImplOptions.CustomColumns);
            var schema     = new RoleMappedSchema(loader.Schema, label: null, feature: feat, group: group, custom: customCols, opt: true);
            var mapper     = bindable.Bind(Host, schema);

            // No scorer was specified: derive one from the bound mapper.
            if (scorer == null)
            {
                scorer = ScoreUtils.GetScorerComponent(Host, mapper);
            }

            loader = CompositeDataLoader.ApplyTransform(Host, loader, "Scorer", scorer.ToString(),
                                                        (env, view) => scorer.CreateComponent(env, view, mapper, trainSchema));

            loader = CompositeDataLoader.Create(Host, loader, ImplOptions.PostTransform);

            if (!string.IsNullOrWhiteSpace(ImplOptions.OutputModelFile))
            {
                ch.Trace("Saving the data pipe");
                SaveLoader(loader, ImplOptions.OutputModelFile);
            }

            ch.Trace("Creating saver");
            IDataSaver writer;

            if (ImplOptions.Saver == null)
            {
                // No saver specified: choose text or binary from the output file extension.
                var ext    = Path.GetExtension(ImplOptions.OutputDataFile);
                var isText = ext == ".txt" || ext == ".tlc";
                if (isText)
                {
                    writer = new TextSaver(Host, new TextSaver.Arguments());
                }
                else
                {
                    writer = new BinarySaver(Host, new BinarySaver.Arguments());
                }
            }
            else
            {
                writer = ImplOptions.Saver.CreateComponent(Host);
            }
            ch.Assert(writer != null);

            // Compare against BinarySaver (the IDataSaver created above for binary output).
            // The previous check used System.IO.BinaryWriter, which is not an IDataSaver,
            // so it could never be true and the binary default below never applied.
            var outputIsBinary = writer is BinarySaver;

            // All columns are written when explicitly requested, or by default for binary
            // output when neither OutputAllColumns nor OutputColumns was specified.
            bool outputAllColumns =
                ImplOptions.OutputAllColumns == true ||
                (ImplOptions.OutputAllColumns == null && Utils.Size(ImplOptions.OutputColumns) == 0 && outputIsBinary);

            bool outputNamesAndLabels =
                ImplOptions.OutputAllColumns == true || Utils.Size(ImplOptions.OutputColumns) == 0;

            if (ImplOptions.OutputAllColumns == true && Utils.Size(ImplOptions.OutputColumns) != 0)
            {
                ch.Warning(nameof(ImplOptions.OutputAllColumns) + "=+ always writes all columns irrespective of " + nameof(ImplOptions.OutputColumns) + " specified.");
            }

            // Validate that every explicitly requested column exists in the scored schema.
            if (!outputAllColumns && Utils.Size(ImplOptions.OutputColumns) != 0)
            {
                foreach (var outCol in ImplOptions.OutputColumns)
                {
                    if (!loader.Schema.TryGetColumnIndex(outCol, out int _))
                    {
                        throw ch.ExceptUserArg(nameof(Arguments.OutputColumns), "Column '{0}' not found.", outCol);
                    }
                }
            }

            uint maxScoreId = 0;

            if (!outputAllColumns)
            {
                maxScoreId = loader.Schema.GetMaxMetadataKind(out int _, MetadataUtils.Kinds.ScoreColumnSetId);
            }
            ch.Assert(outputAllColumns || maxScoreId > 0); // score set IDs are one-based
            var cols = new List<int>();

            // Collect the indices of the columns to write, skipping hidden columns (unless
            // KeepHidden is set), unselected columns, and columns the saver cannot write.
            for (int i = 0; i < loader.Schema.Count; i++)
            {
                if (!ImplOptions.KeepHidden && loader.Schema[i].IsHidden)
                {
                    continue;
                }
                if (!(outputAllColumns || ShouldAddColumn(loader.Schema, i, maxScoreId, outputNamesAndLabels)))
                {
                    continue;
                }
                var type = loader.Schema[i].Type;
                if (writer.IsColumnSavable(type))
                {
                    cols.Add(i);
                }
                else
                {
                    ch.Warning("The column '{0}' will not be written as it has unsavable column type.",
                               loader.Schema[i].Name);
                }
            }

            ch.Check(cols.Count > 0, "No valid columns to save");

            ch.Trace("Scoring and saving data");
            using (var file = Host.CreateOutputFile(ImplOptions.OutputDataFile))
            {
                using (var stream = file.CreateWriteStream())
                {
                    writer.SaveData(stream, loader, cols.ToArray());
                }
            }
        }