/// <summary>
/// Attempt to apply the data transform to a different data view source.
/// If the transform in question implements <see cref="ITransformTemplate"/>, <see cref="ITransformTemplate.ApplyToData"/>
/// is called. Otherwise, the transform is serialized into a byte array and then deserialized.
/// </summary>
/// <param name="env">The host to use.</param>
/// <param name="transform">The transform to apply.</param>
/// <param name="newSource">The data view to apply the transform to.</param>
/// <returns>The resulting data view.</returns>
public static IDataTransform ApplyTransformToData(IHostEnvironment env, IDataTransform transform, IDataView newSource)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(transform, nameof(transform));
    env.CheckValue(newSource, nameof(newSource));

    // Prefer the cheap path: a transform that knows how to rebind itself.
    var rebindable = transform as ITransformTemplate;
    if (rebindable != null)
        return rebindable.ApplyToData(env, newSource);

    // Revert to serialization: save the transform, then reload it against the new source.
    using (var stream = new MemoryStream())
    {
        using (var rep = RepositoryWriter.CreateNew(stream, env))
        {
            ModelSaveContext.SaveModel(rep, transform, "model");
            rep.Commit();
        }
        stream.Position = 0;
        using (var rep = RepositoryReader.Open(stream, env))
        {
            IDataTransform newData;
            ModelLoadContext.LoadModel<IDataTransform, SignatureLoadDataTransform>(env, out newData, rep, "model", newSource);
            return newData;
        }
    }
}
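The serialization fallback means any save-loadable transform can be rebased onto a new input, at the cost of a round trip through an in-memory repository. A minimal usage sketch, assuming an environment `env`, a fitted `IDataTransform` `xf`, and a new `IDataView` `other` are already in scope (these names are hypothetical):

// Hypothetical usage: `env`, `xf`, and `other` are assumed to exist in scope.
// Rebind the transform's logic to `other` without retraining anything.
IDataTransform rebased = ApplyTransformUtils.ApplyTransformToData(env, xf, other);

// The result is itself an IDataView, so it can feed the next pipeline stage
// or be cursored directly.
using (var cursor = rebased.GetRowCursorForAllColumns())
{
    while (cursor.MoveNext())
    {
        // Consume rows produced against the new source.
    }
}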
/// <summary>
/// Train a model on a single example, save and reload the resulting predictor,
/// and verify that both copies produce the same prediction.
/// </summary>
/// <typeparam name="TOutput">The prediction type produced by the trained predictor.</typeparam>
/// <param name="trainerMaker">Creates the trainer to exercise.</param>
/// <param name="checker">Asserts that the original and reloaded predictions match.</param>
private static void TrivialHelper<TOutput>(Func<ITrainerHost, ITrainer<Instances, IPredictorProducing<TOutput>>> trainerMaker, Action<TOutput, TOutput> checker)
{
    // The following simple instance should result in a "trivial" predictor for
    // binary classification, regression, and multiclass.
    ListInstances instances = new ListInstances();
    instances.AddInst(new Float[] { (Float)0.0 }, (Float)0);
    instances.CopyMetadata(null);

    ITrainerHost host = new TrainHost(new Random(1), 0);
    var trainer = trainerMaker(host);
    trainer.Train(instances);
    IPredictor<Instance, TOutput> predictor = (IPredictor<Instance, TOutput>)trainer.CreatePredictor();

    IPredictor<Instance, TOutput> loadedPredictor = default(IPredictor<Instance, TOutput>);
    using (Stream stream = new MemoryStream())
    {
        using (RepositoryWriter writer = RepositoryWriter.CreateNew(stream, false))
        {
            ModelSaveContext.SaveModel(writer, predictor, "foo");
            writer.Commit();
        }
        stream.Position = 0;
        using (RepositoryReader reader = RepositoryReader.Open(stream, false))
        {
            ModelLoadContext.LoadModel(out loadedPredictor, reader, "foo");
        }
        Assert.AreNotEqual(default(IPredictor<Instance, TOutput>), loadedPredictor, "did not load expected model");
    }

    TOutput result = predictor.Predict(instances[0]);
    TOutput loadedResult = loadedPredictor.Predict(instances[0]);
    checker(result, loadedResult);
}
public static void LoadModel(IHostEnvironment env, Stream modelStream, bool loadNames, out IPredictor predictor, out RoleMappedSchema schema)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(modelStream, nameof(modelStream));

    schema = null;
    using (var rep = RepositoryReader.Open(modelStream, env))
    {
        ModelLoadContext.LoadModel<IPredictor, SignatureLoadModel>(env, out predictor, rep, ModelFileUtils.DirPredictor);
        if (loadNames)
        {
            var roles = ModelFileUtils.LoadRoleMappingsOrNull(env, rep);
            if (roles != null)
            {
                var emptyView = ModelFileUtils.LoadPipeline(env, rep, new MultiFileSource(null));
                schema = new RoleMappedSchema(emptyView.Schema, roles, opt: true);
            }
            else
            {
                FeatureNameCollection names;
                if (ModelFileUtils.TryLoadFeatureNames(out names, rep))
                    schema = names.Schema;
            }
        }
    }
}
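A caller-side sketch of the helper above, assuming a hypothetical `modelPath` and an `env` host environment; note that `schema` can come back null when the file carries neither role mappings nor feature names:

// Hypothetical usage; `env` and `modelPath` are assumed to exist in scope.
IPredictor predictor;
RoleMappedSchema trainSchema;
using (var fs = File.OpenRead(modelPath))
    LoadModel(env, fs, loadNames: true, out predictor, out trainSchema);

if (trainSchema == null)
{
    // The model carries no role mappings or feature names; downstream code
    // must fall back to default column names.
}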
private ILegacyDataLoader LoadTransformChain(ILegacyDataLoader srcData)
{
    Host.Assert(!string.IsNullOrWhiteSpace(ImplOptions.InputModelFile));

    using (var file = Host.OpenInputFile(ImplOptions.InputModelFile))
    using (var strm = file.OpenReadStream())
    using (var rep = RepositoryReader.Open(strm, Host))
    using (var pipeLoaderEntry = rep.OpenEntry(ModelFileUtils.DirDataLoaderModel, ModelLoadContext.ModelStreamName))
    using (var ctx = new ModelLoadContext(rep, pipeLoaderEntry, ModelFileUtils.DirDataLoaderModel))
        return LegacyCompositeDataLoader.Create(Host, ctx, srcData, x => true);
}
private ILegacyDataLoader CreateLoaderFromBytes(byte[] loaderBytes, IMultiStreamSource files)
{
    Contracts.CheckValue(loaderBytes, nameof(loaderBytes));
    Contracts.CheckValue(files, nameof(files));

    using (var stream = new MemoryStream(loaderBytes))
    using (var rep = RepositoryReader.Open(stream, _host))
    {
        return ModelFileUtils.LoadLoader(_host, rep, files, false);
    }
}
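For context, the byte array fed into this helper is produced by the mirror-image save path. A sketch of that side, with the caveat that saving under `ModelFileUtils.DirDataLoaderModel` as the repository directory is an assumption here rather than a confirmed convention:

// Hypothetical: serialize `loader` into an in-memory repository and capture the bytes.
byte[] loaderBytes;
using (var stream = new MemoryStream())
{
    using (var rep = RepositoryWriter.CreateNew(stream, _host))
    {
        // Directory name is an assumption; it must match what LoadLoader expects.
        ModelSaveContext.SaveModel(rep, loader, ModelFileUtils.DirDataLoaderModel);
        rep.Commit();
    }
    loaderBytes = stream.ToArray();
}

// Later, re-bind the same loader definition to a different file set.
var newLoader = CreateLoaderFromBytes(loaderBytes, new MultiFileSource("newdata.txt"));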
private static IDataView LoadTransforms(IHostEnvironment env, IDataView input, string modelFile)
{
    var view = input;
    using (var file = env.OpenInputFile(modelFile))
    using (var strm = file.OpenReadStream())
    using (var rep = RepositoryReader.Open(strm, env))
    {
        view = ModelFileUtils.LoadTransforms(env, view, rep);
    }
    return view;
}
protected ILegacyDataLoader CreateRawLoader(
    Func<IHostEnvironment, IMultiStreamSource, ILegacyDataLoader> defaultLoaderFactory = null,
    string dataFile = null)
{
    if (string.IsNullOrWhiteSpace(dataFile))
        dataFile = ImplOptions.DataFile;

    ILegacyDataLoader loader;
    if (!string.IsNullOrWhiteSpace(ImplOptions.InputModelFile) && ImplOptions.Loader == null)
    {
        // Load the loader from the data model.
        using (var file = Host.OpenInputFile(ImplOptions.InputModelFile))
        using (var strm = file.OpenReadStream())
        using (var rep = RepositoryReader.Open(strm, Host))
            loader = LoadLoader(rep, dataFile, ImplOptions.LoadTransforms ?? true);
    }
    else
    {
        // Either there is no input model file, or there is, but the loader is overridden.
        IMultiStreamSource fileSource = new MultiFileSource(dataFile);
        var loaderFactory = ImplOptions.Loader;
        if (loaderFactory == null)
        {
            // Infer the loader from the file extension.
            var ext = Path.GetExtension(dataFile);
            var isText = string.Equals(ext, ".txt", StringComparison.OrdinalIgnoreCase)
                || string.Equals(ext, ".tlc", StringComparison.OrdinalIgnoreCase);
            var isBinary = string.Equals(ext, ".idv", StringComparison.OrdinalIgnoreCase);
            var isTranspose = string.Equals(ext, ".tdv", StringComparison.OrdinalIgnoreCase);

            return isText ? TextLoader.Create(Host, new TextLoader.Options(), fileSource)
                : isBinary ? new BinaryLoader(Host, new BinaryLoader.Arguments(), fileSource)
                : isTranspose ? new TransposeLoader(Host, new TransposeLoader.Arguments(), fileSource)
                : defaultLoaderFactory != null ? defaultLoaderFactory(Host, fileSource)
                : TextLoader.Create(Host, new TextLoader.Options(), fileSource);
        }
        else
        {
            loader = loaderFactory.CreateComponent(Host, fileSource);
        }

        if (ImplOptions.LoadTransforms == true)
        {
            Host.CheckUserArg(!string.IsNullOrWhiteSpace(ImplOptions.InputModelFile), nameof(ImplOptions.InputModelFile));
            loader = LoadTransformChain(loader);
        }
    }
    return loader;
}
/// <summary>
/// Load an <see cref="ICanForecast{T}"/> model.
/// </summary>
/// <typeparam name="T">The type of <see cref="ICanForecast{T}"/>, usually float.</typeparam>
/// <param name="catalog">The <see cref="ModelOperationsCatalog"/>.</param>
/// <param name="filePath">File path to load the model from.</param>
/// <returns>The loaded <see cref="ICanForecast{T}"/> model.</returns>
public static ICanForecast<T> LoadForecastingModel<T>(this ModelOperationsCatalog catalog, string filePath)
{
    var env = CatalogUtils.GetEnvironment(catalog);
    using (var file = File.OpenRead(filePath))
    using (var rep = RepositoryReader.Open(file, env))
    {
        ModelLoadContext.LoadModel<ICanForecast<T>, SignatureLoadModel>(env, out var model, rep, LoaderSignature);
        return model;
    }
}
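A caller-side sketch of this extension; the `Forecast(horizon)` call is stated as an assumption about the `ICanForecast{T}` surface, and "model.zip" is a placeholder path:

// Hypothetical usage of the extension above.
var mlContext = new MLContext();
ICanForecast<float> model = mlContext.Model.LoadForecastingModel<float>("model.zip");

// Ask the loaded model for the next five points of the series
// (Forecast(int) is assumed here, not confirmed by the snippet).
float[] next = model.Forecast(5);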
public void LoadOriginalBinaryLoaderModel()
{
    var env = new MLContext().AddStandardComponents();
    using (var modelStream = File.OpenRead(Path.Combine("TestModels", "BinaryLoader-v3.11.0.0.zip")))
    using (var rep = RepositoryReader.Open(modelStream, env))
    {
        IDataLoader result = ModelFileUtils.LoadLoader(env, rep, new MultiFileSource(null), true);

        Assert.Equal(2, result.Schema.Count);
        Assert.Equal("Image", result.Schema[0].Name);
        Assert.Equal("Class", result.Schema[1].Name);
    }
}
public void TestTextLoaderKeyTypeBackCompat()
{
    // Model generated with the following command on a version of the code prior to the
    // KeyType change that removed Min and Contiguous:
    // Train data=...\breast-cancer.txt loader=TextLoader{col=Label:R4:0 col=Features:R4:1-9 col=key:U4[0-*]:3} tr=LogisticRegression {} out=model.zip
    var mlContext = new MLContext(1);
    string textLoaderModelPath = GetDataPath("backcompat/textloader-with-key-model.zip");
    string breastCancerPath = GetDataPath(TestDatasets.breastCancer.trainFilename);

    using (FileStream modelfs = File.OpenRead(textLoaderModelPath))
    using (var rep = RepositoryReader.Open(modelfs, mlContext))
    {
        var result = ModelFileUtils.LoadLoader(mlContext, rep, new MultiFileSource(breastCancerPath), false);
        Assert.True(result.Schema.TryGetColumnIndex("key", out int featureIdx));
        Assert.True(result.Schema[featureIdx].Type is KeyDataViewType keyType && keyType.Count == typeof(uint).ToMaxInt());
    }
}
public void LoadOldConcatTransformModel()
{
    var env = new MLContext().AddStandardComponents();
    using (var modelStream = File.OpenRead(Path.Combine("TestModels", "ConcatTransform.zip")))
    using (var rep = RepositoryReader.Open(modelStream, env))
    {
        var result = ModelFileUtils.LoadPipeline(env, rep, new MultiFileSource(null), true);

        Assert.Equal(3, result.Schema.Count);
        Assert.Equal("Label", result.Schema[0].Name);
        Assert.Equal("Features", result.Schema[1].Name);
        Assert.Equal("Features", result.Schema[2].Name);
        Assert.Equal(9, (result.Schema[1].Type as VectorType)?.Size);
        Assert.Equal(18, (result.Schema[2].Type as VectorType)?.Size);
    }
}
private IDataScorerTransform GetScorer(IHostEnvironment env, IDataView transforms, IPredictor pred, string testDataPath = null)
{
    using (var ch = env.Start("Saving model"))
    using (var memoryStream = new MemoryStream())
    {
        var trainRoles = new RoleMappedData(transforms, label: "Label", feature: "Features");

        // The model cannot be saved with a CacheDataView in the pipeline.
        TrainUtils.SaveModel(env, ch, memoryStream, pred, trainRoles);
        memoryStream.Position = 0;
        using (var rep = RepositoryReader.Open(memoryStream, ch))
        {
            IDataLoader testPipe = ModelFileUtils.LoadLoader(env, rep, new MultiFileSource(testDataPath), true);
            RoleMappedData testRoles = new RoleMappedData(testPipe, label: "Label", feature: "Features");
            return ScoreUtils.GetScorer(pred, testRoles, env, testRoles.Schema);
        }
    }
}
protected IDataLoader CreateRawLoader(string defaultLoader = "TextLoader", string dataFile = null)
{
    if (string.IsNullOrWhiteSpace(dataFile))
        dataFile = Args.DataFile;

    IDataLoader loader;
    if (!string.IsNullOrWhiteSpace(Args.InputModelFile) && !Args.Loader.IsGood())
    {
        // Load the loader from the data model.
        using (var file = Host.OpenInputFile(Args.InputModelFile))
        using (var strm = file.OpenReadStream())
        using (var rep = RepositoryReader.Open(strm, Host))
            loader = LoadLoader(rep, dataFile, Args.LoadTransforms ?? true);
    }
    else
    {
        // Either there is no input model file, or there is, but the loader is overridden.
        var sub = Args.Loader;
        if (!sub.IsGood())
        {
            // Infer the loader from the file extension.
            var ext = Path.GetExtension(dataFile);
            var isText = string.Equals(ext, ".txt", StringComparison.OrdinalIgnoreCase)
                || string.Equals(ext, ".tlc", StringComparison.OrdinalIgnoreCase);
            var isBinary = string.Equals(ext, ".idv", StringComparison.OrdinalIgnoreCase);
            var isTranspose = string.Equals(ext, ".tdv", StringComparison.OrdinalIgnoreCase);
            sub = new SubComponent<IDataLoader, SignatureDataLoader>(
                isText ? "TextLoader" : isBinary ? "BinaryLoader" : isTranspose ? "TransposeLoader" : defaultLoader);
        }
        loader = sub.CreateInstance(Host, new MultiFileSource(dataFile));

        if (Args.LoadTransforms == true)
        {
            Host.CheckUserArg(!string.IsNullOrWhiteSpace(Args.InputModelFile), nameof(Args.InputModelFile));
            loader = LoadTransformChain(loader);
        }
    }
    return loader;
}
public static bool TryLoadPredictor(IChannel ch, IHostEnvironment env, string inputModelFile, out IPredictor inputPredictor)
{
    Contracts.AssertValue(env);
    Contracts.AssertValue(ch);

    if (!string.IsNullOrEmpty(inputModelFile))
    {
        ch.Trace("Constructing predictor from input model");
        using (var file = env.OpenInputFile(inputModelFile))
        using (var strm = file.OpenReadStream())
        using (var rep = RepositoryReader.Open(strm, ch))
        {
            ch.Trace("Loading predictor");
            return ModelLoadContext.LoadModelOrNull<IPredictor, SignatureLoadModel>(env, out inputPredictor, rep, ModelFileUtils.DirPredictor);
        }
    }

    inputPredictor = null;
    return false;
}
public void TestTextLoaderBackCompat_VerWritt_0x0001000C()
{
    // Checks backward compatibility with a text loader created with "verWrittenCur: 0x0001000C".
    // Model generated with:
    // loader=text{header+ col=SepalLength:Num:0 col=SepalWidth:Num:1 col=PetalLength:Num:2 col=PetalWidth:Num:2 col=Cat:TX:1-8 col=Num:9-14 col=Type:TX:4}
    var mlContext = new MLContext(1);
    string textLoaderModelPath = GetDataPath("backcompat/textloader_VerWritt_0x0001000C.zip");
    string irisPath = GetDataPath(TestDatasets.irisData.trainFilename);

    IDataView iris;
    using (FileStream modelfs = File.OpenRead(textLoaderModelPath))
    using (var rep = RepositoryReader.Open(modelfs, mlContext))
    {
        iris = ModelFileUtils.LoadLoader(mlContext, rep, new MultiFileSource(irisPath), false);
    }

    var previewIris = iris.Preview(1);
    var irisFirstRow = new Dictionary<string, float>();
    irisFirstRow["SepalLength"] = 5.1f;
    irisFirstRow["SepalWidth"] = 3.5f;
    irisFirstRow["PetalLength"] = 1.4f;
    irisFirstRow["PetalWidth"] = 0.2f;

    Assert.Equal(5, previewIris.ColumnView.Length);
    Assert.Equal("SepalLength", previewIris.Schema[0].Name);
    Assert.Equal(NumberDataViewType.Single, previewIris.Schema[0].Type);
    int index = 0;
    foreach (var entry in irisFirstRow)
    {
        Assert.Equal(entry.Key, previewIris.RowView[0].Values[index].Key);
        Assert.Equal(entry.Value, previewIris.RowView[0].Values[index++].Value);
    }
    Assert.Equal("Type", previewIris.RowView[0].Values[index].Key);
    Assert.Equal("Iris-setosa", previewIris.RowView[0].Values[index].Value.ToString());
}
public LoaderWrapper(IHostEnvironment env, ModelLoadContext ctx)
{
    ctx.CheckAtModel(GetVersionInfo());
    ctx.LoadModel<IDataLoader, SignatureLoadDataLoader>(env, out var loader, "Loader", new MultiFileSource(null));

    // Re-serialize the loader into an in-memory buffer so that fresh instances can be
    // created later against arbitrary stream sources.
    var loaderStream = new MemoryStream();
    using (var rep = RepositoryWriter.CreateNew(loaderStream))
    {
        ModelSaveContext.SaveModel(rep, loader, "Loader");
        rep.Commit();
    }

    _env = env;
    _loaderFactory = (IMultiStreamSource source) =>
    {
        using (var rep = RepositoryReader.Open(loaderStream))
        {
            ModelLoadContext.LoadModel<IDataLoader, SignatureLoadDataLoader>(env, out var ldr, rep, "Loader", source);
            return ldr;
        }
    };
}
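Serializing once and re-opening the buffer on every factory call means the wrapper never keeps a live reference to the original loader or its files; each caller gets an independent loader bound to its own source. A hypothetical call through the stored factory, using the field names from the snippet:

// Hypothetical: re-instantiate the wrapped loader definition over a different file set.
IDataLoader rebound = _loaderFactory(new MultiFileSource("some-other-data.txt"));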
public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckUserArg(!string.IsNullOrWhiteSpace(args.InputModelFile), nameof(args.InputModelFile), "The input model file is required.");

    IPredictor predictor;
    RoleMappedSchema trainSchema = null;
    using (var file = env.OpenInputFile(args.InputModelFile))
    using (var strm = file.OpenReadStream())
    using (var rep = RepositoryReader.Open(strm, env))
    {
        ModelLoadContext.LoadModel<IPredictor, SignatureLoadModel>(env, out predictor, rep, ModelFileUtils.DirPredictor);
        trainSchema = ModelFileUtils.LoadRoleMappedSchemaOrNull(env, rep);
    }

    string feat = TrainUtils.MatchNameOrDefaultOrNull(env, input.Schema, nameof(args.FeatureColumn), args.FeatureColumn, DefaultColumnNames.Features);
    string group = TrainUtils.MatchNameOrDefaultOrNull(env, input.Schema, nameof(args.GroupColumn), args.GroupColumn, DefaultColumnNames.GroupId);
    var customCols = TrainUtils.CheckAndGenerateCustomColumns(env, args.CustomColumn);

    return ScoreUtils.GetScorer(args.Scorer, predictor, input, feat, group, customCols, env, trainSchema);
}
public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
{
    Contracts.CheckValue(env, nameof(env));
    var h = env.Register("LoadTransform");
    h.CheckValue(args, nameof(args));
    h.CheckValue(input, nameof(input));
    h.CheckUserArg(File.Exists(args.ModelFile), nameof(args.ModelFile), "File does not exist");

    IDataView currentView;

    // If there are no 'tag' parameters, we load everything, regardless of 'comp'.
    bool complement = args.Complement || Utils.Size(args.Tag) == 0;
    var allTags = new HashSet<string>();
    for (int i = 0; i < Utils.Size(args.Tag); i++)
    {
        var curList = args.Tag[i];
        if (string.IsNullOrWhiteSpace(curList))
            continue;
        foreach (var tag in curList.Split(','))
        {
            if (!string.IsNullOrWhiteSpace(tag))
                allTags.Add(tag.ToLower());
        }
    }

    Func<string, bool> predicate = tag =>
    {
        bool found = allTags.Contains(tag.ToLower());
        return found == !complement;
    };

    using (var file = h.OpenInputFile(args.ModelFile))
    using (var strm = file.OpenReadStream())
    using (var rep = RepositoryReader.Open(strm, h))
    using (var pipeLoaderEntry = rep.OpenEntry(ModelFileUtils.DirDataLoaderModel, ModelLoadContext.ModelStreamName))
    using (var ctx = new ModelLoadContext(rep, pipeLoaderEntry, ModelFileUtils.DirDataLoaderModel))
    {
        currentView = CompositeDataLoader.LoadSelectedTransforms(ctx, input, h, predicate);
        if (currentView == input)
        {
            // REVIEW: we are required to return an IDataTransform. Therefore, if we don't introduce a new transform
            // on top of 'input', we must throw (since input may not be a data transform).
            // We could of course introduce a 'no-op transform', or we could lift the requirement to always return
            // an IDataTransform associated with SignatureDataTransform.
            var criteria = string.Format(
                complement
                    ? "transforms that don't have tags from the list: '{0}'"
                    : "transforms that have tags from the list: '{0}'",
                string.Join(",", allTags));
            throw h.ExceptUserArg(nameof(args.Tag), "No transforms were found that match the search criteria ({0})", criteria);
        }
    }

    h.Assert(currentView is IDataTransform);
    return (IDataTransform)currentView;
}
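A hypothetical invocation of this factory, selecting only the transforms tagged "norm" or "cat" from a saved pipeline and splicing them on top of a new data view; the field names follow the `Arguments` usage above, while the tag values and path are placeholders:

// Hypothetical usage; `env` and `newData` are assumed to exist in scope.
var args = new Arguments
{
    ModelFile = "pipeline.zip",   // placeholder path
    Tag = new[] { "norm,cat" },   // comma-separated tag lists, as parsed above
    Complement = false,           // keep matching transforms rather than excluding them
};
IDataTransform selected = Create(env, args, newData);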
// Factory method for SignatureDataTransform.
private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
{
    Contracts.CheckValue(env, nameof(env));
    var host = env.Register("Tree Featurizer Transform");
    host.CheckValue(args, nameof(args));
    host.CheckValue(input, nameof(input));
    host.CheckUserArg(!string.IsNullOrWhiteSpace(args.TrainedModelFile) || args.Trainer != null, nameof(args.TrainedModelFile),
        "Please specify either a trainer or an input model file.");
    host.CheckUserArg(!string.IsNullOrEmpty(args.FeatureColumn), nameof(args.FeatureColumn), "Transform needs an input features column");

    IDataTransform xf;
    using (var ch = host.Start("Create Tree Ensemble Scorer"))
    {
        var scorerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments() { Suffix = args.Suffix };
        if (!string.IsNullOrWhiteSpace(args.TrainedModelFile))
        {
            if (args.Trainer != null)
                ch.Warning("Both an input model and a trainer were specified. Using the model file.");

            ch.Trace("Loading model");
            IPredictor predictor;
            using (Stream strm = new FileStream(args.TrainedModelFile, FileMode.Open, FileAccess.Read))
            using (var rep = RepositoryReader.Open(strm, ch))
                ModelLoadContext.LoadModel<IPredictor, SignatureLoadModel>(host, out predictor, rep, ModelFileUtils.DirPredictor);

            ch.Trace("Creating scorer");
            var data = TrainAndScoreTransformer.CreateDataFromArgs(ch, input, args);
            Contracts.Assert(data.Schema.Feature.HasValue);

            // Make sure that the given predictor has the correct number of input features.
            if (predictor is CalibratedPredictorBase)
                predictor = ((CalibratedPredictorBase)predictor).SubPredictor;

            // The predictor should be a TreeEnsembleModelParameters, which implements IValueMapper,
            // so this should be non-null.
            var vm = predictor as IValueMapper;
            ch.CheckUserArg(vm != null, nameof(args.TrainedModelFile), "Predictor in model file does not have compatible type");
            if (vm.InputType.VectorSize != data.Schema.Feature.Value.Type.VectorSize)
            {
                throw ch.ExceptUserArg(nameof(args.TrainedModelFile),
                    "Predictor in model file expects {0} features, but data has {1} features",
                    vm.InputType.VectorSize, data.Schema.Feature.Value.Type.VectorSize);
            }

            ISchemaBindableMapper bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, predictor);
            var bound = bindable.Bind(env, data.Schema);
            xf = new GenericScorer(env, scorerArgs, input, bound, data.Schema);
        }
        else
        {
            ch.AssertValue(args.Trainer);

            ch.Trace("Creating TrainAndScoreTransform");
            var trainScoreArgs = new TrainAndScoreTransformer.Arguments();
            args.CopyTo(trainScoreArgs);
            trainScoreArgs.Trainer = args.Trainer;
            trainScoreArgs.Scorer = ComponentFactoryUtils.CreateFromFunction<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform>(
                (e, data, mapper, trainSchema) => Create(e, scorerArgs, data, mapper, trainSchema));
            var mapperFactory = ComponentFactoryUtils.CreateFromFunction<IPredictor, ISchemaBindableMapper>(
                (e, predictor) => new TreeEnsembleFeaturizerBindableMapper(e, scorerArgs, predictor));
            var labelInput = AppendLabelTransform(host, ch, input, trainScoreArgs.LabelColumn, args.LabelPermutationSeed);
            var scoreXf = TrainAndScoreTransformer.Create(host, trainScoreArgs, labelInput, mapperFactory);

            if (input == labelInput)
                return scoreXf;
            return (IDataTransform)ApplyTransformUtils.ApplyAllTransformsToData(host, scoreXf, input, labelInput);
        }
    }
    return xf;
}
private void RunCore(IChannel ch, string cmd)
{
    Host.AssertValue(ch);
    Host.AssertNonEmpty(cmd);

    ch.Trace("Constructing trainer");
    ITrainer trainer = Args.Trainer.CreateInstance(Host);

    IPredictor inputPredictor = null;
    if (Args.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, Args.InputModelFile, out inputPredictor))
        ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");

    ch.Trace("Constructing the training pipeline");
    IDataView trainPipe = CreateLoader();

    ISchema schema = trainPipe.Schema;
    string label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.LabelColumn), Args.LabelColumn, DefaultColumnNames.Label);
    string features = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.FeatureColumn), Args.FeatureColumn, DefaultColumnNames.Features);
    string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.GroupColumn), Args.GroupColumn, DefaultColumnNames.GroupId);
    string weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.WeightColumn), Args.WeightColumn, DefaultColumnNames.Weight);
    string name = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.NameColumn), Args.NameColumn, DefaultColumnNames.Name);

    TrainUtils.AddNormalizerIfNeeded(Host, ch, trainer, ref trainPipe, features, Args.NormalizeFeatures);

    ch.Trace("Binding columns");
    var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);
    var data = new RoleMappedData(trainPipe, label, features, group, weight, name, customCols);

    RoleMappedData validData = null;
    if (!string.IsNullOrWhiteSpace(Args.ValidationFile))
    {
        if (!TrainUtils.CanUseValidationData(trainer))
        {
            ch.Warning("Ignoring validationFile: Trainer does not accept validation dataset.");
        }
        else
        {
            ch.Trace("Constructing the validation pipeline");
            IDataView validPipe = CreateRawLoader(dataFile: Args.ValidationFile);
            validPipe = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainPipe, validPipe);
            validData = new RoleMappedData(validPipe, data.Schema.GetColumnRoleNames());
        }
    }

    var predictor = TrainUtils.Train(Host, ch, data, trainer, _info.LoadNames[0], validData,
        Args.Calibrator, Args.MaxCalibrationExamples, Args.CacheData, inputPredictor);

    IDataLoader testPipe;
    using (var file = !string.IsNullOrEmpty(Args.OutputModelFile)
        ? Host.CreateOutputFile(Args.OutputModelFile)
        : Host.CreateTempFile(".zip"))
    {
        TrainUtils.SaveModel(Host, ch, file, predictor, data, cmd);

        ch.Trace("Constructing the testing pipeline");
        using (var stream = file.OpenReadStream())
        using (var rep = RepositoryReader.Open(stream, ch))
            testPipe = LoadLoader(rep, Args.TestFile, true);
    }

    // Score.
    ch.Trace("Scoring and evaluating");
    IDataScorerTransform scorePipe = ScoreUtils.GetScorer(Args.Scorer, predictor, testPipe, features, group, customCols, Host, data.Schema);

    // Evaluate.
    var evalComp = Args.Evaluator;
    if (!evalComp.IsGood())
        evalComp = EvaluateUtils.GetEvaluatorType(ch, scorePipe.Schema);
    var evaluator = evalComp.CreateInstance(Host);
    var dataEval = new RoleMappedData(scorePipe, label, features, group, weight, name, customCols, opt: true);
    var metrics = evaluator.Evaluate(dataEval);
    MetricWriter.PrintWarnings(ch, metrics);
    evaluator.PrintFoldResults(ch, metrics);
    if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out var overall))
        throw ch.Except("No overall metrics found");
    overall = evaluator.GetOverallResults(overall);
    MetricWriter.PrintOverallMetrics(Host, ch, Args.SummaryFilename, overall, 1);
    evaluator.PrintAdditionalMetrics(ch, metrics);
    Dictionary<string, IDataView>[] metricValues = { metrics };
    SendTelemetryMetric(metricValues);

    if (!string.IsNullOrWhiteSpace(Args.OutputDataFile))
    {
        var perInst = evaluator.GetPerInstanceMetrics(dataEval);
        var perInstData = new RoleMappedData(perInst, label, null, group, weight, name, customCols);
        var idv = evaluator.GetPerInstanceDataViewToSave(perInstData);
        MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, idv);
    }
}
private void RunCore(IChannel ch, string cmd)
{
    Host.AssertValue(ch);
    Host.AssertNonEmpty(cmd);

    ch.Trace("Constructing trainer");
    ITrainer trainer = ImplOptions.Trainer.CreateComponent(Host);

    IPredictor inputPredictor = null;
    if (ImplOptions.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, ImplOptions.InputModelFile, out inputPredictor))
        ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");

    ch.Trace("Constructing the training pipeline");
    IDataView trainPipe = CreateLoader();

    var schema = trainPipe.Schema;
    string label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.LabelColumn), ImplOptions.LabelColumn, DefaultColumnNames.Label);
    string features = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.FeatureColumn), ImplOptions.FeatureColumn, DefaultColumnNames.Features);
    string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.GroupColumn), ImplOptions.GroupColumn, DefaultColumnNames.GroupId);
    string weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.WeightColumn), ImplOptions.WeightColumn, DefaultColumnNames.Weight);
    string name = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.NameColumn), ImplOptions.NameColumn, DefaultColumnNames.Name);

    TrainUtils.AddNormalizerIfNeeded(Host, ch, trainer, ref trainPipe, features, ImplOptions.NormalizeFeatures);

    ch.Trace("Binding columns");
    var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, ImplOptions.CustomColumns);
    var data = new RoleMappedData(trainPipe, label, features, group, weight, name, customCols);

    RoleMappedData validData = null;
    if (!string.IsNullOrWhiteSpace(ImplOptions.ValidationFile))
    {
        if (!trainer.Info.SupportsValidation)
        {
            ch.Warning("Ignoring validationFile: Trainer does not accept validation dataset.");
        }
        else
        {
            ch.Trace("Constructing the validation pipeline");
            IDataView validPipe = CreateRawLoader(dataFile: ImplOptions.ValidationFile);
            validPipe = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainPipe, validPipe);
            validData = new RoleMappedData(validPipe, data.Schema.GetColumnRoleNames());
        }
    }

    // In addition to the training set, some trainers can accept two extra data sets during
    // the training phase: a validation set and a test set. The major difference between them
    // is that the training process may indirectly use the validation set to improve the model,
    // whereas the learned model should be totally independent of the test set. Similar to the
    // validation set, the trainer can report scores computed over the test set.
    RoleMappedData testDataUsedInTrainer = null;
    if (!string.IsNullOrWhiteSpace(ImplOptions.TestFile))
    {
        // In contrast to the if-else block for validation above, we do not emit a warning
        // if a test file is provided, because this is the TrainTest command.
        if (trainer.Info.SupportsTest)
        {
            ch.Trace("Constructing the test pipeline");
            IDataView testPipeUsedInTrainer = CreateRawLoader(dataFile: ImplOptions.TestFile);
            testPipeUsedInTrainer = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainPipe, testPipeUsedInTrainer);
            testDataUsedInTrainer = new RoleMappedData(testPipeUsedInTrainer, data.Schema.GetColumnRoleNames());
        }
    }

    var predictor = TrainUtils.Train(Host, ch, data, trainer, validData,
        ImplOptions.Calibrator, ImplOptions.MaxCalibrationExamples, ImplOptions.CacheData, inputPredictor, testDataUsedInTrainer);

    ILegacyDataLoader testPipe;
    bool hasOutfile = !string.IsNullOrEmpty(ImplOptions.OutputModelFile);
    var tempFilePath = hasOutfile ? null : Path.GetTempFileName();

    using (var file = new SimpleFileHandle(ch, hasOutfile ? ImplOptions.OutputModelFile : tempFilePath, true, !hasOutfile))
    {
        TrainUtils.SaveModel(Host, ch, file, predictor, data, cmd);

        ch.Trace("Constructing the testing pipeline");
        using (var stream = file.OpenReadStream())
        using (var rep = RepositoryReader.Open(stream, ch))
            testPipe = LoadLoader(rep, ImplOptions.TestFile, true);
    }

    // Score.
    ch.Trace("Scoring and evaluating");
    ch.Assert(ImplOptions.Scorer == null || ImplOptions.Scorer is ICommandLineComponentFactory,
        "TrainTestCommand should only be used from the command line.");
    IDataScorerTransform scorePipe = ScoreUtils.GetScorer(ImplOptions.Scorer, predictor, testPipe, features, group, customCols, Host, data.Schema);

    // Evaluate.
    var evaluator = ImplOptions.Evaluator?.CreateComponent(Host) ?? EvaluateUtils.GetEvaluator(Host, scorePipe.Schema);
    var dataEval = new RoleMappedData(scorePipe, label, features, group, weight, name, customCols, opt: true);
    var metrics = evaluator.Evaluate(dataEval);
    MetricWriter.PrintWarnings(ch, metrics);
    evaluator.PrintFoldResults(ch, metrics);
    if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out var overall))
        throw ch.Except("No overall metrics found");
    overall = evaluator.GetOverallResults(overall);
    MetricWriter.PrintOverallMetrics(Host, ch, ImplOptions.SummaryFilename, overall, 1);
    evaluator.PrintAdditionalMetrics(ch, metrics);
    Dictionary<string, IDataView>[] metricValues = { metrics };
    SendTelemetryMetric(metricValues);

    if (!string.IsNullOrWhiteSpace(ImplOptions.OutputDataFile))
    {
        var perInst = evaluator.GetPerInstanceMetrics(dataEval);
        var perInstData = new RoleMappedData(perInst, label, null, group, weight, name, customCols);
        var idv = evaluator.GetPerInstanceDataViewToSave(perInstData);
        MetricWriter.SavePerInstance(Host, ch, ImplOptions.OutputDataFile, idv);
    }
}
/// <summary>
/// Loads multiple artifacts of interest from the input model file, given the context
/// established by the command line arguments.
/// </summary>
/// <param name="ch">The channel to which to provide output.</param>
/// <param name="wantPredictor">Whether we want a predictor from the model file. If
/// <c>false</c> we will not even attempt to load a predictor. If <c>null</c> we will
/// load the predictor, if present. If <c>true</c> we will load the predictor, or fail
/// noisily if we cannot.</param>
/// <param name="predictor">The predictor in the model, or <c>null</c> if
/// <paramref name="wantPredictor"/> was false, or <paramref name="wantPredictor"/> was
/// <c>null</c> and no predictor was present.</param>
/// <param name="wantTrainSchema">Whether we want the training schema. Unlike
/// <paramref name="wantPredictor"/>, this has no "hard fail if not present" option. If
/// this is <c>true</c>, it is still possible for <paramref name="trainSchema"/> to remain
/// <c>null</c> if the model contains no role mappings or pipeline.</param>
/// <param name="trainSchema">The training schema if <paramref name="wantTrainSchema"/>
/// is true, and there were role mappings stored in the model.</param>
/// <param name="pipe">The data pipe constructed from the combination of the
/// model and command line arguments.</param>
protected void LoadModelObjects(
    IChannel ch,
    bool? wantPredictor, out IPredictor predictor,
    bool wantTrainSchema, out RoleMappedSchema trainSchema,
    out ILegacyDataLoader pipe)
{
    // First handle the case where there is no input model file. In that case,
    // everything must come from the command line.
    using (var file = Host.OpenInputFile(ImplOptions.InputModelFile))
    using (var strm = file.OpenReadStream())
    using (var rep = RepositoryReader.Open(strm, Host))
    {
        // First consider loading the predictor.
        if (wantPredictor == false)
        {
            predictor = null;
        }
        else
        {
            ch.Trace("Loading predictor");
            predictor = ModelFileUtils.LoadPredictorOrNull(Host, rep);
            if (wantPredictor == true)
                Host.Check(predictor != null, "Could not load predictor from model file");
        }

        // Next create the loader.
        var loaderFactory = ImplOptions.Loader;
        ILegacyDataLoader trainPipe = null;
        if (loaderFactory != null)
        {
            // The loader is overridden from the command line.
            pipe = loaderFactory.CreateComponent(Host, new MultiFileSource(ImplOptions.DataFile));
            if (ImplOptions.LoadTransforms == true)
            {
                Host.CheckUserArg(!string.IsNullOrWhiteSpace(ImplOptions.InputModelFile), nameof(ImplOptions.InputModelFile));
                pipe = LoadTransformChain(pipe);
            }
        }
        else
        {
            var loadTrans = ImplOptions.LoadTransforms ?? true;
            pipe = LoadLoader(rep, ImplOptions.DataFile, loadTrans);
            if (loadTrans)
                trainPipe = pipe;
        }

        if (Utils.Size(ImplOptions.Transforms) > 0)
            pipe = LegacyCompositeDataLoader.Create(Host, pipe, ImplOptions.Transforms);

        // Next consider loading the training data's role mapped schema.
        trainSchema = null;
        if (wantTrainSchema)
        {
            // First try to get the role mappings.
            var trainRoleMappings = ModelFileUtils.LoadRoleMappingsOrNull(Host, rep);
            if (trainRoleMappings != null)
            {
                // Next create the training schema. In the event that the loaded pipeline happens
                // to be the training pipe, we can just use that. If it differs, then we need to
                // load the full pipeline from the model, relying upon the fact that all loaders
                // can be loaded with no data at all to get their schemas.
                if (trainPipe == null)
                    trainPipe = ModelFileUtils.LoadLoader(Host, rep, new MultiFileSource(null), loadTransforms: true);
                trainSchema = new RoleMappedSchema(trainPipe.Schema, trainRoleMappings);
            }
            // If the role mappings are null, an alternative would be to fail. However, the idea
            // is that the scorer should always still succeed, although perhaps with reduced
            // functionality, even when the training schema is null, since not all versions of
            // TLC models will have the role mappings preserved. And we do want to maintain
            // backwards compatibility.
        }
    }
}