/// <summary>
/// Deserializes a <see cref="BindingsImpl"/> from <paramref name="ctx"/> and binds it against <paramref name="input"/>.
/// </summary>
/// <param name="ctx">Load context the base info, score column kind and score column name are read from.</param>
/// <param name="input">Input schema the deserialized roles are applied to.</param>
/// <param name="env">Host environment used for binding and validation.</param>
/// <param name="bindable">Bindable mapper that is bound to <paramref name="input"/> with the deserialized roles.</param>
/// <param name="outputTypeMatches">Predicate the score column's type must satisfy; otherwise a decode error is raised.</param>
/// <param name="getPredColType">Computes the predicted-label column type from the score column type and the bound row mapper.</param>
/// <returns>Bindings over the bound row mapper, with the score column located by name.</returns>
public static BindingsImpl Create(ModelLoadContext ctx, DataViewSchema input, IHostEnvironment env, ISchemaBindableMapper bindable,
    Func<DataViewType, bool> outputTypeMatches, Func<DataViewType, ISchemaBoundRowMapper, DataViewType> getPredColType)
{
    Contracts.AssertValue(env);
    env.AssertValue(ctx);
    env.AssertValue(input);
    env.AssertValue(bindable);
    env.AssertValue(outputTypeMatches);
    env.AssertValue(getPredColType);

    // *** Binary format ***
    // <base info>
    // int: id of the scores column kind (metadata output)
    // int: id of the column used for deriving the predicted label column
    var roles = LoadBaseInfo(ctx, out string suffix);
    string scoreKind = ctx.LoadNonEmptyString();
    string scoreCol = ctx.LoadNonEmptyString();

    var mapper = bindable.Bind(env, new RoleMappedSchema(input, roles));
    var rowMapper = mapper as ISchemaBoundRowMapper;
    // The requirement is on the *bound* mapper's type, so name that interface in the message.
    env.CheckParam(rowMapper != null, nameof(bindable), "Bindable expected to produce an " + nameof(ISchemaBoundRowMapper) + "!");

    // Find the score column of the mapper. A missing column, or a score type the caller
    // does not accept, means the serialized bindings do not match this mapper.
    env.CheckDecode(rowMapper.OutputSchema.TryGetColumnIndex(scoreCol, out int scoreColIndex));
    var scoreType = rowMapper.OutputSchema[scoreColIndex].Type;
    env.CheckDecode(outputTypeMatches(scoreType));

    var predColType = getPredColType(scoreType, rowMapper);
    return new BindingsImpl(input, rowMapper, suffix, scoreKind, false, scoreColIndex, predColType);
}
/// <summary>
/// Wraps a trained <paramref name="model"/>, recording the training schema and, optionally,
/// the feature column the model was trained on.
/// </summary>
/// <param name="host">Host used for validation and error reporting; must not be null.</param>
/// <param name="model">The trained model to wrap.</param>
/// <param name="trainSchema">Schema of the training data; must not be null.</param>
/// <param name="featureColumn">
/// Name of the feature column in <paramref name="trainSchema"/>, or null when there is none.
/// A non-null name that is absent from the schema causes a schema-mismatch exception.
/// </param>
public PredictionTransformerBase(IHost host, TModel model, ISchema trainSchema, string featureColumn)
{
    Contracts.CheckValue(host, nameof(host));
    Contracts.CheckValueOrNull(featureColumn);
    Host = host;
    Host.CheckValue(trainSchema, nameof(trainSchema));

    Model = model;
    FeatureColumn = featureColumn;

    if (featureColumn != null)
    {
        // A named feature column must be present in the training schema.
        if (!trainSchema.TryGetColumnIndex(featureColumn, out int col))
        {
            throw Host.ExceptSchemaMismatch(nameof(featureColumn), RoleMappedSchema.ColumnRole.Feature.Value, featureColumn);
        }

        FeatureColumnType = trainSchema.GetColumnType(col);
    }
    else
    {
        FeatureColumnType = null;
    }

    TrainSchema = trainSchema;
    BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
}
/// <summary>
/// Loads a prediction transformer from <paramref name="ctx"/>: the model, the training schema
/// and the (possibly null) feature column name.
/// </summary>
/// <param name="host">Host used for loading and error reporting.</param>
/// <param name="ctx">Load context holding the serialized transformer.</param>
internal PredictionTransformerBase(IHost host, ModelLoadContext ctx)
{
    Contracts.AssertValue(host);
    host.AssertValue(ctx);
    Host = host;

    // *** Binary format ***
    // model: prediction model.
    // stream: empty data view that contains train schema.
    // id of string: feature column (may be null).
    ctx.LoadModel<TModel, SignatureLoadModel>(host, out TModel model, DirModel);
    Model = model;

    // Clone the stream with the schema into memory, so the loader below reads from a stream
    // this object owns rather than the one handed to the callback.
    // NOTE(review): the return value of TryLoadBinaryStream is ignored, as before.
    var ms = new MemoryStream();
    ctx.TryLoadBinaryStream(DirTransSchema, reader =>
    {
        reader.BaseStream.CopyTo(ms);
    });

    ms.Position = 0;
    var loader = new BinaryLoader(host, new BinaryLoader.Arguments(), ms);
    TrainSchema = loader.Schema;

    // The public constructor accepts a null feature column, so the loader must accept one too.
    // LoadStringOrNull reads non-null string ids exactly as LoadString does.
    FeatureColumn = ctx.LoadStringOrNull();
    if (FeatureColumn == null)
    {
        FeatureColumnType = null;
    }
    else if (!TrainSchema.TryGetColumnIndex(FeatureColumn, out int col))
    {
        throw Host.ExceptSchemaMismatch(nameof(FeatureColumn), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn);
    }
    else
    {
        FeatureColumnType = TrainSchema.GetColumnType(col);
    }

    BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
}
/// <summary>
/// Returns a batch prediction engine for the given scoring arguments. An engine cached under an
/// equivalent configuration is reused; otherwise a new engine is built, cached and returned.
/// </summary>
/// <param name="usageItems">Usage events; only used to build the data view when a new engine must be created.</param>
/// <param name="sarScoringArguments">Scoring options; the recommendation count and reference date are rounded before use.</param>
/// <returns>The cached or newly created engine.</returns>
private BatchPredictionEngine<SarUsageEvent, SarScoreResult> GetOrCreateBatchPredictionEngine(
    IList<SarUsageEvent> usageItems, SarScoringArguments sarScoringArguments)
{
    // Rounding the recommendation count lets more requests share a cached engine.
    var scorerArguments = new RecommenderScorerTransform.Arguments
    {
        recommendationCount = GetRoundedNumberOfResults(sarScoringArguments.RecommendationCount),
        includeHistory = sarScoringArguments.IncludeHistory
    };

    // Column role -> column name mapping for the usage data.
    var columnRoles = new Dictionary<RoleMappedSchema.ColumnRole, string>
    {
        { new RoleMappedSchema.ColumnRole("Item"), "Item" },
        { new RoleMappedSchema.ColumnRole("User"), "user" }
    };

    string weightColumnName = null;
    if (sarScoringArguments.ReferenceDate.HasValue)
    {
        // Snap the reference date to the start of the following day, again to improve cache reuse.
        DateTime roundedReferenceDate = sarScoringArguments.ReferenceDate.Value.Date + TimeSpan.FromDays(1);
        scorerArguments.referenceDate = roundedReferenceDate.ToString("s");

        if (sarScoringArguments.Decay.HasValue)
        {
            scorerArguments.decay = sarScoringArguments.Decay.Value.TotalDays;
        }

        columnRoles.Add(new RoleMappedSchema.ColumnRole("Date"), "date");
        weightColumnName = "weight";
    }

    // The key covers every argument that shapes the engine.
    string cacheKey = $"{scorerArguments.recommendationCount}|{scorerArguments.includeHistory}|{scorerArguments.referenceDate}|{scorerArguments.decay}";

    _tracer.TraceVerbose("Trying to find the engine in the cache");
    var cachedEngine = _enginesCache.Get(cacheKey) as BatchPredictionEngine<SarUsageEvent, SarScoreResult>;
    if (cachedEngine != null)
    {
        return cachedEngine;
    }

    _tracer.TraceInformation("Engine is not cached - creating a new engine");
    IDataView usageDataView = _environment.CreateDataView(usageItems, _usageDataSchema);
    RoleMappedData roleMappedUsageData = _environment.CreateExamples(usageDataView, null, weight: weightColumnName, custom: columnRoles);
    ISchemaBindableMapper bindableMapper = RecommenderScorerTransform.Create(_environment, scorerArguments, _recommender);
    ISchemaBoundMapper boundMapper = bindableMapper.Bind(_environment, roleMappedUsageData.Schema);
    IDataScorerTransform scorer = RecommenderScorerTransform.Create(_environment, scorerArguments, usageDataView, boundMapper, null);

    BatchPredictionEngine<SarUsageEvent, SarScoreResult> newEngine =
        _environment.CreateBatchPredictionEngine<SarUsageEvent, SarScoreResult>(scorer, false, _usageDataSchema);

    bool added = _enginesCache.Add(cacheKey, newEngine, new CacheItemPolicy { SlidingExpiration = TimeSpan.FromDays(1) });
    _tracer.TraceVerbose($"Addition of engine to the cache resulted with '{added}'");
    return newEngine;
}
/// <summary>
/// Tries to create a bindable mapper registered under the same load name as <paramref name="scorerSettings"/>,
/// passing the scorer's settings string and <paramref name="predictor"/> to the factory.
/// </summary>
/// <returns>True if the component catalog created an instance; the instance is returned in <paramref name="bindable"/>.</returns>
private static bool TryCreateBindableFromScorer(IHostEnvironment env, IPredictor predictor, ICommandLineComponentFactory scorerSettings, out ISchemaBindableMapper bindable)
{
    Contracts.AssertValue(env);
    env.AssertValue(predictor);
    env.AssertValue(scorerSettings);

    // The mapper factory is looked up by the scorer's load name.
    string loadName = scorerSettings.Name;
    string settings = scorerSettings.GetSettingsString();
    return ComponentCatalog.TryCreateInstance<ISchemaBindableMapper, SignatureBindableMapper>(
        env, out bindable, loadName, settings, predictor);
}
/// <summary>
/// Reads the base binding info from <paramref name="ctx"/> and rebuilds the bindings by binding
/// <paramref name="bindable"/> against <paramref name="input"/> with the loaded roles and suffix.
/// </summary>
/// <param name="ctx">Load context holding the serialized base info.</param>
/// <param name="env">Host environment.</param>
/// <param name="bindable">Bindable mapper to bind.</param>
/// <param name="input">Input schema to bind against.</param>
public static Bindings Create(ModelLoadContext ctx, IHostEnvironment env, ISchemaBindableMapper bindable, DataViewSchema input)
{
    Contracts.AssertValue(ctx);

    // *** Binary format ***
    // <base info>
    var roles = LoadBaseInfo(ctx, out string suffix);

    // Deserialized bindings are created with user: false.
    return Create(env, bindable, input, roles, suffix, user: false);
}
/// <summary>
/// Creates the mapper, registering its own host on <paramref name="env"/> and storing the wrapped
/// bindable together with the vector type, getter, metadata kind and optional wrap predicate.
/// </summary>
/// <param name="env">Environment a host is registered on (under <c>LoaderSignature</c>).</param>
/// <param name="bindable">The inner bindable mapper being wrapped.</param>
/// <param name="type">Vector type stored alongside the getter.</param>
/// <param name="getter">Getter delegate stored for later use.</param>
/// <param name="metadataKind">Non-empty metadata kind name.</param>
/// <param name="canWrap">Optional predicate; may be null.</param>
private LabelNameBindableMapper(IHostEnvironment env, ISchemaBindableMapper bindable, VectorType type, Delegate getter, string metadataKind, Func<ISchemaBoundMapper, ColumnType, bool> canWrap)
{
    Contracts.AssertValue(env);
    _host = env.Register(LoaderSignature);

    // Each argument is asserted and then stored.
    _host.AssertValue(bindable);
    _bindable = bindable;

    _host.AssertValue(type);
    _type = type;

    _host.AssertValue(getter);
    _getter = getter;

    _host.AssertNonEmpty(metadataKind);
    _metadataKind = metadataKind;

    _host.AssertValueOrNull(canWrap);
    _canWrap = canWrap;
}
/// <summary>
/// Re-creates these bindings over a new <paramref name="input"/> schema by rebinding <paramref name="bindable"/>
/// with the current mapper's input roles and locating the same-named score column in the result.
/// </summary>
/// <param name="input">New input schema.</param>
/// <param name="bindable">Bindable mapper to bind against <paramref name="input"/>.</param>
/// <param name="env">Host environment used for binding and checks.</param>
/// <returns>New bindings sharing this instance's suffix, score column kind and predicted column type.</returns>
public BindingsImpl ApplyToSchema(DataViewSchema input, ISchemaBindableMapper bindable, IHostEnvironment env)
{
    Contracts.AssertValue(env);
    env.AssertValue(input);
    env.AssertValue(bindable);

    string scoreColumnName = RowMapper.OutputSchema[ScoreColumnIndex].Name;
    var roleMappedSchema = new RoleMappedSchema(input, RowMapper.GetInputColumnRoles());

    // Binding checks compatibility of the predictor input types.
    var rowMapper = bindable.Bind(env, roleMappedSchema) as ISchemaBoundRowMapper;
    env.CheckParam(rowMapper != null, nameof(bindable), "Mapper must implement ISchemaBoundRowMapper");

    int mapperScoreColumn;
    bool hasScoreColumn = rowMapper.OutputSchema.TryGetColumnIndex(scoreColumnName, out mapperScoreColumn);
    env.Check(hasScoreColumn, "Mapper doesn't have expected score column");

    return new BindingsImpl(input, rowMapper, Suffix, ScoreColumnKind, true, mapperScoreColumn, PredColType);
}
/// <summary>
/// Binds <paramref name="bindable"/> to <paramref name="input"/> under the given column <paramref name="roles"/>
/// and builds bindings from the resulting row mapper and column-name <paramref name="suffix"/>.
/// </summary>
/// <param name="env">Host environment used for binding.</param>
/// <param name="bindable">Bindable mapper to bind.</param>
/// <param name="input">Input schema.</param>
/// <param name="roles">Column role to column name pairs.</param>
/// <param name="suffix">Optional column-name suffix; may be null.</param>
/// <param name="user">Forwarded unchanged to the row-mapper overload of <c>Create</c>.</param>
private static Bindings Create(IHostEnvironment env, ISchemaBindableMapper bindable, DataViewSchema input,
    IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> roles, string suffix, bool user = true)
{
    Contracts.AssertValue(env);
    Contracts.AssertValue(bindable);
    Contracts.AssertValue(input);
    Contracts.AssertValue(roles);
    Contracts.AssertValueOrNull(suffix);

    var roleMappedSchema = new RoleMappedSchema(input, roles);
    var boundMapper = bindable.Bind(env, roleMappedSchema);

    // Nothing below relies on this, but a bound mapper whose input schema differs from
    // `input` indicates a misbehaving bindable.
    Contracts.Assert(boundMapper.InputRoleMappedSchema.Schema == input);

    var rowMapper = boundMapper as ISchemaBoundRowMapper;
    Contracts.Check(rowMapper != null, "Predictor expected to be a RowMapper!");
    return Create(input, rowMapper, suffix, user);
}
/// <summary>
/// Loads a feature-contribution bindable mapper from <paramref name="ctx"/>.
/// </summary>
/// <param name="env">Host environment; must not be null.</param>
/// <param name="ctx">Load context positioned at this mapper's model; must not be null.</param>
public BindableMapper(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));
    _env = env;
    _env.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    // *** Binary format ***
    // <predictor: IFeatureContributionMapper, stored under ModelFileUtils.DirPredictor>
    // int: topContributionsCount
    // int: bottomContributionsCount
    // bool: normalize
    // bool: stringify
    ctx.LoadModel<IFeatureContributionMapper, SignatureLoadModel>(env, out Predictor, ModelFileUtils.DirPredictor);
    GenericMapper = ScoreUtils.GetSchemaBindableMapper(_env, Predictor, null);

    // The other constructor accepts any non-negative count (including 0, with no upper bound),
    // so decode the same range. Otherwise a mapper built with such counts could presumably not
    // be reloaded after saving.
    _topContributionsCount = ctx.Reader.ReadInt32();
    Contracts.CheckDecode(0 <= _topContributionsCount);
    _bottomContributionsCount = ctx.Reader.ReadInt32();
    Contracts.CheckDecode(0 <= _bottomContributionsCount);

    _normalize = ctx.Reader.ReadBoolByte();
    Stringify = ctx.Reader.ReadBoolByte();
}
/// <summary>
/// Creates a feature-contribution bindable mapper over <paramref name="predictor"/>.
/// </summary>
/// <param name="env">Host environment; must not be null.</param>
/// <param name="predictor">Predictor able to report feature contributions; must not be null.</param>
/// <param name="topContributionsCount">Number of top contributions; must be non-negative.</param>
/// <param name="bottomContributionsCount">Number of bottom contributions; must be non-negative.</param>
/// <param name="normalize">Stored as the normalize flag.</param>
/// <param name="stringify">Stored as <c>Stringify</c>.</param>
public BindableMapper(IHostEnvironment env, IFeatureContributionMapper predictor, int topContributionsCount, int bottomContributionsCount, bool normalize, bool stringify)
{
    Contracts.CheckValue(env, nameof(env));
    _env = env;
    _env.CheckValue(predictor, nameof(predictor));

    // Negative counts are rejected; zero is allowed.
    if (topContributionsCount < 0)
    {
        throw env.Except("Number of top contribution must be non negative");
    }
    if (bottomContributionsCount < 0)
    {
        throw env.Except("Number of bottom contribution must be non negative");
    }

    Predictor = predictor;
    GenericMapper = ScoreUtils.GetSchemaBindableMapper(_env, Predictor, null);

    _topContributionsCount = topContributionsCount;
    _bottomContributionsCount = bottomContributionsCount;
    _normalize = normalize;
    Stringify = stringify;
}
/// <summary>
/// Returns a new mapper that wraps <paramref name="inner"/> instead of the current bindable, keeping
/// this instance's type, getter, metadata kind and wrap predicate. The new instance is constructed
/// with this mapper's host as its environment.
/// </summary>
/// <param name="inner">The bindable mapper to wrap.</param>
public ISchemaBindableMapper Clone(ISchemaBindableMapper inner)
    => new LabelNameBindableMapper(_host, inner, _type, _getter, _metadataKind, _canWrap);
