private FoldResult RunFold(int fold)
{
    var host = GetHost();
    host.Assert(0 <= fold && fold < _numFolds);
    // REVIEW: Make channels buffered in multi-threaded environments.
    using (var ch = host.Start($"Fold {fold}"))
    {
        ch.Trace("Constructing trainer");
        ITrainer trainer = _trainer.CreateComponent(host);

        // Train pipe.
        var trainFilter = new RangeFilter.Arguments();
        trainFilter.Column = _splitColumn;
        trainFilter.Min = (Double)fold / _numFolds;
        trainFilter.Max = (Double)(fold + 1) / _numFolds;
        trainFilter.Complement = true;
        IDataView trainPipe = new RangeFilter(host, trainFilter, _inputDataView);
        trainPipe = new OpaqueDataView(trainPipe);
        var trainData = _createExamples(host, ch, trainPipe, trainer);

        // Test pipe.
        var testFilter = new RangeFilter.Arguments();
        testFilter.Column = trainFilter.Column;
        testFilter.Min = trainFilter.Min;
        testFilter.Max = trainFilter.Max;
        ch.Assert(!testFilter.Complement);
        IDataView testPipe = new RangeFilter(host, testFilter, _inputDataView);
        testPipe = new OpaqueDataView(testPipe);
        var testData = _applyTransformsToTestData(host, ch, testPipe, trainData, trainPipe);

        // Validation pipe and examples.
        RoleMappedData validData = null;
        if (_getValidationDataView != null)
        {
            ch.Assert(_applyTransformsToValidationData != null);
            if (!trainer.Info.SupportsValidation)
            {
                ch.Warning("Trainer does not accept validation dataset.");
            }
            else
            {
                ch.Trace("Constructing the validation pipeline");
                IDataView validLoader = _getValidationDataView();
                var validPipe = ApplyTransformUtils.ApplyAllTransformsToData(host, _inputDataView, validLoader);
                validPipe = new OpaqueDataView(validPipe);
                validData = _applyTransformsToValidationData(host, ch, validPipe, trainData, trainPipe);
            }
        }

        // Train.
        var predictor = TrainUtils.Train(host, ch, trainData, trainer, validData,
            _calibrator, _maxCalibrationExamples, _cacheData, _inputPredictor);

        // Score.
        ch.Trace("Scoring and evaluating");
        ch.Assert(_scorer == null || _scorer is ICommandLineComponentFactory,
            "CrossValidationCommand should only be used from the command line.");
        var bindable = ScoreUtils.GetSchemaBindableMapper(host, predictor,
            scorerFactorySettings: _scorer as ICommandLineComponentFactory);
        ch.AssertValue(bindable);
        var mapper = bindable.Bind(host, testData.Schema);
        var scorerComp = _scorer ?? ScoreUtils.GetScorerComponent(host, mapper);
        IDataScorerTransform scorePipe = scorerComp.CreateComponent(host, testData.Data, mapper, trainData.Schema);

        // Save per-fold model.
        string modelFileName = ConstructPerFoldName(_outputModelFile, fold);
        if (modelFileName != null && _loader != null)
        {
            using (var file = host.CreateOutputFile(modelFileName))
            {
                var rmd = new RoleMappedData(
                    CompositeDataLoader.ApplyTransform(host, _loader, null, null,
                        (e, newSource) => ApplyTransformUtils.ApplyAllTransformsToData(e, trainData.Data, newSource)),
                    trainData.Schema.GetColumnRoleNames());
                TrainUtils.SaveModel(host, ch, file, predictor, rmd, _cmd);
            }
        }

        // Evaluate.
        var eval = _evaluator?.CreateComponent(host) ??
            EvaluateUtils.GetEvaluator(host, scorePipe.Schema);
        // Note that this doesn't require the provided columns to exist (because of the "opt" parameter).
        // We don't normally expect the scorer to drop columns, but if it does, we should not require
        // all the columns in the test pipeline to still be present.
        var dataEval = new RoleMappedData(scorePipe, testData.Schema.GetColumnRoleNames(), opt: true);
        var dict = eval.Evaluate(dataEval);
        RoleMappedData perInstance = null;
        if (_savePerInstance)
        {
            var perInst = eval.GetPerInstanceMetrics(dataEval);
            perInstance = new RoleMappedData(perInst, dataEval.Schema.GetColumnRoleNames(), opt: true);
        }
        return new FoldResult(dict, dataEval.Schema.Schema, perInstance, trainData.Schema);
    }
}
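// A minimal standalone sketch of the fold arithmetic used in RunFold above, assuming the
// split column holds values distributed uniformly in [0, 1). Fold i of k uses [i/k, (i+1)/k)
// as the test range and its complement as the train range, so every row lands in exactly one
// test fold. The helper name below is illustrative, not part of this command.
private static (Double Min, Double Max) GetFoldTestRange(int fold, int numFolds)
{
    // RangeFilter with Complement = false keeps rows inside this range (the test set);
    // with Complement = true it keeps the rest (the train set).
    return ((Double)fold / numFolds, (Double)(fold + 1) / numFolds);
}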
public static IDataScorerTransform GetScorer(IPredictor predictor, RoleMappedData data, IHostEnvironment env, RoleMappedSchema trainSchema)
{
    var sc = GetScorerComponentAndMapper(predictor, null, data.Schema, env, null, out var mapper);
    return sc.CreateComponent(env, data.Data, mapper, trainSchema);
}
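// Hedged usage sketch for the overload above. The wrapper method and its parameters are
// illustrative; "env", "predictor", "testData", and "trainData" are assumed to come from a
// surrounding command, not from this file.
private static void ScorerUsageSketch(IHostEnvironment env, IPredictor predictor, RoleMappedData testData, RoleMappedData trainData)
{
    IDataScorerTransform scorer = GetScorer(predictor, testData, env, trainData.Schema);
    IDataView scoredView = scorer; // the scorer is itself an IDataView, ready to feed an evaluator
}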
private void RunCore(IChannel ch, string cmd)
{
    Host.AssertValue(ch);
    Host.AssertNonEmpty(cmd);

    ch.Trace("Constructing trainer");
    ITrainer trainer = Args.Trainer.CreateInstance(Host);

    IPredictor inputPredictor = null;
    if (Args.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, Args.InputModelFile, out inputPredictor))
        ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");

    ch.Trace("Constructing the training pipeline");
    IDataView trainPipe = CreateLoader();

    ISchema schema = trainPipe.Schema;
    string label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.LabelColumn),
        Args.LabelColumn, DefaultColumnNames.Label);
    string features = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.FeatureColumn),
        Args.FeatureColumn, DefaultColumnNames.Features);
    string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.GroupColumn),
        Args.GroupColumn, DefaultColumnNames.GroupId);
    string weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.WeightColumn),
        Args.WeightColumn, DefaultColumnNames.Weight);
    string name = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.NameColumn),
        Args.NameColumn, DefaultColumnNames.Name);

    TrainUtils.AddNormalizerIfNeeded(Host, ch, trainer, ref trainPipe, features, Args.NormalizeFeatures);

    ch.Trace("Binding columns");
    var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);
    var data = TrainUtils.CreateExamples(trainPipe, label, features, group, weight, name, customCols);

    RoleMappedData validData = null;
    if (!string.IsNullOrWhiteSpace(Args.ValidationFile))
    {
        if (!TrainUtils.CanUseValidationData(trainer))
        {
            ch.Warning("Ignoring validationFile: Trainer does not accept validation dataset.");
        }
        else
        {
            ch.Trace("Constructing the validation pipeline");
            IDataView validPipe = CreateRawLoader(dataFile: Args.ValidationFile);
            validPipe = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainPipe, validPipe);
            validData = RoleMappedData.Create(validPipe, data.Schema.GetColumnRoleNames());
        }
    }

    var predictor = TrainUtils.Train(Host, ch, data, trainer, _info.LoadNames[0], validData,
        Args.Calibrator, Args.MaxCalibrationExamples, Args.CacheData, inputPredictor);

    IDataLoader testPipe;
    using (var file = !string.IsNullOrEmpty(Args.OutputModelFile)
        ? Host.CreateOutputFile(Args.OutputModelFile)
        : Host.CreateTempFile(".zip"))
    {
        TrainUtils.SaveModel(Host, ch, file, predictor, data, cmd);
        ch.Trace("Constructing the testing pipeline");
        using (var stream = file.OpenReadStream())
        using (var rep = RepositoryReader.Open(stream, ch))
            testPipe = LoadLoader(rep, Args.TestFile, true);
    }

    // Score.
    ch.Trace("Scoring and evaluating");
    IDataScorerTransform scorePipe = ScoreUtils.GetScorer(Args.Scorer, predictor, testPipe,
        features, group, customCols, Host, data.Schema);

    // Evaluate.
    var evalComp = Args.Evaluator;
    if (!evalComp.IsGood())
        evalComp = EvaluateUtils.GetEvaluatorType(ch, scorePipe.Schema);
    var evaluator = evalComp.CreateInstance(Host);
    var dataEval = TrainUtils.CreateExamplesOpt(scorePipe, label, features, group, weight, name, customCols);
    var metrics = evaluator.Evaluate(dataEval);
    MetricWriter.PrintWarnings(ch, metrics);
    evaluator.PrintFoldResults(ch, metrics);
    evaluator.PrintOverallResults(ch, Args.SummaryFilename, metrics);
    Dictionary<string, IDataView>[] metricValues = { metrics };
    SendTelemetryMetric(metricValues);

    if (!string.IsNullOrWhiteSpace(Args.OutputDataFile))
    {
        var perInst = evaluator.GetPerInstanceMetrics(dataEval);
        var perInstData = TrainUtils.CreateExamples(perInst, label, null, group, weight, name, customCols);
        var idv = evaluator.GetPerInstanceDataViewToSave(perInstData);
        MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, idv);
    }
}
public FoldResult(Dictionary<string, IDataView> metrics, Schema scoreSchema, RoleMappedData perInstance, RoleMappedSchema trainSchema)
{
    Metrics = metrics;
    ScoreSchema = scoreSchema;
    PerInstanceResults = perInstance;
    TrainSchema = trainSchema;
}
/// <summary>
/// Save the model to the stream.
/// The method saves the loader and the transformations of the data pipe, and optionally saves
/// the predictor and the command string. It also uses the feature column of <paramref name="data"/>,
/// if present, to extract feature names.
/// </summary>
/// <param name="env">The host environment to use.</param>
/// <param name="ch">The communication channel to use.</param>
/// <param name="outputStream">The output model stream.</param>
/// <param name="predictor">The predictor.</param>
/// <param name="data">The training examples.</param>
/// <param name="command">The command string.</param>
public static void SaveModel(IHostEnvironment env, IChannel ch, Stream outputStream, IPredictor predictor, RoleMappedData data, string command = null)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ch, nameof(ch));
    ch.CheckValue(outputStream, nameof(outputStream));
    ch.CheckValueOrNull(predictor);
    ch.CheckValue(data, nameof(data));
    ch.CheckValueOrNull(command);

    using (var ch2 = env.Start("SaveModel"))
    using (var pch = env.StartProgressChannel("Saving model"))
    {
        using (var rep = RepositoryWriter.CreateNew(outputStream, ch2))
        {
            if (predictor != null)
            {
                ch2.Trace("Saving predictor");
                ModelSaveContext.SaveModel(rep, predictor, ModelFileUtils.DirPredictor);
            }

            ch2.Trace("Saving loader and transformations");
            var dataPipe = data.Data;
            if (dataPipe is IDataLoader)
            {
                ModelSaveContext.SaveModel(rep, dataPipe, ModelFileUtils.DirDataLoaderModel);
            }
            else
            {
                SaveDataPipe(env, rep, dataPipe);
            }

            // REVIEW: Handle statistics.
            // ModelSaveContext.SaveModel(rep, dataStats, DirDataStats);
            if (!string.IsNullOrWhiteSpace(command))
            {
                using (var ent = rep.CreateEntry(ModelFileUtils.DirTrainingInfo, "Command.txt"))
                using (var writer = Utils.OpenWriter(ent.Stream))
                    writer.WriteLine(command);
            }
            ModelFileUtils.SaveRoleMappings(env, ch, data.Schema, rep);
            rep.Commit();
        }
        ch2.Done();
    }
}
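// Hedged usage sketch for SaveModel above: a hypothetical call site that writes a model to a
// local zip file. The wrapper method and the file name are illustrative; "env", "ch",
// "predictor", and "data" are assumed to come from a surrounding command. Requires System.IO.
private static void SaveModelSketch(IHostEnvironment env, IChannel ch, IPredictor predictor, RoleMappedData data)
{
    using (var stream = File.Create("model.zip"))
        SaveModel(env, ch, stream, predictor, data, command: "train data=train.txt");
}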
private static bool AddCacheIfWanted(IHostEnvironment env, IChannel ch, ITrainer trainer, ref RoleMappedData data, bool? cacheData)
{
    Contracts.AssertValue(env, nameof(env));
    env.AssertValue(ch, nameof(ch));
    ch.AssertValue(trainer, nameof(trainer));
    ch.AssertValue(data, nameof(data));

    ITrainerEx trainerEx = trainer as ITrainerEx;
    bool shouldCache = cacheData ??
        (!(data.Data is BinaryLoader) && (trainerEx == null || trainerEx.WantCaching));

    if (shouldCache)
    {
        ch.Trace("Caching");
        var prefetch = data.Schema.GetColumnRoles().Select(kc => kc.Value.Index).ToArray();
        var cacheView = new CacheDataView(env, data.Data, prefetch);
        // Because the prefetching worked, we know that these are valid columns.
        data = new RoleMappedData(cacheView, data.Schema.GetColumnRoleNames());
    }
    else
    {
        ch.Trace("Not caching");
    }
    return shouldCache;
}
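// The caching rule above, restated as a minimal sketch (the helper and its parameters are
// illustrative): an explicit cacheData flag always wins; otherwise cache unless the data is
// already a BinaryLoader (cheap to re-read) or the trainer opts out via ITrainerEx.WantCaching.
private static bool ShouldCacheSketch(bool? cacheData, bool isBinaryLoader, bool trainerWantsCaching)
    => cacheData ?? (!isBinaryLoader && trainerWantsCaching);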
public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, string name, RoleMappedData validData,
    SubComponent<ICalibratorTrainer, SignatureCalibrator> calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inpPredictor = null)
{
    ICalibratorTrainer caliTrainer = !calibrator.IsGood() ? null : calibrator.CreateInstance(env);
    return TrainCore(env, ch, data, trainer, name, validData, caliTrainer, maxCalibrationExamples, cacheData, inpPredictor);
}
private static IPredictor TrainCore(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, string name, RoleMappedData validData,
    ICalibratorTrainer calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inpPredictor = null)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ch, nameof(ch));
    ch.CheckValue(data, nameof(data));
    ch.CheckValue(trainer, nameof(trainer));
    ch.CheckNonEmpty(name, nameof(name));
    ch.CheckValueOrNull(validData);
    ch.CheckValueOrNull(inpPredictor);

    var trainerRmd = trainer as ITrainer<RoleMappedData>;
    if (trainerRmd == null)
    {
        throw ch.ExceptUserArg(nameof(TrainCommand.Arguments.Trainer),
            "Trainer '{0}' does not accept known training data type", name);
    }

    Action<IChannel, ITrainer, Action<object>, object, object, object> trainCoreAction = TrainCore;
    IPredictor predictor;
    AddCacheIfWanted(env, ch, trainer, ref data, cacheData);
    ch.Trace("Training");
    if (validData != null)
        AddCacheIfWanted(env, ch, trainer, ref validData, cacheData);

    var genericExam = trainCoreAction.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(
        typeof(RoleMappedData),
        inpPredictor != null ? inpPredictor.GetType() : typeof(IPredictor));
    Action<RoleMappedData> trainExam = trainerRmd.Train;
    genericExam.Invoke(null, new object[] { ch, trainerRmd, trainExam, data, validData, inpPredictor });

    ch.Trace("Constructing predictor");
    predictor = trainerRmd.CreatePredictor();
    return CalibratorUtils.TrainCalibratorIfNeeded(env, ch, calibrator, maxCalibrationExamples, trainer, predictor, data);
}
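// The reflection in TrainCore above closes an open generic training helper over the run-time
// type of the input predictor, which is only known once the model file is loaded. A
// self-contained sketch of the same pattern, assuming nothing from this file (Helper<T> and
// Demo are illustrative; requires System and System.Reflection):
private static void Helper<T>(T value) => Console.WriteLine(typeof(T));

private static void Demo(object value)
{
    // Capture the open generic method through a delegate, then rebind it to the concrete
    // run-time type -- the same trick TrainCore uses via MakeGenericMethod.
    Action<object> template = Helper<object>;
    MethodInfo closed = template.GetMethodInfo()
        .GetGenericMethodDefinition()
        .MakeGenericMethod(value.GetType());
    closed.Invoke(null, new[] { value });
}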
public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, RoleMappedData validData,
    IComponentFactory<ICalibratorTrainer> calibrator, int maxCalibrationExamples, bool? cacheData,
    IPredictor inputPredictor = null, RoleMappedData testData = null)
{
    return TrainCore(env, ch, data, trainer, validData, calibrator, maxCalibrationExamples, cacheData, inputPredictor, testData);
}
private static IPredictor TrainCore(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, RoleMappedData validData,
    IComponentFactory<ICalibratorTrainer> calibrator, int maxCalibrationExamples, bool? cacheData,
    IPredictor inputPredictor = null, RoleMappedData testData = null)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ch, nameof(ch));
    ch.CheckValue(data, nameof(data));
    ch.CheckValue(trainer, nameof(trainer));
    ch.CheckValueOrNull(validData);
    ch.CheckValueOrNull(inputPredictor);

    AddCacheIfWanted(env, ch, trainer, ref data, cacheData);
    ch.Trace("Training");
    if (validData != null)
        AddCacheIfWanted(env, ch, trainer, ref validData, cacheData);

    if (inputPredictor != null && !trainer.Info.SupportsIncrementalTraining)
    {
        ch.Warning("Ignoring " + nameof(TrainCommand.Arguments.InputModelFile) +
            ": Trainer does not support incremental training.");
        inputPredictor = null;
    }
    ch.Assert(validData == null || trainer.Info.SupportsValidation);
    var predictor = trainer.Train(new TrainContext(data, validData, testData, inputPredictor));
    var caliTrainer = calibrator?.CreateComponent(env);
    return CalibratorUtils.TrainCalibratorIfNeeded(env, ch, caliTrainer, maxCalibrationExamples, trainer, predictor, data);
}
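// Hedged usage sketch for the TrainContext path above (the wrapper method and argument
// values are illustrative; "env", "ch", "trainData", and "trainer" are assumed to come from
// a surrounding command). Null validData, calibrator, and inputPredictor mean plain training
// with no validation, calibration, or incremental start.
private static IPredictor TrainSketch(IHostEnvironment env, IChannel ch, RoleMappedData trainData, ITrainer trainer)
{
    return Train(env, ch, trainData, trainer, validData: null, calibrator: null,
        maxCalibrationExamples: 0, cacheData: null, inputPredictor: null);
}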
public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer,
    IComponentFactory<ICalibratorTrainer> calibrator, int maxCalibrationExamples)
{
    return TrainCore(env, ch, data, trainer, null, calibrator, maxCalibrationExamples, false);
}
private void RunCore(IChannel ch, string cmd)
{
    Host.AssertValue(ch);
    Host.AssertNonEmpty(cmd);

    ch.Trace("Constructing trainer");
    ITrainer trainer = _trainer.CreateComponent(Host);

    IPredictor inputPredictor = null;
    if (Args.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, Args.InputModelFile, out inputPredictor))
        ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");

    ch.Trace("Constructing data pipeline");
    IDataView view = CreateLoader();

    ISchema schema = view.Schema;
    var label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.LabelColumn), _labelColumn, DefaultColumnNames.Label);
    var feature = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.FeatureColumn), _featureColumn, DefaultColumnNames.Features);
    var group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.GroupColumn), _groupColumn, DefaultColumnNames.GroupId);
    var weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.WeightColumn), _weightColumn, DefaultColumnNames.Weight);
    var name = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.NameColumn), _nameColumn, DefaultColumnNames.Name);

    TrainUtils.AddNormalizerIfNeeded(Host, ch, trainer, ref view, feature, Args.NormalizeFeatures);

    ch.Trace("Binding columns");
    var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);
    var data = new RoleMappedData(view, label, feature, group, weight, name, customCols);

    // REVIEW: Unify the code that creates validation examples in Train, TrainTest and CV commands.
    RoleMappedData validData = null;
    if (!string.IsNullOrWhiteSpace(Args.ValidationFile))
    {
        if (!trainer.Info.SupportsValidation)
        {
            ch.Warning("Ignoring validationFile: Trainer does not accept validation dataset.");
        }
        else
        {
            ch.Trace("Constructing the validation pipeline");
            IDataView validPipe = CreateRawLoader(dataFile: Args.ValidationFile);
            validPipe = ApplyTransformUtils.ApplyAllTransformsToData(Host, view, validPipe);
            validData = new RoleMappedData(validPipe, data.Schema.GetColumnRoleNames());
        }
    }

    // In addition to the training set, some trainers can accept two extra data sets, a validation set
    // and a test set, during the training phase. The major difference between them is that the training
    // process may indirectly use the validation set to improve the model, while the learned model should
    // be totally independent of the test set. As with the validation set, the trainer can report the
    // scores computed using the test set.
    RoleMappedData testDataUsedInTrainer = null;
    if (!string.IsNullOrWhiteSpace(Args.TestFile))
    {
        // In contrast to the if-else block for validation above, we do not print a warning if a test file
        // is provided, because this is the TrainTest command.
        if (trainer.Info.SupportsTest)
        {
            ch.Trace("Constructing the test pipeline");
            IDataView testPipeUsedInTrainer = CreateRawLoader(dataFile: Args.TestFile);
            testPipeUsedInTrainer = ApplyTransformUtils.ApplyAllTransformsToData(Host, view, testPipeUsedInTrainer);
            testDataUsedInTrainer = new RoleMappedData(testPipeUsedInTrainer, data.Schema.GetColumnRoleNames());
        }
    }

    var predictor = TrainUtils.Train(Host, ch, data, trainer, validData,
        Args.Calibrator, Args.MaxCalibrationExamples, Args.CacheData, inputPredictor, testDataUsedInTrainer);
    using (var file = Host.CreateOutputFile(Args.OutputModelFile))
        TrainUtils.SaveModel(Host, ch, file, predictor, data, cmd);
}
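// A hedged restatement of the validation/test wiring above (the helper and the dataFile
// parameter are illustrative, not part of this command): every transform applied to the
// training view is replayed on the raw auxiliary file, and the training role mappings are
// rebound so that column roles (label, features, group, ...) stay consistent across data sets.
private RoleMappedData CreateAuxiliaryExamplesSketch(IDataView trainView, RoleMappedData trainExamples, string dataFile)
{
    IDataView pipe = CreateRawLoader(dataFile: dataFile);
    pipe = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainView, pipe);
    return new RoleMappedData(pipe, trainExamples.Schema.GetColumnRoleNames());
}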
public override IDataTransform GetPerInstanceMetrics(RoleMappedData data)
{
    return NopTransform.CreateIfNeeded(Host, data.Data);
}
private void RunCore(IChannel ch, string cmd)
{
    Host.AssertValue(ch);
    Host.AssertNonEmpty(cmd);

    ch.Trace("Constructing trainer");
    ITrainer trainer = Args.Trainer.CreateComponent(Host);

    IPredictor inputPredictor = null;
    if (Args.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, Args.InputModelFile, out inputPredictor))
        ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");

    ch.Trace("Constructing the training pipeline");
    IDataView trainPipe = CreateLoader();

    var schema = trainPipe.Schema;
    string label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.LabelColumn), Args.LabelColumn, DefaultColumnNames.Label);
    string features = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.FeatureColumn), Args.FeatureColumn, DefaultColumnNames.Features);
    string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.GroupColumn), Args.GroupColumn, DefaultColumnNames.GroupId);
    string weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.WeightColumn), Args.WeightColumn, DefaultColumnNames.Weight);
    string name = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.NameColumn), Args.NameColumn, DefaultColumnNames.Name);

    TrainUtils.AddNormalizerIfNeeded(Host, ch, trainer, ref trainPipe, features, Args.NormalizeFeatures);

    ch.Trace("Binding columns");
    var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);
    var data = new RoleMappedData(trainPipe, label, features, group, weight, name, customCols);

    RoleMappedData validData = null;
    if (!string.IsNullOrWhiteSpace(Args.ValidationFile))
    {
        if (!trainer.Info.SupportsValidation)
        {
            ch.Warning("Ignoring validationFile: Trainer does not accept validation dataset.");
        }
        else
        {
            ch.Trace("Constructing the validation pipeline");
            IDataView validPipe = CreateRawLoader(dataFile: Args.ValidationFile);
            validPipe = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainPipe, validPipe);
            validData = new RoleMappedData(validPipe, data.Schema.GetColumnRoleNames());
        }
    }

    // In addition to the training set, some trainers can accept two extra data sets, a validation set
    // and a test set, during the training phase. The major difference between them is that the training
    // process may indirectly use the validation set to improve the model, while the learned model should
    // be totally independent of the test set. As with the validation set, the trainer can report the
    // scores computed using the test set.
    RoleMappedData testDataUsedInTrainer = null;
    if (!string.IsNullOrWhiteSpace(Args.TestFile))
    {
        // In contrast to the if-else block for validation above, we do not print a warning if a test file
        // is provided, because this is the TrainTest command.
        if (trainer.Info.SupportsTest)
        {
            ch.Trace("Constructing the test pipeline");
            IDataView testPipeUsedInTrainer = CreateRawLoader(dataFile: Args.TestFile);
            testPipeUsedInTrainer = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainPipe, testPipeUsedInTrainer);
            testDataUsedInTrainer = new RoleMappedData(testPipeUsedInTrainer, data.Schema.GetColumnRoleNames());
        }
    }

    var predictor = TrainUtils.Train(Host, ch, data, trainer, validData,
        Args.Calibrator, Args.MaxCalibrationExamples, Args.CacheData, inputPredictor, testDataUsedInTrainer);

    IDataLoader testPipe;
    bool hasOutfile = !string.IsNullOrEmpty(Args.OutputModelFile);
    var tempFilePath = hasOutfile ? null : Path.GetTempFileName();
    using (var file = new SimpleFileHandle(ch, hasOutfile ? Args.OutputModelFile : tempFilePath, true, !hasOutfile))
    {
        TrainUtils.SaveModel(Host, ch, file, predictor, data, cmd);
        ch.Trace("Constructing the testing pipeline");
        using (var stream = file.OpenReadStream())
        using (var rep = RepositoryReader.Open(stream, ch))
            testPipe = LoadLoader(rep, Args.TestFile, true);
    }

    // Score.
    ch.Trace("Scoring and evaluating");
    ch.Assert(Args.Scorer == null || Args.Scorer is ICommandLineComponentFactory,
        "TrainTestCommand should only be used from the command line.");
    IDataScorerTransform scorePipe = ScoreUtils.GetScorer(Args.Scorer, predictor, testPipe,
        features, group, customCols, Host, data.Schema);

    // Evaluate.
    var evaluator = Args.Evaluator?.CreateComponent(Host) ??
        EvaluateUtils.GetEvaluator(Host, scorePipe.Schema);
    var dataEval = new RoleMappedData(scorePipe, label, features, group, weight, name, customCols, opt: true);
    var metrics = evaluator.Evaluate(dataEval);
    MetricWriter.PrintWarnings(ch, metrics);
    evaluator.PrintFoldResults(ch, metrics);
    if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out var overall))
        throw ch.Except("No overall metrics found");
    overall = evaluator.GetOverallResults(overall);
    MetricWriter.PrintOverallMetrics(Host, ch, Args.SummaryFilename, overall, 1);
    evaluator.PrintAdditionalMetrics(ch, metrics);
    Dictionary<string, IDataView>[] metricValues = { metrics };
    SendTelemetryMetric(metricValues);

    if (!string.IsNullOrWhiteSpace(Args.OutputDataFile))
    {
        var perInst = evaluator.GetPerInstanceMetrics(dataEval);
        var perInstData = new RoleMappedData(perInst, label, null, group, weight, name, customCols);
        var idv = evaluator.GetPerInstanceDataViewToSave(perInstData);
        MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, idv);
    }
}
private void RunCore(IChannel ch)
{
    ch.Trace("Constructing data pipeline");
    IDataLoader loader;
    IPredictor predictor;
    RoleMappedSchema trainSchema;
    LoadModelObjects(ch, true, out predictor, true, out trainSchema, out loader);
    ch.AssertValue(predictor);
    ch.AssertValueOrNull(trainSchema);
    ch.AssertValue(loader);

    ch.Trace("Binding columns");
    ISchema schema = loader.Schema;
    string label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.LabelColumn), Args.LabelColumn, DefaultColumnNames.Label);
    string features = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.FeatureColumn), Args.FeatureColumn, DefaultColumnNames.Features);
    string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.GroupColumn), Args.GroupColumn, DefaultColumnNames.GroupId);
    string weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.WeightColumn), Args.WeightColumn, DefaultColumnNames.Weight);
    string name = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.NameColumn), Args.NameColumn, DefaultColumnNames.Name);
    var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);

    // Score.
    ch.Trace("Scoring and evaluating");
    ch.Assert(Args.Scorer == null || Args.Scorer is ICommandLineComponentFactory,
        "TestCommand should only be used from the command line.");
    IDataScorerTransform scorePipe = ScoreUtils.GetScorer(Args.Scorer, predictor, loader,
        features, group, customCols, Host, trainSchema);

    // Evaluate.
    var evalComp = Args.Evaluator;
    if (!evalComp.IsGood())
        evalComp = EvaluateUtils.GetEvaluatorType(ch, scorePipe.Schema);
    var evaluator = evalComp.CreateInstance(Host);
    var data = new RoleMappedData(scorePipe, label, null, group, weight, name, customCols);
    var metrics = evaluator.Evaluate(data);
    MetricWriter.PrintWarnings(ch, metrics);
    evaluator.PrintFoldResults(ch, metrics);
    if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out var overall))
        throw ch.Except("No overall metrics found");
    overall = evaluator.GetOverallResults(overall);
    MetricWriter.PrintOverallMetrics(Host, ch, Args.SummaryFilename, overall, 1);
    evaluator.PrintAdditionalMetrics(ch, metrics);
    Dictionary<string, IDataView>[] metricValues = { metrics };
    SendTelemetryMetric(metricValues);

    if (!string.IsNullOrWhiteSpace(Args.OutputDataFile))
    {
        var perInst = evaluator.GetPerInstanceMetrics(data);
        var perInstData = new RoleMappedData(perInst, label, null, group, weight, name, customCols);
        var idv = evaluator.GetPerInstanceDataViewToSave(perInstData);
        MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, idv);
    }
}