/// <summary>
/// Deserializes a tree ensemble: the regression trees, then the bias, then the
/// input-initialization content string (which may be null).
/// </summary>
public Ensemble(ModelLoadContext ctx, bool usingDefaultValues, bool categoricalSplits)
{
    // REVIEW: Verify the contents of the ensemble, both during building,
    // and during deserialization.

    // *** Binary format ***
    // int: Number of trees
    // Regression trees (num trees of these)
    // double: Bias
    // int: Id to InputInitializationContent string, currently ignored

    _trees = new List<RegressionTree>();

    int remaining = ctx.Reader.ReadInt32();
    Contracts.CheckDecode(remaining >= 0);

    // Load exactly 'remaining' trees, in stored order.
    while (remaining-- > 0)
    {
        var tree = RegressionTree.Load(ctx, usingDefaultValues, categoricalSplits);
        AddTree(tree);
    }

    Bias = ctx.Reader.ReadDouble();
    _firstInputInitializationContent = ctx.LoadStringOrNull();
}
/// <summary>
/// Reads the raw column infos from <paramref name="ctx"/> against the input schema
/// <paramref name="schemaInput"/> and then builds the derived column infos.
/// Output column names must be unique; every HiddenColumnOption byte must be a defined value.
/// </summary>
public Bindings(ModelLoadContext ctx, ISchema schemaInput)
{
    Contracts.AssertValue(ctx);
    Contracts.AssertValue(schemaInput);

    Input = schemaInput;

    // *** Binary format ***
    // byte: default HiddenColumnOption value
    // int: number of raw column infos
    // for each raw column info
    //   int: id of output column name
    //   int: id of input column name
    //   byte: HiddenColumnOption
    HidDefault = (HiddenColumnOption)ctx.Reader.ReadByte();
    Contracts.CheckDecode(Enum.IsDefined(typeof(HiddenColumnOption), HidDefault));

    int rawCount = ctx.Reader.ReadInt32();
    Contracts.CheckDecode(rawCount >= 0);

    RawInfos = new RawColInfo[rawCount];

    // Tracks output names already seen so duplicates fail decoding.
    var seenOutputNames = new HashSet<string>();
    for (int index = 0; index < rawCount; index++)
    {
        string outputName = ctx.LoadNonEmptyString();
        Contracts.CheckDecode(seenOutputNames.Add(outputName));

        string inputName = ctx.LoadNonEmptyString();

        var hidden = (HiddenColumnOption)ctx.Reader.ReadByte();
        Contracts.CheckDecode(Enum.IsDefined(typeof(HiddenColumnOption), hidden));

        RawInfos[index] = new RawColInfo(outputName, inputName, hidden);
    }

    BuildInfos(out Infos, out NameToInfoIndex, user: false);
}
/// <summary>
/// Back-compatibility function that handles loading the DropColumns Transform.
/// </summary>
/// <param name="env">Host environment passed to the resulting transform.</param>
/// <param name="ctx">Model load context positioned at the legacy DropColumns payload.</param>
/// <param name="input">Input data view (not referenced here).</param>
/// <returns>
/// A <see cref="SelectColumnsTransform"/> that keeps or drops the stored column names,
/// in the order they were serialized.
/// </returns>
private static SelectColumnsTransform LoadDropColumnsTransform(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
{
    // *** Binary format ***
    // int: sizeof(Float)
    // bindings
    // The float size is read only to advance the stream; its value is not validated.
    ctx.Reader.ReadInt32();

    // *** Binary format ***
    // bool: whether to keep (vs drop) the named columns
    // int: number of names
    // int[]: the ids of the names
    var keep = ctx.Reader.ReadBoolByte();
    int count = ctx.Reader.ReadInt32();
    Contracts.CheckDecode(count > 0);

    // HashSet<T> enumeration order is unspecified, so names are kept in a List to
    // preserve the serialized column order; the set is used only to reject duplicates.
    var names = new List<string>();
    var seen = new HashSet<string>();
    for (int i = 0; i < count; i++)
    {
        string name = ctx.LoadNonEmptyString();
        Contracts.CheckDecode(seen.Add(name));
        names.Add(name);
    }

    string[] columns = names.ToArray();
    string[] keepColumns = keep ? columns : null;
    string[] dropColumns = keep ? null : columns;

    return new SelectColumnsTransform(env, keepColumns, dropColumns, keep);
}
/// <summary>
/// Deserializes the mapper: the wrapped <see cref="IFeatureContributionMapper"/> predictor,
/// the top and bottom contribution counts (each must be non-negative), and the normalize and
/// stringify flags.
/// </summary>
public BindableMapper(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));
    _env = env;
    _env.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    // *** Binary format ***
    // IFeatureContributionMapper: Predictor
    // int: topContributionsCount
    // int: bottomContributionsCount
    // bool: normalize
    // bool: stringify
    ctx.LoadModel<IFeatureContributionMapper, SignatureLoadModel>(env, out Predictor, ModelFileUtils.DirPredictor);
    GenericMapper = ScoreUtils.GetSchemaBindableMapper(_env, Predictor, null);

    _topContributionsCount = ReadNonNegativeInt32();
    _bottomContributionsCount = ReadNonNegativeInt32();
    _normalize = ctx.Reader.ReadBoolByte();
    Stringify = ctx.Reader.ReadBoolByte();

    // Reads one int32 from the model stream and fails decoding if it is negative.
    int ReadNonNegativeInt32()
    {
        int value = ctx.Reader.ReadInt32();
        Contracts.CheckDecode(value >= 0);
        return value;
    }
}
/// <summary>
/// Deserializes n-gram settings: the n-gram length and skip length (both bounded by
/// <c>NgramBufferBuilder.MaxSkipNgramLength</c>), an optional weighting criterion, and
/// one non-empty-level flag per n-gram position.
/// </summary>
/// <param name="ctx">Model load context to read from.</param>
/// <param name="readWeighting">Whether the stream contains a serialized weighting criterion.</param>
public TransformInfo(ModelLoadContext ctx, bool readWeighting)
{
    Contracts.AssertValue(ctx);

    // *** Binary format ***
    // int: NgramLength
    // int: SkipLength
    // int: Weighting Criteria (if readWeighting == true)
    // bool[NgramLength]: NonEmptyLevels
    var maxLength = NgramBufferBuilder.MaxSkipNgramLength;

    NgramLength = ctx.Reader.ReadInt32();
    Contracts.CheckDecode(NgramLength > 0 && NgramLength <= maxLength);

    SkipLength = ctx.Reader.ReadInt32();
    Contracts.CheckDecode(SkipLength >= 0 && SkipLength <= maxLength);

    // The n-gram length plus the skip length must not exceed the maximum.
    Contracts.CheckDecode(NgramLength <= maxLength - SkipLength);

    // When no weighting is serialized, the current value of Weighting is validated as-is.
    if (readWeighting)
    {
        int rawWeighting = ctx.Reader.ReadInt32();
        Weighting = (NgramExtractingEstimator.WeightingCriteria)rawWeighting;
    }
    Contracts.CheckDecode(Enum.IsDefined(typeof(NgramExtractingEstimator.WeightingCriteria), Weighting));

    NonEmptyLevels = ctx.Reader.ReadBoolArray(NgramLength);
}
/// <summary>
/// Reads the shared base information: a (possibly null) suffix string followed by a list of
/// (column role, column name) pairs.
/// </summary>
/// <param name="ctx">Model load context to read from.</param>
/// <param name="suffix">Receives the stored suffix.</param>
/// <returns>The role/column pairs, in stored order.</returns>
protected static KeyValuePair<RoleMappedSchema.ColumnRole, string>[] LoadBaseInfo(
    ModelLoadContext ctx, out string suffix)
{
    // *** Binary format ***
    // int: id of the suffix
    // int: the number of input column roles
    // for each input column:
    //   int: id of the column role
    //   int: id of the column name
    suffix = ctx.LoadString();

    int roleCount = ctx.Reader.ReadInt32();
    Contracts.CheckDecode(roleCount >= 0);

    var result = new KeyValuePair<RoleMappedSchema.ColumnRole, string>[roleCount];
    for (int slot = 0; slot < roleCount; slot++)
    {
        string roleName = ctx.LoadNonEmptyString();
        string columnName = ctx.LoadNonEmptyString();
        result[slot] = RoleMappedSchema.CreatePair(roleName, columnName);
    }

    return result;
}
/// <summary>
/// Fails decoding unless <paramref name="p"/> lies in the closed interval [0, 1].
/// NaN is rejected because both comparisons are false for it.
/// </summary>
private static void ProbCheckDecode(Double p)
{
    bool inUnitInterval = p >= 0 && p <= 1;
    Contracts.CheckDecode(inUnitInterval);
}
/// <summary>
/// Fails decoding unless <paramref name="param"/> and the associated <paramref name="tvalue"/>
/// have the same sign (both negative, both zero, or both positive).
/// </summary>
/// <param name="param">The decoded parameter value.</param>
/// <param name="tvalue">The decoded t-value paired with <paramref name="param"/>.</param>
private static void TValueCheckDecode(Double param, Double tvalue)
{
    // Math.Sign throws ArithmeticException for NaN; reject NaN through the decode check
    // first so corrupt data surfaces as a decode failure rather than an arithmetic error.
    Contracts.CheckDecode(!Double.IsNaN(param) && !Double.IsNaN(tvalue));
    Contracts.CheckDecode(Math.Sign(param) == Math.Sign(tvalue));
}
/// <summary>
/// Deserializes the column bindings. For each column it reads the name, the item type
/// (optionally a key type) and the source segments, then computes the output schema.
/// </summary>
/// <param name="ctx">The model load context to read from.</param>
/// <param name="parent">NOTE(review): not referenced in this body.</param>
public Bindings(ModelLoadContext ctx, DatabaseLoader parent)
{
    Contracts.AssertValue(ctx);

    // *** Binary format ***
    // int: number of columns
    // foreach column:
    //   int: id of column name
    //   byte: DataKind
    //   byte: bool of whether this is a key type
    //   for a key type:
    //     ulong: count for key range
    //   int: number of segments
    //   foreach segment:
    //     string id: name
    //     int: min
    //     int: lim
    //     byte: force vector (verWrittenCur: verIsVectorSupported)
    int cinfo = ctx.Reader.ReadInt32();
    // At least one column is required.
    Contracts.CheckDecode(cinfo > 0);
    Infos = new ColInfo[cinfo];

    for (int iinfo = 0; iinfo < cinfo; iinfo++)
    {
        string name = ctx.LoadNonEmptyString();

        PrimitiveDataViewType itemType;
        var kind = (InternalDataKind)ctx.Reader.ReadByte();
        Contracts.CheckDecode(Enum.IsDefined(typeof(InternalDataKind), kind));
        bool isKey = ctx.Reader.ReadBoolByte();
        if (isKey)
        {
            ulong count;
            // The stored kind must be acceptable as the underlying type of a key.
            Contracts.CheckDecode(KeyDataViewType.IsValidDataType(kind.ToType()));

            // Key count must be positive.
            count = ctx.Reader.ReadUInt64();
            Contracts.CheckDecode(0 < count);

            itemType = new KeyDataViewType(kind.ToType(), count);
        }
        else
        {
            itemType = ColumnTypeExtensions.PrimitiveTypeFromKind(kind);
        }

        int cseg = ctx.Reader.ReadInt32();

        Segment[] segs;

        // A zero segment count is represented as a null segment array.
        if (cseg == 0)
        {
            segs = null;
        }
        else
        {
            Contracts.CheckDecode(cseg > 0);
            segs = new Segment[cseg];
            for (int iseg = 0; iseg < cseg; iseg++)
            {
                // A null name selects the (min, lim) Segment constructor; a non-null name selects
                // the named constructor, in which case min/lim are read but not passed on.
                string columnName = ctx.LoadStringOrNull();
                int min = ctx.Reader.ReadInt32();
                int lim = ctx.Reader.ReadInt32();
                // NOTE(review): this range check also applies to named segments, whose min/lim
                // are otherwise unused here — confirm the writer always stores a valid range for them.
                Contracts.CheckDecode(0 <= min && min < lim);
                bool forceVector = ctx.Reader.ReadBoolByte();
                segs[iseg] = (columnName is null) ? new Segment(min, lim, forceVector) : new Segment(columnName, forceVector);
            }
        }

        // Note that this will throw if the segments are ill-structured, including the case
        // of multiple variable segments (since those segments will overlap and overlapping
        // segments are illegal).
        Infos[iinfo] = ColInfo.Create(name, itemType, segs, false);
    }

    OutputSchema = ComputeOutputSchema();
}
/// <summary>
/// Performs a single decode check whose condition is the literal <c>true</c>, so it can never fail.
/// </summary>
/// <param name="ctx">NOTE(review): not referenced in this body.</param>
private void Loader(ModelLoadContext ctx)
{
    // NOTE(review): this check is a no-op at runtime (the condition is always true).
    // Presumably it exists only to exercise the message argument — confirm intent before changing.
    Contracts.CheckDecode(true, "This message is suspicious");
}
/// <summary>
/// Checks the validity of the header, reads the string table, etc.
/// </summary>
/// <param name="header">The header to validate.</param>
/// <param name="reader">Reader over the model stream. Header offsets are interpreted relative to <paramref name="fpMin"/>.</param>
/// <param name="fpMin">Stream position at which the model file starts; must be non-negative.</param>
/// <param name="strings">Receives the decoded string table, or null when there is none or on failure.</param>
/// <param name="loaderAssemblyName">Receives the loader assembly name, or null when absent or on failure.</param>
/// <param name="ex">Receives the failure when the method returns false; otherwise null.</param>
/// <returns>
/// True on success. On a full successful validation the reader is seeked back to its position on entry.
/// </returns>
public static bool TryValidate(ref ModelHeader header, BinaryReader reader, long fpMin, out string[] strings, out string loaderAssemblyName, out Exception ex)
{
    Contracts.CheckValue(reader, nameof(reader));
    Contracts.Check(fpMin >= 0);

    if (!TryValidate(ref header, reader.BaseStream.Length - fpMin, out ex))
    {
        strings = null;
        loaderAssemblyName = null;
        return false;
    }

    try
    {
        long fpOrig = reader.FpCur();

        StringBuilder sb = null;
        if (header.FpStringTable == 0)
        {
            // No strings.
            strings = null;
            if (header.VerWritten < VerAssemblyNameSupported)
            {
                // Files older than VerAssemblyNameSupported that carry no strings were never
                // subject to the FpTail checks below, and 'reader' is not positioned at FpTail
                // for them, so those checks would wrongly fail. Preserve the previous behavior
                // by accepting them here (the tail signature is not verified for these files).
                loaderAssemblyName = null;
                ex = null;
                return true;
            }
        }
        else
        {
            reader.Seek(header.FpStringTable + fpMin);
            Contracts.Assert(reader.FpCur() == header.FpStringTable + fpMin);

            // The string table is an array of cumulative end offsets (in bytes, relative to
            // FpStringChars); the last one must equal the total size of the characters.
            long cstr = header.CbStringTable / sizeof(long);
            Contracts.Assert(cstr < int.MaxValue);
            long[] offsets = reader.ReadLongArray((int)cstr);
            Contracts.Assert(header.FpStringChars == reader.FpCur() - fpMin);
            Contracts.CheckDecode(offsets[cstr - 1] == header.CbStringChars);

            strings = new string[cstr];
            long offset = 0;
            sb = new StringBuilder();
            for (int i = 0; i < offsets.Length; i++)
            {
                Contracts.CheckDecode(header.FpStringChars + offset == reader.FpCur() - fpMin);
                long offsetPrev = offset;
                offset = offsets[i];
                Contracts.CheckDecode(offsetPrev <= offset & offset <= header.CbStringChars);
                Contracts.CheckDecode(offset % sizeof(char) == 0);
                long cch = (offset - offsetPrev) / sizeof(char);
                Contracts.CheckDecode(cch < int.MaxValue);
                sb.Clear();
                for (long ich = 0; ich < cch; ich++)
                {
                    sb.Append((char)reader.ReadUInt16());
                }
                strings[i] = sb.ToString();
            }
            Contracts.CheckDecode(offset == header.CbStringChars);
            Contracts.CheckDecode(header.FpStringChars + header.CbStringChars == reader.FpCur() - fpMin);
        }

        if (header.VerWritten >= VerAssemblyNameSupported && header.FpAssemblyName != 0)
        {
            reader.Seek(header.FpAssemblyName + fpMin);
            int assemblyNameLength = (int)header.CbAssemblyName / sizeof(char);

            sb = sb != null ? sb.Clear() : new StringBuilder(assemblyNameLength);
            for (long ich = 0; ich < assemblyNameLength; ich++)
            {
                sb.Append((char)reader.ReadUInt16());
            }
            loaderAssemblyName = sb.ToString();
        }
        else
        {
            loaderAssemblyName = null;
        }

        Contracts.CheckDecode(header.FpTail == reader.FpCur() - fpMin);

        ulong tail = reader.ReadUInt64();
        Contracts.CheckDecode(tail == TailSignatureValue, "Corrupt model file tail");

        ex = null;
        reader.Seek(fpOrig);
        return true;
    }
    catch (Exception e)
    {
        strings = null;
        loaderAssemblyName = null;
        ex = e;
        return false;
    }
}
/// <summary>
/// Checks the basic validity of the header, assuming the stream is at least the given size.
/// Returns false (and the out exception) on failure.
/// </summary>
/// <param name="header">The header whose fields are validated.</param>
/// <param name="size">
/// Number of bytes available for the model file, or zero to skip the final check that the
/// file fits within the available size. Must be non-negative.
/// </param>
/// <param name="ex">Receives the decode failure when the method returns false; otherwise null.</param>
/// <returns>True if every checked header field is consistent; otherwise false.</returns>
public static bool TryValidate(ref ModelHeader header, long size, out Exception ex)
{
    Contracts.Check(size >= 0);

    try
    {
        Contracts.CheckDecode(header.Signature == SignatureValue, "Wrong file type");
        Contracts.CheckDecode(header.VerReadable <= header.VerWritten, "Corrupt file header");
        Contracts.CheckDecode(header.VerReadable <= VerWrittenCur, "File is too new");
        Contracts.CheckDecode(header.VerWritten >= VerWeCanReadBack, "File is too old");

        // Currently the model always comes immediately after the header.
        Contracts.CheckDecode(header.FpModel == Size);
        // Rejects a negative model size and an end offset that overflows.
        Contracts.CheckDecode(header.FpModel + header.CbModel >= header.FpModel);

        if (header.FpStringTable == 0)
        {
            // No strings.
            Contracts.CheckDecode(header.CbStringTable == 0);
            Contracts.CheckDecode(header.FpStringChars == 0);
            Contracts.CheckDecode(header.CbStringChars == 0);
            if (header.VerWritten < VerAssemblyNameSupported || header.FpAssemblyName == 0)
            {
                Contracts.CheckDecode(header.FpTail == header.FpModel + header.CbModel);
            }
        }
        else
        {
            // Currently the string table always comes immediately after the model block.
            Contracts.CheckDecode(header.FpStringTable == header.FpModel + header.CbModel);
            // One long offset per string; the string count must fit in an int.
            Contracts.CheckDecode(header.CbStringTable % sizeof(long) == 0);
            Contracts.CheckDecode(header.CbStringTable / sizeof(long) < int.MaxValue);
            // Requires a non-empty string table whose end offset does not overflow.
            Contracts.CheckDecode(header.FpStringTable + header.CbStringTable > header.FpStringTable);
            Contracts.CheckDecode(header.FpStringChars == header.FpStringTable + header.CbStringTable);
            Contracts.CheckDecode(header.CbStringChars % sizeof(char) == 0);
            Contracts.CheckDecode(header.FpStringChars + header.CbStringChars >= header.FpStringChars);
            if (header.VerWritten < VerAssemblyNameSupported || header.FpAssemblyName == 0)
            {
                Contracts.CheckDecode(header.FpTail == header.FpStringChars + header.CbStringChars);
            }
        }

        if (header.VerWritten >= VerAssemblyNameSupported)
        {
            if (header.FpAssemblyName == 0)
            {
                Contracts.CheckDecode(header.CbAssemblyName == 0);
            }
            else
            {
                // The assembly name always comes immediately after the string table, if there is one.
                if (header.FpStringTable == 0)
                {
                    Contracts.CheckDecode(header.FpAssemblyName == header.FpModel + header.CbModel);
                }
                else
                {
                    Contracts.CheckDecode(header.FpAssemblyName == header.FpStringChars + header.CbStringChars);
                }

                // As with the model block and the string characters, reject a negative
                // assembly-name size or an end offset that overflows; otherwise FpTail/FpLim
                // could land before the assembly name and still pass the size check below.
                Contracts.CheckDecode(header.FpAssemblyName + header.CbAssemblyName >= header.FpAssemblyName);
                Contracts.CheckDecode(header.CbAssemblyName % sizeof(char) == 0);
                Contracts.CheckDecode(header.FpTail == header.FpAssemblyName + header.CbAssemblyName);
            }
        }

        // The tail occupies sizeof(ulong) bytes and ends the file.
        Contracts.CheckDecode(header.FpLim == header.FpTail + sizeof(ulong));
        Contracts.CheckDecode(size == 0 || size >= header.FpLim);

        ex = null;
        return true;
    }
    catch (Exception e)
    {
        ex = e;
        return false;
    }
}
/// <summary>
/// Finish reading. Verifies that the reader sits exactly at the end of the model blob,
/// then seeks past the tail to the end of the entire model file.
/// </summary>
/// <param name="fpMin">Stream position at which the model file starts.</param>
/// <param name="header">The validated model header.</param>
/// <param name="reader">Reader over the model stream.</param>
public static void EndRead(long fpMin, ref ModelHeader header, BinaryReader reader)
{
    long modelEnd = header.FpModel + header.CbModel;
    long relativePosition = reader.FpCur() - fpMin;
    Contracts.CheckDecode(modelEnd == relativePosition);

    long fileEnd = fpMin + header.FpLim;
    reader.Seek(fileEnd);
}
/// <summary>
/// Checks the validity of the header, reads the string table, etc.
/// </summary>
/// <remarks>
/// First runs the size-based header checks, then reads the string table and (for
/// VerAssemblyNameSupported and later) the loader assembly name, and finally verifies the
/// tail signature. On success the reader is left at, or seeked back to, its position on entry.
/// Argument checks throw; failures after the size-based checks are caught and reported
/// through <paramref name="ex"/> with a false return.
/// </remarks>
/// <param name="header">The header to validate.</param>
/// <param name="reader">Reader over the model stream. Header offsets are interpreted relative to <paramref name="fpMin"/>.</param>
/// <param name="fpMin">Stream position at which the model file starts; must be non-negative.</param>
/// <param name="strings">Receives the decoded string table, or null when there is none or on failure.</param>
/// <param name="loaderAssemblyName">Receives the loader assembly name, or null when absent or on failure.</param>
/// <param name="ex">Receives the failure when the method returns false; otherwise null.</param>
public static bool TryValidate(ref ModelHeader header, BinaryReader reader, long fpMin, out string[] strings, out string loaderAssemblyName, out Exception ex)
{
    Contracts.CheckValue(reader, nameof(reader));
    Contracts.Check(fpMin >= 0);

    // The size passed on is the number of stream bytes at or after fpMin.
    if (!TryValidate(ref header, reader.BaseStream.Length - fpMin, out ex))
    {
        strings = null;
        loaderAssemblyName = null;
        return(false);
    }

    try
    {
        long fpOrig = reader.FpCur();

        StringBuilder sb = null;
        if (header.FpStringTable == 0)
        {
            // No strings.
            strings = null;
            if (header.VerWritten < VerAssemblyNameSupported)
            {
                // Before VerAssemblyNameSupported, if there were no strings in the model,
                // validation ended here. Specifically the FpTail checks below were skipped.
                // There are earlier versions of models that don't have strings, and 'reader' is
                // not at FpTail at this point.
                // Preserve the previous behavior by returning early here.
                loaderAssemblyName = null;
                ex = null;
                return(true);
            }
        }
        else
        {
            reader.Seek(header.FpStringTable + fpMin);
            Contracts.Assert(reader.FpCur() == header.FpStringTable + fpMin);

            // The string table is an array of cumulative end offsets (in bytes, relative to
            // FpStringChars); the last one must equal the total size of the characters.
            long cstr = header.CbStringTable / sizeof(long);
            Contracts.Assert(cstr < int.MaxValue);
            long[] offsets = reader.ReadLongArray((int)cstr);
            Contracts.Assert(header.FpStringChars == reader.FpCur() - fpMin);
            Contracts.CheckDecode(offsets[cstr - 1] == header.CbStringChars);

            strings = new string[cstr];
            long offset = 0;
            sb = new StringBuilder();
            for (int i = 0; i < offsets.Length; i++)
            {
                // The reader must be exactly at the start of string i.
                Contracts.CheckDecode(header.FpStringChars + offset == reader.FpCur() - fpMin);
                long offsetPrev = offset;
                offset = offsets[i];
                Contracts.CheckDecode(offsetPrev <= offset & offset <= header.CbStringChars);
                Contracts.CheckDecode(offset % sizeof(char) == 0);
                long cch = (offset - offsetPrev) / sizeof(char);
                Contracts.CheckDecode(cch < int.MaxValue);
                sb.Clear();
                // Characters are stored as 16-bit code units.
                for (long ich = 0; ich < cch; ich++)
                {
                    sb.Append((char)reader.ReadUInt16());
                }
                strings[i] = sb.ToString();
            }
            Contracts.CheckDecode(offset == header.CbStringChars);
            Contracts.CheckDecode(header.FpStringChars + header.CbStringChars == reader.FpCur() - fpMin);
        }

        if (header.VerWritten >= VerAssemblyNameSupported && header.FpAssemblyName != 0)
        {
            reader.Seek(header.FpAssemblyName + fpMin);
            // NOTE(review): the cast applies to CbAssemblyName before the division; an
            // out-of-range value is only caught indirectly by the FpTail check or the catch below.
            int assemblyNameLength = (int)header.CbAssemblyName / sizeof(char);

            // Reuse the string-table builder when one was created.
            sb = sb != null ? sb.Clear() : new StringBuilder(assemblyNameLength);
            for (long ich = 0; ich < assemblyNameLength; ich++)
            {
                sb.Append((char)reader.ReadUInt16());
            }
            loaderAssemblyName = sb.ToString();
        }
        else
        {
            loaderAssemblyName = null;
        }

        // Everything before the tail must have been consumed exactly.
        Contracts.CheckDecode(header.FpTail == reader.FpCur() - fpMin);

        ulong tail = reader.ReadUInt64();
        Contracts.CheckDecode(tail == TailSignatureValue, "Corrupt model file tail");

        ex = null;
        reader.Seek(fpOrig);
        return(true);
    }
    catch (Exception e)
    {
        strings = null;
        loaderAssemblyName = null;
        ex = e;
        return(false);
    }
}