private DictCountTable(IHostEnvironment env, ModelLoadContext ctx)
    : base(env, LoaderSignature, ctx)
{
    // *** Binary format ***
    // for each of the LabelCardinality dictionaries:
    //   int: number N of elements in the dictionary
    //   for each of the N elements:
    //     long: key
    //     Single: value
    Tables = new Dictionary<long, float>[LabelCardinality];
    for (int tableIndex = 0; tableIndex < LabelCardinality; tableIndex++)
    {
        var table = new Dictionary<long, float>();
        Tables[tableIndex] = table;

        int count = ctx.Reader.ReadInt32();
        env.CheckDecode(count >= 0);
        while (count-- > 0)
        {
            long key = ctx.Reader.ReadInt64();
            // Duplicate keys indicate a corrupt model.
            env.CheckDecode(!table.ContainsKey(key));
            float value = ctx.Reader.ReadSingle();
            // Counts are non-negative by construction.
            env.CheckDecode(value >= 0);
            table.Add(key, value);
        }
    }
}
public TransformInfo(IHostEnvironment env, ModelLoadContext ctx, string directoryName)
{
    // BUGFIX: this previously read env.AssertValue(env), asserting the
    // environment against itself. The intent — consistent with every other
    // load constructor in this file — is to validate the load context.
    env.AssertValue(ctx);

    // *** Binary format ***
    // int: d (number of untransformed features)
    // int: NewDim (number of transformed features)
    // bool: UseSin
    // uint[4]: the seeds for the pseudo random number generator.

    SrcDim = ctx.Reader.ReadInt32();
    NewDim = ctx.Reader.ReadInt32();
    env.CheckDecode(NewDim > 0);

    _useSin = ctx.Reader.ReadBoolByte();

    // The PRNG state is serialized as exactly four seed values.
    var length = ctx.Reader.ReadInt32();
    env.CheckDecode(length == 4);
    _state = TauswortheHybrid.State.Load(ctx.Reader);
    _rand = new TauswortheHybrid(_state);

    // The distribution sampler is stored as a nested model in directoryName.
    env.CheckDecode(ctx.Repository != null &&
        ctx.LoadModelOrNull<IFourierDistributionSampler, SignatureLoadModel>(env, out _matrixGenerator, directoryName));

    // Initialize the transform matrix. Dimensions are rounded up so rows are
    // aligned for vectorized (SIMD) math.
    int roundedUpD = RoundUp(NewDim, _cfltAlign);
    int roundedUpNumFeatures = RoundUp(SrcDim, _cfltAlign);
    RndFourierVectors = new AlignedArray(roundedUpD * roundedUpNumFeatures, CpuMathUtils.GetVectorAlignment());
    // When the sine component is used there is no per-feature rotation term.
    RotationTerms = _useSin ? null : new AlignedArray(roundedUpD, CpuMathUtils.GetVectorAlignment());

    InitializeFourierCoefficients(roundedUpNumFeatures, roundedUpD);
}
public static BindingsImpl Create(ModelLoadContext ctx, DataViewSchema input, IHostEnvironment env,
    ISchemaBindableMapper bindable,
    Func<DataViewType, bool> outputTypeMatches,
    Func<DataViewType, ISchemaBoundRowMapper, DataViewType> getPredColType)
{
    Contracts.AssertValue(env);
    env.AssertValue(ctx);

    // *** Binary format ***
    // <base info>
    // int: id of the scores column kind (metadata output)
    // int: id of the column used for deriving the predicted label column
    var roles = LoadBaseInfo(ctx, out string suffix);
    var scoreKind = ctx.LoadNonEmptyString();
    var scoreCol = ctx.LoadNonEmptyString();

    // Bind the mapper to the input schema; a row mapper is required.
    var mapper = bindable.Bind(env, new RoleMappedSchema(input, roles));
    var rowMapper = mapper as ISchemaBoundRowMapper;
    env.CheckParam(rowMapper != null, nameof(bindable),
        "Bindable expected to be an " + nameof(ISchemaBindableMapper) + "!");

    // Locate the score column of the mapper and validate its type.
    env.CheckDecode(mapper.OutputSchema.TryGetColumnIndex(scoreCol, out int scoreColIndex));
    var scoreType = mapper.OutputSchema[scoreColIndex].Type;
    env.CheckDecode(outputTypeMatches(scoreType));

    var predColType = getPredColType(scoreType, rowMapper);
    return new BindingsImpl(input, rowMapper, suffix, scoreKind, false, scoreColIndex, predColType);
}
private static void GetModelInfo(IHostEnvironment env, ModelLoadContext ctx,
    out string[] inputs, out string[] outputs, out bool isFrozen)
{
    // The frozen-model flag only exists in readable version 0x00010002 and
    // later; older models are implicitly frozen.
    isFrozen = ctx.Header.ModelVerReadable >= 0x00010002
        ? ctx.Reader.ReadBoolByte()
        : true;

    var numInputs = ctx.Reader.ReadInt32();
    env.CheckDecode(numInputs > 0);
    inputs = new string[numInputs];
    for (int i = 0; i < numInputs; i++)
        inputs[i] = ctx.LoadNonEmptyString();

    // Multi-output support also arrived with version 0x00010002; older
    // models always carry exactly one output.
    var numOutputs = ctx.Header.ModelVerReadable >= 0x00010002
        ? ctx.Reader.ReadInt32()
        : 1;
    env.CheckDecode(numOutputs > 0);
    outputs = new string[numOutputs];
    for (int i = 0; i < numOutputs; i++)
        outputs[i] = ctx.LoadNonEmptyString();
}
// Factory method for SignatureLoadModel.
private static OnnxTransform Create(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    // The serialized ONNX model lives in its own binary stream entry.
    byte[] modelBytes = null;
    if (!ctx.TryLoadBinaryStream("OnnxModel", reader => modelBytes = reader.ReadByteArray()))
        throw env.ExceptDecode();

    // Multiple input/output columns are only present in formats written
    // after version 0x00010001; older models have exactly one of each.
    bool multiSupported = ctx.Header.ModelVerWritten > 0x00010001;

    int inputCount = multiSupported ? ctx.Reader.ReadInt32() : 1;
    env.CheckDecode(inputCount > 0);
    var inputs = new string[inputCount];
    for (int i = 0; i < inputCount; i++)
        inputs[i] = ctx.LoadNonEmptyString();

    int outputCount = multiSupported ? ctx.Reader.ReadInt32() : 1;
    env.CheckDecode(outputCount > 0);
    var outputs = new string[outputCount];
    for (int i = 0; i < outputCount; i++)
        outputs[i] = ctx.LoadNonEmptyString();

    var args = new Arguments() { InputColumns = inputs, OutputColumns = outputs };
    return new OnnxTransform(env, args, modelBytes);
}
private CMCountTable(IHostEnvironment env, ModelLoadContext ctx)
    : base(env, LoaderSignature, ctx)
{
    // *** Binary format ***
    // int: depth
    // int: width
    // for each of the LabelCardinality tables:
    //   for each of the Depth dictionaries:
    //     int: the number of pairs in the dictionary
    //     for each pair:
    //       int: index
    //       float: value
    Depth = ctx.Reader.ReadInt32();
    env.CheckDecode(Depth > 0);
    Width = ctx.Reader.ReadInt32();
    env.CheckDecode(Width > 0);

    Tables = new Dictionary<int, float>[LabelCardinality][];
    for (int i = 0; i < LabelCardinality; i++)
    {
        Tables[i] = new Dictionary<int, float>[Depth];
        for (int j = 0; j < Depth; j++)
        {
            var count = ctx.Reader.ReadInt32();
            // FIX: validate the serialized count before using it; a corrupt
            // negative value previously surfaced as an
            // ArgumentOutOfRangeException from the Dictionary constructor
            // instead of a decode error. (The sibling DictCountTable loader
            // performs the same check.)
            env.CheckDecode(count >= 0);
            var table = new Dictionary<int, float>(count);
            Tables[i][j] = table;
            for (int k = 0; k < count; k++)
            {
                int index = ctx.Reader.ReadInt32();
                // FIX: duplicate indices mean the model is corrupt; report a
                // decode failure rather than letting Dictionary.Add throw
                // ArgumentException.
                env.CheckDecode(!table.ContainsKey(index));
                float value = ctx.Reader.ReadSingle();
                table.Add(index, value);
            }
        }
    }
}
// Factory method for SignatureLoadModel.
private static OnnxTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    // The serialized ONNX model is stored in a dedicated binary stream entry.
    byte[] modelBytes = null;
    if (!ctx.TryLoadBinaryStream("OnnxModel", reader => modelBytes = reader.ReadByteArray()))
        throw env.ExceptDecode();

    // Multiple input/output columns only exist in formats newer than
    // 0x00010001; older models carry exactly one of each.
    bool multiSupported = ctx.Header.ModelVerWritten > 0x00010001;

    int inputCount = multiSupported ? ctx.Reader.ReadInt32() : 1;
    env.CheckDecode(inputCount > 0);
    var inputs = new string[inputCount];
    for (int i = 0; i < inputCount; i++)
        inputs[i] = ctx.LoadNonEmptyString();

    int outputCount = multiSupported ? ctx.Reader.ReadInt32() : 1;
    env.CheckDecode(outputCount > 0);
    var outputs = new string[outputCount];
    for (int i = 0; i < outputCount; i++)
        outputs[i] = ctx.LoadNonEmptyString();

    // Custom-provided shapes overwrite shapes loaded from the ONNX model file.
    // 0 means no custom shape; non-zero is the count of custom shapes.
    int customShapeCount = ctx.Reader.ReadInt32();
    CustomShapeInfo[] customShapes = null;
    if (customShapeCount > 0)
    {
        customShapes = new CustomShapeInfo[customShapeCount];
        for (int i = 0; i < customShapeCount; i++)
        {
            var name = ctx.LoadNonEmptyString();
            var shape = ctx.Reader.ReadIntArray();
            customShapes[i] = new CustomShapeInfo() { Name = name, Shape = shape };
        }
    }

    var options = new Options()
    {
        InputColumns = inputs,
        OutputColumns = outputs,
        CustomShapeInfos = customShapes
    };
    return new OnnxTransformer(env, options, modelBytes);
}
private static ExpressionTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    // *** Binary format ***
    // int: number of output columns
    // for each output column:
    //   int: number of inputs
    //   foreach input
    //     int: Id of the input column name
    //   int: Id of the expression
    //   int: Id of the output column name
    //   int: the index of the vector input (or -1)
    //   int[]: The data kinds of the input columns
    var columnCount = ctx.Reader.ReadInt32();
    env.CheckDecode(columnCount > 0);

    var columns = new ColumnInfo[columnCount];
    for (int col = 0; col < columnCount; col++)
    {
        var inputCount = ctx.Reader.ReadInt32();
        env.CheckDecode(inputCount >= 0);
        var inputColumnNames = new string[inputCount];
        for (int input = 0; input < inputCount; input++)
            inputColumnNames[input] = ctx.LoadNonEmptyString();

        var expression = ctx.LoadNonEmptyString();
        var outputColumnName = ctx.LoadNonEmptyString();
        // -1 means no vector-valued input.
        var vectorInputColumn = ctx.Reader.ReadInt32();
        env.CheckDecode(vectorInputColumn >= -1);

        // Reconstruct the input column types from their serialized data-kind indices.
        var inputTypes = new DataViewType[inputCount];
        for (int input = 0; input < inputCount; input++)
        {
            var kind = InternalDataKindExtensions.FromIndex(ctx.Reader.ReadInt32());
            inputTypes[input] = ColumnTypeExtensions.PrimitiveTypeFromKind(kind);
        }

        // Re-parse and re-bind the lambda so the column holds a live expression tree.
        var node = ExpressionEstimator.ParseAndBindLambda(env, expression, vectorInputColumn, inputTypes, out var perm);
        columns[col] = new ColumnInfo(env, inputColumnNames, inputTypes, expression, outputColumnName, vectorInputColumn, node, perm);
    }
    return new ExpressionTransformer(env, columns);
}
private RandomNumberGenerator(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.AssertValue(env);
    env.AssertValue(ctx);

    // *** Binary format ***
    // int: sizeof(Float)
    // Float: gamma

    // Reject models serialized with a different floating-point width.
    var floatSize = ctx.Reader.ReadInt32();
    env.CheckDecode(floatSize == sizeof(float));

    _gamma = ctx.Reader.ReadFloat();
    // Gamma must be a finite value (no NaN/Inf).
    env.CheckDecode(FloatUtils.IsFinite(_gamma));
}
/// <summary>
/// Loads data view (loader and transforms) from <paramref name="rep"/> if <paramref name="loadTransforms"/> is set to true,
/// otherwise loads loader only.
/// </summary>
public static IDataLoader LoadLoader(IHostEnvironment env, RepositoryReader rep, IMultiStreamSource files, bool loadTransforms)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(rep, nameof(rep));
    env.CheckValue(files, nameof(files));

    // When only the loader is wanted, first try the nested "Loader" entry of
    // a composite loader.
    string dir = "";
    Repository.Entry ent = null;
    if (!loadTransforms)
    {
        dir = Path.Combine(DirDataLoaderModel, "Loader");
        ent = rep.OpenEntryOrNull(dir, ModelLoadContext.ModelStreamName);
    }
    if (ent == null)
    {
        // Either loadTransforms is true, or this is not a composite loader:
        // fall back to the top-level entry.
        dir = DirDataLoaderModel;
        ent = rep.OpenEntry(dir, ModelLoadContext.ModelStreamName);
    }
    env.CheckDecode(ent != null, "Loader is not found.");
    env.AssertNonEmpty(dir);

    IDataLoader loader;
    using (ent)
    {
        env.Assert(ent.Stream.Position == 0);
        ModelLoadContext.LoadModel<IDataLoader, SignatureLoadDataLoader>(env, out loader, rep, ent, dir, files);
    }
    return loader;
}
public LinearModelStatistics(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));
    _env = env;
    _env.AssertValue(ctx);

    // *** Binary Format ***
    // int: count of parameters
    // long: count of training examples
    // Single: deviance
    // Single: null deviance
    // bool: whether standard error is included
    // (Conditional) Single[_paramCount]: values of std errors of coefficients
    // (Conditional) int: length of std errors of coefficients
    // (Conditional) int[_paramCount]: indices of std errors of coefficients

    _paramCount = ctx.Reader.ReadInt32();
    _env.CheckDecode(_paramCount > 0);
    _trainingExampleCount = ctx.Reader.ReadInt64();
    _env.CheckDecode(_trainingExampleCount > 0);
    _deviance = ctx.Reader.ReadFloat();
    _nullDeviance = ctx.Reader.ReadFloat();

    if (!ctx.Reader.ReadBoolean())
    {
        // No standard errors were serialized.
        _env.Assert(_coeffStdError == null);
        return;
    }

    var stdErrorValues = ctx.Reader.ReadFloatArray(_paramCount);
    int length = ctx.Reader.ReadInt32();
    _env.CheckDecode(length >= _paramCount);
    if (length == _paramCount)
    {
        // Dense representation: one std error per parameter, no indices.
        _coeffStdError = new VBuffer<Single>(length, stdErrorValues);
        return;
    }

    // Sparse representation: values paired with explicit coefficient indices.
    _env.Assert(length > _paramCount);
    var stdErrorIndices = ctx.Reader.ReadIntArray(_paramCount);
    _coeffStdError = new VBuffer<Single>(length, _paramCount, stdErrorValues, stdErrorIndices);
}
// Factory method for SignatureLoadModel.
private static TensorFlowTransform Create(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    // *** Binary format ***
    // stream: tensorFlow model.
    // int: number of input columns
    // for each input column
    //   int: id of int column name
    // int: number of output columns
    // for each output column
    //   int: id of output column name
    byte[] modelBytes = null;
    if (!ctx.TryLoadBinaryStream("TFModel", reader => modelBytes = reader.ReadByteArray()))
        throw env.ExceptDecode();
    var session = TensorFlowUtils.LoadTFSession(env, modelBytes);

    var inputCount = ctx.Reader.ReadInt32();
    env.CheckDecode(inputCount > 0);
    string[] inputs = new string[inputCount];
    for (int i = 0; i < inputCount; i++)
        inputs[i] = ctx.LoadNonEmptyString();

    // Multi-output models only exist from readable version 0x00010002 on;
    // older models always have a single output column.
    var outputCount = ctx.Header.ModelVerReadable >= 0x00010002
        ? ctx.Reader.ReadInt32()
        : 1;
    env.CheckDecode(outputCount > 0);
    var outputs = new string[outputCount];
    for (int i = 0; i < outputCount; i++)
        outputs[i] = ctx.LoadNonEmptyString();

    return new TensorFlowTransform(env, session, inputs, outputs);
}
internal PredictorModelImpl(IHostEnvironment env, Stream stream)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(stream, nameof(stream));
    using (var ch = env.Start("Loading predictor model"))
    {
        // REVIEW: address the asymmetry in the way we're loading and saving the model.
        TransformModel = new TransformModelImpl(env, stream);

        // Role mappings are mandatory in a predictor model.
        var roleMappings = ModelFileUtils.LoadRoleMappingsOrNull(env, stream);
        env.CheckDecode(roleMappings != null, "Predictor model must contain role mappings");
        _roleMappings = roleMappings.ToArray();

        // So is the predictor itself.
        Predictor = ModelFileUtils.LoadPredictorOrNull(env, stream);
        env.CheckDecode(Predictor != null, "Predictor model must contain a predictor");
    }
}
protected CountTableBase(IHostEnvironment env, string name, ModelLoadContext ctx)
{
    env.AssertNonWhiteSpace(name);
    env.AssertValue(ctx);

    // *** Binary format ***
    // int: label cardinality
    // double[]: prior frequencies
    // float: garbage threshold
    // float[]: garbage counts

    LabelCardinality = ctx.Reader.ReadInt32();
    env.CheckDecode(0 < LabelCardinality && LabelCardinality < LabelCardinalityLim);

    // One prior frequency per label; all must be non-negative.
    _priorFrequencies = ctx.Reader.ReadDoubleArray();
    env.CheckDecode(Utils.Size(_priorFrequencies) == LabelCardinality);
    env.CheckDecode(_priorFrequencies.All(freq => freq >= 0));

    GarbageThreshold = ctx.Reader.ReadSingle();
    env.CheckDecode(GarbageThreshold >= 0);

    _garbageCounts = ctx.Reader.ReadSingleArray();
    if (GarbageThreshold == 0)
    {
        // Garbage-bin handling disabled: the counts array must be absent.
        env.CheckDecode(Utils.Size(_garbageCounts) == 0);
    }
    else
    {
        // Otherwise there is one non-negative count per label.
        env.CheckDecode(Utils.Size(_garbageCounts) == LabelCardinality);
        env.CheckDecode(_garbageCounts.All(cnt => cnt >= 0));
    }
}
private protected BaseStacking(IHostEnvironment env, string name, ModelLoadContext ctx)
{
    Contracts.AssertValue(env);
    env.AssertNonWhiteSpace(name);
    Host = env.Register(name);
    Host.AssertValue(ctx);

    // *** Binary format ***
    // int: sizeof(Single)
    // Float: _validationDatasetProportion

    // Reject models written with a different floating-point width.
    int floatSize = ctx.Reader.ReadInt32();
    env.CheckDecode(floatSize == sizeof(Single));

    ValidationDatasetProportion = ctx.Reader.ReadFloat();
    // Proportion must lie in [0, 1).
    env.CheckDecode(0 <= ValidationDatasetProportion && ValidationDatasetProportion < 1);

    // The meta-predictor is stored as a nested model.
    ctx.LoadModel<IPredictorProducing<TOutput>, SignatureLoadModel>(env, out Meta, "MetaPredictor");
    CheckMeta();
}
// Deserializes the raw one-vs-all reduction state: the class-label vector,
// bookkeeping flags, the per-class predictors, and a trailing sentinel byte.
internal ImplRaw(ModelLoadContext ctx, IHostEnvironment env)
{
    // labelType
    GuessLabelType();
    int[] indices = ctx.Reader.ReadIntArray();
    TLabel[] classes;
    // The on-disk representation of the class array depends on the label type
    // resolved by GuessLabelType(); each branch reads the matching primitive
    // array and casts it to TLabel[]. The `as` cast yields null on a type
    // mismatch, which CheckValue then reports.
    if (LabelType == NumberDataViewType.Single)
    {
        classes = ctx.Reader.ReadFloatArray() as TLabel[];
        env.CheckValue(classes, "classes");
    }
    else if (LabelType == NumberDataViewType.Byte)
    {
        classes = ctx.Reader.ReadByteArray() as TLabel[];
        env.CheckValue(classes, "classes");
    }
    else if (LabelType == NumberDataViewType.UInt16)
    {
        // UInt16 labels are stored as uint on disk and narrowed element-wise.
        var val = ctx.Reader.ReadUIntArray();
        env.CheckValue(val, "classes");
        classes = val.Select(c => (ushort)c).ToArray() as TLabel[];
    }
    else if (LabelType == NumberDataViewType.UInt32)
    {
        var val = ctx.Reader.ReadUIntArray();
        env.CheckValue(val, "classes");
        classes = val as TLabel[];
    }
    else
    {
        throw env.Except("Unexpected type for LabelType.");
    }
    // indices may be null, in which case the buffer is dense.
    _classes = new VBuffer<TLabel>(classes.Length, classes, indices);
    // Flags are serialized as int 0/1.
    _singleColumn = ctx.Reader.ReadInt32() == 1;
    _labelKey = ctx.Reader.ReadInt32() == 1;
    FinalizeOutputType();
    int len = ctx.Reader.ReadInt32();
    env.CheckDecode(len > 0);
    var predictors = new TScalarPredictor[len];
    IPredictor reclassPredictor;
    LoadPredictors(env, predictors, out reclassPredictor, ctx);
    Preparation(predictors, reclassPredictor);
    // Sentinel byte written at save time; a mismatch means the stream is
    // corrupt or the reads above fell out of sync with the writes.
    var checkCode = ctx.Reader.ReadByte();
    if (checkCode != 213)
    {
        throw Contracts.Except("CheckCode is wrong. Serialization failed.");
    }
}
internal ModelStatisticsBase(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));
    Env = env;
    Env.AssertValue(ctx);

    // *** Binary Format ***
    // int: count of parameters
    // long: count of training examples
    // float: deviance
    // float: null deviance

    ParametersCount = ctx.Reader.ReadInt32();
    Env.CheckDecode(ParametersCount > 0);

    TrainingExampleCount = ctx.Reader.ReadInt64();
    Env.CheckDecode(TrainingExampleCount > 0);

    // Deviances are not range-checked; any float is a valid serialized value.
    Deviance = ctx.Reader.ReadFloat();
    NullDeviance = ctx.Reader.ReadFloat();
}
// Factory method for SignatureLoadModel.
private static PcaTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));
    var host = env.Register(nameof(PcaTransformer));
    host.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    // The very first format version prefixed the payload with sizeof(float);
    // consume and validate it so the remaining reads stay aligned.
    if (ctx.Header.ModelVerWritten == 0x00010001)
    {
        int floatByteCount = ctx.Reader.ReadInt32();
        env.CheckDecode(floatByteCount == sizeof(float));
    }
    return new PcaTransformer(host, ctx);
}
// Deserializes the Bindings: reconstructs the schema of the data view
// containing the optional columns from an embedded IDV stream, then reads the
// added column names/types and maps them to both the input schema and the
// serialized schema.
public static Bindings Create(IHostEnvironment env, ModelLoadContext ctx, Schema input, OptionalColumnTransform parent)
{
    Contracts.AssertValue(ctx);
    Contracts.AssertValue(input);

    // *** Binary format ***
    // Schema of the data view containing the optional columns
    // int: number of added columns
    // for each added column
    //   int: id of output column name
    //   ColumnType: the type of the column
    byte[] buffer = null;
    if (!ctx.TryLoadBinaryStream("Schema.idv", r => buffer = r.ReadByteArray()))
    {
        throw env.ExceptDecode();
    }
    BinaryLoader loader = null;
    var strm = new MemoryStream(buffer, writable: false);
    // NOTE(review): neither loader nor strm is disposed here; presumably
    // loader.Schema must outlive this method since it is handed to the
    // returned Bindings — confirm ownership before adding a using.
    loader = new BinaryLoader(env, new BinaryLoader.Arguments(), strm);
    int size = ctx.Reader.ReadInt32();
    Contracts.CheckDecode(size > 0);
    // BinarySaver is used only to decode the serialized type descriptions.
    var saver = new BinarySaver(env, new BinarySaver.Arguments());
    var names = new string[size];
    var columnTypes = new ColumnType[size];
    var srcCols = new int[size];
    var srcColsWithOptionalColumn = new int[size];
    for (int i = 0; i < size; i++)
    {
        names[i] = ctx.LoadNonEmptyString();
        columnTypes[i] = saver.LoadTypeDescriptionOrNull(ctx.Reader.BaseStream);
        int col;
        // The column may legitimately be absent from the live input (-1):
        // that is what makes it "optional".
        bool success = input.TryGetColumnIndex(names[i], out col);
        srcCols[i] = success ? col : -1;
        // It must, however, exist in the serialized schema.
        success = loader.Schema.TryGetColumnIndex(names[i], out var colWithOptionalColumn);
        env.CheckDecode(success);
        srcColsWithOptionalColumn[i] = colWithOptionalColumn;
    }
    return (new Bindings(parent, columnTypes, srcCols, srcColsWithOptionalColumn, input, loader.Schema, false, names));
}
/// <summary>
/// Back-compatibility function that handles loading the DropColumns Transform.
/// </summary>
private static ColumnSelectingTransformer LoadDropColumnsTransform(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
{
    // *** Binary format ***
    // int: sizeof(Float)
    // bindings
    int cbFloat = ctx.Reader.ReadInt32();
    env.CheckDecode(cbFloat == sizeof(float));

    // *** Binary format ***
    // bool: whether to keep (vs drop) the named columns
    // int: number of names
    // int[]: the ids of the names
    var keep = ctx.Reader.ReadBoolByte();
    int count = ctx.Reader.ReadInt32();
    Contracts.CheckDecode(count > 0);

    var names = new HashSet<string>();
    for (int i = 0; i < count; i++)
    {
        // Duplicate names indicate a corrupt model (HashSet.Add returns false).
        Contracts.CheckDecode(names.Add(ctx.LoadNonEmptyString()));
    }

    var selected = names.ToArray();
    var keepColumns = keep ? selected : null;
    var dropColumns = keep ? null : selected;

    // Note for backward compatibility, Drop/Keep Columns always preserves
    // hidden columns.
    return new ColumnSelectingTransformer(env, keepColumns, dropColumns, true);
}
// Analyzer test fixture: every statement below exercises the Contracts/env
// diagnostic analyzer. The "Should fail" / "Should not fail" comments mark
// the EXPECTED analyzer verdicts — the violations are intentional test data,
// so do not "fix" them.
public TypeName(IHostEnvironment env, float p, int foo)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckParam(0 <= p && p <= 1, nameof(p), "Should be in range [0,1]");
    env.CheckParam(0 <= p && p <= 1, "p"); // Should fail.
    env.CheckParam(0 <= p && p <= 1, nameof(p) + nameof(p)); // Should fail.
    env.CheckValue(paramName: nameof(p), val: "p"); // Should succeed despite confusing order.
    env.CheckValue(paramName: "p", val: nameof(p)); // Should fail despite confusing order.
    env.CheckValue("p", nameof(p));
    env.CheckUserArg(foo > 5, "foo", "Nice");
    env.CheckUserArg(foo > 5, nameof(foo), "Nice");
    env.Except(); // Not throwing or doing anything with the exception, so should fail.
    Contracts.ExceptParam(nameof(env), "What a silly env"); // Should also fail.
    if (false)
    {
        throw env.Except(); // Should not fail.
    }
    if (false)
    {
        throw env.ExceptParam(nameof(env), "What a silly env"); // Should not fail.
    }
    if (false)
    {
        throw env.ExceptParam("env", "What a silly env"); // Should fail due to name error.
    }
    var e = env.Except();
    env.Check(true, $"Hello {foo} is cool");
    env.Check(true, "Hello it is cool");
    string coolMessage = "Hello it is cool";
    env.Check(true, coolMessage);
    env.Check(true, string.Format("Hello {0} is cool", foo));
    env.Check(true, Messages.CoolMessage);
    env.CheckDecode(true, "Not suspicious, no ModelLoadContext");
    Contracts.Check(true, "Fine: " + nameof(env));
    Contracts.Check(true, "Less fine: " + env.GetType().Name);
    Contracts.CheckUserArg(0 <= p && p <= 1, "p", "On a new line");
}