/// <summary>
/// Create a composite schema of both the partitioned columns and the underlying loader columns.
/// </summary>
/// <param name="ectx">The exception context.</param>
/// <param name="cols">The partitioned columns.</param>
/// <param name="subLoader">The sub loader.</param>
/// <returns>The resulting schema.</returns>
private DataViewSchema CreateSchema(IExceptionContext ectx, Column[] cols, ILegacyDataLoader subLoader)
{
    Contracts.AssertValue(cols);
    Contracts.AssertValue(subLoader);

    var builder = new DataViewSchema.Builder();
    builder.AddColumns(cols.Select(c => new DataViewSchema.DetachedColumn(c.Name, ColumnTypeExtensions.PrimitiveTypeFromKind(c.Type.Value), null)));
    var colSchema = builder.ToSchema();

    var subSchema = subLoader.Schema;

    if (subSchema.Count == 0)
    {
        return colSchema;
    }
    else
    {
        var schemas = new DataViewSchema[]
        {
            subSchema,
            colSchema
        };

        return new ZipBinding(schemas).OutputSchema;
    }
}
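// A minimal usage sketch of DataViewSchema.Builder, the same API CreateSchema uses above.
// The column names and types below are hypothetical, not taken from any loader in this code.
private static DataViewSchema BuildExampleSchema()
{
    var builder = new DataViewSchema.Builder();
    // AddColumn appends a single column; AddColumns (used in CreateSchema) appends a sequence.
    builder.AddColumn("Label", NumberDataViewType.Single);
    builder.AddColumn("Features", new VectorDataViewType(NumberDataViewType.Single, 10));
    return builder.ToSchema();
}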
private void RunCore(IChannel ch) { Host.AssertValue(ch, "ch"); IDataSaver saver; if (ImplOptions.Saver == null) { var ext = Path.GetExtension(ImplOptions.OutputDataFile); var isBinary = string.Equals(ext, ".idv", StringComparison.OrdinalIgnoreCase); var isTranspose = string.Equals(ext, ".tdv", StringComparison.OrdinalIgnoreCase); if (isBinary) { saver = new BinarySaver(Host, new BinarySaver.Arguments()); } else if (isTranspose) { saver = new TransposeSaver(Host, new TransposeSaver.Arguments()); } else { saver = new TextSaver(Host, new TextSaver.Arguments()); } } else { saver = ImplOptions.Saver.CreateComponent(Host); } ILegacyDataLoader loader = CreateAndSaveLoader(); using (var file = Host.CreateOutputFile(ImplOptions.OutputDataFile)) DataSaverUtils.SaveDataView(ch, saver, loader, file, ImplOptions.KeepHidden); }
private LegacyCompositeDataLoader(IHost host, TransformEx[] transforms)
{
    Contracts.AssertValue(host, "host");
    _host = host;
    _host.AssertNonEmpty(transforms);

    View = transforms[transforms.Length - 1].Transform;
    _tview = View as ITransposeDataView;
    var srcLoader = transforms[0].Transform.Source as ILegacyDataLoader;

#if DEBUG
    // Assert that the transforms array is consistent: first one starts with a loader,
    // they are chained together, the loader is not a composite.
    for (int i = 1; i < transforms.Length; i++)
        _host.Assert(transforms[i].Transform.Source == transforms[i - 1].Transform, "Transforms are not linked");
    _host.AssertValue(srcLoader, "loader", "Transform chain doesn't start with a loader");
    _host.Assert(!(srcLoader is LegacyCompositeDataLoader), "Can't have composite source loader");
#endif

    _loader = srcLoader;
    _transforms = transforms;
}
/// <summary>
/// Loads all transforms from the <paramref name="ctx"/> that pass the <paramref name="isTransformTagAccepted"/> test,
/// applies them sequentially to the <paramref name="srcLoader"/>, and returns the (composite) data loader.
/// </summary>
private static ILegacyDataLoader LoadTransforms(ModelLoadContext ctx, ILegacyDataLoader srcLoader, IHost host, Func<string, bool> isTransformTagAccepted)
{
    Contracts.AssertValue(host, "host");
    host.AssertValue(srcLoader);
    host.AssertValue(ctx);

    // *** Binary format ***
    // int: sizeof(Float)
    // int: number of transforms
    // foreach transform: (starting from version VersionAddedTags)
    //     string: tag
    //     string: args string

    int cbFloat = ctx.Reader.ReadInt32();
    host.CheckDecode(cbFloat == sizeof(float));

    int cxf = ctx.Reader.ReadInt32();
    host.CheckDecode(cxf >= 0);

    bool hasTags = ctx.Header.ModelVerReadable >= VersionAddedTags;
    var tagData = new List<KeyValuePair<string, string>>();
    var acceptedIds = new List<int>();

    for (int i = 0; i < cxf; i++)
    {
        string tag = "";
        string argsString = null;
        if (hasTags)
        {
            tag = ctx.LoadNonEmptyString();
            argsString = ctx.LoadStringOrNull();
        }
        if (!isTransformTagAccepted(tag))
            continue;

        acceptedIds.Add(i);
        tagData.Add(new KeyValuePair<string, string>(tag, argsString));
    }
    host.Assert(tagData.Count == acceptedIds.Count);

    if (tagData.Count == 0)
        return srcLoader;

    return ApplyTransformsCore(host, srcLoader, tagData.ToArray(),
        (h, index, data) =>
        {
            IDataTransform xf;
            ctx.LoadModel<IDataTransform, SignatureLoadDataTransform>(host, out xf,
                string.Format(TransformDirTemplate, acceptedIds[index]), data);
            return xf;
        });
}
/// <param name="env">The environment.</param> /// <param name="registrationName">The registration name.</param> /// <param name="inputDataView">The input data view.</param> /// <param name="splitColumn">The column to use for splitting data into folds.</param> /// <param name="args">Cross validation arguments.</param> /// <param name="createExamples">The delegate to create RoleMappedData</param> /// <param name="applyTransformsToTestData">The delegate to apply the transforms from the train pipeline to the test data</param> /// <param name="scorer">The scorer</param> /// <param name="evaluator">The evaluator</param> /// <param name="getValidationDataView">The delegate to create validation data view</param> /// <param name="applyTransformsToValidationData">The delegate to apply the transforms from the train pipeline to the validation data</param> /// <param name="inputPredictor">The input predictor, for the continue training option</param> /// <param name="cmd">The command string.</param> /// <param name="loader">Original loader so we can construct correct pipeline for model saving.</param> /// <param name="savePerInstance">Whether to produce the per-instance data view.</param> /// <returns></returns> public FoldHelper( IHostEnvironment env, string registrationName, IDataView inputDataView, string splitColumn, Arguments args, Func <IHostEnvironment, IChannel, IDataView, ITrainer, RoleMappedData> createExamples, Func <IHostEnvironment, IChannel, IDataView, RoleMappedData, IDataView, RoleMappedData> applyTransformsToTestData, IComponentFactory <IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform> scorer, IComponentFactory <IMamlEvaluator> evaluator, Func <IDataView> getValidationDataView = null, Func <IHostEnvironment, IChannel, IDataView, RoleMappedData, IDataView, RoleMappedData> applyTransformsToValidationData = null, IPredictor inputPredictor = null, string cmd = null, ILegacyDataLoader loader = null, bool savePerInstance = false) { Contracts.CheckValue(env, nameof(env)); env.CheckNonWhiteSpace(registrationName, nameof(registrationName)); env.CheckValue(inputDataView, nameof(inputDataView)); env.CheckValue(splitColumn, nameof(splitColumn)); env.CheckParam(args.NumFolds > 1, nameof(args.NumFolds)); env.CheckValue(createExamples, nameof(createExamples)); env.CheckValue(applyTransformsToTestData, nameof(applyTransformsToTestData)); env.CheckValue(args.Trainer, nameof(args.Trainer)); env.CheckValueOrNull(scorer); env.CheckValueOrNull(evaluator); env.CheckValueOrNull(args.Calibrator); env.CheckParam(args.MaxCalibrationExamples > 0, nameof(args.MaxCalibrationExamples)); env.CheckParam(getValidationDataView == null || applyTransformsToValidationData != null, nameof(applyTransformsToValidationData)); env.CheckValueOrNull(inputPredictor); env.CheckValueOrNull(cmd); env.CheckValueOrNull(args.OutputModelFile); env.CheckValueOrNull(loader); _env = env; _registrationName = registrationName; _inputDataView = inputDataView; _splitColumn = splitColumn; _numFolds = args.NumFolds; _createExamples = createExamples; _applyTransformsToTestData = applyTransformsToTestData; _trainer = args.Trainer; _scorer = scorer; _evaluator = evaluator; _calibrator = args.Calibrator; _maxCalibrationExamples = args.MaxCalibrationExamples; _useThreads = args.UseThreads; _cacheData = args.CacheData; _getValidationDataView = getValidationDataView; _applyTransformsToValidationData = applyTransformsToValidationData; _inputPredictor = inputPredictor; _cmd = cmd; _outputModelFile = args.OutputModelFile; _loader = 
loader; _savePerInstance = savePerInstance; }
private byte[] SaveLoaderToBytes(ILegacyDataLoader loader)
{
    Contracts.CheckValue(loader, nameof(loader));

    using (var stream = new MemoryStream())
    {
        LoaderUtils.SaveLoader(loader, stream);
        // Note: GetBuffer returns the stream's underlying buffer, which may include
        // unused trailing bytes beyond the written length.
        return stream.GetBuffer();
    }
}
/// <summary> /// Creates a <see cref="LegacyCompositeDataLoader"/> that starts with the <paramref name="srcLoader"/>, /// and follows with transforms created from the <paramref name="transformArgs"/> array. /// If there are no transforms, the <paramref name="srcLoader"/> is returned. /// </summary> public static ILegacyDataLoader Create(IHostEnvironment env, ILegacyDataLoader srcLoader, params KeyValuePair <string, IComponentFactory <IDataView, IDataTransform> >[] transformArgs) { Contracts.CheckValue(env, nameof(env)); var h = env.Register(RegistrationName); h.CheckValue(srcLoader, nameof(srcLoader)); h.CheckValueOrNull(transformArgs); return(CreateCore(h, srcLoader, transformArgs)); }
private ILegacyDataLoader LoadTransformChain(ILegacyDataLoader srcData)
{
    Host.Assert(!string.IsNullOrWhiteSpace(ImplOptions.InputModelFile));

    using (var file = Host.OpenInputFile(ImplOptions.InputModelFile))
    using (var strm = file.OpenReadStream())
    using (var rep = RepositoryReader.Open(strm, Host))
    using (var pipeLoaderEntry = rep.OpenEntry(ModelFileUtils.DirDataLoaderModel, ModelLoadContext.ModelStreamName))
    using (var ctx = new ModelLoadContext(rep, pipeLoaderEntry, ModelFileUtils.DirDataLoaderModel))
        return LegacyCompositeDataLoader.Create(Host, ctx, srcData, x => true);
}
/// <summary> /// Saves <paramref name="loader"/> to the specified <paramref name="file"/>. /// </summary> public static void SaveLoader(ILegacyDataLoader loader, IFileHandle file) { Contracts.CheckValue(loader, nameof(loader)); Contracts.CheckValue(file, nameof(file)); Contracts.CheckParam(file.CanWrite, nameof(file), "Must be writable"); using (var stream = file.CreateWriteStream()) { SaveLoader(loader, stream); } }
private void RunCore(IChannel ch)
{
    ILegacyDataLoader loader = CreateAndSaveLoader();

    using (var schemaWriter = new StringWriter())
    {
        RunOnData(schemaWriter, ImplOptions, loader);
        var str = schemaWriter.ToString();
        ch.AssertNonEmpty(str);
        ch.Info(str);
    }
}
/// <summary> /// Saves <paramref name="loader"/> to the specified <paramref name="stream"/>. /// </summary> public static void SaveLoader(ILegacyDataLoader loader, Stream stream) { Contracts.CheckValue(loader, nameof(loader)); Contracts.CheckValue(stream, nameof(stream)); Contracts.CheckParam(stream.CanWrite, nameof(stream), "Must be writable"); using (var rep = RepositoryWriter.CreateNew(stream)) { ModelSaveContext.SaveModel(rep, loader, ModelFileUtils.DirDataLoaderModel); rep.Commit(); } }
/// <summary>
/// Appends transforms to the <paramref name="srcLoader"/> and returns a loader that contains these new transforms.
/// If there are no transforms to append, returns <paramref name="srcLoader"/> intact, otherwise creates a
/// <see cref="LegacyCompositeDataLoader"/>. The transforms are created by sequentially invoking the provided lambda,
/// one time for each element of <paramref name="tagData"/>.
/// </summary>
/// <param name="env">The host environment.</param>
/// <param name="srcLoader">The source loader.</param>
/// <param name="tagData">The array of (tag, creationInfo) pairs. Can be an empty array or null, in which case
/// the function returns <paramref name="srcLoader"/>.</param>
/// <param name="createTransform">The delegate to invoke at each transform creation.
/// Delegate parameters are: host environment, transform index (0 to <c>tagData.Length - 1</c>), source data view.
/// It should return an <see cref="IDataView"/> that shares the same loader as the source data view.</param>
/// <returns>The resulting data loader.</returns>
public static ILegacyDataLoader ApplyTransforms(IHostEnvironment env, ILegacyDataLoader srcLoader,
    KeyValuePair<string, string>[] tagData, Func<IHostEnvironment, int, IDataView, IDataView> createTransform)
{
    Contracts.CheckValue(env, nameof(env));
    var h = env.Register(RegistrationName);
    h.CheckValue(srcLoader, nameof(srcLoader));
    h.CheckValueOrNull(tagData);
    h.CheckValue(createTransform, nameof(createTransform));

    if (Utils.Size(tagData) == 0)
        return srcLoader;

    return ApplyTransformsCore(h, srcLoader, tagData, createTransform);
}
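// A hedged usage sketch for ApplyTransforms. The method name and tags are hypothetical, and
// the lambda is a deliberate no-op that returns its input unchanged: per the contract above
// (and the no-op handling in ApplyTransformsCore below), the original loader comes back intact.
private static ILegacyDataLoader ApplyNoOpTransforms(IHostEnvironment env, ILegacyDataLoader srcLoader)
{
    var tagData = new[]
    {
        new KeyValuePair<string, string>("step1", null),
        new KeyValuePair<string, string>("step2", null),
    };
    return LegacyCompositeDataLoader.ApplyTransforms(env, srcLoader, tagData,
        (host, index, input) => input); // identity lambda: nothing is actually appended
}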
/// <summary> /// Creates a <see cref="ILegacyDataLoader"/> from the specified source loader, followed by /// the transforms that are loaded from the <paramref name="ctx"/>, tags filtered by /// by the <paramref name="isTransformTagAccepted"/>. /// If the <paramref name="ctx"/> contains no accepted transforms, the <paramref name="srcLoader"/> is /// returned intact. /// </summary> public static ILegacyDataLoader Create(IHostEnvironment env, ModelLoadContext ctx, ILegacyDataLoader srcLoader, Func <string, bool> isTransformTagAccepted) { Contracts.CheckValue(env, nameof(env)); var h = env.Register(RegistrationName); h.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); h.CheckValue(srcLoader, nameof(srcLoader)); h.CheckValue(isTransformTagAccepted, nameof(isTransformTagAccepted)); return(LoadTransforms(ctx, srcLoader, h, isTransformTagAccepted)); }
/// <summary>
/// Applies one transform to the data loader and returns a (composite) data loader that contains the result.
/// The transform is created by invoking the lambda for a data source, and it should return an
/// <see cref="IDataView"/> that shares the same loader as the provided source.
/// </summary>
public static ILegacyDataLoader ApplyTransform(IHostEnvironment env, ILegacyDataLoader srcLoader,
    string tag, string creationArgs, Func<IHostEnvironment, IDataView, IDataView> createTransform)
{
    Contracts.CheckValue(env, nameof(env));
    var h = env.Register(RegistrationName);
    h.CheckValue(srcLoader, nameof(srcLoader));
    h.CheckValueOrNull(tag);
    h.CheckValueOrNull(creationArgs);
    h.CheckValue(createTransform, nameof(createTransform));

    var tagData = new[] { new KeyValuePair<string, string>(tag, creationArgs) };
    return ApplyTransformsCore(h, srcLoader, tagData, (e, index, data) => createTransform(e, data));
}
protected override bool MoveNextCore()
{
    // Iterate sub cursor or move to the next file.
    while (_subCursor == null || !_subCursor.MoveNext())
    {
        // Cleanup old sub cursor
        if (_subCursor != null)
        {
            _subCursor.Dispose();
            _subCursor = null;
        }

        if (!TryGetNextPathAndValues(out string path, out string relativePath, out List<string> values))
            return false;

        ILegacyDataLoader loader = null;
        try
        {
            // Load the sub cursor and reset the data.
            loader = _parent.CreateLoaderFromBytes(_parent._subLoaderBytes, new MultiFileSource(path));
        }
        catch (Exception e)
        {
            Ch.Warning($"Failed to load file {path} due to a loader exception. Moving on to the next file. Ex: {e.Message}");
            continue;
        }

        _subCursor = loader.GetRowCursor(_subActivecolumnsNeeded);

        try
        {
            UpdateSubGetters();
            UpdateColumnValues(relativePath, values);
        }
        catch (InvalidOperationException e)
        {
            // Failed to load this file so skip.
            Ch.Warning(MessageSensitivity.Schema, e.Message);
            if (_subCursor != null)
            {
                _subCursor.Dispose();
                _subCursor = null;
            }
        }
    }

    return true;
}
public void LoadBinaryLoaderModelVersion3()
{
    var env = new MLContext(1).AddStandardComponents();
    using (var modelStream = File.OpenRead(Path.Combine("TestModels", "BinaryLoader-v3.11.0.0.zip")))
    using (var rep = RepositoryReader.Open(modelStream, env))
    {
        ILegacyDataLoader result = ModelFileUtils.LoadLoader(env, rep, new MultiFileSource(null), true);
        Assert.Equal(2, result.Schema.Count);
        Assert.Equal("Image", result.Schema[0].Name);
        Assert.Equal("Class", result.Schema[1].Name);
    }
}
private static ILegacyDataLoader CreateCore(IHost host, ILegacyDataLoader srcLoader,
    KeyValuePair<string, IComponentFactory<IDataView, IDataTransform>>[] transformArgs)
{
    Contracts.AssertValue(host, "host");
    host.AssertValue(srcLoader, "srcLoader");
    host.AssertValueOrNull(transformArgs);

    if (Utils.Size(transformArgs) == 0)
        return srcLoader;

    string GetTagData(IComponentFactory<IDataView, IDataTransform> factory)
    {
        // When coming from the command line, preserve the string arguments.
        // For other factories, we aren't able to get the string.
        return (factory as ICommandLineComponentFactory)?.ToString();
    }

    var tagData = transformArgs
        .Select(x => new KeyValuePair<string, string>(x.Key, GetTagData(x.Value)))
        .ToArray();

    // Warn if tags coincide with ones already present in the loader.
    var composite = srcLoader as LegacyCompositeDataLoader;
    if (composite != null)
    {
        using (var ch = host.Start("TagValidation"))
        {
            foreach (var pair in tagData)
            {
                if (!string.IsNullOrEmpty(pair.Key) && composite._transforms.Any(x => x.Tag == pair.Key))
                    ch.Warning("The transform with tag '{0}' already exists in the chain", pair.Key);
            }
        }
    }

    return ApplyTransformsCore(host, srcLoader, tagData,
        (env, index, data) => transformArgs[index].Value.CreateComponent(env, data));
}
private void RunCore(IChannel ch, string cmd)
{
    Host.AssertValue(ch);

    IPredictor inputPredictor = null;
    if (ImplOptions.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, ImplOptions.InputModelFile, out inputPredictor))
        ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");

    ch.Trace("Constructing data pipeline");
    ILegacyDataLoader loader = CreateRawLoader();

    // If the per-instance results are requested and there is no name column, add a GenerateNumberTransform.
    var preXf = ImplOptions.PreTransforms;
    if (!string.IsNullOrEmpty(ImplOptions.OutputDataFile))
    {
        string name = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema, nameof(ImplOptions.NameColumn), ImplOptions.NameColumn, DefaultColumnNames.Name);
        if (name == null)
        {
            preXf = preXf.Concat(
                new[]
                {
                    new KeyValuePair<string, IComponentFactory<IDataView, IDataTransform>>(
                        "",
                        ComponentFactoryUtils.CreateFromFunction<IDataView, IDataTransform>(
                            (env, input) =>
                            {
                                var args = new GenerateNumberTransform.Options();
                                args.Columns = new[] { new GenerateNumberTransform.Column() { Name = DefaultColumnNames.Name }, };
                                args.UseCounter = true;
                                return new GenerateNumberTransform(env, args, input);
                            }))
                }).ToArray();
        }
    }
    loader = LegacyCompositeDataLoader.Create(Host, loader, preXf);

    ch.Trace("Binding label and features columns");
    IDataView pipe = loader;
    var stratificationColumn = GetSplitColumn(ch, loader, ref pipe);
    var scorer = ImplOptions.Scorer;
    var evaluator = ImplOptions.Evaluator;

    Func<IDataView> validDataCreator = null;
    if (ImplOptions.ValidationFile != null)
    {
        validDataCreator = () =>
        {
            // Fork the command.
            var impl = new CrossValidationCommand(this);
            return impl.CreateRawLoader(dataFile: ImplOptions.ValidationFile);
        };
    }

    FoldHelper fold = new FoldHelper(Host, RegistrationName, pipe, stratificationColumn,
        ImplOptions, CreateRoleMappedData, ApplyAllTransformsToData, scorer, evaluator,
        validDataCreator, ApplyAllTransformsToData, inputPredictor, cmd, loader,
        !string.IsNullOrEmpty(ImplOptions.OutputDataFile));
    var tasks = fold.GetCrossValidationTasks();

    var eval = evaluator?.CreateComponent(Host) ?? EvaluateUtils.GetEvaluator(Host, tasks[0].Result.ScoreSchema);

    // Print confusion matrix and fold results for each fold.
    for (int i = 0; i < tasks.Length; i++)
    {
        var dict = tasks[i].Result.Metrics;
        MetricWriter.PrintWarnings(ch, dict);
        eval.PrintFoldResults(ch, dict);
    }

    // Print the overall results.
    if (!TryGetOverallMetrics(tasks.Select(t => t.Result.Metrics).ToArray(), out var overallList))
        throw ch.Except("No overall metrics found");

    var overall = eval.GetOverallResults(overallList.ToArray());
    MetricWriter.PrintOverallMetrics(Host, ch, ImplOptions.SummaryFilename, overall, ImplOptions.NumFolds);
    eval.PrintAdditionalMetrics(ch, tasks.Select(t => t.Result.Metrics).ToArray());
    Dictionary<string, IDataView>[] metricValues = tasks.Select(t => t.Result.Metrics).ToArray();
    SendTelemetryMetric(metricValues);

    // Save the per-instance results.
    if (!string.IsNullOrWhiteSpace(ImplOptions.OutputDataFile))
    {
        var perInstance = EvaluateUtils.ConcatenatePerInstanceDataViews(Host, eval,
            ImplOptions.CollateMetrics, ImplOptions.OutputExampleFoldIndex,
            tasks.Select(t => t.Result.PerInstanceResults).ToArray(), out var variableSizeVectorColumnNames);
        if (variableSizeVectorColumnNames.Length > 0)
        {
            ch.Warning("Detected columns of variable length: {0}. Consider setting collateMetrics- for meaningful per-fold results.",
                string.Join(", ", variableSizeVectorColumnNames));
        }
        if (ImplOptions.CollateMetrics)
        {
            ch.Assert(perInstance.Length == 1);
            MetricWriter.SavePerInstance(Host, ch, ImplOptions.OutputDataFile, perInstance[0]);
        }
        else
        {
            int i = 0;
            foreach (var idv in perInstance)
            {
                MetricWriter.SavePerInstance(Host, ch, ConstructPerFoldName(ImplOptions.OutputDataFile, i), idv);
                i++;
            }
        }
    }
}
private static ILegacyDataLoader ApplyTransformsCore(IHost host, ILegacyDataLoader srcLoader,
    KeyValuePair<string, string>[] tagData, Func<IHostEnvironment, int, IDataView, IDataView> createTransform)
{
    Contracts.AssertValue(host, "host");
    host.AssertValue(srcLoader, "srcLoader");
    host.AssertNonEmpty(tagData);
    host.AssertValue(createTransform, "createTransform");

    // If the loader is a composite, we need to start with its underlying pipeline end.
    var exes = new List<TransformEx>();
    var composite = srcLoader as LegacyCompositeDataLoader;
    IDataView srcView;
    ILegacyDataLoader pipeStart;
    if (composite != null)
    {
        srcView = composite.View;
        exes.AddRange(composite._transforms);
        pipeStart = composite._loader;
    }
    else
    {
        srcView = pipeStart = srcLoader;
    }

    IDataView view = srcView;
    using (var ch = host.Start("Transforms"))
    {
        int count = Utils.Size(tagData);
        var newlyCreated = new List<TransformEx>();
        for (int i = 0; i < count; i++)
        {
            // REVIEW: this might cause silent automatic tag conflicts if the pipeline is short-circuited.
            // Maybe it's better to allow empty tags?
            var tag = tagData[i].Key;
            if (string.IsNullOrEmpty(tag))
                tag = GenerateTag(exes.Count);

            var newDataView = createTransform(host, i, view);

            // Append the newly created transforms to the exes list.
            // If the newTransform is a 'no-op' transform, i.e. equal to the original view,
            // the exes array will not be modified: there's no reason to record details of a no-op transform,
            // especially since this would overwrite the useful details of the upstream transform.
            newlyCreated.Clear();
            IDataView curDataView = newDataView;
            while (true)
            {
                var cur = curDataView as IDataTransform;
                if (cur == null)
                {
                    // We reached all the way back to the pipe start. The exes accumulated so far are irrelevant.
                    ch.Check(curDataView == pipeStart,
                        "The transform has corrupted the chain (chain no longer starts with the same loader).");
                    exes.Clear();
                    break;
                }

                int index = exes.FindLastIndex(x => x.Transform == cur);
                if (index >= 0)
                {
                    // We found a transform in exes to attach to.
                    if (index < exes.Count - 1)
                    {
                        // The transform short-circuited some of the existing ones, remove them.
                        exes.RemoveRange(index + 1, exes.Count - index - 1);
                    }
                    break;
                }

                newlyCreated.Add(new TransformEx(tag, tagData[i].Value, cur));
                curDataView = cur.Source;
            }

            newlyCreated.Reverse();
            exes.AddRange(newlyCreated);

            view = newDataView;
        }
    }

    return view == srcView ? srcLoader : new LegacyCompositeDataLoader(host, exes.ToArray());
}
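// A minimal sketch of the chain walk ApplyTransformsCore performs above (the helper name is
// hypothetical): follow IDataTransform.Source links from the end of a pipeline back to the
// first view that is not a transform, then reverse into application order.
private static List<IDataTransform> EnumerateChain(IDataView end)
{
    var chain = new List<IDataTransform>();
    for (var cur = end as IDataTransform; cur != null; cur = cur.Source as IDataTransform)
        chain.Add(cur);
    chain.Reverse(); // collected end-to-start; reverse so chain[0] is nearest the loader
    return chain;
}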
private void Run(IChannel ch)
{
    ILegacyDataLoader loader = null;
    IPredictor rawPred = null;
    IDataView view;
    RoleMappedSchema trainSchema = null;

    if (_model == null && _predictiveModel == null)
    {
        if (string.IsNullOrEmpty(ImplOptions.InputModelFile))
        {
            loader = CreateLoader();
            rawPred = null;
            trainSchema = null;
            Host.CheckUserArg(ImplOptions.LoadPredictor != true, nameof(ImplOptions.LoadPredictor),
                "Cannot be set to true unless " + nameof(ImplOptions.InputModelFile) + " is also specified.");
        }
        else
        {
            LoadModelObjects(ch, _loadPredictor, out rawPred, true, out trainSchema, out loader);
        }

        view = loader;
    }
    else if (_model != null)
    {
        view = _model.Apply(Host, new EmptyDataView(Host, _model.InputSchema));
    }
    else
    {
        view = _predictiveModel.TransformModel.Apply(Host, new EmptyDataView(Host, _predictiveModel.TransformModel.InputSchema));
        rawPred = _predictiveModel.Predictor;
        trainSchema = _predictiveModel.GetTrainingSchema(Host);
    }

    // Create the ONNX context for storing global information
    var assembly = System.Reflection.Assembly.GetExecutingAssembly();
    var versionInfo = System.Diagnostics.FileVersionInfo.GetVersionInfo(assembly.Location);
    var ctx = new OnnxContextImpl(Host, _name, ProducerName, versionInfo.FileVersion,
        ModelVersion, _domain, ImplOptions.OnnxVersion);

    // Get the transform chain.
    IDataView source;
    IDataView end;
    LinkedList<ITransformCanSaveOnnx> transforms;
    GetPipe(ctx, ch, view, out source, out end, out transforms);
    Host.Assert(transforms.Count == 0 || transforms.Last.Value == end);

    // If we have a predictor, try to get the scorer for it.
    if (rawPred != null)
    {
        RoleMappedData data;
        if (trainSchema != null)
        {
            data = new RoleMappedData(end, trainSchema.GetColumnRoleNames());
        }
        else
        {
            // We had a predictor, but no roles stored in the model. Just suppose
            // default column names are OK, if present.
            data = new RoleMappedData(end, DefaultColumnNames.Label, DefaultColumnNames.Features,
                DefaultColumnNames.GroupId, DefaultColumnNames.Weight, DefaultColumnNames.Name, opt: true);
        }

        var scorePipe = ScoreUtils.GetScorer(rawPred, data, Host, trainSchema);
        var scoreOnnx = scorePipe as ITransformCanSaveOnnx;
        if (scoreOnnx?.CanSaveOnnx(ctx) == true)
        {
            Host.Assert(scorePipe.Source == end);
            end = scorePipe;
            transforms.AddLast(scoreOnnx);

            if (rawPred.PredictionKind == PredictionKind.BinaryClassification ||
                rawPred.PredictionKind == PredictionKind.MulticlassClassification)
            {
                // Check if the PredictedLabel column is a KeyDataViewType and has KeyValue annotations.
                // If it does, add a KeyToValueMappingTransformer, to enable NimbusML to get the values back
                // when using an ONNX model, as described in https://github.com/dotnet/machinelearning/pull/4841
                var predictedLabelColumn = scorePipe.Schema.GetColumnOrNull(DefaultColumnNames.PredictedLabel);
                if (predictedLabelColumn.HasValue && HasKeyValues(predictedLabelColumn.Value))
                {
                    var outputData = new KeyToValueMappingTransformer(Host, DefaultColumnNames.PredictedLabel).Transform(scorePipe);
                    end = outputData;
                    transforms.AddLast(outputData as ITransformCanSaveOnnx);
                }
            }
        }
        else
        {
            Contracts.CheckUserArg(_loadPredictor != true, nameof(Arguments.LoadPredictor),
                "We were explicitly told to load the predictor but we do not know how to save it as ONNX.");
            ch.Warning("We do not know how to save the predictor as ONNX. Ignoring.");
        }
    }
    else
    {
        Contracts.CheckUserArg(_loadPredictor != true, nameof(Arguments.LoadPredictor),
            "We were explicitly told to load the predictor but one was not present.");
    }

    // Convert back to values the KeyDataViewType "pass-through" columns
    // (i.e. those that remained untouched by the model). This is done to enable NimbusML to get these values
    // as described in https://github.com/dotnet/machinelearning/pull/4841
    var passThroughColumnNames = GetPassThroughKeyDataViewTypeColumnsNames(source, end);
    foreach (var name in passThroughColumnNames)
    {
        var outputData = new KeyToValueMappingTransformer(Host, name).Transform(end);
        end = outputData;
        transforms.AddLast(end as ITransformCanSaveOnnx);
    }

    var model = ConvertTransformListToOnnxModel(ctx, ch, source, end, transforms, _inputsToDrop, _outputsToDrop);

    using (var file = Host.CreateOutputFile(_outputModelPath))
    using (var stream = file.CreateWriteStream())
        model.WriteTo(stream);

    if (_outputJsonModelPath != null)
    {
        using (var file = Host.CreateOutputFile(_outputJsonModelPath))
        using (var stream = file.CreateWriteStream())
        using (var writer = new StreamWriter(stream))
        {
            var parsedJson = JsonConvert.DeserializeObject(model.ToString());
            writer.Write(JsonConvert.SerializeObject(parsedJson, Formatting.Indented));
        }
    }

    if (!string.IsNullOrWhiteSpace(ImplOptions.OutputModelFile))
    {
        Contracts.Assert(loader != null);
        ch.Trace("Saving the data pipe");
        // Should probably include "end"?
        SaveLoader(loader, ImplOptions.OutputModelFile);
    }
}
private ILegacyDataLoader CreateTransformChain(ILegacyDataLoader loader)
{
    return LegacyCompositeDataLoader.Create(Host, loader, ImplOptions.Transforms);
}
/// <summary>
/// Loads multiple artifacts of interest from the input model file, given the context
/// established by the command line arguments.
/// </summary>
/// <param name="ch">The channel to which to provide output.</param>
/// <param name="wantPredictor">Whether we want a predictor from the model file. If
/// <c>false</c> we will not even attempt to load a predictor. If <c>null</c> we will
/// load the predictor, if present. If <c>true</c> we will load the predictor, or fail
/// noisily if we cannot.</param>
/// <param name="predictor">The predictor in the model, or <c>null</c> if
/// <paramref name="wantPredictor"/> was false, or <paramref name="wantPredictor"/> was
/// <c>null</c> and no predictor was present.</param>
/// <param name="wantTrainSchema">Whether we want the training schema. Unlike
/// <paramref name="wantPredictor"/>, this has no "hard fail if not present" option. If
/// this is <c>true</c>, it is still possible for <paramref name="trainSchema"/> to remain
/// <c>null</c> if there were no role mappings or pipeline.</param>
/// <param name="trainSchema">The training schema if <paramref name="wantTrainSchema"/>
/// is true, and there were role mappings stored in the model.</param>
/// <param name="pipe">The data pipe constructed from the combination of the
/// model and command line arguments.</param>
protected void LoadModelObjects(
    IChannel ch,
    bool? wantPredictor, out IPredictor predictor,
    bool wantTrainSchema, out RoleMappedSchema trainSchema,
    out ILegacyDataLoader pipe)
{
    // First handle the case where there is no input model file.
    // Everything must come from the command line.

    using (var file = Host.OpenInputFile(ImplOptions.InputModelFile))
    using (var strm = file.OpenReadStream())
    using (var rep = RepositoryReader.Open(strm, Host))
    {
        // First consider loading the predictor.
        if (wantPredictor == false)
        {
            predictor = null;
        }
        else
        {
            ch.Trace("Loading predictor");
            predictor = ModelFileUtils.LoadPredictorOrNull(Host, rep);
            if (wantPredictor == true)
                Host.Check(predictor != null, "Could not load predictor from model file");
        }

        // Next create the loader.
        var loaderFactory = ImplOptions.Loader;
        ILegacyDataLoader trainPipe = null;
        if (loaderFactory != null)
        {
            // The loader is overridden from the command line.
            pipe = loaderFactory.CreateComponent(Host, new MultiFileSource(ImplOptions.DataFile));
            if (ImplOptions.LoadTransforms == true)
            {
                Host.CheckUserArg(!string.IsNullOrWhiteSpace(ImplOptions.InputModelFile), nameof(ImplOptions.InputModelFile));
                pipe = LoadTransformChain(pipe);
            }
        }
        else
        {
            var loadTrans = ImplOptions.LoadTransforms ?? true;
            pipe = LoadLoader(rep, ImplOptions.DataFile, loadTrans);
            if (loadTrans)
                trainPipe = pipe;
        }

        if (Utils.Size(ImplOptions.Transforms) > 0)
            pipe = LegacyCompositeDataLoader.Create(Host, pipe, ImplOptions.Transforms);

        // Next consider loading the training data's role mapped schema.
        trainSchema = null;
        if (wantTrainSchema)
        {
            // First try to get the role mappings.
            var trainRoleMappings = ModelFileUtils.LoadRoleMappingsOrNull(Host, rep);
            if (trainRoleMappings != null)
            {
                // Next create the training schema. In the event that the loaded pipeline happens
                // to be the training pipe, we can just use that. If it differs, then we need to
                // load the full pipeline from the model, relying upon the fact that all loaders
                // can be loaded with no data at all, to get their schemas.
                if (trainPipe == null)
                    trainPipe = ModelFileUtils.LoadLoader(Host, rep, new MultiFileSource(null), loadTransforms: true);
                trainSchema = new RoleMappedSchema(trainPipe.Schema, trainRoleMappings);
            }
            // If the role mappings are null, an alternative would be to fail. However the idea
            // is that the scorer should always still succeed, although perhaps with reduced
            // functionality, even when the training schema is null, since not all versions of
            // TLC models will have the role mappings preserved, I believe. And, we do want to
            // maintain backwards compatibility.
        }
    }
}
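// A hedged calling sketch for LoadModelObjects, showing the three wantPredictor modes described
// in the doc comment above (the method and local names here are hypothetical):
private void LoadModelObjectsUsage(IChannel ch)
{
    IPredictor pred;
    RoleMappedSchema trainSchema;
    ILegacyDataLoader pipe;

    // true: load the predictor, failing noisily if the model has none.
    LoadModelObjects(ch, true, out pred, true, out trainSchema, out pipe);

    // null: load the predictor if one is present, otherwise leave it null.
    LoadModelObjects(ch, null, out pred, false, out trainSchema, out pipe);

    // false: do not even attempt to load a predictor.
    LoadModelObjects(ch, false, out pred, false, out trainSchema, out pipe);
}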
protected void SaveLoader(ILegacyDataLoader loader, string path)
{
    using (var file = Host.CreateOutputFile(path))
        LoaderUtils.SaveLoader(loader, file);
}
private void Run(IChannel ch)
{
    ILegacyDataLoader loader = null;
    IPredictor rawPred = null;
    IDataView view;
    RoleMappedSchema trainSchema = null;

    if (_model == null)
    {
        if (string.IsNullOrEmpty(ImplOptions.InputModelFile))
        {
            loader = CreateLoader();
            rawPred = null;
            trainSchema = null;
            Host.CheckUserArg(ImplOptions.LoadPredictor != true, nameof(ImplOptions.LoadPredictor),
                "Cannot be set to true unless " + nameof(ImplOptions.InputModelFile) + " is also specified.");
        }
        else
        {
            LoadModelObjects(ch, _loadPredictor, out rawPred, true, out trainSchema, out loader);
        }

        view = loader;
    }
    else
    {
        view = _model.Apply(Host, new EmptyDataView(Host, _model.InputSchema));
    }

    // Create the ONNX context for storing global information
    var assembly = System.Reflection.Assembly.GetExecutingAssembly();
    var versionInfo = System.Diagnostics.FileVersionInfo.GetVersionInfo(assembly.Location);
    var ctx = new OnnxContextImpl(Host, _name, ProducerName, versionInfo.FileVersion,
        ModelVersion, _domain, ImplOptions.OnnxVersion);

    // Get the transform chain.
    IDataView source;
    IDataView end;
    LinkedList<ITransformCanSaveOnnx> transforms;
    GetPipe(ctx, ch, view, out source, out end, out transforms);
    Host.Assert(transforms.Count == 0 || transforms.Last.Value == end);

    // If we have a predictor, try to get the scorer for it.
    if (rawPred != null)
    {
        RoleMappedData data;
        if (trainSchema != null)
        {
            data = new RoleMappedData(end, trainSchema.GetColumnRoleNames());
        }
        else
        {
            // We had a predictor, but no roles stored in the model. Just suppose
            // default column names are OK, if present.
            data = new RoleMappedData(end, DefaultColumnNames.Label, DefaultColumnNames.Features,
                DefaultColumnNames.GroupId, DefaultColumnNames.Weight, DefaultColumnNames.Name, opt: true);
        }

        var scorePipe = ScoreUtils.GetScorer(rawPred, data, Host, trainSchema);
        var scoreOnnx = scorePipe as ITransformCanSaveOnnx;
        if (scoreOnnx?.CanSaveOnnx(ctx) == true)
        {
            Host.Assert(scorePipe.Source == end);
            end = scorePipe;
            transforms.AddLast(scoreOnnx);
        }
        else
        {
            Contracts.CheckUserArg(_loadPredictor != true, nameof(Arguments.LoadPredictor),
                "We were explicitly told to load the predictor but we do not know how to save it as ONNX.");
            ch.Warning("We do not know how to save the predictor as ONNX. Ignoring.");
        }
    }
    else
    {
        Contracts.CheckUserArg(_loadPredictor != true, nameof(Arguments.LoadPredictor),
            "We were explicitly told to load the predictor but one was not present.");
    }

    var model = ConvertTransformListToOnnxModel(ctx, ch, source, end, transforms, _inputsToDrop, _outputsToDrop);

    using (var file = Host.CreateOutputFile(_outputModelPath))
    using (var stream = file.CreateWriteStream())
        model.WriteTo(stream);

    if (_outputJsonModelPath != null)
    {
        using (var file = Host.CreateOutputFile(_outputJsonModelPath))
        using (var stream = file.CreateWriteStream())
        using (var writer = new StreamWriter(stream))
        {
            var parsedJson = JsonConvert.DeserializeObject(model.ToString());
            writer.Write(JsonConvert.SerializeObject(parsedJson, Formatting.Indented));
        }
    }

    if (!string.IsNullOrWhiteSpace(ImplOptions.OutputModelFile))
    {
        Contracts.Assert(loader != null);
        ch.Trace("Saving the data pipe");
        // Should probably include "end"?
        SaveLoader(loader, ImplOptions.OutputModelFile);
    }
}