/// <summary>
/// Trains a predictor from the configured loader and options, then saves it to <c>OutputModelFile</c>.
/// </summary>
/// <param name="ch">Channel used for tracing and warnings.</param>
/// <param name="cmd">Command-line text; it is passed to <c>TrainUtils.SaveModel</c> along with the predictor.</param>
private void RunCore(IChannel ch, string cmd)
{
    Host.AssertValue(ch);
    Host.AssertNonEmpty(cmd);

    ch.Trace("Constructing trainer");
    ITrainer trainer = _trainer.CreateComponent(Host);

    // When continuing training, try to seed training with the predictor stored in InputModelFile.
    // Failing to obtain one only produces a warning.
    IPredictor inputPredictor = null;
    if (ImplOptions.ContinueTrain)
    {
        bool loaded = TrainUtils.TryLoadPredictor(ch, Host, ImplOptions.InputModelFile, out inputPredictor);
        if (!loaded)
        {
            ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");
        }
    }

    ch.Trace("Constructing data pipeline");
    IDataView pipeline = CreateLoader();
    var inputSchema = pipeline.Schema;

    // Resolve the column used for each role (user-specified name or the default name; see MatchNameOrDefaultOrNull).
    string labelColumn = TrainUtils.MatchNameOrDefaultOrNull(ch, inputSchema, nameof(Arguments.LabelColumn), _labelColumn, DefaultColumnNames.Label);
    string featureColumn = TrainUtils.MatchNameOrDefaultOrNull(ch, inputSchema, nameof(Arguments.FeatureColumn), _featureColumn, DefaultColumnNames.Features);
    string groupColumn = TrainUtils.MatchNameOrDefaultOrNull(ch, inputSchema, nameof(Arguments.GroupColumn), _groupColumn, DefaultColumnNames.GroupId);
    string weightColumn = TrainUtils.MatchNameOrDefaultOrNull(ch, inputSchema, nameof(Arguments.WeightColumn), _weightColumn, DefaultColumnNames.Weight);
    string nameColumn = TrainUtils.MatchNameOrDefaultOrNull(ch, inputSchema, nameof(Arguments.NameColumn), _nameColumn, DefaultColumnNames.Name);

    // The pipeline is passed by ref, so this call may replace it (e.g. with a normalized version).
    TrainUtils.AddNormalizerIfNeeded(Host, ch, trainer, ref pipeline, featureColumn, ImplOptions.NormalizeFeatures);

    ch.Trace("Binding columns");
    var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, ImplOptions.CustomColumns);
    var data = new RoleMappedData(pipeline, labelColumn, featureColumn, groupColumn, weightColumn, nameColumn, customCols);

    // REVIEW: Unify the code that creates validation examples in Train, TrainTest and CV commands.

    // Loads an auxiliary data file, applies the training pipeline's transforms to it,
    // and binds it with the same column roles as the training data.
    RoleMappedData LoadWithTrainingRoles(string traceMessage, string dataFile)
    {
        ch.Trace(traceMessage);
        IDataView rawView = CreateRawLoader(dataFile: dataFile);
        IDataView transformedView = ApplyTransformUtils.ApplyAllTransformsToData(Host, pipeline, rawView);
        return new RoleMappedData(transformedView, data.Schema.GetColumnRoleNames());
    }

    // A validation file is used only when the trainer supports validation; otherwise the user is warned.
    RoleMappedData validData = null;
    if (!string.IsNullOrWhiteSpace(ImplOptions.ValidationFile))
    {
        if (trainer.Info.SupportsValidation)
        {
            validData = LoadWithTrainingRoles("Constructing the validation pipeline", ImplOptions.ValidationFile);
        }
        else
        {
            ch.Warning("Ignoring validationFile: Trainer does not accept validation dataset.");
        }
    }

    // Besides the training set, some trainers accept a validation set and a test set during training.
    // Training may indirectly use the validation set to improve the model. The learned model must be
    // independent of the test set; the trainer may only report scores computed on it.
    // Unlike the validation file, a test file is ignored without a warning when the trainer does not support it.
    RoleMappedData testDataUsedInTrainer = null;
    if (!string.IsNullOrWhiteSpace(ImplOptions.TestFile) && trainer.Info.SupportsTest)
    {
        testDataUsedInTrainer = LoadWithTrainingRoles("Constructing the test pipeline", ImplOptions.TestFile);
    }

    var predictor = TrainUtils.Train(Host, ch, data, trainer, validData, ImplOptions.Calibrator,
        ImplOptions.MaxCalibrationExamples, ImplOptions.CacheData, inputPredictor, testDataUsedInTrainer);

    using (var modelFile = Host.CreateOutputFile(ImplOptions.OutputModelFile))
    {
        TrainUtils.SaveModel(Host, ch, modelFile, predictor, data, cmd);
    }
}
/// <summary>
/// Trains a predictor, then saves it to <c>OutputModelFile</c> or, when none is given, to a temporary file.
/// Next it rebuilds a data loader from the saved model over <c>TestFile</c>, scores that data, and evaluates it.
/// Metrics are printed and reported. Per-instance results are written when <c>OutputDataFile</c> is set.
/// </summary>
/// <param name="ch">Channel used for tracing, warnings, and errors.</param>
/// <param name="cmd">Command-line text; it is passed to <c>TrainUtils.SaveModel</c> along with the predictor.</param>
/// <exception>Throws the channel exception "No overall metrics found" when the evaluator yields no overall metrics.</exception>
private void RunCore(IChannel ch, string cmd)
{
    Host.AssertValue(ch);
    Host.AssertNonEmpty(cmd);

    ch.Trace("Constructing trainer");
    ITrainer trainer = ImplOptions.Trainer.CreateComponent(Host);

    // When continuing training, try to seed training with the predictor stored in InputModelFile.
    // Failing to obtain one only produces a warning.
    IPredictor inputPredictor = null;
    if (ImplOptions.ContinueTrain)
    {
        bool loaded = TrainUtils.TryLoadPredictor(ch, Host, ImplOptions.InputModelFile, out inputPredictor);
        if (!loaded)
        {
            ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");
        }
    }

    ch.Trace("Constructing the training pipeline");
    IDataView trainPipe = CreateLoader();
    var inputSchema = trainPipe.Schema;
    string labelColumn = TrainUtils.MatchNameOrDefaultOrNull(ch, inputSchema, nameof(Arguments.LabelColumn), ImplOptions.LabelColumn, DefaultColumnNames.Label);
    string featureColumn = TrainUtils.MatchNameOrDefaultOrNull(ch, inputSchema, nameof(Arguments.FeatureColumn), ImplOptions.FeatureColumn, DefaultColumnNames.Features);
    string groupColumn = TrainUtils.MatchNameOrDefaultOrNull(ch, inputSchema, nameof(Arguments.GroupColumn), ImplOptions.GroupColumn, DefaultColumnNames.GroupId);
    string weightColumn = TrainUtils.MatchNameOrDefaultOrNull(ch, inputSchema, nameof(Arguments.WeightColumn), ImplOptions.WeightColumn, DefaultColumnNames.Weight);
    string nameColumn = TrainUtils.MatchNameOrDefaultOrNull(ch, inputSchema, nameof(Arguments.NameColumn), ImplOptions.NameColumn, DefaultColumnNames.Name);

    // The pipeline is passed by ref, so this call may replace it (e.g. with a normalized version).
    TrainUtils.AddNormalizerIfNeeded(Host, ch, trainer, ref trainPipe, featureColumn, ImplOptions.NormalizeFeatures);

    ch.Trace("Binding columns");
    var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, ImplOptions.CustomColumns);
    var data = new RoleMappedData(trainPipe, labelColumn, featureColumn, groupColumn, weightColumn, nameColumn, customCols);

    // Loads an auxiliary data file, applies the training pipeline's transforms to it,
    // and binds it with the same column roles as the training data.
    RoleMappedData LoadWithTrainingRoles(string traceMessage, string dataFile)
    {
        ch.Trace(traceMessage);
        IDataView rawView = CreateRawLoader(dataFile: dataFile);
        IDataView transformedView = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainPipe, rawView);
        return new RoleMappedData(transformedView, data.Schema.GetColumnRoleNames());
    }

    // A validation file is used only when the trainer supports validation; otherwise the user is warned.
    RoleMappedData validData = null;
    if (!string.IsNullOrWhiteSpace(ImplOptions.ValidationFile))
    {
        if (trainer.Info.SupportsValidation)
        {
            validData = LoadWithTrainingRoles("Constructing the validation pipeline", ImplOptions.ValidationFile);
        }
        else
        {
            ch.Warning("Ignoring validationFile: Trainer does not accept validation dataset.");
        }
    }

    // Besides the training set, some trainers accept a validation set and a test set during training.
    // Training may indirectly use the validation set to improve the model. The learned model must be
    // independent of the test set; the trainer may only report scores computed on it.
    // No warning is issued when the trainer cannot use the test file: in this command the test file
    // is still scored and evaluated below.
    RoleMappedData testDataUsedInTrainer = null;
    if (!string.IsNullOrWhiteSpace(ImplOptions.TestFile) && trainer.Info.SupportsTest)
    {
        testDataUsedInTrainer = LoadWithTrainingRoles("Constructing the test pipeline", ImplOptions.TestFile);
    }

    var predictor = TrainUtils.Train(Host, ch, data, trainer, validData, ImplOptions.Calibrator,
        ImplOptions.MaxCalibrationExamples, ImplOptions.CacheData, inputPredictor, testDataUsedInTrainer);

    // Save the model. Without an OutputModelFile, use a temporary file; the last SimpleFileHandle
    // argument is true only in that case (presumably auto-delete on dispose — confirm).
    // The test loader is then rebuilt from the saved model's repository.
    bool hasOutfile = !string.IsNullOrEmpty(ImplOptions.OutputModelFile);
    string modelPath = hasOutfile ? ImplOptions.OutputModelFile : Path.GetTempFileName();
    ILegacyDataLoader testPipe;
    using (var modelFile = new SimpleFileHandle(ch, modelPath, true, !hasOutfile))
    {
        TrainUtils.SaveModel(Host, ch, modelFile, predictor, data, cmd);
        ch.Trace("Constructing the testing pipeline");
        using (var modelStream = modelFile.OpenReadStream())
        using (var repository = RepositoryReader.Open(modelStream, ch))
        {
            testPipe = LoadLoader(repository, ImplOptions.TestFile, true);
        }
    }

    ch.Trace("Scoring and evaluating");
    ch.Assert(ImplOptions.Scorer == null || ImplOptions.Scorer is ICommandLineComponentFactory, "TrainTestCommand should only be used from the command line.");
    IDataScorerTransform scorePipe = ScoreUtils.GetScorer(ImplOptions.Scorer, predictor, testPipe, featureColumn, groupColumn, customCols, Host, data.Schema);

    // Use the configured evaluator, or else one derived from the scored schema.
    var evaluator = ImplOptions.Evaluator?.CreateComponent(Host) ?? EvaluateUtils.GetEvaluator(Host, scorePipe.Schema);
    var scoredData = new RoleMappedData(scorePipe, labelColumn, featureColumn, groupColumn, weightColumn, nameColumn, customCols, opt: true);
    var metrics = evaluator.Evaluate(scoredData);
    MetricWriter.PrintWarnings(ch, metrics);
    evaluator.PrintFoldResults(ch, metrics);

    if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out var rawOverall))
    {
        throw ch.Except("No overall metrics found");
    }

    IDataView overall = evaluator.GetOverallResults(rawOverall);
    MetricWriter.PrintOverallMetrics(Host, ch, ImplOptions.SummaryFilename, overall, 1);
    evaluator.PrintAdditionalMetrics(ch, metrics);
    SendTelemetryMetric(new Dictionary<string, IDataView>[] { metrics });

    // Per-instance results are written only when an output data file was requested.
    if (string.IsNullOrWhiteSpace(ImplOptions.OutputDataFile))
    {
        return;
    }

    var perInstance = evaluator.GetPerInstanceMetrics(scoredData);
    var perInstanceData = new RoleMappedData(perInstance, labelColumn, null, groupColumn, weightColumn, nameColumn, customCols);
    var toSave = evaluator.GetPerInstanceDataViewToSave(perInstanceData);
    MetricWriter.SavePerInstance(Host, ch, ImplOptions.OutputDataFile, toSave);
}
/// <summary>
/// Scores data with a previously trained model. It loads the predictor and data loader, then appends a
/// scorer and any post-transforms to the loader. The resulting pipeline is saved to <c>OutputModelFile</c>
/// when one is given. Finally, the selected columns are written to <c>OutputDataFile</c>.
/// </summary>
/// <param name="ch">Channel used for tracing, warnings, and user errors.</param>
/// <remarks>
/// Column selection:
/// <list type="bullet">
/// <item>All columns are written when <c>OutputAllColumns</c> is true.</item>
/// <item>All columns are also written when <c>OutputAllColumns</c> is unset, no <c>OutputColumns</c> are
/// given, and the saver is a <c>BinarySaver</c>.</item>
/// <item>Otherwise columns are chosen by <c>ShouldAddColumn</c> using the highest score-column-set id.</item>
/// <item>Hidden columns are skipped unless <c>KeepHidden</c> is set.</item>
/// <item>Columns whose type the saver cannot write are skipped with a warning.</item>
/// </list>
/// </remarks>
private void RunCore(IChannel ch)
{
    Host.AssertValue(ch);

    ch.Trace("Creating loader");
    LoadModelObjects(ch, true, out var predictor, true, out var trainSchema, out var loader);
    ch.AssertValue(predictor);
    ch.AssertValueOrNull(trainSchema);
    ch.AssertValue(loader);

    ch.Trace("Creating pipeline");
    var scorer = ImplOptions.Scorer;
    ch.Assert(scorer == null || scorer is ICommandLineComponentFactory, "ScoreCommand should only be used from the command line.");
    var bindable = ScoreUtils.GetSchemaBindableMapper(Host, predictor, scorerFactorySettings: scorer as ICommandLineComponentFactory);
    ch.AssertValue(bindable);

    // REVIEW: We probably ought to prefer role mappings from the training schema.
    string feat = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema, nameof(ImplOptions.FeatureColumn), ImplOptions.FeatureColumn, DefaultColumnNames.Features);
    string group = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema, nameof(ImplOptions.GroupColumn), ImplOptions.GroupColumn, DefaultColumnNames.GroupId);
    var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, ImplOptions.CustomColumns);
    var schema = new RoleMappedSchema(loader.Schema, label: null, feature: feat, group: group, custom: customCols, opt: true);
    var mapper = bindable.Bind(Host, schema);

    // No scorer was specified: let ScoreUtils pick one for the bound mapper.
    if (scorer == null)
    {
        scorer = ScoreUtils.GetScorerComponent(Host, mapper);
    }

    loader = CompositeDataLoader.ApplyTransform(Host, loader, "Scorer", scorer.ToString(),
        (env, view) => scorer.CreateComponent(env, view, mapper, trainSchema));
    loader = CompositeDataLoader.Create(Host, loader, ImplOptions.PostTransform);

    if (!string.IsNullOrWhiteSpace(ImplOptions.OutputModelFile))
    {
        ch.Trace("Saving the data pipe");
        SaveLoader(loader, ImplOptions.OutputModelFile);
    }

    ch.Trace("Creating saver");
    IDataSaver writer;
    if (ImplOptions.Saver == null)
    {
        // The default saver is chosen by output extension: ".txt"/".tlc" gives text, anything else binary.
        // The comparison ignores case so that e.g. "scores.TXT" is also written as text.
        var ext = Path.GetExtension(ImplOptions.OutputDataFile);
        var isText = string.Equals(ext, ".txt", System.StringComparison.OrdinalIgnoreCase)
            || string.Equals(ext, ".tlc", System.StringComparison.OrdinalIgnoreCase);
        if (isText)
        {
            writer = new TextSaver(Host, new TextSaver.Arguments());
        }
        else
        {
            writer = new BinarySaver(Host, new BinarySaver.Arguments());
        }
    }
    else
    {
        writer = ImplOptions.Saver.CreateComponent(Host);
    }
    ch.Assert(writer != null);

    // Binary output defaults to writing all columns, so this must test for the BinarySaver created above.
    // It previously tested for System.IO.BinaryWriter, which the savers created here never are,
    // so the binary default was never applied.
    var outputIsBinary = writer is BinarySaver;
    bool outputAllColumns =
        ImplOptions.OutputAllColumns == true
        || (ImplOptions.OutputAllColumns == null && Utils.Size(ImplOptions.OutputColumns) == 0 && outputIsBinary);
    bool outputNamesAndLabels = ImplOptions.OutputAllColumns == true || Utils.Size(ImplOptions.OutputColumns) == 0;

    if (ImplOptions.OutputAllColumns == true && Utils.Size(ImplOptions.OutputColumns) != 0)
    {
        ch.Warning(nameof(ImplOptions.OutputAllColumns) + "=+ always writes all columns irrespective of " + nameof(ImplOptions.OutputColumns) + " specified.");
    }

    // Fail with a user error if an explicitly requested output column does not exist.
    if (!outputAllColumns && Utils.Size(ImplOptions.OutputColumns) != 0)
    {
        foreach (var outCol in ImplOptions.OutputColumns)
        {
            if (!loader.Schema.TryGetColumnIndex(outCol, out _))
            {
                throw ch.ExceptUserArg(nameof(Arguments.OutputColumns), "Column '{0}' not found.", outCol);
            }
        }
    }

    uint maxScoreId = 0;
    if (!outputAllColumns)
    {
        maxScoreId = loader.Schema.GetMaxMetadataKind(out _, MetadataUtils.Kinds.ScoreColumnSetId);
    }
    ch.Assert(outputAllColumns || maxScoreId > 0); // score set IDs are one-based

    var cols = new List<int>();
    for (int i = 0; i < loader.Schema.Count; i++)
    {
        if (!ImplOptions.KeepHidden && loader.Schema[i].IsHidden)
        {
            continue;
        }
        if (!(outputAllColumns || ShouldAddColumn(loader.Schema, i, maxScoreId, outputNamesAndLabels)))
        {
            continue;
        }
        var type = loader.Schema[i].Type;
        if (writer.IsColumnSavable(type))
        {
            cols.Add(i);
        }
        else
        {
            ch.Warning("The column '{0}' will not be written as it has unsavable column type.", loader.Schema[i].Name);
        }
    }
    ch.Check(cols.Count > 0, "No valid columns to save");

    ch.Trace("Scoring and saving data");
    using (var file = Host.CreateOutputFile(ImplOptions.OutputDataFile))
    using (var stream = file.CreateWriteStream())
    {
        writer.SaveData(stream, loader, cols.ToArray());
    }
}