/// <summary>
/// Re-creates a <see cref="TreeEnsembleFeaturizationTransformer"/> from a serialized model.
/// Reads the stored column names, then rebuilds the bindable mapper and the scorer over the
/// train-time schema so the transformer is ready to transform new data.
/// </summary>
private TreeEnsembleFeaturizationTransformer(IHostEnvironment host, ModelLoadContext ctx)
    : base(Contracts.CheckRef(host, nameof(host)).Register(nameof(TreeEnsembleFeaturizationTransformer)), ctx)
{
    // *** Binary format ***
    // <base info>
    // string: feature column's name.
    // string: the name of the columns where tree prediction values are stored.
    // string: the name of the columns where trees' leaves are stored.
    // string: the name of the columns where trees' paths are stored.

    // Load stored fields. The feature column must exist in the train schema loaded by the base class.
    string featureColumnName = ctx.LoadString();
    _featureDetachedColumn = new DataViewSchema.DetachedColumn(TrainSchema[featureColumnName]);
    // The three output column names may each be null (LoadStringOrNull), meaning that output was not requested.
    _treesColumnName = ctx.LoadStringOrNull();
    _leavesColumnName = ctx.LoadStringOrNull();
    _pathsColumnName = ctx.LoadStringOrNull();

    // Create an argument to specify output columns' names of this transformer.
    _scorerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments
    {
        TreesColumnName = _treesColumnName,
        LeavesColumnName = _leavesColumnName,
        PathsColumnName = _pathsColumnName
    };

    // Create a bindable mapper. It provides the core computation and can be attached to any IDataView and produce
    // a transformed IDataView. Note it is constructed with the raw (unregistered) host argument.
    BindableMapper = new TreeEnsembleFeaturizerBindableMapper(host, _scorerArgs, Model);

    // Create a scorer bound to an empty view carrying the train schema.
    var roleMappedSchema = MakeFeatureRoleMappedSchema(TrainSchema);
    Scorer = new GenericScorer(Host, _scorerArgs, new EmptyDataView(Host, TrainSchema),
        BindableMapper.Bind(Host, roleMappedSchema), roleMappedSchema);
}
/// <summary>
/// Deserializes a <see cref="FeatureNameCollection"/> from the given model context.
/// Supports both the dense layout (one name per feature) and the sparse layout
/// (a strictly increasing index list paired with one name per index).
/// </summary>
public static FeatureNameCollection Create(ModelLoadContext ctx)
{
    Contracts.AssertValue(ctx);
    ctx.CheckAtModel();
    ctx.CheckVersionInfo(GetVersionInfo());

    // *** Binary format ***
    // int: number of features (size)
    // int: number of indices (negative for the dense layout)
    // int[]: indices (sparse layout only)
    // string[]: names (one per feature when dense, one per index when sparse)
    int featureCount = ctx.Reader.ReadInt32();
    Contracts.CheckDecode(featureCount >= 0);
    int indexCount = ctx.Reader.ReadInt32();
    Contracts.CheckDecode(indexCount >= -1);

    if (indexCount < 0)
    {
        // Dense layout: one (possibly absent) name per feature slot.
        // Empty strings are normalized to null.
        var denseNames = new string[featureCount];
        for (int slot = 0; slot < featureCount; slot++)
        {
            var loaded = ctx.LoadStringOrNull();
            denseNames[slot] = string.IsNullOrEmpty(loaded) ? null : loaded;
        }
        return Create(featureCount, denseNames);
    }

    // Sparse layout: first decode the index list, verifying it is strictly increasing
    // and bounded by the feature count.
    var indices = new int[indexCount];
    int previous = -1;
    for (int pos = 0; pos < indexCount; pos++)
    {
        int current = ctx.Reader.ReadInt32();
        Contracts.CheckDecode(previous < current);
        indices[pos] = current;
        previous = current;
    }
    Contracts.CheckDecode(previous < featureCount);

    // Then decode the matching names; empty or null entries are simply dropped.
    var namesByIndex = new Dictionary<int, string>();
    for (int pos = 0; pos < indexCount; pos++)
    {
        var loaded = ctx.LoadStringOrNull();
        if (!string.IsNullOrEmpty(loaded))
            namesByIndex.Add(indices[pos], loaded);
    }
    return Create(featureCount, namesByIndex);
}
/// <summary>
/// Reconstructs per-column stopword-removal configuration from a serialized model:
/// the stopword-list language and an optional per-row languages column.
/// </summary>
public ColInfoEx(ModelLoadContext ctx, ISchema input)
{
    Contracts.AssertValue(ctx);
    Contracts.AssertValue(input);

    // *** Binary format ***
    // int: the stopwords list language
    // int: the id of languages column name
    Lang = (Language)ctx.Reader.ReadInt32();
    Contracts.CheckDecode(Enum.IsDefined(typeof(Language), Lang));

    // Models written with the first version (0x00010001) used Norwegian_Bokmal to mean
    // what later versions call Norwegian_Bokmal_v1; remap for backward compatibility.
    bool writtenByFirstVersion = ctx.Header.ModelVerWritten == 0x00010001;
    if (writtenByFirstVersion && Lang == Language.Norwegian_Bokmal)
        Lang = Language.Norwegian_Bokmal_v1;

    _langsColName = ctx.LoadStringOrNull();
    if (_langsColName == null)
    {
        // No per-row language column was stored; -1 marks "use the fixed language".
        LangsColIndex = -1;
    }
    else
    {
        Bind(input, _langsColName, out LangsColIndex, false);
        Contracts.Assert(LangsColIndex >= 0);
    }
}
/// <summary>
/// Loads all transforms from the <paramref name="ctx"/> that pass the <paramref name="isTransformTagAccepted"/> test,
/// applies them sequentially to the <paramref name="srcLoader"/>, and returns the (composite) data loader.
/// </summary>
private static IDataLoader LoadTransforms(ModelLoadContext ctx, IDataLoader srcLoader, IHost host, Func<string, bool> isTransformTagAccepted)
{
    Contracts.AssertValue(host, "host");
    host.AssertValue(srcLoader);
    host.AssertValue(ctx);

    // *** Binary format ***
    // int: sizeof(Float)
    // int: number of transforms
    // foreach transform: (starting from version VersionAddedTags)
    //   string: tag
    //   string: args string
    int cbFloat = ctx.Reader.ReadInt32();
    host.CheckDecode(cbFloat == sizeof(Float));

    int cxf = ctx.Reader.ReadInt32();
    host.CheckDecode(cxf >= 0);

    // Models older than VersionAddedTags carry no tag/args strings; every transform
    // is then presented to the acceptance test with an empty tag.
    bool hasTags = ctx.Header.ModelVerReadable >= VersionAddedTags;
    var tagData = new List<KeyValuePair<string, string>>();
    // Original serialized indices of the accepted transforms; needed below to address
    // the right sub-model directories.
    var acceptedIds = new List<int>();
    for (int i = 0; i < cxf; i++)
    {
        string tag = "";
        string argsString = null;
        if (hasTags)
        {
            tag = ctx.LoadNonEmptyString();
            argsString = ctx.LoadStringOrNull();
        }
        if (!isTransformTagAccepted(tag))
        {
            continue;
        }
        acceptedIds.Add(i);
        tagData.Add(new KeyValuePair<string, string>(tag, argsString));
    }
    host.Assert(tagData.Count == acceptedIds.Count);

    // Nothing accepted: the source loader is returned unchanged.
    if (tagData.Count == 0)
    {
        return (srcLoader);
    }

    // The lambda captures ctx and acceptedIds: each accepted transform is loaded lazily
    // from the directory named after its ORIGINAL index (acceptedIds[index]), not its
    // position in the accepted list.
    return (ApplyTransformsCore(host, srcLoader, tagData.ToArray(),
        (h, index, data) =>
        {
            IDataTransform xf;
            ctx.LoadModel<IDataTransform, SignatureLoadDataTransform>(host, out xf,
                string.Format(TransformDirTemplate, acceptedIds[index]), data);
            return xf;
        }));
}
/// <summary>
/// Shared deserialization tail: reads the (possibly null) train label column name
/// from <paramref name="ctx"/> and then rebuilds the scorer.
/// </summary>
private void InitializationLogic(ModelLoadContext ctx, out string trainLabelColumn)
{
    // *** Binary format ***
    // <base info>
    // id of string: train label column
    trainLabelColumn = ctx.LoadStringOrNull();
    SetScorer();
}
/// <summary>
/// Deserializes tokenizer bindings: for each group column info, the tokenizer language,
/// source column, optional languages column, and one or two output column names
/// (tokens, and optionally token types).
/// </summary>
public static Bindings Create(ModelLoadContext ctx, ISchema input, IChannel ch)
{
    Contracts.AssertValue(ch);
    ch.AssertValue(ctx);

    // *** Binary format ***
    // int: count of group column infos (ie, count of source columns)
    // For each group column info
    //     int: the tokenizer language
    //     int: the id of source column name
    //     int: the id of languages column name
    //     bool: whether the types output is required
    //     For each column info that belongs to this group column info
    //     (either one column info for tokens or two for tokens and types)
    //         int: the id of the column name
    int groupsLen = ctx.Reader.ReadInt32();
    ch.CheckDecode(groupsLen > 0);

    var names = new List<string>();
    var infos = new List<ColInfo>();
    var groups = new ColGroupInfo[groupsLen];
    for (int i = 0; i < groups.Length; i++)
    {
        int lang = ctx.Reader.ReadInt32();
        ch.CheckDecode(Enum.IsDefined(typeof(Language), lang));

        // The source column must exist in the input schema and hold text.
        string srcName = ctx.LoadNonEmptyString();
        int srcIdx;
        ColumnType srcType;
        Bind(input, srcName, t => t.ItemType.IsText, SrcTypeName, out srcIdx, out srcType, false);

        // The languages column is optional; when absent, -1 marks "use the fixed language".
        string langsName = ctx.LoadStringOrNull();
        int langsIdx;
        if (langsName != null)
        {
            ColumnType langsType;
            Bind(input, langsName, t => t.IsText, LangTypeName, out langsIdx, out langsType, false);
        }
        else
            langsIdx = -1;

        bool requireTypes = ctx.Reader.ReadBoolByte();
        groups[i] = new ColGroupInfo((Language)lang, srcIdx, srcName, srcType, langsIdx, langsName, requireTypes);

        // Tokens output column is always present...
        infos.Add(new ColInfo(i));
        names.Add(ctx.LoadNonEmptyString());
        // ...and the token-types output column only when requested at save time.
        if (requireTypes)
        {
            infos.Add(new ColInfo(i, isTypes: true));
            names.Add(ctx.LoadNonEmptyString());
        }
    }

    return new Bindings(groups, infos.ToArray(), input, false, names.ToArray());
}
/// <summary>
/// Re-creates a <see cref="MulticlassPredictionTransformer{TModel}"/> from a serialized model:
/// reads the (possibly null) train label column name and rebuilds the scorer.
/// </summary>
public MulticlassPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx)
    : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MulticlassPredictionTransformer<TModel>)), ctx)
{
    // *** Binary format ***
    // <base info>
    // id of string: train label column
    _trainLabelColumn = ctx.LoadStringOrNull();
    SetScorer();
}
/// <summary>
/// Loads all transforms from the <paramref name="ctx"/> that pass the <paramref name="isTransformTagAccepted"/> test,
/// applies them sequentially to the <paramref name="srcView"/>, and returns the resulting data view.
/// If there are no transforms in <paramref name="ctx"/> that are accepted, returns the original <paramref name="srcView"/>.
/// The difference from the <c>Create</c> method above is that:
/// - it doesn't wrap the results into a loader, just returns the last transform in the chain.
/// - it accepts <see cref="IDataView"/> as input, not necessarily a loader.
/// - it throws away the tag information.
/// - it doesn't throw if the context is not representing a <see cref="LegacyCompositeDataLoader"/>: in this case it's assumed that no transforms
///   meet the test, and the <paramref name="srcView"/> is returned.
/// Essentially, this is a helper method for the LoadTransform class.
/// </summary>
public static IDataView LoadSelectedTransforms(ModelLoadContext ctx, IDataView srcView, IHostEnvironment env, Func<string, bool> isTransformTagAccepted)
{
    Contracts.CheckValue(env, nameof(env));
    var h = env.Register(RegistrationName);
    h.CheckValue(ctx, nameof(ctx));
    // The reader must be positioned at the start of the model body.
    h.Check(ctx.Reader.BaseStream.Position == ctx.FpMin + ctx.Header.FpModel);

    // A mismatched model signature means this context is not a composite loader;
    // by contract we then return the source view unchanged instead of throwing.
    var ver = GetVersionInfo();
    if (ctx.Header.ModelSignature != ver.ModelSignature)
    {
        using (var ch = h.Start("ModelCheck"))
        {
            ch.Info("The data model doesn't contain transforms.");
        }
        return srcView;
    }
    ModelHeader.CheckVersionInfo(ref ctx.Header, ver);

    h.CheckValue(srcView, nameof(srcView));
    h.CheckValue(isTransformTagAccepted, nameof(isTransformTagAccepted));

    // *** Binary format ***
    // int: sizeof(Float)
    // int: number of transforms
    // foreach transform: (starting from version VersionAddedTags)
    //   string: tag
    //   string: args string
    int cbFloat = ctx.Reader.ReadInt32();
    h.CheckDecode(cbFloat == sizeof(float));

    int cxf = ctx.Reader.ReadInt32();
    h.CheckDecode(cxf >= 0);

    // Models older than VersionAddedTags carry no tag/args strings; every transform is
    // then presented to the acceptance test with an empty tag.
    bool hasTags = ctx.Header.ModelVerReadable >= VersionAddedTags;
    var curView = srcView;
    for (int i = 0; i < cxf; i++)
    {
        string tag = "";
        if (hasTags)
        {
            tag = ctx.LoadNonEmptyString();
            ctx.LoadStringOrNull(); // ignore the args string
        }
        if (!isTransformTagAccepted(tag))
            continue;

        // Chain each accepted transform onto the current view; the directory index is
        // the transform's original serialized position.
        IDataTransform xf;
        ctx.LoadModel<IDataTransform, SignatureLoadDataTransform>(env, out xf,
            string.Format(TransformDirTemplate, i), curView);
        curView = xf;
    }

    return curView;
}
/// <summary>
/// Shared deserialization tail: reads the train label column name and — for models written
/// with version 0x00010002 or later — the score and predicted-label column names; older
/// models fall back to the conventional defaults. Finishes by rebuilding the scorer.
/// </summary>
private void InitializationLogic(ModelLoadContext ctx, out string trainLabelColumn, out string scoreColumn, out string predictedLabelColumn)
{
    // *** Binary format ***
    // <base info>
    // id of string: train label column
    trainLabelColumn = ctx.LoadStringOrNull();

    bool hasColumnNames = ctx.Header.ModelVerWritten >= 0x00010002;
    if (!hasColumnNames)
    {
        // Older models did not persist these names; use the standard defaults.
        scoreColumn = AnnotationUtils.Const.ScoreValueKind.Score;
        predictedLabelColumn = DefaultColumnNames.PredictedLabel;
    }
    else
    {
        scoreColumn = ctx.LoadStringOrNull();
        predictedLabelColumn = ctx.LoadStringOrNull();
    }

    SetScorer();
}
/// <summary>
/// Deserializes the bindings of a one-to-one transform: for each added column, its output
/// name and (possibly null, meaning "same as output") input name, validated against the
/// input schema and the optional per-column type test.
/// </summary>
public static Bindings Create(OneToOneTransformBase parent, ModelLoadContext ctx, DataViewSchema inputSchema, ITransposeDataView transposeInput, Func<DataViewType, string> testType)
{
    Contracts.AssertValue(parent);
    var host = parent.Host;
    host.CheckValue(ctx, nameof(ctx));
    host.AssertValue(inputSchema);
    host.AssertValueOrNull(transposeInput);
    host.AssertValueOrNull(testType);

    // *** Binary format ***
    // int: number of added columns
    // for each added column
    //   int: id of output column name
    //   int: id of input column name
    int cinfo = ctx.Reader.ReadInt32();
    host.CheckDecode(cinfo > 0);

    var names = new string[cinfo];
    var infos = new ColInfo[cinfo];
    for (int i = 0; i < cinfo; i++)
    {
        string dst = ctx.LoadNonEmptyString();
        names[i] = dst;

        // Note that in old files, the source name may be null indicating that
        // the source column has the same name as the added column.
        string tmp = ctx.LoadStringOrNull();
        string src = tmp ?? dst;
        host.CheckDecode(!string.IsNullOrEmpty(src));

        int colSrc;
        if (!inputSchema.TryGetColumnIndex(src, out colSrc))
        {
            throw host.ExceptSchemaMismatch(nameof(inputSchema), "source", src);
        }
        var type = inputSchema[colSrc].Type;
        if (testType != null)
        {
            string reason = testType(type);
            if (reason != null)
            {
                throw host.Except(InvalidTypeErrorFormat, src, type, reason);
            }
        }
        // BUG FIX: the slot type must be looked up by the SOURCE column's index in the
        // input view (colSrc), not by the added column's ordinal (i). The two only
        // coincide when the mapping happens to be the identity; using i could fetch the
        // slot type of an unrelated column or an out-of-range index.
        var slotType = transposeInput?.GetSlotType(colSrc);
        infos[i] = new ColInfo(dst, colSrc, type, slotType as VectorDataViewType);
    }

    return (new Bindings(parent, infos, inputSchema, false, names));
}
/// <summary>
/// Deserializes an ensemble: optional per-model weights, each sub-model with its selected
/// feature bitmask and metrics, and finally the output combiner.
/// </summary>
private protected EnsembleModelParametersBase(IHostEnvironment env, string name, ModelLoadContext ctx)
    : base(env, name, ctx)
{
    // *** Binary format ***
    // int: model count
    // int: weight count (0 or model count)
    // Float[]: weights
    // for each model:
    //   int: number of SelectedFeatures (in bits)
    //   byte[]: selected features (as many as needed for number of bits == (numSelectedFeatures + 7) / 8)
    //   int: number of Metric values
    //   for each Metric:
    //     Float: metric value
    //     int: metric name (id of the metric name in the string table)
    //     in old-version models (VerOld) only:
    //       bool: is the metric averaged (read and discarded)
    int count = ctx.Reader.ReadInt32();
    Host.CheckDecode(count > 0);
    int weightCount = ctx.Reader.ReadInt32();
    // Zero weight count means an unweighted ensemble.
    Host.CheckDecode(weightCount == 0 || weightCount == count);
    Weights = ctx.Reader.ReadFloatArray(weightCount);

    Models = new FeatureSubsetModel<TOutput>[count];
    var ver = ctx.Header.ModelVerWritten;
    for (int i = 0; i < count; i++)
    {
        // Each sub-predictor lives in its own sub-directory (SubPredictorFmt with index i).
        ctx.LoadModel<IPredictor, SignatureLoadModel>(Host, out IPredictor p, string.Format(SubPredictorFmt, i));
        var predictor = p as IPredictorProducing<TOutput>;
        Host.Check(predictor != null, "Inner predictor type not compatible with the ensemble type.");
        var features = ctx.Reader.ReadBitArray();
        int numMetrics = ctx.Reader.ReadInt32();
        Host.CheckDecode(numMetrics >= 0);
        var metrics = new KeyValuePair<string, double>[numMetrics];
        for (int j = 0; j < numMetrics; j++)
        {
            // Note the on-disk order: value first, then name.
            var metricValue = ctx.Reader.ReadFloat();
            var metricName = ctx.LoadStringOrNull();
            // Old models persisted an "is averaged" flag per metric; it is consumed
            // here to keep the stream aligned but is otherwise ignored.
            if (ver == VerOld)
            {
                ctx.Reader.ReadBoolByte();
            }
            metrics[j] = new KeyValuePair<string, double>(metricName, metricValue);
        }
        Models[i] = new FeatureSubsetModel<TOutput>(predictor, features, metrics);
    }
    ctx.LoadModel<IOutputCombiner<TOutput>, SignatureLoadModel>(Host, out Combiner, @"Combiner");
}
/// <summary>
/// Loads the (possibly null) feature column name, resolves its type against the train
/// schema, and recreates the bindable mapper from the given model.
/// </summary>
private protected SingleFeaturePredictionTransformerBase(IHost host, ModelLoadContext ctx, TModel model)
    : base(host, ctx, model)
{
    FeatureColumnName = ctx.LoadStringOrNull();

    if (FeatureColumnName is null)
    {
        // No feature column was recorded at save time.
        FeatureColumnType = null;
    }
    else
    {
        if (!TrainSchema.TryGetColumnIndex(FeatureColumnName, out int featureIndex))
            throw Host.ExceptSchemaMismatch(nameof(FeatureColumnName), "feature", FeatureColumnName);
        FeatureColumnType = TrainSchema[featureIndex].Type;
    }

    BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, ModelAsPredictor);
}
/// <summary>
/// Re-creates a <see cref="MulticlassPredictionTransformer{TModel}"/> from a serialized model:
/// reads the (possibly null) train label column name and rebuilds the multiclass scorer
/// bound to the train-time schema.
/// </summary>
public MulticlassPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx)
    : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MulticlassPredictionTransformer<TModel>)), ctx)
{
    // *** Binary format ***
    // <base info>
    // id of string: train label column
    _trainLabelColumn = ctx.LoadStringOrNull();

    // Rebind the schema roles (label + feature) and recreate the scorer over an empty
    // view carrying the train schema.
    var schema = new RoleMappedSchema(TrainSchema, _trainLabelColumn, FeatureColumn);
    var args = new MultiClassClassifierScorer.Arguments();
    _scorer = new MultiClassClassifierScorer(Host, args, new EmptyDataView(Host, TrainSchema),
        BindableMapper.Bind(Host, schema), schema);
}
/// <summary>
/// Deserializes the per-instance evaluator's column bindings: a mandatory score column
/// and an optional label column, both resolved against the given schema.
/// </summary>
protected PerInstanceEvaluatorBase(IHostEnvironment env, ModelLoadContext ctx, ISchema schema)
{
    Host = env.Register("PerInstanceRowMapper");

    // *** Binary format ***
    // int: Id of the score column name
    // int: Id of the label column name
    ScoreCol = ctx.LoadNonEmptyString();
    LabelCol = ctx.LoadStringOrNull();

    // The label column is optional, but when present it must exist in the schema.
    bool labelResolved = string.IsNullOrEmpty(LabelCol) || schema.TryGetColumnIndex(LabelCol, out LabelIndex);
    if (!labelResolved)
        throw Host.Except($"Did not find the label column '{LabelCol}'");

    // The score column is mandatory.
    bool scoreResolved = schema.TryGetColumnIndex(ScoreCol, out ScoreIndex);
    if (!scoreResolved)
        throw Host.Except($"Did not find column '{ScoreCol}'");
}
/// <summary>
/// Loads the (possibly null) feature column name, resolves its type against the train
/// schema, and recreates the bindable mapper from the loaded model.
/// </summary>
internal SingleFeaturePredictionTransformerBase(IHost host, ModelLoadContext ctx)
    : base(host, ctx)
{
    FeatureColumn = ctx.LoadStringOrNull();

    if (FeatureColumn is null)
    {
        // No feature column was recorded at save time.
        FeatureColumnType = null;
    }
    else if (TrainSchema.TryGetColumnIndex(FeatureColumn, out int featureIndex))
    {
        FeatureColumnType = TrainSchema[featureIndex].Type;
    }
    else
    {
        throw Host.ExceptSchemaMismatch(nameof(FeatureColumn), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn);
    }

    BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, Model);
}
/// <summary>
/// Deserializes the common prediction-transformer state: the prediction model, the
/// train-time schema (stored as an empty binary data view), and the feature column
/// name/type; then recreates the bindable mapper.
/// </summary>
internal PredictionTransformerBase(IHost host, ModelLoadContext ctx)
{
    Host = host;

    ctx.LoadModel<TModel, SignatureLoadModel>(host, out TModel model, DirModel);
    Model = model;

    // *** Binary format ***
    // model: prediction model.
    // stream: empty data view that contains train schema.
    // id of string: feature column.

    // Clone the stream with the schema into memory. The copy is needed because the
    // binary loader requires a seekable stream positioned at its start.
    var ms = new MemoryStream();
    ctx.TryLoadBinaryStream(DirTransSchema, reader =>
    {
        reader.BaseStream.CopyTo(ms);
    });

    ms.Position = 0;
    var loader = new BinaryLoader(host, new BinaryLoader.Arguments(), ms);
    TrainSchema = loader.Schema;

    // The feature column is optional; when present it must exist in the train schema.
    FeatureColumn = ctx.LoadStringOrNull();
    if (FeatureColumn == null)
    {
        FeatureColumnType = null;
    }
    else if (!TrainSchema.TryGetColumnIndex(FeatureColumn, out int col))
    {
        throw Host.ExceptSchemaMismatch(nameof(FeatureColumn), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn);
    }
    else
    {
        FeatureColumnType = TrainSchema.GetColumnType(col);
    }

    BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
}
/// <summary>
/// Deserializes a tree ensemble: the tree count, the regression trees themselves,
/// the bias, and an optional input-initialization string.
/// </summary>
public TreeEnsemble(ModelLoadContext ctx, bool usingDefaultValues, bool categoricalSplits)
{
    // REVIEW: Verify the contents of the ensemble, both during building,
    // and during deserialization.

    // *** Binary format ***
    // int: Number of trees
    // Regression trees (num trees of these)
    // double: Bias
    // int: Id to InputInitializationContent string, currently ignored
    _trees = new List<RegressionTree>();

    int treeCount = ctx.Reader.ReadInt32();
    Contracts.CheckDecode(treeCount >= 0);

    int remaining = treeCount;
    while (remaining-- > 0)
        AddTree(RegressionTree.Load(ctx, usingDefaultValues, categoricalSplits));

    Bias = ctx.Reader.ReadDouble();
    _firstInputInitializationContent = ctx.LoadStringOrNull();
}
/// <summary>
/// Deserializes the column bindings of a <see cref="DatabaseLoader"/>: per column its name,
/// data kind (possibly a key type with a count), and an optional list of source segments.
/// </summary>
public Bindings(ModelLoadContext ctx, DatabaseLoader parent)
{
    Contracts.AssertValue(ctx);

    // *** Binary format ***
    // int: number of columns
    // foreach column:
    //   int: id of column name
    //   byte: DataKind
    //   byte: bool of whether this is a key type
    //   for a key type:
    //     ulong: count for key range
    //   int: number of segments
    //   foreach segment:
    //     string id: name
    //     int: min
    //     int: lim
    //     byte: force vector (verWrittenCur: verIsVectorSupported)
    int cinfo = ctx.Reader.ReadInt32();
    Contracts.CheckDecode(cinfo > 0);
    Infos = new ColInfo[cinfo];

    for (int iinfo = 0; iinfo < cinfo; iinfo++)
    {
        string name = ctx.LoadNonEmptyString();

        PrimitiveDataViewType itemType;
        var kind = (InternalDataKind)ctx.Reader.ReadByte();
        Contracts.CheckDecode(Enum.IsDefined(typeof(InternalDataKind), kind));
        bool isKey = ctx.Reader.ReadBoolByte();
        if (isKey)
        {
            ulong count;
            // The underlying kind must be a legal key data type, and the key count
            // must be strictly positive.
            Contracts.CheckDecode(KeyDataViewType.IsValidDataType(kind.ToType()));

            count = ctx.Reader.ReadUInt64();
            Contracts.CheckDecode(0 < count);

            itemType = new KeyDataViewType(kind.ToType(), count);
        }
        else
        {
            itemType = ColumnTypeExtensions.PrimitiveTypeFromKind(kind);
        }

        int cseg = ctx.Reader.ReadInt32();

        Segment[] segs;
        if (cseg == 0)
        {
            // No segments stored for this column.
            segs = null;
        }
        else
        {
            Contracts.CheckDecode(cseg > 0);
            segs = new Segment[cseg];
            for (int iseg = 0; iseg < cseg; iseg++)
            {
                // A segment is addressed either by a column name or by a [min, lim) index
                // range; a null name selects the index-range form.
                string columnName = ctx.LoadStringOrNull();
                int min = ctx.Reader.ReadInt32();
                int lim = ctx.Reader.ReadInt32();
                Contracts.CheckDecode(0 <= min && min < lim);
                bool forceVector = ctx.Reader.ReadBoolByte();
                segs[iseg] = (columnName is null) ? new Segment(min, lim, forceVector) : new Segment(columnName, forceVector);
            }
        }

        // Note that this will throw if the segments are ill-structured, including the case
        // of multiple variable segments (since those segments will overlap and overlapping
        // segments are illegal).
        Infos[iinfo] = ColInfo.Create(name, itemType, segs, false);
    }

    OutputSchema = ComputeOutputSchema();
}