/// <summary>
/// Create a <see cref="ModelSaveContext"/> supporting saving to a repository, for implementors of <see cref="ICanSaveModel"/>.
/// </summary>
/// <param name="rep">Repository to write into; its exception context is used for all subsequent checks.</param>
/// <param name="dir">Directory within the repository; may be null.</param>
/// <param name="name">Entry name within <paramref name="dir"/>; must be non-empty.</param>
internal ModelSaveContext(RepositoryWriter rep, string dir, string name)
{
    Contracts.CheckValue(rep, nameof(rep));
    Repository = rep;
    _ectx = rep.ExceptionContext;
    _ectx.CheckValueOrNull(dir);
    _ectx.CheckNonEmpty(name, nameof(name));
    Directory = dir;
    Strings = new NormStr.Pool();
    // Create the repository entry whose stream receives the model bytes.
    _ent = rep.CreateEntry(dir, name);
    try
    {
        // leaveOpen: true, so disposing the writer does not close the entry's stream.
        Writer = new BinaryWriter(_ent.Stream, Encoding.UTF8, leaveOpen: true);
        try
        {
            // Writes a placeholder header and records where it starts (FpMin).
            ModelHeader.BeginWrite(Writer, out FpMin, out Header);
        }
        catch
        {
            // Dispose the writer before the entry; any exception from here still reaches the outer handler.
            Writer.Dispose();
            throw;
        }
    }
    catch
    {
        // Don't leak the repository entry if construction fails at any point.
        _ent.Dispose();
        throw;
    }
}
/// <summary>
/// Stores <paramref name="sig"/> as the alternate loader signature in <paramref name="header"/>.
/// Each character occupies one byte: characters 0-7 go into LoaderSignatureAlt0, 8-15 into
/// LoaderSignatureAlt1, and 16-23 into LoaderSignatureAlt2, with the first character of each
/// group in the least-significant byte. A null <paramref name="sig"/> clears the alternate signature.
/// </summary>
/// <exception>Throws (via Contracts.Check) if <paramref name="sig"/> is longer than 24 characters
/// or contains a character above 0xFF.</exception>
public static void SetLoaderSigAlt(ref ModelHeader header, string sig)
{
    header.LoaderSignatureAlt0 = 0;
    header.LoaderSignatureAlt1 = 0;
    header.LoaderSignatureAlt2 = 0;

    if (sig != null)
    {
        Contracts.Check(sig.Length <= 24);

        int position = 0;
        foreach (char c in sig)
        {
            Contracts.Check(c <= 0xFF);
            ulong bits = (ulong)c << ((position % 8) * 8);
            switch (position / 8)
            {
                case 0:
                    header.LoaderSignatureAlt0 |= bits;
                    break;
                case 1:
                    header.LoaderSignatureAlt1 |= bits;
                    break;
                default:
                    // Length was checked above, so this is always the third word (positions 16-23).
                    header.LoaderSignatureAlt2 |= bits;
                    break;
            }
            position++;
        }
    }
}
/// <summary>
/// Decodes the alternate loader signature packed into the three LoaderSignatureAlt words of
/// <paramref name="header"/> (one byte per character, least-significant byte first) and returns it
/// with trailing NUL characters removed.
/// </summary>
public static string GetLoaderSigAlt(ref ModelHeader header)
{
    ulong[] words = { header.LoaderSignatureAlt0, header.LoaderSignatureAlt1, header.LoaderSignatureAlt2 };
    var decoded = new char[words.Length * sizeof(ulong)];

    for (int pos = 0; pos < decoded.Length; pos++)
    {
        decoded[pos] = (char)((words[pos / 8] >> ((pos % 8) * 8)) & 0xFF);
    }

    // Unused slots are zero; trim them from the end.
    int length = decoded.Length;
    while (length > 0 && decoded[length - 1] == 0)
    {
        length--;
    }
    return new string(decoded, 0, length);
}
/// <summary>
/// Writes the tail signature at the current writer position, then fills in the header placeholder
/// at <paramref name="fpMin"/> and seeks back to just past the tail.
/// Typically this isn't called directly unless you are doing custom string table serialization;
/// in that case EndModelCore should have been called before writing the string table information.
/// </summary>
/// <param name="writer">Writer positioned where the tail belongs.</param>
/// <param name="fpMin">Absolute stream position of the start of the header; must be non-negative.</param>
/// <param name="header">Header to finalize; FpTail and FpLim are set here.</param>
public static void WriteHeaderAndTailCore(BinaryWriter writer, long fpMin, ref ModelHeader header)
{
    Contracts.CheckValue(writer, nameof(writer));
    Contracts.CheckParam(fpMin >= 0, nameof(fpMin));

    // Record where the tail starts, write it, and record where the whole blob ends.
    header.FpTail = writer.FpCur() - fpMin;
    writer.Write(TailSignatureValue);
    header.FpLim = writer.FpCur() - fpMin;

    // A failure here means the header was assembled incorrectly. That is a bug, but one we also
    // guard against at runtime, hence both the Assert and the Check.
    bool isValid = TryValidate(ref header, header.FpLim, out _);
    Contracts.Assert(isValid);
    Contracts.Check(isValid);

    // Overwrite the placeholder header, then return to the end of the blob.
    writer.Seek(fpMin);
    var bytes = new byte[ModelHeader.Size];
    MarshalToBytes(ref header, bytes);
    writer.Write(bytes);
    Contracts.Assert(writer.FpCur() == fpMin + ModelHeader.Size);
    writer.Seek(header.FpLim + fpMin);
}
/// <summary>
/// Copies the model signature, the written/readable version numbers, and the primary and
/// alternate loader signatures from <paramref name="ver"/> into <paramref name="header"/>.
/// </summary>
public static void SetVersionInfo(ref ModelHeader header, VersionInfo ver)
{
    // Signature and version numbers are copied field-for-field.
    header.ModelSignature = ver.ModelSignature;
    header.ModelVerWritten = ver.VerWrittenCur;
    header.ModelVerReadable = ver.VerReadableCur;

    // The loader signatures need packing into fixed-size header fields.
    SetLoaderSig(ref header, ver.LoaderSignature);
    SetLoaderSigAlt(ref header, ver.LoaderSignatureAlt);
}
/// <summary>
/// Low-level raw memory copy of the first <c>Size</c> bytes of <paramref name="bytes"/> into a
/// <see cref="ModelHeader"/> structure. No validation of the resulting fields is done here.
/// </summary>
/// <exception>Throws (via Contracts.Check) if <paramref name="bytes"/> holds fewer than <c>Size</c> bytes.</exception>
public static void MarshalFromBytes(out ModelHeader header, byte[] bytes)
{
    Contracts.Check(Utils.Size(bytes) >= Size);
    unsafe
    {
        // 'header' refers to a movable location, so it must be pinned before its address is taken.
        fixed (ModelHeader* destination = &header)
        {
            Marshal.Copy(bytes, 0, (IntPtr)destination, Size);
        }
    }
}
/// <summary>
/// Create a ModelLoadContext supporting loading from a single-stream, for implementors of ICanSaveInBinaryFormat.
/// The header, string table, and loader assembly name are read immediately from <paramref name="reader"/>.
/// </summary>
/// <param name="reader">Reader positioned at the start of the model header.</param>
/// <param name="ectx">Optional exception context used for argument checks; may be null.</param>
internal ModelLoadContext(BinaryReader reader, IExceptionContext ectx = null)
{
    Contracts.AssertValueOrNull(ectx);
    _ectx = ectx;
    _ectx.CheckValue(reader, nameof(reader));

    // Single-stream mode: there is no backing repository or directory.
    Repository = null;
    Directory = null;

    Reader = reader;
    ModelHeader.BeginRead(out FpMin, out Header, out Strings, out LoaderAssemblyName, Reader);
}
/// <summary>
/// The current writer position should be the end of the model blob. Records the size of the model blob
/// (header.CbModel). Typically this isn't called directly unless you are doing custom string table serialization.
/// </summary>
/// <param name="writer">Writer positioned at the end of the model blob; must not be null.</param>
/// <param name="fpMin">Absolute stream position of the start of the header; must be non-negative.</param>
/// <param name="header">Header whose CbModel is set; must not have been sized yet.</param>
public static void EndModelCore(BinaryWriter writer, long fpMin, ref ModelHeader header)
{
    // Same argument validation as the sibling public writers (EndWrite, WriteHeaderAndTailCore),
    // so a null writer reports a proper argument error instead of a NullReferenceException.
    Contracts.CheckValue(writer, nameof(writer));
    Contracts.CheckParam(fpMin >= 0, nameof(fpMin));

    // The model blob must start right after the header and must not have been sized already.
    Contracts.Check(header.FpModel == ModelHeader.Size);
    Contracts.Check(header.CbModel == 0);

    long fpCur = writer.FpCur();
    Contracts.Check(fpCur - fpMin >= header.FpModel);

    // Record the size of the model; header offsets are relative to fpMin.
    header.CbModel = fpCur - header.FpModel - fpMin;
}
/// <summary>
/// Create a <see cref="ModelSaveContext"/> supporting saving to a single-stream, for implementors of <see cref="ICanSaveInBinaryFormat"/>.
/// A placeholder header is written to <paramref name="writer"/> immediately.
/// </summary>
/// <param name="writer">Writer positioned where the model header should start.</param>
/// <param name="ectx">Optional exception context used for argument checks; may be null.</param>
internal ModelSaveContext(BinaryWriter writer, IExceptionContext ectx = null)
{
    Contracts.AssertValueOrNull(ectx);
    _ectx = ectx;
    _ectx.CheckValue(writer, nameof(writer));

    // Single-stream mode: no repository, directory, or repository entry is involved.
    Repository = null;
    Directory = null;
    _ent = null;

    Strings = new NormStr.Pool();
    Writer = writer;
    ModelHeader.BeginWrite(Writer, out FpMin, out Header);
}
/// <summary>
/// The current writer position should be the end of the model blob. Records the model size, writes the
/// string table (if <paramref name="pool"/> is non-empty) and the loader assembly name (if given),
/// then completes and writes the header and writes the tail.
/// </summary>
/// <param name="writer">Writer positioned at the end of the model blob.</param>
/// <param name="fpMin">Absolute stream position of the start of the header; must be non-negative.</param>
/// <param name="header">Header being completed.</param>
/// <param name="pool">Optional string pool; strings are written in enumeration order, which must match their ids.</param>
/// <param name="loaderAssemblyName">Optional loader assembly name to record.</param>
public static void EndWrite(BinaryWriter writer, long fpMin, ref ModelHeader header, NormStr.Pool pool = null, string loaderAssemblyName = null)
{
    Contracts.CheckValue(writer, nameof(writer));
    Contracts.CheckParam(fpMin >= 0, nameof(fpMin));
    Contracts.CheckValueOrNull(pool);

    // Record the model size.
    EndModelCore(writer, fpMin, ref header);

    // The string table section must still be empty at this point.
    Contracts.Check(header.FpStringTable == 0);
    Contracts.Check(header.CbStringTable == 0);
    Contracts.Check(header.FpStringChars == 0);
    Contracts.Check(header.CbStringChars == 0);

    bool hasStrings = pool != null && pool.Count > 0;
    if (hasStrings)
    {
        // Table section: one long per string holding the cumulative end offset (in bytes)
        // of that string within the character section.
        header.FpStringTable = writer.FpCur() - fpMin;
        long cbChars = 0;
        int expectedId = 0;
        // REVIEW: Implement an indexer on pool!
        foreach (var entry in pool)
        {
            Contracts.Assert(entry.Id == expectedId);
            cbChars += entry.Value.Length * sizeof(char);
            writer.Write(cbChars);
            expectedId++;
        }
        Contracts.Assert(expectedId == pool.Count);
        header.CbStringTable = pool.Count * sizeof(long);

        // Character section: the characters of every string, back to back, as 16-bit values.
        header.FpStringChars = writer.FpCur() - fpMin;
        Contracts.Assert(header.FpStringChars == header.FpStringTable + header.CbStringTable);
        foreach (var entry in pool)
        {
            var span = entry.Value.Span;
            for (int i = 0; i < span.Length; i++)
            {
                writer.Write((short)span[i]);
            }
        }
        header.CbStringChars = writer.FpCur() - header.FpStringChars - fpMin;
        Contracts.Assert(cbChars == header.CbStringChars);
    }

    WriteLoaderAssemblyName(writer, fpMin, ref header, loaderAssemblyName);
    WriteHeaderAndTailCore(writer, fpMin, ref header);
}
// Utilities for writing.

