/// <summary>
/// Validates that <paramref name="modelFile"/> names an existing file, reads the whole file into memory,
/// and hands the bytes to <c>TensorFlowUtils.LoadTFSession</c>.
/// </summary>
/// <param name="env">Host environment used for argument checks and for loading the session.</param>
/// <param name="modelFile">Path of the model file; must not be null/whitespace and must exist on disk.</param>
/// <returns>The session produced by <c>TensorFlowUtils.LoadTFSession</c> for the file contents.</returns>
private static TFSession CheckFileAndRead(IHostEnvironment env, string modelFile)
{
    env.CheckNonWhiteSpace(modelFile, nameof(modelFile));
    env.CheckUserArg(File.Exists(modelFile), nameof(modelFile));

    byte[] fileContents = File.ReadAllBytes(modelFile);
    TFSession session = TensorFlowUtils.LoadTFSession(env, fileContents, modelFile);
    return session;
}
/// <summary>
/// Factory method for SignatureLoadModel. Deserializes a <c>TensorFlowTransform</c> whose TensorFlow model
/// is stored as a single "TFModel" byte stream, followed by the input and output column names.
/// </summary>
/// <param name="env">Host environment; must not be null.</param>
/// <param name="ctx">Model load context positioned at this transform's model; must not be null.</param>
/// <returns>A transform wrapping the loaded session and the decoded column names.</returns>
/// <remarks>
/// Throws a decode exception (via <c>env.ExceptDecode</c>/<c>env.CheckDecode</c>) when the "TFModel"
/// stream is missing or when the input/output counts are not positive.
/// </remarks>
private static TensorFlowTransform Create(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    // *** Binary format ***
    // stream: tensorFlow model.
    // int: number of input columns
    // for each input column
    //   int: id of input column name
    // int: number of output columns (only written when ModelVerReadable >= 0x00010002; older models have exactly 1)
    // for each output column
    //   int: id of output column name
    byte[] modelBytes = null;
    if (!ctx.TryLoadBinaryStream("TFModel", r => modelBytes = r.ReadByteArray()))
    {
        throw env.ExceptDecode();
    }

    var numInputs = ctx.Reader.ReadInt32();
    env.CheckDecode(numInputs > 0);
    string[] inputs = new string[numInputs];
    for (int j = 0; j < inputs.Length; j++)
    {
        inputs[j] = ctx.LoadNonEmptyString();
    }

    bool isMultiOutput = ctx.Header.ModelVerReadable >= 0x00010002;
    var numOutputs = 1;
    if (isMultiOutput)
    {
        numOutputs = ctx.Reader.ReadInt32();
    }
    env.CheckDecode(numOutputs > 0);
    var outputs = new string[numOutputs];
    for (int j = 0; j < outputs.Length; j++)
    {
        outputs[j] = ctx.LoadNonEmptyString();
    }

    // Create the session only after the header has been fully decoded and validated, so a malformed
    // header is rejected before any session object exists that would otherwise be left undisposed.
    var session = TensorFlowUtils.LoadTFSession(env, modelBytes);
    return new TensorFlowTransform(env, session, inputs, outputs);
}
/// <summary>
/// Factory method for SignatureLoadModel. Deserializes a <c>TensorFlowTransform</c> from either a frozen
/// model (single "TFModel" byte stream) or a SavedModel directory (packed into a "TFSavedModel" stream),
/// as indicated by the frozen flag read by <c>GetModelInfo</c>.
/// </summary>
/// <param name="env">Host environment; must not be null.</param>
/// <param name="ctx">Model load context positioned at this transform's model; must not be null.</param>
/// <returns>A transform wrapping the loaded session and the decoded column names.</returns>
/// <remarks>
/// For SavedModel input, the files are extracted into a fresh folder under the system temp path. That
/// folder is passed to the transform on success and deleted on any failure. A missing model stream,
/// a truncated file entry, or an entry whose path escapes the temp folder raises a decode exception.
/// </remarks>
private static TensorFlowTransform Create(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    // *** Binary format ***
    // byte: indicator for frozen models
    // stream: tensorFlow model.
    // int: number of input columns
    // for each input column
    //   int: id of input column name
    // int: number of output columns
    // for each output column
    //   int: id of output column name
    GetModelInfo(env, ctx, out string[] inputs, out string[] outputs, out bool isFrozen);
    if (isFrozen)
    {
        byte[] modelBytes = null;
        if (!ctx.TryLoadBinaryStream("TFModel", r => modelBytes = r.ReadByteArray()))
        {
            throw env.ExceptDecode();
        }
        return new TensorFlowTransform(env, TensorFlowUtils.LoadTFSession(env, modelBytes), inputs, outputs, null, false);
    }

    var tempDirPath = Path.GetFullPath(Path.Combine(Path.GetTempPath(), RegistrationName + "_" + Guid.NewGuid()));
    TensorFlowUtils.CreateFolderWithAclIfNotExists(env, tempDirPath);
    try
    {
        // "TFSavedModel" stream layout: int file count, then for each file a relative path string,
        // an int64 byte length, and that many raw bytes.
        bool loaded = ctx.TryLoadBinaryStream("TFSavedModel", br =>
        {
            int count = br.ReadInt32();
            for (int n = 0; n < count; n++)
            {
                string relativeFile = br.ReadString();
                long fileLength = br.ReadInt64();
                env.CheckDecode(fileLength >= 0);

                // The path comes from the model file, so normalize it and refuse entries that would
                // land outside the temp folder (e.g. containing ".." or an absolute path, which
                // Path.Combine would otherwise honor as-is).
                string fullFilePath = Path.GetFullPath(Path.Combine(tempDirPath, relativeFile));
                if (!fullFilePath.StartsWith(tempDirPath + Path.DirectorySeparatorChar, StringComparison.Ordinal))
                {
                    throw env.ExceptDecode();
                }

                string fullFileDir = Path.GetDirectoryName(fullFilePath);
                if (fullFileDir != tempDirPath)
                {
                    TensorFlowUtils.CreateFolderWithAclIfNotExists(env, fullFileDir);
                }

                using (var fs = new FileStream(fullFilePath, FileMode.Create, FileAccess.Write))
                {
                    long actualRead = br.BaseStream.CopyRange(fs, fileLength);
                    // A short copy means the stream is truncated or corrupt: report it as a decode
                    // error rather than as an internal assertion.
                    env.CheckDecode(actualRead == fileLength);
                }
            }
        });

        // Mirror the frozen branch: a missing model stream is a decode error, not an empty model.
        if (!loaded)
        {
            throw env.ExceptDecode();
        }

        return new TensorFlowTransform(env, TensorFlowUtils.GetSession(env, tempDirPath), inputs, outputs, tempDirPath, true);
    }
    catch (Exception)
    {
        // Don't leave partially extracted model files behind on failure.
        TensorFlowUtils.DeleteFolderWithRetries(env, tempDirPath);
        throw;
    }
}