/// <summary>
/// Creates the trainer and the training data pipeline, optionally builds validation and
/// test data sets when the trainer reports support for them, trains a predictor, and
/// writes the resulting model to <c>Args.OutputModelFile</c>.
/// </summary>
/// <param name="ch">Channel used for trace messages and warnings.</param>
/// <param name="cmd">Command string; asserted non-empty and handed to the model save call.</param>
private void RunCore(IChannel ch, string cmd)
{
    Host.AssertValue(ch);
    Host.AssertNonEmpty(cmd);

    ch.Trace("Constructing trainer");
    ITrainer trainer = _trainer.CreateComponent(Host);

    // Only try to load an initial predictor when continued training was requested.
    IPredictor inputPredictor = null;
    if (Args.ContinueTrain)
    {
        bool loaded = TrainUtils.TryLoadPredictor(ch, Host, Args.InputModelFile, out inputPredictor);
        if (!loaded)
        {
            ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");
        }
    }

    ch.Trace("Constructing data pipeline");
    IDataView trainView = CreateLoader();
    var trainSchema = trainView.Schema;

    // Resolve each role column from the user-supplied name, falling back to the default name.
    string labelCol = TrainUtils.MatchNameOrDefaultOrNull(
        ch, trainSchema, nameof(Arguments.LabelColumn), _labelColumn, DefaultColumnNames.Label);
    string featureCol = TrainUtils.MatchNameOrDefaultOrNull(
        ch, trainSchema, nameof(Arguments.FeatureColumn), _featureColumn, DefaultColumnNames.Features);
    string groupCol = TrainUtils.MatchNameOrDefaultOrNull(
        ch, trainSchema, nameof(Arguments.GroupColumn), _groupColumn, DefaultColumnNames.GroupId);
    string weightCol = TrainUtils.MatchNameOrDefaultOrNull(
        ch, trainSchema, nameof(Arguments.WeightColumn), _weightColumn, DefaultColumnNames.Weight);
    string nameCol = TrainUtils.MatchNameOrDefaultOrNull(
        ch, trainSchema, nameof(Arguments.NameColumn), _nameColumn, DefaultColumnNames.Name);

    // May replace trainView (passed by ref); everything below uses the updated view.
    TrainUtils.AddNormalizerIfNeeded(Host, ch, trainer, ref trainView, featureCol, Args.NormalizeFeatures);

    ch.Trace("Binding columns");
    var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);
    RoleMappedData trainData = new RoleMappedData(
        trainView, labelCol, featureCol, groupCol, weightCol, nameCol, customCols);

    // REVIEW: Unify the code that creates validation examples in Train, TrainTest and CV commands.
    RoleMappedData validData = null;
    if (!string.IsNullOrWhiteSpace(Args.ValidationFile))
    {
        if (trainer.Info.SupportsValidation)
        {
            ch.Trace("Constructing the validation pipeline");
            IDataView validView = CreateRawLoader(dataFile: Args.ValidationFile);
            validView = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainView, validView);
            validData = new RoleMappedData(validView, trainData.Schema.GetColumnRoleNames());
        }
        else
        {
            ch.Warning("Ignoring validationFile: Trainer does not accept validation dataset.");
        }
    }

    // In addition to the training set, some trainers can accept two extra data sets, a validation set
    // and a test set, during the training phase. The major difference between them is that the training
    // process may indirectly use the validation set to improve the model, whereas the learned model
    // should be totally independent of the test set. As with the validation set, the trainer can report
    // the scores computed using the test set.
    // In contrast to the validation block above, no warning is issued when a test file is provided
    // but the trainer does not support it.
    RoleMappedData testDataUsedInTrainer = null;
    if (!string.IsNullOrWhiteSpace(Args.TestFile) && trainer.Info.SupportsTest)
    {
        ch.Trace("Constructing the test pipeline");
        IDataView testView = CreateRawLoader(dataFile: Args.TestFile);
        testView = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainView, testView);
        testDataUsedInTrainer = new RoleMappedData(testView, trainData.Schema.GetColumnRoleNames());
    }

    var predictor = TrainUtils.Train(
        Host,
        ch,
        trainData,
        trainer,
        validData,
        Args.Calibrator,
        Args.MaxCalibrationExamples,
        Args.CacheData,
        inputPredictor,
        testDataUsedInTrainer);

    using (var modelFile = Host.CreateOutputFile(Args.OutputModelFile))
    {
        TrainUtils.SaveModel(Host, ch, modelFile, predictor, trainData, cmd);
    }
}
/// <summary>
/// Trains a predictor (optionally with validation and test sets passed to the trainer),
/// saves the model, reloads the pipeline from the saved model against the test file,
/// then scores, evaluates and reports metrics. Per-instance results are written when
/// <c>ImplOptions.OutputDataFile</c> is set.
/// </summary>
/// <param name="ch">Channel used for trace messages, warnings, assertions and exceptions.</param>
/// <param name="cmd">Command string; asserted non-empty and handed to the model save call.</param>
private void RunCore(IChannel ch, string cmd)
{
    Host.AssertValue(ch);
    Host.AssertNonEmpty(cmd);

    ch.Trace("Constructing trainer");
    ITrainer trainer = ImplOptions.Trainer.CreateComponent(Host);

    // Only try to load an initial predictor when continued training was requested.
    IPredictor inputPredictor = null;
    if (ImplOptions.ContinueTrain)
    {
        bool loaded = TrainUtils.TryLoadPredictor(ch, Host, ImplOptions.InputModelFile, out inputPredictor);
        if (!loaded)
        {
            ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");
        }
    }

    ch.Trace("Constructing the training pipeline");
    IDataView trainView = CreateLoader();
    var trainSchema = trainView.Schema;

    // Resolve each role column from the user-supplied name, falling back to the default name.
    string labelCol = TrainUtils.MatchNameOrDefaultOrNull(
        ch, trainSchema, nameof(Arguments.LabelColumn), ImplOptions.LabelColumn, DefaultColumnNames.Label);
    string featureCol = TrainUtils.MatchNameOrDefaultOrNull(
        ch, trainSchema, nameof(Arguments.FeatureColumn), ImplOptions.FeatureColumn, DefaultColumnNames.Features);
    string groupCol = TrainUtils.MatchNameOrDefaultOrNull(
        ch, trainSchema, nameof(Arguments.GroupColumn), ImplOptions.GroupColumn, DefaultColumnNames.GroupId);
    string weightCol = TrainUtils.MatchNameOrDefaultOrNull(
        ch, trainSchema, nameof(Arguments.WeightColumn), ImplOptions.WeightColumn, DefaultColumnNames.Weight);
    string nameCol = TrainUtils.MatchNameOrDefaultOrNull(
        ch, trainSchema, nameof(Arguments.NameColumn), ImplOptions.NameColumn, DefaultColumnNames.Name);

    // May replace trainView (passed by ref); everything below uses the updated view.
    TrainUtils.AddNormalizerIfNeeded(Host, ch, trainer, ref trainView, featureCol, ImplOptions.NormalizeFeatures);

    ch.Trace("Binding columns");
    var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, ImplOptions.CustomColumns);
    RoleMappedData trainData = new RoleMappedData(
        trainView, labelCol, featureCol, groupCol, weightCol, nameCol, customCols);

    RoleMappedData validData = null;
    if (!string.IsNullOrWhiteSpace(ImplOptions.ValidationFile))
    {
        if (trainer.Info.SupportsValidation)
        {
            ch.Trace("Constructing the validation pipeline");
            IDataView validView = CreateRawLoader(dataFile: ImplOptions.ValidationFile);
            validView = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainView, validView);
            validData = new RoleMappedData(validView, trainData.Schema.GetColumnRoleNames());
        }
        else
        {
            ch.Warning("Ignoring validationFile: Trainer does not accept validation dataset.");
        }
    }

    // In addition to the training set, some trainers can accept two data sets, a validation set and a
    // test set, during the training phase. The major difference between them is that the training process
    // may indirectly use the validation set to improve the model, whereas the learned model should be
    // totally independent of the test set. As with the validation set, the trainer can report the scores
    // computed using the test set.
    // In contrast to the validation block above, no warning is issued when the trainer does not support
    // a test set, because this is the TrainTest command and the test file is used for evaluation anyway.
    RoleMappedData testDataUsedInTrainer = null;
    if (!string.IsNullOrWhiteSpace(ImplOptions.TestFile) && trainer.Info.SupportsTest)
    {
        ch.Trace("Constructing the test pipeline");
        IDataView trainerTestView = CreateRawLoader(dataFile: ImplOptions.TestFile);
        trainerTestView = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainView, trainerTestView);
        testDataUsedInTrainer = new RoleMappedData(trainerTestView, trainData.Schema.GetColumnRoleNames());
    }

    var predictor = TrainUtils.Train(
        Host,
        ch,
        trainData,
        trainer,
        validData,
        ImplOptions.Calibrator,
        ImplOptions.MaxCalibrationExamples,
        ImplOptions.CacheData,
        inputPredictor,
        testDataUsedInTrainer);

    // Save to the requested model file, or to a temp file when no output model file was given.
    // The last SimpleFileHandle argument is true only in the temp-file case.
    ILegacyDataLoader testPipe;
    bool hasOutfile = !string.IsNullOrEmpty(ImplOptions.OutputModelFile);
    string modelPath = hasOutfile ? ImplOptions.OutputModelFile : Path.GetTempFileName();
    using (var modelFile = new SimpleFileHandle(ch, modelPath, true, !hasOutfile))
    {
        TrainUtils.SaveModel(Host, ch, modelFile, predictor, trainData, cmd);

        ch.Trace("Constructing the testing pipeline");
        using (var modelStream = modelFile.OpenReadStream())
        {
            using (var repository = RepositoryReader.Open(modelStream, ch))
            {
                testPipe = LoadLoader(repository, ImplOptions.TestFile, true);
            }
        }
    }

    // Score.
    ch.Trace("Scoring and evaluating");
    ch.Assert(
        ImplOptions.Scorer == null || ImplOptions.Scorer is ICommandLineComponentFactory,
        "TrainTestCommand should only be used from the command line.");
    IDataScorerTransform scorePipe = ScoreUtils.GetScorer(
        ImplOptions.Scorer, predictor, testPipe, featureCol, groupCol, customCols, Host, trainData.Schema);

    // Evaluate. Use the configured evaluator if there is one, otherwise pick one from the scored schema.
    var evaluator = ImplOptions.Evaluator?.CreateComponent(Host)
        ?? EvaluateUtils.GetEvaluator(Host, scorePipe.Schema);
    var dataEval = new RoleMappedData(
        scorePipe, labelCol, featureCol, groupCol, weightCol, nameCol, customCols, opt: true);
    var metrics = evaluator.Evaluate(dataEval);

    MetricWriter.PrintWarnings(ch, metrics);
    evaluator.PrintFoldResults(ch, metrics);

    if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out var overall))
    {
        throw ch.Except("No overall metrics found");
    }

    overall = evaluator.GetOverallResults(overall);
    MetricWriter.PrintOverallMetrics(Host, ch, ImplOptions.SummaryFilename, overall, 1);
    evaluator.PrintAdditionalMetrics(ch, metrics);

    Dictionary<string, IDataView>[] metricValues = { metrics };
    SendTelemetryMetric(metricValues);

    // Per-instance output is optional; nothing more to do when no output data file was requested.
    if (string.IsNullOrWhiteSpace(ImplOptions.OutputDataFile))
    {
        return;
    }

    var perInstance = evaluator.GetPerInstanceMetrics(dataEval);
    var perInstanceData = new RoleMappedData(
        perInstance, labelCol, null, groupCol, weightCol, nameCol, customCols);
    var perInstanceView = evaluator.GetPerInstanceDataViewToSave(perInstanceData);
    MetricWriter.SavePerInstance(Host, ch, ImplOptions.OutputDataFile, perInstanceView);
}
/// <summary>
/// Runs cross-validation: builds the raw loader plus pre-transforms, creates the fold tasks
/// through <c>FoldHelper</c>, prints per-fold and overall metrics, sends telemetry, and
/// optionally saves per-instance results.
/// </summary>
/// <param name="ch">Channel used for trace messages, warnings, assertions and exceptions.</param>
/// <param name="cmd">Command string; passed through to <c>FoldHelper</c>.</param>
private void RunCore(IChannel ch, string cmd)
{
    Host.AssertValue(ch);

    IPredictor inputPredictor = null;
    if (Args.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, Args.InputModelFile, out inputPredictor))
    {
        ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");
    }

    ch.Trace("Constructing data pipeline");
    IDataLoader loader = CreateRawLoader();

    // Single source of truth for "per-instance results were requested". A blank (empty or
    // whitespace-only) path counts as "not requested", so results that would never be saved
    // are not computed either.
    bool savePerInstance = !string.IsNullOrWhiteSpace(Args.OutputDataFile);

    // If the per-instance results are requested and there is no name column, add a GenerateNumberTransform.
    var preXf = Args.PreTransforms;
    if (savePerInstance)
    {
        string name = TrainUtils.MatchNameOrDefaultOrNull(
            ch, loader.Schema, nameof(Args.NameColumn), Args.NameColumn, DefaultColumnNames.Name);
        if (name == null)
        {
            var generateName = new[]
            {
                new KeyValuePair<string, IComponentFactory<IDataView, IDataTransform>>(
                    "",
                    ComponentFactoryUtils.CreateFromFunction<IDataView, IDataTransform>(
                        (env, input) =>
                        {
                            var args = new GenerateNumberTransform.Options();
                            args.Columns = new[]
                            {
                                new GenerateNumberTransform.Column() { Name = DefaultColumnNames.Name },
                            };
                            args.UseCounter = true;
                            return new GenerateNumberTransform(env, args, input);
                        }))
            };

            // Concat throws on a null source, so handle an unset pre-transform list explicitly.
            preXf = preXf == null ? generateName : preXf.Concat(generateName).ToArray();
        }
    }

    loader = CompositeDataLoader.Create(Host, loader, preXf);

    ch.Trace("Binding label and features columns");

    IDataView pipe = loader;
    var stratificationColumn = GetSplitColumn(ch, loader, ref pipe);
    var scorer = Args.Scorer;
    var evaluator = Args.Evaluator;

    // A blank (empty or whitespace-only) path is treated as "no validation file" rather than
    // being handed to the loader.
    Func<IDataView> validDataCreator = null;
    if (!string.IsNullOrWhiteSpace(Args.ValidationFile))
    {
        validDataCreator = () =>
        {
            // Fork the command.
            var impl = new CrossValidationCommand(this);
            return impl.CreateRawLoader(dataFile: Args.ValidationFile);
        };
    }

    FoldHelper fold = new FoldHelper(
        Host,
        RegistrationName,
        pipe,
        stratificationColumn,
        Args,
        CreateRoleMappedData,
        ApplyAllTransformsToData,
        scorer,
        evaluator,
        validDataCreator,
        ApplyAllTransformsToData,
        inputPredictor,
        cmd,
        loader,
        savePerInstance);
    var tasks = fold.GetCrossValidationTasks();

    var eval = evaluator?.CreateComponent(Host)
        ?? EvaluateUtils.GetEvaluator(Host, tasks[0].Result.ScoreSchema);

    // Print confusion matrix and fold results for each fold.
    // Results are read task by task so each fold is printed as soon as it (and the folds before it) finish.
    for (int i = 0; i < tasks.Length; i++)
    {
        var dict = tasks[i].Result.Metrics;
        MetricWriter.PrintWarnings(ch, dict);
        eval.PrintFoldResults(ch, dict);
    }

    // Collect the per-fold metrics once; the same array feeds the overall results,
    // the additional metrics and the telemetry call below.
    Dictionary<string, IDataView>[] foldMetrics = tasks.Select(t => t.Result.Metrics).ToArray();

    // Print the overall results.
    if (!TryGetOverallMetrics(foldMetrics, out var overallList))
    {
        throw ch.Except("No overall metrics found");
    }

    var overall = eval.GetOverallResults(overallList.ToArray());
    MetricWriter.PrintOverallMetrics(Host, ch, Args.SummaryFilename, overall, Args.NumFolds);
    eval.PrintAdditionalMetrics(ch, foldMetrics);
    SendTelemetryMetric(foldMetrics);

    // Save the per-instance results.
    if (savePerInstance)
    {
        var perInstance = EvaluateUtils.ConcatenatePerInstanceDataViews(
            Host,
            eval,
            Args.CollateMetrics,
            Args.OutputExampleFoldIndex,
            tasks.Select(t => t.Result.PerInstanceResults).ToArray(),
            out var variableSizeVectorColumnNames);

        if (variableSizeVectorColumnNames.Length > 0)
        {
            ch.Warning(
                "Detected columns of variable length: {0}. Consider setting collateMetrics- for meaningful per-Folds results.",
                string.Join(", ", variableSizeVectorColumnNames));
        }

        if (Args.CollateMetrics)
        {
            // Collated output is asserted to be a single data view.
            ch.Assert(perInstance.Length == 1);
            MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, perInstance[0]);
        }
        else
        {
            // One output file per fold, named from the requested path and the fold index.
            int foldIndex = 0;
            foreach (var idv in perInstance)
            {
                MetricWriter.SavePerInstance(Host, ch, ConstructPerFoldName(Args.OutputDataFile, foldIndex), idv);
                foldIndex++;
            }
        }
    }
}