/// <summary>
/// Creates a scorer transform that applies <paramref name="predictor"/> to <paramref name="input"/>.
/// A role-mapped schema is built from the input schema and the supplied column names, the scorer
/// component and bound mapper are resolved via <c>GetScorerComponentAndMapper</c>, and the scorer
/// is then instantiated over the input.
/// </summary>
/// <param name="scorer">Scorer component to use; validated with <c>CheckValueOrNull</c>.</param>
/// <param name="predictor">Predictor to score with; validated with <c>CheckValue</c>.</param>
/// <param name="input">Data view to score; validated with <c>CheckValue</c>.</param>
/// <param name="featureColName">Feature column name passed to the role-mapped schema; validated with <c>CheckValueOrNull</c>.</param>
/// <param name="groupColName">Group column name passed to the role-mapped schema; validated with <c>CheckValueOrNull</c>.</param>
/// <param name="customColumns">Additional role/column-name pairs; validated with <c>CheckValueOrNull</c>.</param>
/// <param name="env">Host environment used for validation and for creating the scorer; validated with <c>Contracts.CheckValue</c>.</param>
/// <param name="trainSchema">Training-time schema forwarded to the scorer; validated with <c>CheckValueOrNull</c>.</param>
/// <returns>The scorer transform created by the resolved scorer component.</returns>
public static IDataScorerTransform GetScorer(
    SubComponent<IDataScorerTransform, SignatureDataScorer> scorer,
    IPredictor predictor,
    IDataView input,
    string featureColName,
    string groupColName,
    IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> customColumns,
    IHostEnvironment env,
    RoleMappedSchema trainSchema)
{
    // The environment is validated first because every subsequent check is issued through it.
    Contracts.CheckValue(env, nameof(env));

    // Argument checks, in the same order as the original so the first failing check is unchanged.
    env.CheckValueOrNull(scorer);
    env.CheckValue(predictor, nameof(predictor));
    env.CheckValue(input, nameof(input));
    env.CheckValueOrNull(featureColName);
    env.CheckValueOrNull(groupColName);
    env.CheckValueOrNull(customColumns);
    env.CheckValueOrNull(trainSchema);

    // Map the feature/group/custom roles onto the input's schema.
    var roleMappedSchema = TrainUtils.CreateRoleMappedSchemaOpt(
        input.Schema, featureColName, groupColName, customColumns);

    // Resolve the scorer component together with the mapper bound to that schema.
    ISchemaBoundMapper boundMapper;
    var scorerComponent = GetScorerComponentAndMapper(predictor, scorer, roleMappedSchema, env, out boundMapper);

    return scorerComponent.CreateInstance(env, input, boundMapper, trainSchema);
}
/// <summary>
/// Runs the scoring command: loads the predictor, training schema and data loader, appends the
/// scorer and any post-transforms to the loader, optionally saves the resulting data pipe,
/// chooses a saver, selects the columns to write, and saves the scored data to
/// <c>Args.OutputDataFile</c>.
/// </summary>
/// <param name="ch">Channel used for tracing, warnings, assertions and user-argument errors.</param>
private void RunCore(IChannel ch)
{
    Host.AssertValue(ch);

    ch.Trace("Creating loader");

    IPredictor predictor;
    IDataLoader pipe;
    RoleMappedSchema trainSchema;
    LoadModelObjects(ch, true, out predictor, true, out trainSchema, out pipe);
    ch.AssertValue(predictor);
    ch.AssertValueOrNull(trainSchema);
    ch.AssertValue(pipe);

    ch.Trace("Creating pipeline");

    var scorerComponent = Args.Scorer;
    var bindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, predictor, scorerComponent);
    ch.AssertValue(bindableMapper);

    // REVIEW: We probably ought to prefer role mappings from the training schema.
    string featureName = TrainUtils.MatchNameOrDefaultOrNull(
        ch, pipe.Schema, nameof(Args.FeatureColumn), Args.FeatureColumn, DefaultColumnNames.Features);
    string groupName = TrainUtils.MatchNameOrDefaultOrNull(
        ch, pipe.Schema, nameof(Args.GroupColumn), Args.GroupColumn, DefaultColumnNames.GroupId);
    var customColumns = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);
    var roleMappedSchema = TrainUtils.CreateRoleMappedSchemaOpt(pipe.Schema, featureName, groupName, customColumns);
    var boundMapper = bindableMapper.Bind(Host, roleMappedSchema);

    // When no usable scorer was supplied, pick one based on the bound mapper.
    if (!scorerComponent.IsGood())
    {
        scorerComponent = ScoreUtils.GetScorerComponent(boundMapper);
    }

    // Append the scorer, then the user-specified post-transforms, to the loader.
    pipe = CompositeDataLoader.ApplyTransform(
        Host,
        pipe,
        "Scorer",
        scorerComponent.ToString(),
        (e, data) => scorerComponent.CreateInstance(e, data, boundMapper, trainSchema));
    pipe = CompositeDataLoader.Create(Host, pipe, Args.PostTransform);

    if (!string.IsNullOrWhiteSpace(Args.OutputModelFile))
    {
        ch.Trace("Saving the data pipe");
        SaveLoader(pipe, Args.OutputModelFile);
    }

    ch.Trace("Creating saver");

    var saverComponent = Args.Saver;
    if (!saverComponent.IsGood())
    {
        // No usable saver given: choose one from the output file's extension.
        var extension = Path.GetExtension(Args.OutputDataFile);
        string saverName;
        if (extension == ".txt" || extension == ".tlc")
        {
            saverName = "TextSaver";
        }
        else
        {
            saverName = "BinarySaver";
        }

        saverComponent = new SubComponent<IDataSaver, SignatureDataSaver>(saverName);
    }

    var writer = saverComponent.CreateInstance(Host);
    ch.Assert(writer != null);

    // NOTE(review): `BinaryWriter` here presumably resolves to System.IO.BinaryWriter; if so, a data
    // saver would likely never match and this flag would always be false. Possibly the binary
    // saver type was intended - TODO confirm before changing, since it affects the default column set.
    var outputIsBinary = writer is BinaryWriter;

    bool allColumnsRequested = Args.OutputAllColumns == true;
    bool noColumnsListed = Utils.Size(Args.OutputColumn) == 0;

    // All columns are written when explicitly requested, or when the option is unset, no output
    // columns were listed, and the output is binary.
    bool outputAllColumns =
        allColumnsRequested || (Args.OutputAllColumns == null && noColumnsListed && outputIsBinary);
    bool outputNamesAndLabels = allColumnsRequested || noColumnsListed;

    if (allColumnsRequested && !noColumnsListed)
    {
        ch.Warning("outputAllColumns=+ always writes all columns irrespective of outputColumn specified.");
    }

    // Every explicitly requested output column must exist in the final schema.
    if (!outputAllColumns && !noColumnsListed)
    {
        foreach (var requested in Args.OutputColumn)
        {
            int ignoredIndex;
            if (!pipe.Schema.TryGetColumnIndex(requested, out ignoredIndex))
            {
                throw ch.ExceptUserArg(nameof(Arguments.OutputColumn), "Column '{0}' not found.", requested);
            }
        }
    }

    int colMax;
    uint maxScoreId = 0;
    if (!outputAllColumns)
    {
        maxScoreId = pipe.Schema.GetMaxMetadataKind(out colMax, MetadataUtils.Kinds.ScoreColumnSetId);
    }

    // Score set IDs are one-based, so a zero here means no score set was found.
    ch.Assert(outputAllColumns || maxScoreId > 0);

    var outputSchema = pipe.Schema;
    var selectedColumns = new List<int>();
    for (int col = 0; col < outputSchema.ColumnCount; col++)
    {
        // Hidden columns are skipped unless the user asked to keep them.
        if (!Args.KeepHidden && outputSchema.IsHidden(col))
        {
            continue;
        }

        if (!outputAllColumns && !ShouldAddColumn(outputSchema, col, maxScoreId, outputNamesAndLabels))
        {
            continue;
        }

        var columnType = outputSchema.GetColumnType(col);
        if (!writer.IsColumnSavable(columnType))
        {
            ch.Warning("The column '{0}' will not be written as it has unsavable column type.", outputSchema.GetColumnName(col));
            continue;
        }

        selectedColumns.Add(col);
    }

    ch.Check(selectedColumns.Count > 0, "No valid columns to save");

    ch.Trace("Scoring and saving data");
    using (var file = Host.CreateOutputFile(Args.OutputDataFile))
    {
        using (var stream = file.CreateWriteStream())
        {
            writer.SaveData(stream, pipe, selectedColumns.ToArray());
        }
    }
}