/// <summary>
/// Deserializes a Bindings instance: reloads the saved schema of the optional columns
/// (stored as an embedded binary IDV stream) plus the added columns' names and types,
/// and resolves each name against both the current input schema and the reloaded schema.
/// </summary>
public static Bindings Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema input, OptionalColumnTransform parent)
{
    Contracts.AssertValue(ctx);
    Contracts.AssertValue(input);

    // *** Binary format ***
    // Schema of the data view containing the optional columns
    // int: number of added columns
    // for each added column
    //   int: id of output column name
    //   ColumnType: the type of the column

    // Load the embedded "Schema.idv" stream fully into memory; it carries the schema
    // that includes the optional columns.
    byte[] buffer = null;
    if (!ctx.TryLoadBinaryStream("Schema.idv", r => buffer = r.ReadByteArray()))
    {
        throw env.ExceptDecode();
    }
    BinaryLoader loader = null;
    var strm = new MemoryStream(buffer, writable: false);
    loader = new BinaryLoader(env, new BinaryLoader.Arguments(), strm);

    int size = ctx.Reader.ReadInt32();
    Contracts.CheckDecode(size > 0);

    // BinarySaver is used here only to decode the serialized column types.
    var saver = new BinarySaver(env, new BinarySaver.Arguments());
    var names = new string[size];
    var columnTypes = new DataViewType[size];
    var srcCols = new int[size];
    var srcColsWithOptionalColumn = new int[size];
    for (int i = 0; i < size; i++)
    {
        names[i] = ctx.LoadNonEmptyString();
        columnTypes[i] = saver.LoadTypeDescriptionOrNull(ctx.Reader.BaseStream);
        int col;
        bool success = input.TryGetColumnIndex(names[i], out col);
        // -1 marks a column that is absent from the current input (it is optional).
        srcCols[i] = success ? col : -1;
        // The reloaded schema, however, must contain every added column.
        success = loader.Schema.TryGetColumnIndex(names[i], out var colWithOptionalColumn);
        env.CheckDecode(success);
        srcColsWithOptionalColumn[i] = colWithOptionalColumn;
    }
    return (new Bindings(parent, columnTypes, srcCols, srcColsWithOptionalColumn, input, loader.Schema, false, names));
}
/// <summary>
/// Deserialization constructor: restores the set of columns whose NA values are filtered.
/// Throws if a saved column is missing from the source schema, specified twice, or has a
/// type rejected by TestType.
/// </summary>
public NAFilter(IHost host, ModelLoadContext ctx, IDataView input)
    : base(host, input)
{
    Host.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    // *** Binary format ***
    // int: sizeof(Float)
    // int: number of columns
    // int[]: ids of column names
    int cbFloat = ctx.Reader.ReadInt32();
    // Models may have been written with either single- or double-precision floats.
    Host.CheckDecode(cbFloat == sizeof(Single) || cbFloat == sizeof(Double));
    int cinfo = ctx.Reader.ReadInt32();
    Host.CheckDecode(cinfo > 0);
    _infos = new ColInfo[cinfo];
    // Maps a source-schema column index back to its slot in _infos; also used to
    // detect duplicate column specifications below.
    _srcIndexToInfoIndex = new Dictionary<int, int>(_infos.Length);
    var schema = Source.Schema;
    for (int i = 0; i < cinfo; i++)
    {
        string src = ctx.LoadNonEmptyString();
        int index;
        if (!schema.TryGetColumnIndex(src, out index))
        {
            throw Host.Except("Source column '{0}' not found", src);
        }
        if (_srcIndexToInfoIndex.ContainsKey(index))
        {
            throw Host.Except("Source column '{0}' specified multiple times", src);
        }
        var type = schema.GetColumnType(index);
        if (!TestType(type))
        {
            throw Host.Except("Column '{0}' does not have compatible numeric type", src);
        }
        _infos[i] = new ColInfo(index, type);
        _srcIndexToInfoIndex.Add(index, i);
    }
}
/// <summary>
/// Deserialization constructor: reloads the wrapped bindable mapper from the inner
/// directory, decodes the saved vector of label names, and (for models at or above
/// VersionAddedMetadataKind) the metadata kind string under which the names are exposed.
/// </summary>
private LabelNameBindableMapper(IHost host, ModelLoadContext ctx)
{
    Contracts.AssertValue(host);
    _host = host;
    _host.AssertValue(ctx);

    ctx.LoadModel<ISchemaBindableMapper, SignatureLoadModel>(_host, out _bindable, _innerDir);
    // BinarySaver is used here only to decode the serialized type/value pair.
    BinarySaver saver = new BinarySaver(_host, new BinarySaver.Arguments());
    ColumnType type;
    object value;
    _host.CheckDecode(saver.TryLoadTypeAndValue(ctx.Reader.BaseStream, out type, out value));
    // The label names must decode to a non-null vector value.
    _host.CheckDecode(type.IsVector);
    _host.CheckDecode(value != null);
    _type = type.AsVector;
    // Builds a strongly-typed getter for the decoded vector, dispatched on its item type.
    _getter = Utils.MarshalInvoke(DecodeInit<int>, _type.ItemType.RawType, value);
    // Older models predate the metadata-kind field; default to SlotNames for those.
    _metadataKind = ctx.Header.ModelVerReadable >= VersionAddedMetadataKind
        ? ctx.LoadNonEmptyString()
        : MetadataUtils.Kinds.SlotNames;
}
/// <summary>
/// Deserialization constructor: after the base-class state, reads the number of
/// quantiles and their names, and sets up the R8 vector output type (one slot
/// per quantile).
/// </summary>
private QuantileRegressionPerInstanceEvaluator(IHostEnvironment env, ModelLoadContext ctx, ISchema schema)
    : base(env, ctx, schema)
{
    CheckInputColumnTypes(schema);

    // *** Binary format **
    // base
    // int: _scoreSize
    // int[]: Ids of the quantile names
    _scoreSize = ctx.Reader.ReadInt32();
    Host.CheckDecode(_scoreSize > 0);

    // The quantile names are stored as model-string ids; materialize them in order.
    var quantileNames = new DvText[_scoreSize];
    for (int slot = 0; slot < quantileNames.Length; slot++)
        quantileNames[slot] = new DvText(ctx.LoadNonEmptyString());
    _quantiles = quantileNames;

    _outputType = new VectorType(NumberType.R8, _scoreSize);
}
/// <summary>
/// Deserialization constructor: reads the score column name (required) and the label
/// column name (optional), then resolves both against the given schema, throwing when
/// a named column cannot be found.
/// </summary>
protected PerInstanceEvaluatorBase(IHostEnvironment env, ModelLoadContext ctx, ISchema schema)
{
    Host = env.Register("PerInstanceRowMapper");

    // *** Binary format **
    // int: Id of the score column name
    // int: Id of the label column name
    ScoreCol = ctx.LoadNonEmptyString();
    LabelCol = ctx.LoadStringOrNull();

    // The label column is optional; when present it must exist in the schema.
    if (!string.IsNullOrEmpty(LabelCol))
    {
        if (!schema.TryGetColumnIndex(LabelCol, out LabelIndex))
            throw Host.Except($"Did not find the label column '{LabelCol}'");
    }

    // The score column is mandatory.
    if (!schema.TryGetColumnIndex(ScoreCol, out ScoreIndex))
        throw Host.Except($"Did not find column '{ScoreCol}'");
}
/// <summary>
/// Back-compatibility function that handles loading the DropColumns Transform.
/// Reads the keep/drop flag and the set of column names, and builds an equivalent
/// <see cref="ColumnSelectingTransformer"/> (hidden columns are always preserved).
/// </summary>
private static ColumnSelectingTransformer LoadDropColumnsTransform(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
{
    // *** Binary format ***
    // int: sizeof(Float)
    // bindings
    int cbFloat = ctx.Reader.ReadInt32();
    env.CheckDecode(cbFloat == sizeof(float));

    // *** Binary format ***
    // bool: whether to keep (vs drop) the named columns
    // int: number of names
    // int[]: the ids of the names
    var keep = ctx.Reader.ReadBoolByte();
    int count = ctx.Reader.ReadInt32();
    // FIX(consistency): the method previously mixed Contracts.CheckDecode with
    // env.CheckDecode; use env throughout so decode failures carry the host context.
    env.CheckDecode(count > 0);
    var names = new HashSet<string>();
    for (int i = 0; i < count; i++)
    {
        string name = ctx.LoadNonEmptyString();
        // A duplicate name indicates a corrupt model.
        env.CheckDecode(names.Add(name));
    }

    string[] keepColumns = null;
    string[] dropColumns = null;
    if (keep)
        keepColumns = names.ToArray();
    else
        dropColumns = names.ToArray();

    // Note for backward compatibility, Drop/Keep Columns always preserves
    // hidden columns
    return (new ColumnSelectingTransformer(env, keepColumns, dropColumns, true));
}
/// <summary>
/// Deserialization constructor: restores the filtered column, the [min, max] range and
/// the complement/inclusion flags, validating that the column exists and has a Single,
/// Double or Key type.
/// </summary>
private RangeFilter(IHost host, ModelLoadContext ctx, IDataView input)
    : base(host, input)
{
    Host.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    // *** Binary format ***
    // int: sizeof(Float)
    // int: id of column name
    // double: min
    // double: max
    // byte: complement
    int cbFloat = ctx.Reader.ReadInt32();
    Host.CheckDecode(cbFloat == sizeof(float));

    var column = ctx.LoadNonEmptyString();
    var schema = Source.Schema;
    if (!schema.TryGetColumnIndex(column, out _index))
    {
        throw Host.ExceptSchemaMismatch(nameof(schema), "source", column);
    }

    _type = schema[_index].Type;
    if (_type != NumberDataViewType.Single && _type != NumberDataViewType.Double && _type.GetKeyCount() == 0)
    {
        throw Host.ExceptSchemaMismatch(nameof(schema), "source", column, "Single, Double or Key", _type.ToString());
    }

    _min = ctx.Reader.ReadDouble();
    _max = ctx.Reader.ReadDouble();
    if (!(_min <= _max))
    {
        // FIX: previously "min" was passed as the message/format argument, so the thrown
        // exception's text was just "min"; pass the real message instead.
        throw Host.Except("min must be less than or equal to max");
    }
    _complement = ctx.Reader.ReadBoolByte();
    _includeMin = ctx.Reader.ReadBoolByte();
    _includeMax = ctx.Reader.ReadBoolByte();
}
/// <summary>
/// Deserialization constructor (legacy API variant): restores the filtered column, the
/// [min, max] range and the complement/inclusion flags, validating that the column
/// exists and has an R4, R8 or key type.
/// </summary>
private RangeFilter(IHost host, ModelLoadContext ctx, IDataView input)
    : base(host, input)
{
    Host.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    // *** Binary format ***
    // int: sizeof(Float)
    // int: id of column name
    // double: min
    // double: max
    // byte: complement
    int cbFloat = ctx.Reader.ReadInt32();
    Host.CheckDecode(cbFloat == sizeof(Float));

    var column = ctx.LoadNonEmptyString();
    var schema = Source.Schema;
    if (!schema.TryGetColumnIndex(column, out _index))
    {
        // FIX: previously "column" was passed as the format string (with the real message
        // as an ignored argument), so the exception text was just "column".
        throw Host.Except("Source column '{0}' not found", column);
    }

    _type = schema.GetColumnType(_index);
    if (_type != NumberType.R4 && _type != NumberType.R8 && _type.KeyCount == 0)
    {
        // FIX: same garbled-message pattern as above.
        throw Host.Except("Column '{0}' does not have compatible type", column);
    }

    _min = ctx.Reader.ReadDouble();
    _max = ctx.Reader.ReadDouble();
    if (!(_min <= _max))
    {
        // FIX: "min" was the format string, so the exception text was just "min".
        throw Host.Except("min must be less than or equal to max");
    }
    _complement = ctx.Reader.ReadBoolByte();
    _includeMin = ctx.Reader.ReadBoolByte();
    _includeMax = ctx.Reader.ReadBoolByte();
}
/// <summary>
/// Deserialization constructor: reads whether a custom lookup table is used, then either
/// the custom model file path or the pretrained model kind (resolving the model file via
/// EnsureModelFile), and finally builds the vocabulary dictionary and output type.
/// </summary>
private WordEmbeddingsTransform(IHost host, ModelLoadContext ctx, IDataView input)
    : base(host, ctx, input, TestIsTextVector)
{
    Host.AssertValue(ctx);
    Host.AssertNonEmpty(Infos);
    _customLookup = ctx.Reader.ReadBoolByte();
    if (_customLookup)
    {
        _modelFileNameWithPath = ctx.LoadNonEmptyString();
        _modelKind = null;
    }
    else
    {
        _modelKind = (PretrainedModelKind)ctx.Reader.ReadUInt32();
        // Resolves (and possibly downloads) the pretrained model file; also yields
        // the number of header lines to skip when parsing it.
        _modelFileNameWithPath = EnsureModelFile(Host, out _linesToSkip, (PretrainedModelKind)_modelKind);
    }
    Host.CheckNonWhiteSpace(_modelFileNameWithPath, nameof(_modelFileNameWithPath));
    _currentVocab = GetVocabularyDictionary();
    // NOTE(review): output width is 3x the embedding dimension — presumably three pooled
    // aggregates per embedding slot; confirm against the getter implementation.
    _outputType = new VectorType(NumberType.R4, 3 * _currentVocab.Dimension);
    Metadata.Seal();
}
/// <summary>
/// Deserializes a SchemaImpl: reads the ungroup mode and the names of the pivot
/// columns, then constructs the schema against the given input schema.
/// </summary>
public static SchemaImpl Create(ModelLoadContext ctx, IExceptionContext ectx, Schema inputSchema)
{
    // ectx may legitimately be null; the Contracts extension methods tolerate that.
    Contracts.AssertValueOrNull(ectx);
    ectx.AssertValue(ctx);
    ectx.AssertValue(inputSchema);

    // *** Binary format ***
    // int: ungroup mode
    // int: K - number of pivot columns
    // int[K]: ids of pivot column names
    int rawMode = ctx.Reader.ReadInt32();
    ectx.CheckDecode(Enum.IsDefined(typeof(UngroupMode), rawMode));
    var ungroupMode = (UngroupMode)rawMode;

    int pivotCount = ctx.Reader.ReadInt32();
    ectx.CheckDecode(pivotCount > 0);
    var pivotNames = new string[pivotCount];
    for (int idx = 0; idx < pivotCount; idx++)
    {
        pivotNames[idx] = ctx.LoadNonEmptyString();
    }

    return new SchemaImpl(ectx, inputSchema, ungroupMode, pivotNames);
}
/// <summary>
/// Deserialization constructor: reloads each predictor model (handling the directory
/// layout change after model version 0x00010001) and the input column names. The
/// score-column-kind string itself is consumed by the Create factory method.
/// </summary>
protected SchemaBindablePipelineEnsembleBase(IHostEnvironment env, ModelLoadContext ctx, string scoreColumnKind)
{
    Host = env.Register(LoaderSignature);
    Host.AssertNonEmpty(scoreColumnKind);

    _scoreColumnKind = scoreColumnKind;

    // *** Binary format ***
    // int: id of _scoreColumnKind (loaded in the Create method)
    // int: number of predictors
    // The predictor models
    // int: the number of input columns
    // for each input column:
    //   int: id of the column name

    var length = ctx.Reader.ReadInt32();
    Host.CheckDecode(length > 0);
    PredictorModels = new IPredictorModel[length];
    for (int i = 0; i < PredictorModels.Length; i++)
    {
        // Version 0x00010001 stored predictor models at the repository root;
        // later versions nest them under the context directory.
        string dir =
            ctx.Header.ModelVerWritten == 0x00010001
            ? "PredictorModels"
            : Path.Combine(ctx.Directory, "PredictorModels");
        using (var ent = ctx.Repository.OpenEntry(dir, $"PredictorModel_{i:000}"))
            PredictorModels[i] = new PredictorModel(Host, ent.Stream);
    }

    // Zero input columns is legal here (>= 0, unlike the predictor count).
    length = ctx.Reader.ReadInt32();
    Host.CheckDecode(length >= 0);
    _inputCols = new string[length];
    for (int i = 0; i < length; i++)
    {
        _inputCols[i] = ctx.LoadNonEmptyString();
    }
}
/// <summary>
/// Factory: reads the score-column-kind string from the model and dispatches to the
/// matching concrete ensemble implementation; throws for an unrecognized kind.
/// </summary>
public static SchemaBindablePipelineEnsembleBase Create(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    // The saved kind string selects which implementation class to revive.
    var kind = ctx.LoadNonEmptyString();
    switch (kind)
    {
        case MetadataUtils.Const.ScoreColumnKind.BinaryClassification:
            return new ImplOneWithCalibrator(env, ctx, kind);

        case MetadataUtils.Const.ScoreColumnKind.Regression:
        case MetadataUtils.Const.ScoreColumnKind.AnomalyDetection:
            return new ImplOne(env, ctx, kind);

        case MetadataUtils.Const.ScoreColumnKind.MultiClassClassification:
            return new ImplVec(env, ctx, kind);

        default:
            throw env.Except("Unknown score kind");
    }
}
/// <summary>
/// Loads all transforms from the <paramref name="ctx"/> that pass the <paramref name="isTransformTagAccepted"/> test,
/// applies them sequentially to the <paramref name="srcView"/>, and returns the resulting data view.
/// If there are no transforms in <paramref name="ctx"/> that are accepted, returns the original <paramref name="srcView"/>.
/// The difference from the <c>Create</c> method above is that:
/// - it doesn't wrap the results into a loader, just returns the last transform in the chain.
/// - it accepts <see cref="IDataView"/> as input, not necessarily a loader.
/// - it throws away the tag information.
/// - it doesn't throw if the context is not representing a <see cref="CompositeDataLoader"/>: in this case it's assumed that no transforms
/// meet the test, and the <paramref name="srcView"/> is returned.
/// Essentially, this is a helper method for the LoadTransform class.
/// </summary>
public static IDataView LoadSelectedTransforms(ModelLoadContext ctx, IDataView srcView, IHostEnvironment env, Func<string, bool> isTransformTagAccepted)
{
    Contracts.CheckValue(env, nameof(env));
    var h = env.Register(RegistrationName);

    h.CheckValue(ctx, nameof(ctx));
    // The reader must be positioned exactly at the start of the model blob.
    h.Check(ctx.Reader.BaseStream.Position == ctx.FpMin + ctx.Header.FpModel);
    var ver = GetVersionInfo();
    // A non-matching signature means this isn't a CompositeDataLoader model; by
    // contract we then apply no transforms and return the input view unchanged.
    if (ctx.Header.ModelSignature != ver.ModelSignature)
    {
        using (var ch = h.Start("ModelCheck"))
        {
            ch.Info("The data model doesn't contain transforms.");
        }
        return (srcView);
    }
    ModelHeader.CheckVersionInfo(ref ctx.Header, ver);

    h.CheckValue(srcView, nameof(srcView));
    h.CheckValue(isTransformTagAccepted, nameof(isTransformTagAccepted));

    // *** Binary format ***
    // int: sizeof(Float)
    // int: number of transforms
    // foreach transform: (starting from version VersionAddedTags)
    //   string: tag
    //   string: args string
    int cbFloat = ctx.Reader.ReadInt32();
    h.CheckDecode(cbFloat == sizeof(Float));
    int cxf = ctx.Reader.ReadInt32();
    h.CheckDecode(cxf >= 0);
    // Tags only exist in models written at or after VersionAddedTags.
    bool hasTags = ctx.Header.ModelVerReadable >= VersionAddedTags;
    var curView = srcView;
    for (int i = 0; i < cxf; i++)
    {
        string tag = "";
        if (hasTags)
        {
            tag = ctx.LoadNonEmptyString();
            ctx.LoadStringOrNull(); // ignore the args string
        }
        // Rejected transforms are skipped entirely — never instantiated.
        if (!isTransformTagAccepted(tag))
        {
            continue;
        }

        IDataTransform xf;
        // Each accepted transform loads from its numbered sub-directory and is
        // chained onto the output of the previous one.
        ctx.LoadModel<IDataTransform, SignatureLoadDataTransform>(env, out xf, string.Format(TransformDirTemplate, i), curView);
        curView = xf;
    }

    return (curView);
}
/// <summary>
/// Deserialization constructor: reads each column's name, item type (possibly a key type
/// with an explicit count), and optional source segments, then computes the output schema.
/// </summary>
public Bindings(ModelLoadContext ctx, DatabaseLoader parent)
{
    Contracts.AssertValue(ctx);

    // *** Binary format ***
    // int: number of columns
    // foreach column:
    //   int: id of column name
    //   byte: DataKind
    //   byte: bool of whether this is a key type
    //   for a key type:
    //     ulong: count for key range
    //   int: number of segments
    //   foreach segment:
    //     string id: name
    //     int: min
    //     int: lim
    //     byte: force vector (verWrittenCur: verIsVectorSupported)
    int cinfo = ctx.Reader.ReadInt32();
    Contracts.CheckDecode(cinfo > 0);
    Infos = new ColInfo[cinfo];

    for (int iinfo = 0; iinfo < cinfo; iinfo++)
    {
        string name = ctx.LoadNonEmptyString();

        PrimitiveDataViewType itemType;
        var kind = (InternalDataKind)ctx.Reader.ReadByte();
        Contracts.CheckDecode(Enum.IsDefined(typeof(InternalDataKind), kind));
        bool isKey = ctx.Reader.ReadBoolByte();
        if (isKey)
        {
            // Key types carry an explicit, positive cardinality.
            ulong count;
            Contracts.CheckDecode(KeyDataViewType.IsValidDataType(kind.ToType()));
            count = ctx.Reader.ReadUInt64();
            Contracts.CheckDecode(0 < count);
            itemType = new KeyDataViewType(kind.ToType(), count);
        }
        else
        {
            itemType = ColumnTypeExtensions.PrimitiveTypeFromKind(kind);
        }

        int cseg = ctx.Reader.ReadInt32();

        Segment[] segs;
        if (cseg == 0)
        {
            segs = null;
        }
        else
        {
            Contracts.CheckDecode(cseg > 0);
            segs = new Segment[cseg];
            for (int iseg = 0; iseg < cseg; iseg++)
            {
                // A segment is addressed either by source column name or
                // by a half-open [min, lim) index range.
                string columnName = ctx.LoadStringOrNull();
                int min = ctx.Reader.ReadInt32();
                int lim = ctx.Reader.ReadInt32();
                Contracts.CheckDecode(0 <= min && min < lim);
                bool forceVector = ctx.Reader.ReadBoolByte();
                segs[iseg] = (columnName is null) ? new Segment(min, lim, forceVector) : new Segment(columnName, forceVector);
            }
        }

        // Note that this will throw if the segments are ill-structured, including the case
        // of multiple variable segments (since those segments will overlap and overlapping
        // segments are illegal).
        Infos[iinfo] = ColInfo.Create(name, itemType, segs, false);
    }

    OutputSchema = ComputeOutputSchema();
}
// Factory method for SignatureLoadModel.
/// <summary>
/// Deserializes an OnnxTransformer: the raw ONNX model bytes, the input and output column
/// names (multiple only for models written after version 0x00010001), optional custom
/// shape overrides, and (from version 0x00010003) the protobuf recursion limit.
/// </summary>
private static OnnxTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    // The serialized ONNX model lives in a separate binary stream entry.
    byte[] modelBytes = null;
    if (!ctx.TryLoadBinaryStream("OnnxModel", r => modelBytes = r.ReadByteArray()))
    {
        throw env.ExceptDecode();
    }

    // Models written at version 0x00010001 supported exactly one input and one output,
    // with no counts serialized.
    bool supportsMultiInputOutput = ctx.Header.ModelVerWritten > 0x00010001;

    var numInputs = (supportsMultiInputOutput) ? ctx.Reader.ReadInt32() : 1;
    env.CheckDecode(numInputs > 0);
    var inputs = new string[numInputs];
    for (int j = 0; j < inputs.Length; j++)
    {
        inputs[j] = ctx.LoadNonEmptyString();
    }

    var numOutputs = (supportsMultiInputOutput) ? ctx.Reader.ReadInt32() : 1;
    env.CheckDecode(numOutputs > 0);
    var outputs = new string[numOutputs];
    for (int j = 0; j < outputs.Length; j++)
    {
        outputs[j] = ctx.LoadNonEmptyString();
    }

    // Save custom-provided shapes. Those shapes overwrite shapes loaded from the ONNX model file.
    int customShapeInfosLength = ctx.Reader.ReadInt32(); // 0 means no custom shape. Non-zero means count of custom shapes.
    CustomShapeInfo[] loadedCustomShapeInfos = null;
    if (customShapeInfosLength > 0)
    {
        loadedCustomShapeInfos = new CustomShapeInfo[customShapeInfosLength];
        for (int i = 0; i < customShapeInfosLength; ++i)
        {
            var name = ctx.LoadNonEmptyString();
            var shape = ctx.Reader.ReadIntArray();
            loadedCustomShapeInfos[i] = new CustomShapeInfo() { Name = name, Shape = shape };
        }
    }

    int recursionLimit;

    // Recursion limit change
    if (ctx.Header.ModelVerWritten >= 0x00010003)
    {
        recursionLimit = ctx.Reader.ReadInt32();
    }
    else
    {
        // Default if not written inside ONNX model
        recursionLimit = 100;
    }

    var options = new Options()
    {
        InputColumns = inputs,
        OutputColumns = outputs,
        CustomShapeInfos = loadedCustomShapeInfos,
        RecursionLimit = recursionLimit
    };

    return (new OnnxTransformer(env, options, modelBytes));
}