/// <summary>
/// Initialize the header and writer for writing. The values of <paramref name="fpMin"/> and
/// <paramref name="header"/> should be passed to the other utility methods here.
/// A zero-filled placeholder of <c>ModelHeader.Size</c> bytes is written at the current position.
/// </summary>
public static void BeginWrite(BinaryWriter writer, out long fpMin, out ModelHeader header)
{
    Contracts.Assert(Marshal.SizeOf(typeof(ModelHeader)) == Size);
    Contracts.CheckValue(writer, nameof(writer));

    fpMin = writer.FpCur();
    header = new ModelHeader
    {
        Signature = SignatureValue,
        VerWritten = VerWrittenCur,
        VerReadable = VerReadableCur,
        // The model blob always starts immediately after the header.
        FpModel = ModelHeader.Size,
    };

    // Reserve space with a blank header; the real contents are written by WriteHeaderAndTailCore.
    writer.Write(new byte[ModelHeader.Size]);
    Contracts.CheckIO(writer.FpCur() == fpMin + ModelHeader.Size);
}
// Utilities for reading.

/// <summary>
/// Read the model header, strings, etc. from reader. Also validates the header (throws if bad).
/// Leaves the reader position at the beginning of the model blob.
/// </summary>
/// <param name="fpMin">Receives the absolute stream position of the start of the header.</param>
/// <param name="header">Receives the decoded header.</param>
/// <param name="strings">Receives the string table, or null if the model has none.</param>
/// <param name="loaderAssemblyName">Receives the recorded loader assembly name, or null.</param>
/// <param name="reader">Reader positioned at the start of the header; must not be null.</param>
public static void BeginRead(out long fpMin, out ModelHeader header, out string[] strings, out string loaderAssemblyName, BinaryReader reader)
{
    Contracts.CheckValue(reader, nameof(reader));

    fpMin = reader.FpCur();
    byte[] headerBytes = reader.ReadBytes(ModelHeader.Size);
    // A short read means the stream ended before a full header was available.
    Contracts.CheckDecode(headerBytes.Length == ModelHeader.Size);
    ModelHeader.MarshalFromBytes(out header, headerBytes);

    Exception ex;
    if (!ModelHeader.TryValidate(ref header, reader, fpMin, out strings, out loaderAssemblyName, out ex))
    {
        // Rethrow the validation failure without discarding the stack trace recorded where it was
        // originally thrown; 'throw ex;' would reset it to this frame.
        System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(ex).Throw();
    }

    reader.Seek(header.FpModel + fpMin);
}
/// <summary>
/// Tries to load a <typeparamref name="TRes"/> from this context through the component catalog.
/// The loader signature recorded in the header is tried first, then the alternate loader signature.
/// On success, <see cref="Done"/> is called to finish reading and true is returned.
/// Returns false iff the default loader(s) could not be bound to a compatible loadable class;
/// in that case the underlying stream is rewound to <c>FpMin</c>.
/// </summary>
/// <param name="env">Host environment whose component catalog creates the loader.</param>
/// <param name="result">The loaded object, or null when false is returned.</param>
/// <param name="extra">Extra constructor arguments; combined with this context by ConcatArgsRev
/// (resulting argument order is defined there — TODO confirm).</param>
private bool TryLoadModelCore<TRes, TSig>(IHostEnvironment env, out TRes result, params object[] extra)
    where TRes : class
{
    _ectx.AssertValue(env, "env");
    // The reader is expected to be positioned at the start of the model blob.
    _ectx.Assert(Reader.BaseStream.Position == FpMin + Header.FpModel);

    var args = ConcatArgsRev(extra, this);

    // NOTE(review): presumably registers the loader assembly recorded in the header with the
    // catalog so the signatures below can be resolved — confirm against its definition.
    EnsureLoaderAssemblyIsRegistered(env.ComponentCatalog);

    object tmp;
    // First attempt: the primary loader signature.
    string sig = ModelHeader.GetLoaderSig(ref Header);
    if (!string.IsNullOrWhiteSpace(sig) && ComponentCatalog.TryCreateInstance<object, TSig>(env, out tmp, sig, "", args))
    {
        result = tmp as TRes;
        if (result != null)
        {
            Done();
            return(true);
        }
        // REVIEW: Should this fall through?
        // An instance was created but is not a TRes; it is dropped (not disposed) and the
        // alternate signature is tried next.
    }

    // NOTE(review): a loader that was constructed but had the wrong type may already have advanced
    // the reader, in which case this Assert fires in debug builds.
    _ectx.Assert(Reader.BaseStream.Position == FpMin + Header.FpModel);
    // Second attempt: the alternate loader signature.
    string sigAlt = ModelHeader.GetLoaderSigAlt(ref Header);
    if (!string.IsNullOrWhiteSpace(sigAlt) && ComponentCatalog.TryCreateInstance<object, TSig>(env, out tmp, sigAlt, "", args))
    {
        result = tmp as TRes;
        if (result != null)
        {
            Done();
            return(true);
        }
        // REVIEW: Should this fall through?
    }

    _ectx.Assert(Reader.BaseStream.Position == FpMin + Header.FpModel);
    // Neither signature produced a TRes. Rewind to FpMin, which is the start of the header,
    // not the start of the model blob.
    Reader.BaseStream.Position = FpMin;
    result = null;
    return(false);
}
/// <summary>
/// Writes <paramref name="loaderAssemblyName"/> as 16-bit character values at the current writer position
/// and records its location (relative to <paramref name="fpMin"/>) and byte length in the header.
/// A null or empty name clears the header's assembly-name fields and writes nothing.
/// </summary>
private static void WriteLoaderAssemblyName(BinaryWriter writer, long fpMin, ref ModelHeader header, string loaderAssemblyName)
{
    if (string.IsNullOrEmpty(loaderAssemblyName))
    {
        header.FpAssemblyName = 0;
        header.CbAssemblyName = 0;
        return;
    }

    header.FpAssemblyName = writer.FpCur() - fpMin;
    header.CbAssemblyName = (uint)loaderAssemblyName.Length * sizeof(char);
    for (int i = 0; i < loaderAssemblyName.Length; i++)
    {
        writer.Write((short)loaderAssemblyName[i]);
    }
}
/// <summary>
/// Performs standard version validation of a component header against the versions this build supports.
/// </summary>
/// <param name="header">Header whose model signature and model versions are checked.</param>
/// <param name="ver">Version information of the component doing the reading.</param>
/// <exception>Throws a decode exception if the signature does not match, the header is inconsistent,
/// or the model is too new or too old for this build.</exception>
public static void CheckVersionInfo(ref ModelHeader header, VersionInfo ver)
{
    Contracts.CheckDecode(header.ModelSignature == ver.ModelSignature, "Unknown file type");
    Contracts.CheckDecode(header.ModelVerReadable <= header.ModelVerWritten, "Corrupt file header");

    if (header.ModelVerReadable > ver.VerWrittenCur)
    {
        // The model needs a newer reader than this build provides. The highest version we can
        // read is our current written version; report that against the model's readable version.
        throw Contracts.ExceptDecode("Cause: ML.NET {0} cannot read component '{1}' of the model, because the model is too new.\n" +
            "Suggestion: Make sure the model is trained with ML.NET {0} or older.\n" +
            "Debug details: Maximum expected version {2}, got {3}.",
            typeof(VersionInfo).Assembly.GetName().Version, ver.LoaderSignature, ver.VerWrittenCur, header.ModelVerReadable);
    }

    if (header.ModelVerWritten < ver.VerWeCanReadBack)
    {
        // Breaking backwards compatibility is something we should avoid if at all possible. If
        // this message is observed, it may be a bug. Report the oldest version we can read back
        // against the version the model was written with.
        throw Contracts.ExceptDecode("Cause: ML.NET {0} cannot read component '{1}' of the model, because the model is too old.\n" +
            "Suggestion: Make sure the model is trained with ML.NET {0}.\n" +
            "Debug details: Minimum expected version {2}, got {3}.",
            typeof(VersionInfo).Assembly.GetName().Version, ver.LoaderSignature, ver.VerWeCanReadBack, header.ModelVerWritten);
    }
}
/// <summary>
/// Create a ModelLoadContext supporting loading from a repository, for implementors of ICanSaveModel.
/// The header, string table, and loader assembly name are read immediately from the entry's stream.
/// </summary>
/// <param name="rep">Repository being read; supplies the exception context.</param>
/// <param name="ent">Repository entry whose stream holds the model.</param>
/// <param name="dir">Directory of the entry within the repository; may be null.</param>
internal ModelLoadContext(RepositoryReader rep, Repository.Entry ent, string dir)
{
    Contracts.CheckValue(rep, nameof(rep));
    Repository = rep;
    _ectx = rep.ExceptionContext;

    _ectx.CheckValue(ent, nameof(ent));
    _ectx.CheckValueOrNull(dir);
    Directory = dir;

    // leaveOpen: true, so disposing this reader leaves the entry's stream open.
    Reader = new BinaryReader(ent.Stream, Encoding.UTF8, leaveOpen: true);
    try
    {
        ModelHeader.BeginRead(out FpMin, out Header, out Strings, out LoaderAssemblyName, Reader);
    }
    catch
    {
        // Release the reader if the header cannot be read or validated.
        Reader.Dispose();
        throw;
    }
}
/// <summary>
/// Finishes saving: requires that version information (a non-zero model signature) has been set,
/// writes the string table, loader assembly name, header, and tail, then disposes this context.
/// </summary>
internal void Done()
{
    _ectx.Check(Header.ModelSignature != 0, "ModelSignature not specified!");
    ModelHeader.EndWrite(Writer, FpMin, ref Header, pool: Strings, loaderAssemblyName: _loaderAssemblyName);
    Dispose();
}
/// <summary>
/// Records <paramref name="ver"/> in the header being written and remembers its loader assembly
/// name so that <c>Done</c> can write it out.
/// </summary>
internal void SetVersionInfo(VersionInfo ver)
{
    ModelHeader.SetVersionInfo(ref Header, ver);
    // Stored separately; it is passed to ModelHeader.EndWrite when saving completes.
    _loaderAssemblyName = ver.LoaderAssemblyName;
}
/// <summary>
/// Verifies that the reader is positioned at the start of the model blob, then performs the
/// standard version checks against <paramref name="ver"/>.
/// </summary>
public void CheckAtModel(VersionInfo ver)
{
    long fpModelStart = FpMin + Header.FpModel;
    _ectx.Check(Reader.BaseStream.Position == fpModelStart);
    ModelHeader.CheckVersionInfo(ref Header, ver);
}
/// <summary>
/// Checks the validity of the header against the stream, reads the string table and the loader
/// assembly name (when present), and verifies the tail signature. On success the reader is returned
/// to the position it had on entry. On failure, returns false with the exception in <paramref name="ex"/>;
/// <paramref name="strings"/> and <paramref name="loaderAssemblyName"/> are null and the reader position
/// is left wherever the failure occurred.
/// </summary>
/// <param name="header">Header previously read from the stream.</param>
/// <param name="reader">Reader over the stream containing the model; must not be null.</param>
/// <param name="fpMin">Absolute stream position of the start of the header; must be non-negative.</param>
/// <param name="strings">Receives the string table, or null if there is none.</param>
/// <param name="loaderAssemblyName">Receives the recorded loader assembly name, or null.</param>
/// <param name="ex">Receives the failure, or null on success.</param>
public static bool TryValidate(ref ModelHeader header, BinaryReader reader, long fpMin, out string[] strings, out string loaderAssemblyName, out Exception ex)
{
    Contracts.CheckValue(reader, nameof(reader));
    Contracts.Check(fpMin >= 0);

    // Structural checks that only need the header and the available stream length.
    if (!TryValidate(ref header, reader.BaseStream.Length - fpMin, out ex))
    {
        strings = null;
        loaderAssemblyName = null;
        return(false);
    }

    try
    {
        long fpOrig = reader.FpCur();

        StringBuilder sb = null;
        if (header.FpStringTable == 0)
        {
            // No strings.
            strings = null;
            if (header.VerWritten < VerAssemblyNameSupported)
            {
                // Before VerAssemblyNameSupported, if there were no strings in the model,
                // validation ended here. Specifically the FpTail checks below were skipped.
                // There are earlier versions of models that don't have strings, and 'reader' is
                // not at FpTail at this point.
                // Preserve the previous behavior by returning early here.
                loaderAssemblyName = null;
                ex = null;
                return(true);
            }

            // With no string table, the next section (assembly name or tail) starts right after the
            // model blob. Position the reader there explicitly, as the string-table branch does,
            // rather than relying on wherever the caller left it. Otherwise the FpTail check below
            // would reject every non-empty model that has neither strings nor an assembly name.
            reader.Seek(header.FpModel + header.CbModel + fpMin);
        }
        else
        {
            reader.Seek(header.FpStringTable + fpMin);
            Contracts.Assert(reader.FpCur() == header.FpStringTable + fpMin);

            // The table holds one long per string: its cumulative end offset (in bytes) within
            // the character section. cstr >= 1 here because the size-based TryValidate required
            // a non-empty table.
            long cstr = header.CbStringTable / sizeof(long);
            Contracts.Assert(cstr < int.MaxValue);
            long[] offsets = reader.ReadLongArray((int)cstr);
            Contracts.Assert(header.FpStringChars == reader.FpCur() - fpMin);
            Contracts.CheckDecode(offsets[cstr - 1] == header.CbStringChars);

            strings = new string[cstr];
            long offset = 0;
            sb = new StringBuilder();
            for (int i = 0; i < offsets.Length; i++)
            {
                // Strings are contiguous: each begins where the previous one ended.
                Contracts.CheckDecode(header.FpStringChars + offset == reader.FpCur() - fpMin);
                long offsetPrev = offset;
                offset = offsets[i];
                // End offsets must be non-decreasing, within the character section, and char-aligned.
                Contracts.CheckDecode(offsetPrev <= offset & offset <= header.CbStringChars);
                Contracts.CheckDecode(offset % sizeof(char) == 0);
                long cch = (offset - offsetPrev) / sizeof(char);
                Contracts.CheckDecode(cch < int.MaxValue);
                sb.Clear();
                for (long ich = 0; ich < cch; ich++)
                {
                    sb.Append((char)reader.ReadUInt16());
                }
                strings[i] = sb.ToString();
            }
            Contracts.CheckDecode(offset == header.CbStringChars);
            Contracts.CheckDecode(header.FpStringChars + header.CbStringChars == reader.FpCur() - fpMin);
        }

        if (header.VerWritten >= VerAssemblyNameSupported && header.FpAssemblyName != 0)
        {
            // The assembly name is stored as 16-bit character values, CbAssemblyName bytes long.
            reader.Seek(header.FpAssemblyName + fpMin);
            int assemblyNameLength = (int)header.CbAssemblyName / sizeof(char);
            sb = sb != null ? sb.Clear() : new StringBuilder(assemblyNameLength);
            for (long ich = 0; ich < assemblyNameLength; ich++)
            {
                sb.Append((char)reader.ReadUInt16());
            }
            loaderAssemblyName = sb.ToString();
        }
        else
        {
            loaderAssemblyName = null;
        }

        // Whichever sections were present, the reader must now be at the tail signature.
        Contracts.CheckDecode(header.FpTail == reader.FpCur() - fpMin);
        ulong tail = reader.ReadUInt64();
        Contracts.CheckDecode(tail == TailSignatureValue, "Corrupt model file tail");

        ex = null;
        reader.Seek(fpOrig);
        return(true);
    }
    catch (Exception e)
    {
        // Any exception raised while reading or validating is reported through 'ex' rather than thrown.
        strings = null;
        loaderAssemblyName = null;
        ex = e;
        return(false);
    }
}
/// <summary>
/// Checks the basic validity of the header, assuming the stream is at least the given size.
/// Verifies the file signature and format versions, and that the sections are laid out contiguously:
/// the model blob, then the optional string table and string characters, then the optional loader
/// assembly name (format versions that support it), then the 8-byte tail.
/// Returns false (and the out exception) on failure.
/// </summary>
/// <param name="header">Header to check.</param>
/// <param name="size">Number of bytes available from the start of the header; 0 skips the length check.</param>
/// <param name="ex">Receives the failure, or null on success.</param>
public static bool TryValidate(ref ModelHeader header, long size, out Exception ex)
{
    // Argument misuse is thrown, not reported through 'ex'.
    Contracts.Check(size >= 0);
    try
    {
        Contracts.CheckDecode(header.Signature == SignatureValue, "Wrong file type");
        Contracts.CheckDecode(header.VerReadable <= header.VerWritten, "Corrupt file header");
        Contracts.CheckDecode(header.VerReadable <= VerWrittenCur, "File is too new");
        Contracts.CheckDecode(header.VerWritten >= VerWeCanReadBack, "File is too old");

        // Currently the model always comes immediately after the header.
        Contracts.CheckDecode(header.FpModel == Size);
        // Rejects a negative model size (and overflow).
        Contracts.CheckDecode(header.FpModel + header.CbModel >= header.FpModel);

        if (header.FpStringTable == 0)
        {
            // No strings.
            Contracts.CheckDecode(header.CbStringTable == 0);
            Contracts.CheckDecode(header.FpStringChars == 0);
            Contracts.CheckDecode(header.CbStringChars == 0);
            if (header.VerWritten < VerAssemblyNameSupported || header.FpAssemblyName == 0)
            {
                // No later sections: the tail follows the model blob directly.
                Contracts.CheckDecode(header.FpTail == header.FpModel + header.CbModel);
            }
        }
        else
        {
            // Currently the string table always comes immediately after the model block.
            Contracts.CheckDecode(header.FpStringTable == header.FpModel + header.CbModel);
            // The table is an array of longs whose count must fit in an int.
            Contracts.CheckDecode(header.CbStringTable % sizeof(long) == 0);
            Contracts.CheckDecode(header.CbStringTable / sizeof(long) < int.MaxValue);
            // Requires a non-empty table (and rejects overflow).
            Contracts.CheckDecode(header.FpStringTable + header.CbStringTable > header.FpStringTable);
            // Character data follows the table directly and is made of 16-bit chars.
            Contracts.CheckDecode(header.FpStringChars == header.FpStringTable + header.CbStringTable);
            Contracts.CheckDecode(header.CbStringChars % sizeof(char) == 0);
            Contracts.CheckDecode(header.FpStringChars + header.CbStringChars >= header.FpStringChars);
            if (header.VerWritten < VerAssemblyNameSupported || header.FpAssemblyName == 0)
            {
                // No assembly name: the tail follows the string characters directly.
                Contracts.CheckDecode(header.FpTail == header.FpStringChars + header.CbStringChars);
            }
        }

        if (header.VerWritten >= VerAssemblyNameSupported)
        {
            if (header.FpAssemblyName == 0)
            {
                Contracts.CheckDecode(header.CbAssemblyName == 0);
            }
            else
            {
                // The assembly name always comes immediately after the string table if there is one,
                // otherwise immediately after the model blob.
                if (header.FpStringTable == 0)
                {
                    Contracts.CheckDecode(header.FpAssemblyName == header.FpModel + header.CbModel);
                }
                else
                {
                    Contracts.CheckDecode(header.FpAssemblyName == header.FpStringChars + header.CbStringChars);
                }
                Contracts.CheckDecode(header.CbAssemblyName % sizeof(char) == 0);
                // The tail follows the assembly name directly.
                Contracts.CheckDecode(header.FpTail == header.FpAssemblyName + header.CbAssemblyName);
            }
        }

        // The tail is a single ulong signature, and the whole blob must fit in the stream.
        Contracts.CheckDecode(header.FpLim == header.FpTail + sizeof(ulong));
        Contracts.CheckDecode(size == 0 || size >= header.FpLim);
        ex = null;
        return(true);
    }
    catch (Exception e)
    {
        ex = e;
        return(false);
    }
}
/// <summary>
/// Performs version checks of this context's header against <paramref name="ver"/>,
/// delegating to <c>ModelHeader.CheckVersionInfo</c>.
/// </summary>
public void CheckVersionInfo(VersionInfo ver) => ModelHeader.CheckVersionInfo(ref Header, ver);
/// <summary>
/// Finish reading. Checks that the current reader position is the end of the model blob, then
/// seeks to the end of the entire model file (just past the tail).
/// </summary>
/// <param name="fpMin">Absolute stream position of the start of the header.</param>
/// <param name="header">Header that was read by BeginRead.</param>
/// <param name="reader">Reader that has just consumed the model blob.</param>
public static void EndRead(long fpMin, ref ModelHeader header, BinaryReader reader)
{
    long consumedEnd = reader.FpCur() - fpMin;
    Contracts.CheckDecode(consumedEnd == header.FpModel + header.CbModel);
    reader.Seek(header.FpLim + fpMin);
}
/// <summary>
/// Commit the load operation: verifies the model blob was fully consumed, moves the reader past the
/// tail via <c>ModelHeader.EndRead</c>, and then disposes this context. When in repository mode, it
/// disposes the Reader (but not the repository).
/// </summary>
public void Done()
{
    ModelHeader.EndRead(FpMin, ref Header, Reader);
    Dispose();
}