/// <summary>
/// Forwards <paramref name="env"/>, <paramref name="registrationName"/> and <paramref name="input"/> to the
/// base constructor and stores <paramref name="bindable"/> in <c>Bindable</c>.
/// </summary>
/// <param name="bindable">The bindable mapper this scorer uses; asserted non-null.</param>
private protected RowToRowScorerBase(IHostEnvironment env, IDataView input, string registrationName, ISchemaBindableMapper bindable)
    : base(env, registrationName, input)
{
    Contracts.AssertValue(bindable);
    Bindable = bindable;
}
/// <summary>
/// Tries to create a bindable mapper whose kind and settings match <paramref name="scorerSettings"/>,
/// passing <paramref name="predictor"/> to the factory.
/// </summary>
/// <returns>True if the component catalog created an instance; the instance is returned in <paramref name="bindable"/>.</returns>
private static bool TryCreateBindableFromScorer(IHostEnvironment env, IPredictor predictor, SubComponent<IDataScorerTransform, SignatureDataScorer> scorerSettings, out ISchemaBindableMapper bindable)
{
    Contracts.AssertValue(env);
    env.AssertValue(predictor);
    env.Assert(scorerSettings.IsGood());

    // Reuse the scorer's load name and settings to look up a same-named mapper factory.
    var mapperKind = scorerSettings.Kind;
    var mapperSettings = scorerSettings.Settings;
    var mapperComponent = new SubComponent<ISchemaBindableMapper, SignatureBindableMapper>(mapperKind, mapperSettings);
    return ComponentCatalog.TryCreateInstance(env, out bindable, mapperComponent, predictor);
}