internal SingleFeaturePredictionTransformerBase(IHost host, ModelLoadContext ctx)
    : base(host, ctx)
{
    FeatureColumn = ctx.LoadStringOrNull();

    if (FeatureColumn == null)
    {
        FeatureColumnType = null;
    }
    else if (!TrainSchema.TryGetColumnIndex(FeatureColumn, out int col))
    {
        throw Host.ExceptSchemaMismatch(nameof(FeatureColumn), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn);
    }
    else
    {
        FeatureColumnType = TrainSchema.GetColumnType(col);
    }

    BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, Model);
}
internal PredictionTransformerBase(IHost host, ModelLoadContext ctx)
{
    Host = host;

    // *** Binary format ***
    // model: prediction model.
    // stream: empty data view that contains train schema.
    // id of string: feature column.
    ctx.LoadModel<TModel, SignatureLoadModel>(host, out TModel model, DirModel);
    Model = model;

    // Clone the stream with the schema into memory.
    var ms = new MemoryStream();
    ctx.TryLoadBinaryStream(DirTransSchema, reader =>
    {
        reader.BaseStream.CopyTo(ms);
    });
    ms.Position = 0;
    var loader = new BinaryLoader(host, new BinaryLoader.Arguments(), ms);
    TrainSchema = loader.Schema;

    FeatureColumn = ctx.LoadStringOrNull();
    if (FeatureColumn == null)
    {
        FeatureColumnType = null;
    }
    else if (!TrainSchema.TryGetColumnIndex(FeatureColumn, out int col))
    {
        throw Host.ExceptSchemaMismatch(nameof(FeatureColumn), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn);
    }
    else
    {
        FeatureColumnType = TrainSchema.GetColumnType(col);
    }

    BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
}
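For reference, the save path that produces the binary layout read above could look like the following sketch. It assumes the same DirModel/DirTransSchema constants and the usual ML.NET helpers (BinarySaver, DataSaverUtils.SaveDataView, EmptyDataView); treat it as illustrative rather than the verbatim implementation.

private protected virtual void SaveCore(ModelSaveContext ctx)
{
    // *** Binary format *** (mirrors the load path above)
    // model: prediction model.
    // stream: empty data view that contains train schema.
    // id of string: feature column.
    ctx.SaveModel(Model, DirModel);
    ctx.SaveBinaryStream(DirTransSchema, writer =>
    {
        using (var ch = Host.Start("Saving train schema"))
        {
            var saver = new BinarySaver(Host, new BinarySaver.Arguments { Silent = true });
            // Persist only the schema by saving an empty view over it.
            DataSaverUtils.SaveDataView(ch, saver, new EmptyDataView(Host, TrainSchema), writer.BaseStream);
        }
    });
    ctx.SaveStringOrNull(FeatureColumn);
}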
/// <summary>
/// Initializes a new instance of <see cref="SingleFeaturePredictionTransformerBase{TModel, TScorer}"/>.
/// </summary>
/// <param name="host">The local instance of <see cref="IHost"/>.</param>
/// <param name="model">The model used for scoring.</param>
/// <param name="trainSchema">The schema of the training data.</param>
/// <param name="featureColumn">The feature column name.</param>
public SingleFeaturePredictionTransformerBase(IHost host, TModel model, Schema trainSchema, string featureColumn)
    : base(host, model, trainSchema)
{
    FeatureColumn = featureColumn;
    if (featureColumn == null)
    {
        FeatureColumnType = null;
    }
    else if (!trainSchema.TryGetColumnIndex(featureColumn, out int col))
    {
        throw Host.ExceptSchemaMismatch(nameof(featureColumn), RoleMappedSchema.ColumnRole.Feature.Value, featureColumn);
    }
    else
    {
        FeatureColumnType = trainSchema[col].Type;
    }

    BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
}
public BindableMapper(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));
    _env = env;
    _env.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    // *** Binary format ***
    // int: topContributionsCount
    // int: bottomContributionsCount
    // bool: normalize
    // bool: stringify
    ctx.LoadModel<IFeatureContributionMapper, SignatureLoadModel>(env, out Predictor, ModelFileUtils.DirPredictor);
    GenericMapper = ScoreUtils.GetSchemaBindableMapper(_env, Predictor, null);
    _topContributionsCount = ctx.Reader.ReadInt32();
    Contracts.CheckDecode(0 < _topContributionsCount && _topContributionsCount <= MaxTopBottom);
    _bottomContributionsCount = ctx.Reader.ReadInt32();
    Contracts.CheckDecode(0 < _bottomContributionsCount && _bottomContributionsCount <= MaxTopBottom);
    _normalize = ctx.Reader.ReadBoolByte();
    Stringify = ctx.Reader.ReadBoolByte();
}
public BindableMapper(IHostEnvironment env, IFeatureContributionMapper predictor, int topContributionsCount, int bottomContributionsCount, bool normalize, bool stringify)
{
    Contracts.CheckValue(env, nameof(env));
    _env = env;
    _env.CheckValue(predictor, nameof(predictor));
    if (topContributionsCount <= 0 || topContributionsCount > MaxTopBottom)
        throw env.Except($"Number of top contributions must be in range (0,{MaxTopBottom}]");
    if (bottomContributionsCount <= 0 || bottomContributionsCount > MaxTopBottom)
        throw env.Except($"Number of bottom contributions must be in range (0,{MaxTopBottom}]");

    _topContributionsCount = topContributionsCount;
    _bottomContributionsCount = bottomContributionsCount;
    _normalize = normalize;
    Stringify = stringify;
    Predictor = predictor;

    GenericMapper = ScoreUtils.GetSchemaBindableMapper(_env, Predictor, null);
}
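A matching Save method must write fields in exactly the order the loading constructor reads them. A minimal sketch under that assumption, where WriteBoolByte is taken to be the writer-side counterpart of the ReadBoolByte call above:

public void Save(ModelSaveContext ctx)
{
    Contracts.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel();
    ctx.SetVersionInfo(GetVersionInfo());

    // *** Binary format *** (must match the loading constructor above)
    // int: topContributionsCount
    // int: bottomContributionsCount
    // bool: normalize
    // bool: stringify
    ctx.SaveModel(Predictor, ModelFileUtils.DirPredictor);
    ctx.Writer.Write(_topContributionsCount);
    ctx.Writer.Write(_bottomContributionsCount);
    ctx.Writer.WriteBoolByte(_normalize);
    ctx.Writer.WriteBoolByte(Stringify);
}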
private static IDataTransform Create(IHostEnvironment env, Arguments args, ITrainer trainer, IDataView input, IComponentFactory<IPredictor, ISchemaBindableMapper> mapperFactory)
{
    Contracts.AssertValue(env, nameof(env));
    env.AssertValue(args, nameof(args));
    env.AssertValue(trainer, nameof(trainer));
    env.AssertValue(input, nameof(input));

    var host = env.Register("TrainAndScoreTransform");
    using (var ch = host.Start("Train"))
    {
        ch.Trace("Constructing trainer");
        var customCols = TrainUtils.CheckAndGenerateCustomColumns(env, args.CustomColumn);
        string feat;
        string group;
        var data = CreateDataFromArgs(ch, input, args, out feat, out group);
        var predictor = TrainUtils.Train(host, ch, data, trainer, null, args.Calibrator, args.MaxCalibrationExamples, null);
        return ScoreUtils.GetScorer(args.Scorer, predictor, input, feat, group, customCols, env, data.Schema, mapperFactory);
    }
}
public static IDataScorerTransform Create(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(data, nameof(data));
    env.CheckValue(mapper, nameof(mapper));
    if (args.Top <= 0 || args.Top > MaxTopBottom)
        throw env.Except($"Number of top contributions must be in range (0,{MaxTopBottom}]");
    if (args.Bottom <= 0 || args.Bottom > MaxTopBottom)
        throw env.Except($"Number of bottom contributions must be in range (0,{MaxTopBottom}]");

    // Check the result of the cast, not the original argument, so that a mapper
    // of the wrong type is rejected rather than passed through as null.
    var contributionMapper = mapper as RowMapper;
    env.CheckParam(contributionMapper != null, nameof(mapper), "Unexpected mapper");

    var scorer = ScoreUtils.GetScorerComponent(env, contributionMapper);
    var scoredPipe = scorer.CreateComponent(env, data, contributionMapper, trainSchema);
    return scoredPipe;
}
public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckUserArg(!string.IsNullOrWhiteSpace(args.InputModelFile), nameof(args.InputModelFile), "The input model file is required.");

    IPredictor predictor;
    RoleMappedSchema trainSchema = null;
    using (var file = env.OpenInputFile(args.InputModelFile))
    using (var strm = file.OpenReadStream())
    using (var rep = RepositoryReader.Open(strm, env))
    {
        ModelLoadContext.LoadModel<IPredictor, SignatureLoadModel>(env, out predictor, rep, ModelFileUtils.DirPredictor);
        trainSchema = ModelFileUtils.LoadRoleMappedSchemaOrNull(env, rep);
    }

    string feat = TrainUtils.MatchNameOrDefaultOrNull(env, input.Schema, nameof(args.FeatureColumn), args.FeatureColumn, DefaultColumnNames.Features);
    string group = TrainUtils.MatchNameOrDefaultOrNull(env, input.Schema, nameof(args.GroupColumn), args.GroupColumn, DefaultColumnNames.GroupId);
    var customCols = TrainUtils.CheckAndGenerateCustomColumns(env, args.CustomColumn);
    return ScoreUtils.GetScorer(args.Scorer, predictor, input, feat, group, customCols, env, trainSchema);
}
private void RunCore(IChannel ch)
{
    Host.AssertValue(ch);

    ch.Trace("Creating loader");
    LoadModelObjects(ch, true, out var predictor, true, out var trainSchema, out var loader);
    ch.AssertValue(predictor);
    ch.AssertValueOrNull(trainSchema);
    ch.AssertValue(loader);

    ch.Trace("Creating pipeline");
    var scorer = Args.Scorer;
    ch.Assert(scorer == null || scorer is ICommandLineComponentFactory, "ScoreCommand should only be used from the command line.");
    var bindable = ScoreUtils.GetSchemaBindableMapper(Host, predictor, scorerFactorySettings: scorer as ICommandLineComponentFactory);
    ch.AssertValue(bindable);

    // REVIEW: We probably ought to prefer role mappings from the training schema.
    string feat = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema, nameof(Args.FeatureColumn), Args.FeatureColumn, DefaultColumnNames.Features);
    string group = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema, nameof(Args.GroupColumn), Args.GroupColumn, DefaultColumnNames.GroupId);
    var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);
    var schema = new RoleMappedSchema(loader.Schema, label: null, feature: feat, group: group, custom: customCols, opt: true);
    var mapper = bindable.Bind(Host, schema);

    if (scorer == null)
        scorer = ScoreUtils.GetScorerComponent(Host, mapper);

    loader = CompositeDataLoader.ApplyTransform(Host, loader, "Scorer", scorer.ToString(),
        (env, view) => scorer.CreateComponent(env, view, mapper, trainSchema));
    loader = CompositeDataLoader.Create(Host, loader, Args.PostTransform);

    if (!string.IsNullOrWhiteSpace(Args.OutputModelFile))
    {
        ch.Trace("Saving the data pipe");
        SaveLoader(loader, Args.OutputModelFile);
    }

    ch.Trace("Creating saver");
    IDataSaver writer;
    if (Args.Saver == null)
    {
        var ext = Path.GetExtension(Args.OutputDataFile);
        var isText = ext == ".txt" || ext == ".tlc";
        if (isText)
            writer = new TextSaver(Host, new TextSaver.Arguments());
        else
            writer = new BinarySaver(Host, new BinarySaver.Arguments());
    }
    else
    {
        writer = Args.Saver.CreateComponent(Host);
    }
    ch.Assert(writer != null);
    var outputIsBinary = writer is BinarySaver;

    bool outputAllColumns =
        Args.OutputAllColumns == true
        || (Args.OutputAllColumns == null && Utils.Size(Args.OutputColumn) == 0 && outputIsBinary);

    bool outputNamesAndLabels =
        Args.OutputAllColumns == true || Utils.Size(Args.OutputColumn) == 0;

    if (Args.OutputAllColumns == true && Utils.Size(Args.OutputColumn) != 0)
        ch.Warning(nameof(Args.OutputAllColumns) + "=+ always writes all columns irrespective of the " + nameof(Args.OutputColumn) + " specified.");

    if (!outputAllColumns && Utils.Size(Args.OutputColumn) != 0)
    {
        foreach (var outCol in Args.OutputColumn)
        {
            if (!loader.Schema.TryGetColumnIndex(outCol, out int dummyColIndex))
                throw ch.ExceptUserArg(nameof(Arguments.OutputColumn), "Column '{0}' not found.", outCol);
        }
    }

    uint maxScoreId = 0;
    if (!outputAllColumns)
        maxScoreId = loader.Schema.GetMaxMetadataKind(out int colMax, MetadataUtils.Kinds.ScoreColumnSetId);
    ch.Assert(outputAllColumns || maxScoreId > 0); // Score set IDs are one-based.

    var cols = new List<int>();
    for (int i = 0; i < loader.Schema.Count; i++)
    {
        if (!Args.KeepHidden && loader.Schema.IsHidden(i))
            continue;
        if (!(outputAllColumns || ShouldAddColumn(loader.Schema, i, maxScoreId, outputNamesAndLabels)))
            continue;
        var type = loader.Schema.GetColumnType(i);
        if (writer.IsColumnSavable(type))
            cols.Add(i);
        else
            ch.Warning("The column '{0}' will not be written as it has an unsavable column type.", loader.Schema.GetColumnName(i));
    }
    ch.Check(cols.Count > 0, "No valid columns to save");

    ch.Trace("Scoring and saving data");
    using (var file = Host.CreateOutputFile(Args.OutputDataFile))
    using (var stream = file.CreateWriteStream())
        writer.SaveData(stream, loader, cols.ToArray());
}
private FoldResult RunFold(int fold)
{
    var host = GetHost();
    host.Assert(0 <= fold && fold <= _numFolds);

    // REVIEW: Make channels buffered in multi-threaded environments.
    using (var ch = host.Start($"Fold {fold}"))
    {
        ch.Trace("Constructing trainer");
        ITrainer trainer = _trainer.CreateInstance(host);

        // Train pipe.
        var trainFilter = new RangeFilter.Arguments();
        trainFilter.Column = _splitColumn;
        trainFilter.Min = (Double)fold / _numFolds;
        trainFilter.Max = (Double)(fold + 1) / _numFolds;
        trainFilter.Complement = true;
        IDataView trainPipe = new RangeFilter(host, trainFilter, _inputDataView);
        trainPipe = new OpaqueDataView(trainPipe);
        var trainData = _createExamples(host, ch, trainPipe, trainer);

        // Test pipe.
        var testFilter = new RangeFilter.Arguments();
        testFilter.Column = trainFilter.Column;
        testFilter.Min = trainFilter.Min;
        testFilter.Max = trainFilter.Max;
        ch.Assert(!testFilter.Complement);
        IDataView testPipe = new RangeFilter(host, testFilter, _inputDataView);
        testPipe = new OpaqueDataView(testPipe);
        var testData = _applyTransformsToTestData(host, ch, testPipe, trainData, trainPipe);

        // Validation pipe and examples.
        RoleMappedData validData = null;
        if (_getValidationDataView != null)
        {
            ch.Assert(_applyTransformsToValidationData != null);
            if (!trainer.Info.SupportsValidation)
            {
                ch.Warning("Trainer does not accept validation dataset.");
            }
            else
            {
                ch.Trace("Constructing the validation pipeline");
                IDataView validLoader = _getValidationDataView();
                var validPipe = ApplyTransformUtils.ApplyAllTransformsToData(host, _inputDataView, validLoader);
                validPipe = new OpaqueDataView(validPipe);
                validData = _applyTransformsToValidationData(host, ch, validPipe, trainData, trainPipe);
            }
        }

        // Train.
        var predictor = TrainUtils.Train(host, ch, trainData, trainer, _trainer.Kind, validData,
            _calibrator, _maxCalibrationExamples, _cacheData, _inputPredictor);

        // Score.
        ch.Trace("Scoring and evaluating");
        var bindable = ScoreUtils.GetSchemaBindableMapper(host, predictor, _scorer);
        ch.AssertValue(bindable);
        var mapper = bindable.Bind(host, testData.Schema);
        var scorerComp = _scorer.IsGood() ? _scorer : ScoreUtils.GetScorerComponent(mapper);
        IDataScorerTransform scorePipe = scorerComp.CreateInstance(host, testData.Data, mapper, trainData.Schema);

        // Save per-fold model.
        string modelFileName = ConstructPerFoldName(_outputModelFile, fold);
        if (modelFileName != null && _loader != null)
        {
            using (var file = host.CreateOutputFile(modelFileName))
            {
                var rmd = new RoleMappedData(
                    CompositeDataLoader.ApplyTransform(host, _loader, null, null,
                        (e, newSource) => ApplyTransformUtils.ApplyAllTransformsToData(e, trainData.Data, newSource)),
                    trainData.Schema.GetColumnRoleNames());
                TrainUtils.SaveModel(host, ch, file, predictor, rmd, _cmd);
            }
        }

        // Evaluate.
        var evalComp = _evaluator;
        if (!evalComp.IsGood())
            evalComp = EvaluateUtils.GetEvaluatorType(ch, scorePipe.Schema);
        var eval = evalComp.CreateInstance(host);

        // Note that this doesn't require the provided columns to exist (because of the "opt" parameter).
        // We don't normally expect the scorer to drop columns, but if it does, we should not require
        // all the columns in the test pipeline to still be present.
        var dataEval = new RoleMappedData(scorePipe, testData.Schema.GetColumnRoleNames(), opt: true);
        var dict = eval.Evaluate(dataEval);

        RoleMappedData perInstance = null;
        if (_savePerInstance)
        {
            var perInst = eval.GetPerInstanceMetrics(dataEval);
            perInstance = new RoleMappedData(perInst, dataEval.Schema.GetColumnRoleNames(), opt: true);
        }

        ch.Done();
        return new FoldResult(dict, dataEval.Schema.Schema, perInstance, trainData.Schema);
    }
}
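The two RangeFilter configurations above implement the k-fold split: fold i tests on rows whose split-column value falls in [i/k, (i+1)/k) and trains on the complement of that range. A tiny hypothetical helper, purely to illustrate the arithmetic; it is not part of the command:

// Hypothetical helper mirroring the fold arithmetic used by RunFold:
// a row belongs to fold i's test set when its split value v satisfies
// i/k <= v < (i+1)/k; the training set is the complement of that range.
private static bool IsInTestFold(double v, int fold, int numFolds)
{
    double min = (double)fold / numFolds;
    double max = (double)(fold + 1) / numFolds;
    return min <= v && v < max;
}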
private void RunCore(IChannel ch, string cmd)
{
    Host.AssertValue(ch);
    Host.AssertNonEmpty(cmd);

    ch.Trace("Constructing trainer");
    ITrainer trainer = Args.Trainer.CreateInstance(Host);

    IPredictor inputPredictor = null;
    if (Args.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, Args.InputModelFile, out inputPredictor))
        ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");

    ch.Trace("Constructing the training pipeline");
    IDataView trainPipe = CreateLoader();

    ISchema schema = trainPipe.Schema;
    string label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.LabelColumn), Args.LabelColumn, DefaultColumnNames.Label);
    string features = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.FeatureColumn), Args.FeatureColumn, DefaultColumnNames.Features);
    string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.GroupColumn), Args.GroupColumn, DefaultColumnNames.GroupId);
    string weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.WeightColumn), Args.WeightColumn, DefaultColumnNames.Weight);
    string name = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.NameColumn), Args.NameColumn, DefaultColumnNames.Name);

    TrainUtils.AddNormalizerIfNeeded(Host, ch, trainer, ref trainPipe, features, Args.NormalizeFeatures);

    ch.Trace("Binding columns");
    var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);
    var data = new RoleMappedData(trainPipe, label, features, group, weight, name, customCols);

    RoleMappedData validData = null;
    if (!string.IsNullOrWhiteSpace(Args.ValidationFile))
    {
        if (!TrainUtils.CanUseValidationData(trainer))
        {
            ch.Warning("Ignoring validationFile: Trainer does not accept validation dataset.");
        }
        else
        {
            ch.Trace("Constructing the validation pipeline");
            IDataView validPipe = CreateRawLoader(dataFile: Args.ValidationFile);
            validPipe = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainPipe, validPipe);
            validData = new RoleMappedData(validPipe, data.Schema.GetColumnRoleNames());
        }
    }

    var predictor = TrainUtils.Train(Host, ch, data, trainer, _info.LoadNames[0], validData,
        Args.Calibrator, Args.MaxCalibrationExamples, Args.CacheData, inputPredictor);

    IDataLoader testPipe;
    using (var file = !string.IsNullOrEmpty(Args.OutputModelFile)
        ? Host.CreateOutputFile(Args.OutputModelFile)
        : Host.CreateTempFile(".zip"))
    {
        TrainUtils.SaveModel(Host, ch, file, predictor, data, cmd);

        ch.Trace("Constructing the testing pipeline");
        using (var stream = file.OpenReadStream())
        using (var rep = RepositoryReader.Open(stream, ch))
            testPipe = LoadLoader(rep, Args.TestFile, true);
    }

    // Score.
    ch.Trace("Scoring and evaluating");
    IDataScorerTransform scorePipe = ScoreUtils.GetScorer(Args.Scorer, predictor, testPipe, features, group, customCols, Host, data.Schema);

    // Evaluate.
    var evalComp = Args.Evaluator;
    if (!evalComp.IsGood())
        evalComp = EvaluateUtils.GetEvaluatorType(ch, scorePipe.Schema);
    var evaluator = evalComp.CreateInstance(Host);
    var dataEval = new RoleMappedData(scorePipe, label, features, group, weight, name, customCols, opt: true);
    var metrics = evaluator.Evaluate(dataEval);
    MetricWriter.PrintWarnings(ch, metrics);
    evaluator.PrintFoldResults(ch, metrics);
    if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out var overall))
        throw ch.Except("No overall metrics found");
    overall = evaluator.GetOverallResults(overall);
    MetricWriter.PrintOverallMetrics(Host, ch, Args.SummaryFilename, overall, 1);
    evaluator.PrintAdditionalMetrics(ch, metrics);

    Dictionary<string, IDataView>[] metricValues = { metrics };
    SendTelemetryMetric(metricValues);

    if (!string.IsNullOrWhiteSpace(Args.OutputDataFile))
    {
        var perInst = evaluator.GetPerInstanceMetrics(dataEval);
        var perInstData = new RoleMappedData(perInst, label, null, group, weight, name, customCols);
        var idv = evaluator.GetPerInstanceDataViewToSave(perInstData);
        MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, idv);
    }
}
private void RunCore(IChannel ch, string cmd)
{
    Host.AssertValue(ch);
    Host.AssertNonEmpty(cmd);

    ch.Trace("Constructing trainer");
    ITrainer trainer = Args.Trainer.CreateComponent(Host);

    IPredictor inputPredictor = null;
    if (Args.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, Args.InputModelFile, out inputPredictor))
        ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");

    ch.Trace("Constructing the training pipeline");
    IDataView trainPipe = CreateLoader();

    ISchema schema = trainPipe.Schema;
    string label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.LabelColumn), Args.LabelColumn, DefaultColumnNames.Label);
    string features = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.FeatureColumn), Args.FeatureColumn, DefaultColumnNames.Features);
    string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.GroupColumn), Args.GroupColumn, DefaultColumnNames.GroupId);
    string weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.WeightColumn), Args.WeightColumn, DefaultColumnNames.Weight);
    string name = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.NameColumn), Args.NameColumn, DefaultColumnNames.Name);

    TrainUtils.AddNormalizerIfNeeded(Host, ch, trainer, ref trainPipe, features, Args.NormalizeFeatures);

    ch.Trace("Binding columns");
    var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);
    var data = new RoleMappedData(trainPipe, label, features, group, weight, name, customCols);

    RoleMappedData validData = null;
    if (!string.IsNullOrWhiteSpace(Args.ValidationFile))
    {
        if (!trainer.Info.SupportsValidation)
        {
            ch.Warning("Ignoring validationFile: Trainer does not accept validation dataset.");
        }
        else
        {
            ch.Trace("Constructing the validation pipeline");
            IDataView validPipe = CreateRawLoader(dataFile: Args.ValidationFile);
            validPipe = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainPipe, validPipe);
            validData = new RoleMappedData(validPipe, data.Schema.GetColumnRoleNames());
        }
    }

    // In addition to the training set, some trainers can accept two extra data sets, a validation set and a test
    // set, during the training phase. The major difference between them is that the training process may
    // indirectly use the validation set to improve the model, whereas the learned model should be totally
    // independent of the test set. As with the validation set, the trainer can report scores computed on the test set.
    RoleMappedData testDataUsedInTrainer = null;
    if (!string.IsNullOrWhiteSpace(Args.TestFile))
    {
        // In contrast to the if-else block for validation above, we do not emit a warning if the test file is
        // provided but unsupported, because this is the TrainTest command.
        if (trainer.Info.SupportsTest)
        {
            ch.Trace("Constructing the test pipeline");
            IDataView testPipeUsedInTrainer = CreateRawLoader(dataFile: Args.TestFile);
            testPipeUsedInTrainer = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainPipe, testPipeUsedInTrainer);
            testDataUsedInTrainer = new RoleMappedData(testPipeUsedInTrainer, data.Schema.GetColumnRoleNames());
        }
    }

    var predictor = TrainUtils.Train(Host, ch, data, trainer, validData,
        Args.Calibrator, Args.MaxCalibrationExamples, Args.CacheData, inputPredictor, testDataUsedInTrainer);

    IDataLoader testPipe;
    bool hasOutfile = !string.IsNullOrEmpty(Args.OutputModelFile);
    var tempFilePath = hasOutfile ? null : Path.GetTempFileName();

    using (var file = new SimpleFileHandle(ch, hasOutfile ? Args.OutputModelFile : tempFilePath, true, !hasOutfile))
    {
        TrainUtils.SaveModel(Host, ch, file, predictor, data, cmd);

        ch.Trace("Constructing the testing pipeline");
        using (var stream = file.OpenReadStream())
        using (var rep = RepositoryReader.Open(stream, ch))
            testPipe = LoadLoader(rep, Args.TestFile, true);
    }

    // Score.
    ch.Trace("Scoring and evaluating");
    ch.Assert(Args.Scorer == null || Args.Scorer is ICommandLineComponentFactory, "TrainTestCommand should only be used from the command line.");
    IDataScorerTransform scorePipe = ScoreUtils.GetScorer(Args.Scorer, predictor, testPipe, features, group, customCols, Host, data.Schema);

    // Evaluate.
    var evaluator = Args.Evaluator?.CreateComponent(Host) ?? EvaluateUtils.GetEvaluator(Host, scorePipe.Schema);
    var dataEval = new RoleMappedData(scorePipe, label, features, group, weight, name, customCols, opt: true);
    var metrics = evaluator.Evaluate(dataEval);
    MetricWriter.PrintWarnings(ch, metrics);
    evaluator.PrintFoldResults(ch, metrics);
    if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out var overall))
        throw ch.Except("No overall metrics found");
    overall = evaluator.GetOverallResults(overall);
    MetricWriter.PrintOverallMetrics(Host, ch, Args.SummaryFilename, overall, 1);
    evaluator.PrintAdditionalMetrics(ch, metrics);

    Dictionary<string, IDataView>[] metricValues = { metrics };
    SendTelemetryMetric(metricValues);

    if (!string.IsNullOrWhiteSpace(Args.OutputDataFile))
    {
        var perInst = evaluator.GetPerInstanceMetrics(dataEval);
        var perInstData = new RoleMappedData(perInst, label, null, group, weight, name, customCols);
        var idv = evaluator.GetPerInstanceDataViewToSave(perInstData);
        MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, idv);
    }
}
private void RunCore(IChannel ch)
{
    ch.Trace("Constructing data pipeline");
    IDataLoader loader;
    IPredictor predictor;
    RoleMappedSchema trainSchema;
    LoadModelObjects(ch, true, out predictor, true, out trainSchema, out loader);
    ch.AssertValue(predictor);
    ch.AssertValueOrNull(trainSchema);
    ch.AssertValue(loader);

    ch.Trace("Binding columns");
    ISchema schema = loader.Schema;
    string label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.LabelColumn), Args.LabelColumn, DefaultColumnNames.Label);
    string features = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.FeatureColumn), Args.FeatureColumn, DefaultColumnNames.Features);
    string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.GroupColumn), Args.GroupColumn, DefaultColumnNames.GroupId);
    string weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.WeightColumn), Args.WeightColumn, DefaultColumnNames.Weight);
    string name = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.NameColumn), Args.NameColumn, DefaultColumnNames.Name);
    var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);

    // Score.
    ch.Trace("Scoring and evaluating");
    ch.Assert(Args.Scorer == null || Args.Scorer is ICommandLineComponentFactory, "TestCommand should only be used from the command line.");
    IDataScorerTransform scorePipe = ScoreUtils.GetScorer(Args.Scorer, predictor, loader, features, group, customCols, Host, trainSchema);

    // Evaluate.
    var evalComp = Args.Evaluator;
    if (!evalComp.IsGood())
        evalComp = EvaluateUtils.GetEvaluatorType(ch, scorePipe.Schema);
    var evaluator = evalComp.CreateInstance(Host);
    var data = new RoleMappedData(scorePipe, label, null, group, weight, name, customCols);
    var metrics = evaluator.Evaluate(data);
    MetricWriter.PrintWarnings(ch, metrics);
    evaluator.PrintFoldResults(ch, metrics);
    if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out var overall))
        throw ch.Except("No overall metrics found");
    overall = evaluator.GetOverallResults(overall);
    MetricWriter.PrintOverallMetrics(Host, ch, Args.SummaryFilename, overall, 1);
    evaluator.PrintAdditionalMetrics(ch, metrics);

    Dictionary<string, IDataView>[] metricValues = { metrics };
    SendTelemetryMetric(metricValues);

    if (!string.IsNullOrWhiteSpace(Args.OutputDataFile))
    {
        var perInst = evaluator.GetPerInstanceMetrics(data);
        var perInstData = new RoleMappedData(perInst, label, null, group, weight, name, customCols);
        var idv = evaluator.GetPerInstanceDataViewToSave(perInstData);
        MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, idv);
    }
}
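Several of the commands above resolve column names through TrainUtils.MatchNameOrDefaultOrNull. As used here, the contract appears to be: an explicitly supplied name must exist in the schema (otherwise a user error), while the default name is only picked up when the schema actually has it. The hypothetical sketch below captures that behavior; it is not the verbatim implementation.

// Sketch of the resolution contract assumed by the callers above:
// - If the user supplied a name, it must exist in the schema (else user error).
// - Otherwise fall back to the default name only when the schema has it.
// - Otherwise return null, meaning "no column plays this role".
private static string MatchNameOrDefaultOrNullSketch(IExceptionContext ectx, ISchema schema, string argName, string userName, string defaultName)
{
    if (!string.IsNullOrWhiteSpace(userName))
    {
        if (!schema.TryGetColumnIndex(userName, out int col))
            throw ectx.ExceptUserArg(argName, "Column '{0}' not found.", userName);
        return userName;
    }
    return schema.TryGetColumnIndex(defaultName, out int _) ? defaultName : null;
}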