/// <summary>
/// Deserializes a <see cref="TensorFlowMapper"/> from <paramref name="ctx"/>.
/// </summary>
/// <remarks>
/// Reads, in order: an int count of input columns (must be positive), one non-empty string per input
/// column, the model bytes from the "TFModel" binary stream, and the non-empty output column name.
/// </remarks>
/// <param name="env">Host environment; must not be null.</param>
/// <param name="ctx">Load context positioned at this mapper's model; must not be null.</param>
/// <param name="schema">Input schema, passed through to the mapper constructor.</param>
/// <returns>The reconstructed mapper.</returns>
public static TensorFlowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema schema)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    var numInputs = ctx.Reader.ReadInt32();
    // Report decode failures through the host environment, like every other check in this method.
    env.CheckDecode(numInputs > 0);

    string[] source = new string[numInputs];
    for (int j = 0; j < source.Length; j++)
    {
        source[j] = ctx.LoadNonEmptyString();
    }

    byte[] data = null;
    if (!ctx.TryLoadBinaryStream("TFModel", r => data = r.ReadByteArray()))
    {
        throw env.ExceptDecode();
    }

    var outputColName = ctx.LoadNonEmptyString();
    return new TensorFlowMapper(env, schema, data, source, outputColName);
}
// Factory method for SignatureLoadModel.
/// <summary>
/// Rebuilds an <see cref="OnnxTransform"/> from a saved model: the ONNX bytes come from the
/// "OnnxModel" stream, followed by the input column names and then the output column names.
/// </summary>
/// <param name="env">Host environment; must not be null.</param>
/// <param name="ctx">Load context for the saved transform; must not be null.</param>
/// <returns>The reconstructed transform.</returns>
private static OnnxTransform Create(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    byte[] modelBytes = null;
    bool foundModel = ctx.TryLoadBinaryStream("OnnxModel", reader => modelBytes = reader.ReadByteArray());
    if (!foundModel)
    {
        throw env.ExceptDecode();
    }

    // Models written at version 0x00010001 or earlier store no column counts; exactly one
    // input and one output column are assumed for them.
    bool hasColumnCounts = ctx.Header.ModelVerWritten > 0x00010001;

    // Reads an (optional) positive count followed by that many non-empty column names.
    string[] ReadColumnNames()
    {
        int count = hasColumnCounts ? ctx.Reader.ReadInt32() : 1;
        env.CheckDecode(count > 0);
        var names = new string[count];
        for (int k = 0; k < names.Length; k++)
        {
            names[k] = ctx.LoadNonEmptyString();
        }
        return names;
    }

    string[] inputNames = ReadColumnNames();
    string[] outputNames = ReadColumnNames();

    var args = new Arguments() { InputColumns = inputNames, OutputColumns = outputNames };
    return new OnnxTransform(env, args, modelBytes);
}
// Factory method for SignatureLoadModel.
/// <summary>
/// Deserializes an <see cref="OnnxTransformer"/>. It reads the ONNX model bytes from the "OnnxModel"
/// stream, then the input column names, the output column names, and any custom tensor shapes.
/// </summary>
/// <param name="env">Host environment; must not be null.</param>
/// <param name="ctx">Load context for the saved transformer; must not be null.</param>
/// <returns>The reconstructed transformer.</returns>
private static OnnxTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    byte[] modelBytes = null;
    if (!ctx.TryLoadBinaryStream("OnnxModel", r => modelBytes = r.ReadByteArray()))
    {
        throw env.ExceptDecode();
    }

    // Models written at version 0x00010001 or earlier store no column counts; one input and
    // one output column are assumed for them.
    bool supportsMultiInputOutput = ctx.Header.ModelVerWritten > 0x00010001;

    var numInputs = supportsMultiInputOutput ? ctx.Reader.ReadInt32() : 1;
    env.CheckDecode(numInputs > 0);
    var inputs = new string[numInputs];
    for (int j = 0; j < inputs.Length; j++)
    {
        inputs[j] = ctx.LoadNonEmptyString();
    }

    var numOutputs = supportsMultiInputOutput ? ctx.Reader.ReadInt32() : 1;
    env.CheckDecode(numOutputs > 0);
    var outputs = new string[numOutputs];
    for (int j = 0; j < outputs.Length; j++)
    {
        outputs[j] = ctx.LoadNonEmptyString();
    }

    // Load custom-provided shapes. Those shapes overwrite shapes loaded from the ONNX model file.
    // 0 means no custom shape; a positive value is the number of custom shapes that follow.
    // NOTE(review): unlike the column counts above, this value is read regardless of ModelVerWritten;
    // confirm GetVersionInfo() rejects models written before custom shapes were serialized.
    int customShapeInfosLength = ctx.Reader.ReadInt32();
    // A negative count can only come from a corrupt stream; reject it instead of silently
    // treating it as "no custom shapes".
    env.CheckDecode(customShapeInfosLength >= 0);

    CustomShapeInfo[] loadedCustomShapeInfos = null;
    if (customShapeInfosLength > 0)
    {
        loadedCustomShapeInfos = new CustomShapeInfo[customShapeInfosLength];
        for (int i = 0; i < customShapeInfosLength; ++i)
        {
            var name = ctx.LoadNonEmptyString();
            var shape = ctx.Reader.ReadIntArray();
            loadedCustomShapeInfos[i] = new CustomShapeInfo() { Name = name, Shape = shape };
        }
    }

    var options = new Options()
    {
        InputColumns = inputs,
        OutputColumns = outputs,
        CustomShapeInfos = loadedCustomShapeInfos
    };

    return new OnnxTransformer(env, options, modelBytes);
}
/// <summary>
/// Verifies that the label column shared by <paramref name="predictors"/> has exactly two classes.
/// </summary>
/// <param name="user">
/// When true, the failure is reported as a bad argument (<c>predictors</c>); otherwise as a decode error.
/// </param>
/// <param name="env">Host environment used to create the exception.</param>
/// <param name="predictors">Predictor models whose label column is inspected.</param>
private static void CheckBinaryLabel(bool user, IHostEnvironment env, IPredictorModel[] predictors)
{
    int labelClasses = CheckLabelColumn(env, predictors, true);
    if (labelClasses == 2)
    {
        return;
    }

    string message = $"Expected label to have exactly 2 classes, instead has {labelClasses}";
    if (user)
    {
        throw env.ExceptParam(nameof(predictors), message);
    }

    throw env.ExceptDecode(message);
}
/// <summary>
/// Load from the given repository entry using the default loader(s) specified in the header.
/// </summary>
/// <typeparam name="TRes">Type of the object to load.</typeparam>
/// <typeparam name="TSig">Loader signature used to find the loader.</typeparam>
/// <param name="env">Host environment; must not be null.</param>
/// <param name="result">Receives the loaded object.</param>
/// <param name="rep">Repository to read from; must not be null.</param>
/// <param name="ent">Repository entry holding the model.</param>
/// <param name="dir">Directory of the model inside the repository; also used in the error message.</param>
/// <param name="extra">Extra arguments forwarded to the loader.</param>
public static void LoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, RepositoryReader rep, Repository.Entry ent, string dir, params object[] extra)
    where TRes : class
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(rep, nameof(rep));

    bool loaded = TryLoadModel<TRes, TSig>(env, out result, rep, ent, dir, extra);
    if (loaded)
    {
        return;
    }

    throw env.ExceptDecode("Couldn't load model: '{0}'", dir);
}
/// <summary>
/// Load an object from the repository directory.
/// </summary>
/// <typeparam name="TRes">Type of the object to load.</typeparam>
/// <typeparam name="TSig">Loader signature used to find the loader.</typeparam>
/// <param name="env">Host environment; must not be null.</param>
/// <param name="result">Receives the loaded object.</param>
/// <param name="rep">Repository to read from; must not be null.</param>
/// <param name="dir">Directory of the model inside the repository.</param>
/// <param name="extra">Extra arguments forwarded to the loader.</param>
public static void LoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, RepositoryReader rep, string dir, params object[] extra)
    where TRes : class
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(rep, nameof(rep));

    bool found = LoadModelOrNull<TRes, TSig>(env, out result, rep, dir, extra);
    if (found)
    {
        env.AssertValue(result);
        return;
    }

    throw env.ExceptDecode("Corrupt model file");
}
/// <summary>
/// Deserializes the <see cref="Bindings"/> of an <see cref="OptionalColumnTransform"/>.
/// </summary>
/// <param name="env">Host environment, used for error reporting and to build the binary loader/saver.</param>
/// <param name="ctx">Load context positioned at the bindings.</param>
/// <param name="input">Schema of the data view the transform is applied to.</param>
/// <param name="parent">The owning transform.</param>
/// <returns>The reconstructed bindings.</returns>
public static Bindings Create(IHostEnvironment env, ModelLoadContext ctx, Schema input, OptionalColumnTransform parent)
{
    Contracts.AssertValue(env);
    Contracts.AssertValue(ctx);
    Contracts.AssertValue(input);

    // *** Binary format ***
    // stream "Schema.idv": schema of the data view containing the optional columns
    // int: number of added columns
    // for each added column
    //   int: id of output column name
    //   ColumnType: the type of the column

    byte[] buffer = null;
    if (!ctx.TryLoadBinaryStream("Schema.idv", r => buffer = r.ReadByteArray()))
    {
        throw env.ExceptDecode();
    }

    // NOTE(review): the stream is handed to BinaryLoader and deliberately not disposed here;
    // presumably the loader keeps reading from it — confirm ownership before adding a using.
    var strm = new MemoryStream(buffer, writable: false);
    var loader = new BinaryLoader(env, new BinaryLoader.Arguments(), strm);

    int size = ctx.Reader.ReadInt32();
    // Report decode failures through the host environment, consistent with the other checks.
    env.CheckDecode(size > 0);

    var saver = new BinarySaver(env, new BinarySaver.Arguments());
    var names = new string[size];
    var columnTypes = new ColumnType[size];
    var srcCols = new int[size];
    var srcColsWithOptionalColumn = new int[size];

    for (int i = 0; i < size; i++)
    {
        names[i] = ctx.LoadNonEmptyString();
        // NOTE(review): the loaded type is passed through unchecked even though the method name
        // suggests it may be null.
        columnTypes[i] = saver.LoadTypeDescriptionOrNull(ctx.Reader.BaseStream);

        // The column may be absent from the input schema; -1 is stored when it is not found.
        srcCols[i] = input.TryGetColumnIndex(names[i], out int col) ? col : -1;

        // The column must exist in the saved schema that carries the optional columns.
        env.CheckDecode(loader.Schema.TryGetColumnIndex(names[i], out int colWithOptionalColumn));
        srcColsWithOptionalColumn[i] = colWithOptionalColumn;
    }

    return new Bindings(parent, columnTypes, srcCols, srcColsWithOptionalColumn, input, loader.Schema, false, names);
}
// Factory method for SignatureLoadModel.
/// <summary>
/// Deserializes a <see cref="TensorFlowTransform"/> whose model bytes are stored in the "TFModel" stream.
/// </summary>
/// <param name="env">Host environment; must not be null.</param>
/// <param name="ctx">Load context for the saved transform; must not be null.</param>
/// <returns>The reconstructed transform.</returns>
private static TensorFlowTransform Create(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    // *** Binary format ***
    // stream: tensorFlow model.
    // int: number of input columns
    // for each input column
    //   int: id of input column name
    // int: number of output columns (only when ModelVerReadable >= 0x00010002; otherwise one output)
    // for each output column
    //   int: id of output column name
    byte[] modelBytes = null;
    if (!ctx.TryLoadBinaryStream("TFModel", r => modelBytes = r.ReadByteArray()))
    {
        throw env.ExceptDecode();
    }

    var numInputs = ctx.Reader.ReadInt32();
    env.CheckDecode(numInputs > 0);
    string[] inputs = new string[numInputs];
    for (int j = 0; j < inputs.Length; j++)
    {
        inputs[j] = ctx.LoadNonEmptyString();
    }

    // NOTE(review): the other loaders gate format changes on ModelVerWritten; confirm that
    // ModelVerReadable is the intended field here.
    bool isMultiOutput = ctx.Header.ModelVerReadable >= 0x00010002;
    var numOutputs = 1;
    if (isMultiOutput)
    {
        numOutputs = ctx.Reader.ReadInt32();
    }

    env.CheckDecode(numOutputs > 0);
    var outputs = new string[numOutputs];
    for (int j = 0; j < outputs.Length; j++)
    {
        outputs[j] = ctx.LoadNonEmptyString();
    }

    // Create the session only after all column metadata has been decoded. A decode failure
    // above then cannot leave behind a session that nothing releases.
    var session = TensorFlowUtils.LoadTFSession(env, modelBytes);
    return new TensorFlowTransform(env, session, inputs, outputs);
}
// Factory method for SignatureLoadModel.
/// <summary>
/// Rebuilds a single-input, single-output <see cref="OnnxTransform"/>. The ONNX bytes are read from the
/// "OnnxModel" stream, followed by the input column name and the output column name.
/// </summary>
/// <param name="env">Host environment; must not be null.</param>
/// <param name="ctx">Load context for the saved transform; must not be null.</param>
/// <returns>The reconstructed transform.</returns>
private static OnnxTransform Create(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    byte[] modelBytes = null;
    bool hasModel = ctx.TryLoadBinaryStream("OnnxModel", reader => modelBytes = reader.ReadByteArray());
    if (!hasModel)
    {
        throw env.ExceptDecode();
    }

    string inputName = ctx.LoadNonEmptyString();
    string outputName = ctx.LoadNonEmptyString();

    var args = new Arguments()
    {
        InputColumn = inputName,
        OutputColumn = outputName
    };
    return new OnnxTransform(env, args, modelBytes);
}
// Factory method for SignatureLoadModel.
/// <summary>
/// Deserializes a <see cref="TensorFlowTransform"/>.
/// </summary>
/// <remarks>
/// Frozen models are loaded from the "TFModel" stream. Otherwise, the SavedModel files stored in
/// the "TFSavedModel" stream are extracted into a new temporary directory, and the session is
/// created from that directory. If construction fails, the directory is deleted again.
/// </remarks>
/// <param name="env">Host environment; must not be null.</param>
/// <param name="ctx">Load context for the saved transform; must not be null.</param>
/// <returns>The reconstructed transform.</returns>
private static TensorFlowTransform Create(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    // *** Binary format ***
    // byte: indicator for frozen models
    // stream: tensorFlow model.
    // int: number of input columns
    // for each input column
    //   int: id of input column name
    // int: number of output columns
    // for each output column
    //   int: id of output column name
    GetModelInfo(env, ctx, out string[] inputs, out string[] outputs, out bool isFrozen);

    if (isFrozen)
    {
        byte[] modelBytes = null;
        if (!ctx.TryLoadBinaryStream("TFModel", r => modelBytes = r.ReadByteArray()))
        {
            throw env.ExceptDecode();
        }

        return new TensorFlowTransform(env, TensorFlowUtils.LoadTFSession(env, modelBytes), inputs, outputs, null, false);
    }

    var tempDirPath = Path.GetFullPath(Path.Combine(Path.GetTempPath(), RegistrationName + "_" + Guid.NewGuid()));
    TensorFlowUtils.CreateFolderWithAclIfNotExists(env, tempDirPath);

    // Every extracted file must resolve to a path strictly inside tempDirPath.
    string tempDirPrefix = tempDirPath + Path.DirectorySeparatorChar;

    try
    {
        bool loaded = ctx.TryLoadBinaryStream("TFSavedModel", br =>
        {
            int count = br.ReadInt32();
            env.CheckDecode(count >= 0);
            for (int n = 0; n < count; n++)
            {
                string relativeFile = br.ReadString();
                long fileLength = br.ReadInt64();
                env.CheckDecode(fileLength >= 0);

                // The file name comes from the model file. Reject absolute paths and ".." segments
                // that would resolve outside the temporary directory (zip-slip).
                string fullFilePath = Path.GetFullPath(Path.Combine(tempDirPath, relativeFile));
                env.CheckDecode(fullFilePath.StartsWith(tempDirPrefix, StringComparison.Ordinal));

                string fullFileDir = Path.GetDirectoryName(fullFilePath);
                if (fullFileDir != tempDirPath)
                {
                    TensorFlowUtils.CreateFolderWithAclIfNotExists(env, fullFileDir);
                }

                using (var fs = new FileStream(fullFilePath, FileMode.Create, FileAccess.Write))
                {
                    long actualRead = br.BaseStream.CopyRange(fs, fileLength);
                    // Checked decode rather than an assertion (which may be compiled out of release
                    // builds), so a truncated stream is rejected in every configuration.
                    env.CheckDecode(actualRead == fileLength);
                }
            }
        });

        // Without the SavedModel stream there is nothing to build a session from.
        if (!loaded)
        {
            throw env.ExceptDecode();
        }

        return new TensorFlowTransform(env, TensorFlowUtils.GetSession(env, tempDirPath), inputs, outputs, tempDirPath, true);
    }
    catch (Exception)
    {
        // Construction failed: remove the extracted files before propagating the error.
        TensorFlowUtils.DeleteFolderWithRetries(env, tempDirPath);
        throw;
    }
}