/// <summary>
/// The counterpart constructor that re-creates a <see cref="MatrixFactorizationPredictionTransformer"/>
/// from the context where the original transform was saved.
/// </summary>
public MatrixFactorizationPredictionTransformer(IHostEnvironment host, ModelLoadContext ctx)
    : base(Contracts.CheckRef(host, nameof(host)).Register(nameof(MatrixFactorizationPredictionTransformer)), ctx)
{
    // *** Binary format ***
    // <base info>
    // string: the column name of matrix's column ids.
    // string: the column name of matrix's row ids.
    // NOTE: reads must stay in exactly this order to match the saver's format.
    MatrixColumnIndexColumnName = ctx.LoadString();
    MatrixRowIndexColumnName = ctx.LoadString();

    // Both matrix-index columns must exist in the schema the model was trained
    // on; otherwise the saved model cannot be rebound to data.
    if (!TrainSchema.TryGetColumnIndex(MatrixColumnIndexColumnName, out int xCol))
    {
        throw Host.ExceptSchemaMismatch(nameof(MatrixColumnIndexColumnName), RecommenderUtils.MatrixColumnIndexKind.Value, MatrixColumnIndexColumnName);
    }
    MatrixColumnIndexColumnType = TrainSchema.GetColumnType(xCol);
    if (!TrainSchema.TryGetColumnIndex(MatrixRowIndexColumnName, out int yCol))
    {
        throw Host.ExceptSchemaMismatch(nameof(MatrixRowIndexColumnName), RecommenderUtils.MatrixRowIndexKind.Value, MatrixRowIndexColumnName);
    }
    MatrixRowIndexColumnType = TrainSchema.GetColumnType(yCol);

    // Rebuild the scoring pipeline: wrap the loaded model in a bindable mapper,
    // bind it to the train-time schema, and attach a generic scorer over an
    // empty data view (the scorer only needs the schema here, not data).
    BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, Model);
    var schema = GetSchema();
    var args = new GenericScorer.Arguments { Suffix = "" };
    Scorer = new GenericScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema);
}
/// <summary>
/// Reads the base information persisted by the saver: an output-name suffix and
/// the list of (column role, column name) pairs for the input columns.
/// </summary>
/// <param name="ctx">The model-load context positioned at the base info.</param>
/// <param name="suffix">Receives the suffix stored with the model.</param>
/// <returns>The deserialized role-to-column-name mappings.</returns>
protected static KeyValuePair<RoleMappedSchema.ColumnRole, string>[] LoadBaseInfo(ModelLoadContext ctx, out string suffix)
{
    // *** Binary format ***
    // int: id of the suffix
    // int: the number of input column roles
    // for each input column:
    //   int: id of the column role
    //   int: id of the column name
    suffix = ctx.LoadString();

    int roleCount = ctx.Reader.ReadInt32();
    Contracts.CheckDecode(roleCount >= 0);

    var roleMappings = new KeyValuePair<RoleMappedSchema.ColumnRole, string>[roleCount];
    for (int index = 0; index < roleMappings.Length; index++)
    {
        // Role and name ids were written as a pair, in this order.
        string roleName = ctx.LoadNonEmptyString();
        string columnName = ctx.LoadNonEmptyString();
        roleMappings[index] = RoleMappedSchema.CreatePair(roleName, columnName);
    }
    return roleMappings;
}
/// <summary>
/// Re-creates a <see cref="TreeEnsembleFeaturizationTransformer"/> from a saved model context.
/// </summary>
private TreeEnsembleFeaturizationTransformer(IHostEnvironment host, ModelLoadContext ctx)
    : base(Contracts.CheckRef(host, nameof(host)).Register(nameof(TreeEnsembleFeaturizationTransformer)), ctx)
{
    // *** Binary format ***
    // <base info>
    // string: feature column's name.
    // string: the name of the columns where tree prediction values are stored.
    // string: the name of the columns where trees' leave are stored.
    // string: the name of the columns where trees' paths are stored.

    // Restore the stored column names in the exact order they were saved.
    var savedFeatureColumn = ctx.LoadString();
    _featureDetachedColumn = new DataViewSchema.DetachedColumn(TrainSchema[savedFeatureColumn]);
    _treesColumnName = ctx.LoadStringOrNull();
    _leavesColumnName = ctx.LoadStringOrNull();
    _pathsColumnName = ctx.LoadStringOrNull();

    // Output-column configuration consumed by the bindable mapper below.
    _scorerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments
    {
        TreesColumnName = _treesColumnName,
        LeavesColumnName = _leavesColumnName,
        PathsColumnName = _pathsColumnName
    };

    // The bindable mapper provides the core computation and can be attached to
    // any IDataView to produce a transformed IDataView.
    BindableMapper = new TreeEnsembleFeaturizerBindableMapper(host, _scorerArgs, Model);

    // Bind the mapper to the train-time schema and wrap it in a scorer; the
    // empty data view only supplies the schema, not data.
    var boundSchema = MakeFeatureRoleMappedSchema(TrainSchema);
    Scorer = new GenericScorer(Host, _scorerArgs, new EmptyDataView(Host, TrainSchema),
        BindableMapper.Bind(Host, boundSchema), boundSchema);
}
/// <summary>
/// Re-creates the base prediction transformer state from a saved model context:
/// the prediction model, the train-time schema, and the feature column.
/// </summary>
internal PredictionTransformerBase(IHost host, ModelLoadContext ctx)
{
    Host = host;
    ctx.LoadModel<TModel, SignatureLoadModel>(host, out TModel model, DirModel);
    Model = model;

    // *** Binary format ***
    // model: prediction model.
    // stream: empty data view that contains train schema.
    // id of string: feature column.

    // Copy the persisted schema stream into memory, then load it as an empty
    // binary data view whose schema is the train-time schema. The stream is
    // deliberately not disposed: the loader keeps reading from it.
    var schemaBuffer = new MemoryStream();
    ctx.TryLoadBinaryStream(DirTransSchema, reader => reader.BaseStream.CopyTo(schemaBuffer));
    schemaBuffer.Position = 0;
    var schemaLoader = new BinaryLoader(host, new BinaryLoader.Arguments(), schemaBuffer);
    TrainSchema = schemaLoader.Schema;

    // The saved feature column must still exist in the loaded schema.
    FeatureColumn = ctx.LoadString();
    if (!TrainSchema.TryGetColumnIndex(FeatureColumn, out int featureIndex))
    {
        throw Host.ExceptSchemaMismatch(nameof(FeatureColumn), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn);
    }
    FeatureColumnType = TrainSchema.GetColumnType(featureIndex);

    BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
}
// Shared deserialization step: reads the scorer threshold and the threshold
// column name from the model stream (in that order), then rebuilds the scorer.
private void InitializationLogic(ModelLoadContext ctx, out float threshold, out string thresholdcolumn)
{
    // *** Binary format ***
    // <base info>
    // float: scorer threshold
    // id of string: scorer threshold column
    threshold = ctx.Reader.ReadSingle();
    thresholdcolumn = ctx.LoadString();
    SetScorer();
}
// Re-creates an AnomalyPredictionTransformer from a saved model context:
// reads the scorer threshold and its column name, then rebuilds the scorer.
internal AnomalyPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx)
    : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(AnomalyPredictionTransformer<TModel>)), ctx)
{
    // *** Binary format ***
    // <base info>
    // float: scorer threshold
    // id of string: scorer threshold column
    Threshold = ctx.Reader.ReadSingle();
    ThresholdColumn = ctx.LoadString();
    SetScorer();
}
/// <summary>
/// Re-creates a <see cref="FieldAwareFactorizationMachinePredictionTransformer"/>
/// from a saved model context. Unlike most transformers, FAFM consumes one
/// feature column per field, so multiple feature column names are stored.
/// </summary>
internal FieldAwareFactorizationMachinePredictionTransformer(IHostEnvironment host, ModelLoadContext ctx)
    : base(Contracts.CheckRef(host, nameof(host)).Register(nameof(FieldAwareFactorizationMachinePredictionTransformer)), ctx)
{
    // *** Binary format ***
    // <base info>
    // ids of strings: feature columns.
    // float: scorer threshold
    // id of string: scorer threshold column

    // The number of stored feature-column names equals the model's field count.
    int fieldCount = Model.FieldCount;
    var names = new string[fieldCount];
    var types = new DataViewType[fieldCount];
    for (int field = 0; field < fieldCount; field++)
    {
        names[field] = ctx.LoadString();
        // Every saved feature column must still exist in the train-time schema.
        if (!TrainSchema.TryGetColumnIndex(names[field], out int columnIndex))
        {
            throw Host.ExceptSchemaMismatch(nameof(FeatureColumns), "feature", names[field]);
        }
        types[field] = TrainSchema[columnIndex].Type;
    }
    FeatureColumns = names;
    FeatureColumnTypes = types;

    _threshold = ctx.Reader.ReadSingle();
    _thresholdColumn = ctx.LoadString();

    // Rebuild the binary-classifier scorer over the loaded model and schema;
    // the empty data view only supplies the schema, not data.
    BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, Model);
    var boundSchema = GetSchema();
    var scorerArgs = new BinaryClassifierScorer.Arguments
    {
        Threshold = _threshold,
        ThresholdColumn = _thresholdColumn
    };
    Scorer = new BinaryClassifierScorer(Host, scorerArgs, new EmptyDataView(Host, TrainSchema),
        BindableMapper.Bind(Host, boundSchema), boundSchema);
}
// Factory for SignatureLoadModel.
// The saved model records only a MEF contract name; the transformer itself is
// resolved from the environment's composition container at load time.
private static ITransformer Create(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    // *** Binary format ***
    // id of string: contract name
    string contractName = ctx.LoadString();
    return env.GetCompositionContainer().GetExportedValue<ITransformer>(contractName);
}
// Factory for SignatureLoadModel.
// Resolves a custom-mapping factory by contract name from the component catalog
// and asks it to create the transformer.
private static ITransformer Create(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    string contractName = ctx.LoadString();

    // Newer model versions also persist the name of the assembly that hosts the
    // factory; load and register it so the catalog lookup below can succeed.
    if (ctx.Header.ModelVerWritten >= VerAssemblyNameSaved)
    {
        string assemblyName = ctx.LoadString();
        env.ComponentCatalog.RegisterAssembly(Assembly.Load(assemblyName));
    }

    object factoryObject = env.ComponentCatalog.GetExtensionValue(env, typeof(CustomMappingFactoryAttributeAttribute), contractName);
    if (factoryObject is ICustomMappingFactory mappingFactory)
    {
        return mappingFactory.CreateTransformer(env, contractName);
    }

    throw env.Except($"The class with contract '{contractName}' must derive from '{typeof(CustomMappingFactory<,>).FullName}' or from '{typeof(StatefulCustomMappingFactory<,,>).FullName}'.");
}
// Re-creates a CustomStopWordsRemoverTransform from a saved model context by
// deserializing the stopword pool stored in the "Stopwords" sub-model.
private CustomStopWordsRemoverTransform(IHost host, ModelLoadContext ctx, IDataView input)
    : base(host, ctx, input, TestIsTextVector)
{
    Host.AssertValue(ctx);

    using (var ch = Host.Start("Deserialization"))
    {
        // *** Binary format ***
        // <base>
        ch.AssertNonEmpty(Infos);

        const string dir = "Stopwords";
        // NOTE(review): "stopwrods" is a typo for "stopwords" (as is the
        // GetStopwrodsManagerVersionInfo method name, which is external to this block).
        NormStr.Pool stopwrods = null;
        bool res = ctx.TryProcessSubModel(dir,
            c =>
            {
                // NOTE(review): this callback receives the sub-model context 'c'
                // (and validates its header via c.CheckAtModel), yet the data reads
                // below go through the OUTER context ('ctx.Reader', 'ctx.LoadString').
                // If the saver wrote the stopword data into the "Stopwords"
                // sub-model stream, these should presumably use 'c' instead —
                // confirm against the matching Save code before changing anything,
                // since a one-sided change breaks model-format compatibility.
                Host.CheckValue(c, nameof(ctx));
                c.CheckAtModel(GetStopwrodsManagerVersionInfo());

                // *** Binary format ***
                // int: number of stopwords
                // int[]: stopwords string ids
                int cstr = ctx.Reader.ReadInt32();
                Host.CheckDecode(cstr > 0);
                stopwrods = new NormStr.Pool();
                for (int istr = 0; istr < cstr; istr++)
                {
                    // Pool ids are assigned sequentially, so a mismatch means the
                    // stream contained a duplicate or was written out of order.
                    var nstr = stopwrods.Add(ctx.LoadString());
                    Host.CheckDecode(nstr.Id == istr);
                }

                // All stopwords are distinct.
                Host.CheckDecode(stopwrods.Count == cstr);
                // The deserialized pool should not have the empty string.
                Host.CheckDecode(stopwrods.Get("") == null);
            });
        if (!res)
        {
            throw Host.ExceptDecode();
        }

        _stopWordsMap = stopwrods;
        ch.Done();
    }
    Metadata.Seal();
}
/// <summary>
/// Re-creates a <see cref="BinaryPredictionTransformer{TModel}"/> from a saved
/// model context: reads the scorer threshold and its column name, then rebuilds
/// the binary-classifier scorer over the train-time schema.
/// </summary>
public BinaryPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx)
    : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(BinaryPredictionTransformer<TModel>)), ctx)
{
    // *** Binary format ***
    // <base info>
    // float: scorer threshold
    // id of string: scorer threshold column
    Threshold = ctx.Reader.ReadSingle();
    ThresholdColumn = ctx.LoadString();

    // Bind the loaded mapper against the feature column of the train-time
    // schema; the empty data view only supplies the schema, not data.
    var mappedSchema = new RoleMappedSchema(TrainSchema, null, FeatureColumn);
    var scorerArgs = new BinaryClassifierScorer.Arguments
    {
        Threshold = Threshold,
        ThresholdColumn = ThresholdColumn
    };
    _scorer = new BinaryClassifierScorer(Host, scorerArgs, new EmptyDataView(Host, TrainSchema),
        BindableMapper.Bind(Host, mappedSchema), mappedSchema);
}
// Factory method for SignatureLoadModel.
// Re-creates a CalibratorTransformer from a saved model context: loads the
// calibrator sub-model, then the score column name (when the model is new
// enough to have stored one).
private protected CalibratorTransformer(IHostEnvironment env, ModelLoadContext ctx, string loaderSignature)
    : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(CalibratorTransformer<TICalibrator>)))
{
    Contracts.AssertValue(ctx);
    _loaderSignature = loaderSignature;
    ctx.CheckAtModel(GetVersionInfo());

    // *** Binary format ***
    // model: _calibrator
    // scoreColumnName: _scoreColumnName
    ctx.LoadModel<TICalibrator, SignatureLoadModel>(env, out _calibrator, "Calibrator");

    // The score column name was first persisted in model version 0x00010002;
    // older models fall back to the conventional default score column name.
    if (ctx.Header.ModelVerWritten >= 0x00010002)
    {
        _scoreColumnName = ctx.LoadString();
    }
    else
    {
        _scoreColumnName = DefaultColumnNames.Score;
    }
}
/// <summary>
/// Re-creates a <see cref="ParquetPartitionedPathParser"/> from a saved model
/// context by parsing each persisted column representation back into a column.
/// </summary>
private ParquetPartitionedPathParser(IHost host, ModelLoadContext ctx)
{
    Contracts.AssertValue(host);
    _host = host;
    _host.AssertValue(ctx);

    // ** Binary format **
    // int: number of columns
    // foreach column:
    //   string: column representation
    int columnCount = ctx.Reader.ReadInt32();
    _host.CheckDecode(columnCount >= 0);

    _columns = new PartitionedFileLoader.Column[columnCount];
    for (int i = 0; i < columnCount; i++)
    {
        // Each column was saved as its textual representation; a failed parse
        // means the model stream is corrupt.
        var parsed = PartitionedFileLoader.Column.Parse(ctx.LoadString());
        _host.CheckDecode(parsed != null);
        _columns[i] = parsed;
    }
}