/// <summary>
/// Deserializes a <see cref="BindingsImpl"/> from <paramref name="ctx"/>, binding
/// <paramref name="bindable"/> against <paramref name="input"/> with the column roles stored in the model.
/// </summary>
/// <param name="ctx">Model load context positioned at the serialized bindings.</param>
/// <param name="input">Input schema the bindable mapper is bound to.</param>
/// <param name="env">Host environment used for validation and for binding.</param>
/// <param name="bindable">Mapper to bind; the bound result must implement <see cref="ISchemaBoundRowMapper"/>.</param>
/// <param name="outputTypeMatches">Predicate that the type of the score column must satisfy.</param>
/// <param name="getPredColType">Computes the predicted-label column type from the score type and the bound row mapper.</param>
/// <returns>The bindings built from the loaded information.</returns>
public static BindingsImpl Create(ModelLoadContext ctx, DataViewSchema input,
                                  IHostEnvironment env, ISchemaBindableMapper bindable,
                                  Func<DataViewType, bool> outputTypeMatches, Func<DataViewType, ISchemaBoundRowMapper, DataViewType> getPredColType)
            {
                Contracts.AssertValue(env);
                env.AssertValue(ctx);

                // *** Binary format ***
                // <base info>
                // int: id of the scores column kind (metadata output)
                // int: id of the column used for deriving the predicted label column

                var roles = LoadBaseInfo(ctx, out string suffix);

                string scoreKind = ctx.LoadNonEmptyString();
                string scoreCol = ctx.LoadNonEmptyString();

                var mapper = bindable.Bind(env, new RoleMappedSchema(input, roles));
                var rowMapper = mapper as ISchemaBoundRowMapper;

                // The requirement is on the *bound* mapper: it has to be a row mapper. The message names
                // the interface that is actually being tested for.
                env.CheckParam(rowMapper != null, nameof(bindable), "Bound mapper expected to be an " + nameof(ISchemaBoundRowMapper) + "!");

                // Find the score column of the mapper; a missing column means the model is corrupt.
                env.CheckDecode(rowMapper.OutputSchema.TryGetColumnIndex(scoreCol, out int scoreColIndex));

                var scoreType = rowMapper.OutputSchema[scoreColIndex].Type;

                env.CheckDecode(outputTypeMatches(scoreType));
                var predColType = getPredColType(scoreType, rowMapper);

                return new BindingsImpl(input, rowMapper, suffix, scoreKind, false, scoreColIndex, predColType);
            }
Beispiel #2
0
        /// <summary>
        /// Initializes the transformer from a model and the schema it was trained on.
        /// </summary>
        /// <param name="host">Host used for validation and exception creation; must not be null.</param>
        /// <param name="model">Model stored in <c>Model</c> and used to obtain the bindable mapper.</param>
        /// <param name="trainSchema">Training schema; must not be null.</param>
        /// <param name="featureColumn">
        /// Name of the feature column in <paramref name="trainSchema"/>, or null. When non-null the
        /// column must exist in the schema, otherwise a schema-mismatch exception is thrown.
        /// </param>
        public PredictionTransformerBase(IHost host, TModel model, ISchema trainSchema, string featureColumn)
        {
            Contracts.CheckValue(host, nameof(host));
            Contracts.CheckValueOrNull(featureColumn);
            Host = host;
            Host.CheckValue(trainSchema, nameof(trainSchema));

            Model = model;
            FeatureColumn = featureColumn;

            if (featureColumn != null)
            {
                int featureIndex;
                bool found = trainSchema.TryGetColumnIndex(featureColumn, out featureIndex);
                if (!found)
                {
                    throw Host.ExceptSchemaMismatch(nameof(featureColumn), RoleMappedSchema.ColumnRole.Feature.Value, featureColumn);
                }

                FeatureColumnType = trainSchema.GetColumnType(featureIndex);
            }
            else
            {
                // No feature column was supplied, so there is no type to record.
                FeatureColumnType = null;
            }

            TrainSchema = trainSchema;
            BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
        }
Beispiel #3
0
        /// <summary>
        /// Deserialization constructor: restores the model, the training schema and the feature
        /// column from <paramref name="ctx"/>.
        /// </summary>
        /// <param name="host">Host used for validation and exception creation; must not be null.</param>
        /// <param name="ctx">Model load context to read from; must not be null.</param>
        internal PredictionTransformerBase(IHost host, ModelLoadContext ctx)
        {
            // Validate up front, as the other constructor does, so a null argument surfaces as an
            // argument exception rather than a NullReferenceException further down.
            Contracts.CheckValue(host, nameof(host));
            Host = host;
            Host.CheckValue(ctx, nameof(ctx));

            ctx.LoadModel<TModel, SignatureLoadModel>(host, out TModel model, DirModel);
            Model = model;

            // *** Binary format ***
            // model: prediction model.
            // stream: empty data view that contains train schema.
            // id of string: feature column.

            // Clone the stream with the schema into memory.
            // The MemoryStream is handed to the BinaryLoader and is intentionally not disposed here.
            // NOTE(review): the result of TryLoadBinaryStream is ignored; if the stream is absent the
            // loader is given an empty stream - confirm that is the intended failure mode.
            var ms = new MemoryStream();

            ctx.TryLoadBinaryStream(DirTransSchema, reader =>
            {
                reader.BaseStream.CopyTo(ms);
            });

            ms.Position = 0;
            var loader = new BinaryLoader(host, new BinaryLoader.Arguments(), ms);

            TrainSchema = loader.Schema;

            FeatureColumn = ctx.LoadString();
            if (!TrainSchema.TryGetColumnIndex(FeatureColumn, out int col))
            {
                throw Host.ExceptSchemaMismatch(nameof(FeatureColumn), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn);
            }
            FeatureColumnType = TrainSchema.GetColumnType(col);

            BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
        }
Beispiel #4
0
        /// <summary>
        /// Returns the batch prediction engine for the given scoring arguments, building a new one
        /// and adding it to the engine cache when no cached engine is found.
        /// </summary>
        /// <param name="usageItems">Usage events used to create the data view a new engine is built on.</param>
        /// <param name="sarScoringArguments">Scoring options; they determine the scorer arguments and the cache key.</param>
        /// <returns>The cached engine if present, otherwise the newly created engine.</returns>
        private BatchPredictionEngine<SarUsageEvent, SarScoreResult> GetOrCreateBatchPredictionEngine(
            IList<SarUsageEvent> usageItems, SarScoringArguments sarScoringArguments)
        {
            var arguments = new RecommenderScorerTransform.Arguments();

            // The recommendation count is rounded so that nearby requests share a cached engine.
            arguments.recommendationCount = GetRoundedNumberOfResults(sarScoringArguments.RecommendationCount);
            arguments.includeHistory = sarScoringArguments.IncludeHistory;

            // Column roles -> column names in the usage data.
            var roleToColumn = new Dictionary<RoleMappedSchema.ColumnRole, string>();
            roleToColumn.Add(new RoleMappedSchema.ColumnRole("Item"), "Item");
            roleToColumn.Add(new RoleMappedSchema.ColumnRole("User"), "user");

            string weightColumnName = null;

            if (sarScoringArguments.ReferenceDate.HasValue)
            {
                // The reference date is rounded up to the start of the following day, again so that
                // requests made during the same day map to the same cached engine.
                DateTime startOfNextDay = sarScoringArguments.ReferenceDate.Value.Date.AddDays(1);
                arguments.referenceDate = startOfNextDay.ToString("s");

                if (sarScoringArguments.Decay.HasValue)
                {
                    arguments.decay = sarScoringArguments.Decay.Value.TotalDays;
                }

                roleToColumn.Add(new RoleMappedSchema.ColumnRole("Date"), "date");
                weightColumnName = "weight";
            }

            // Every argument that influences the engine takes part in the key.
            string cacheKey = $"{arguments.recommendationCount}|{arguments.includeHistory}|{arguments.referenceDate}|{arguments.decay}";

            _tracer.TraceVerbose("Trying to find the engine in the cache");
            var cachedEngine = _enginesCache.Get(cacheKey) as BatchPredictionEngine<SarUsageEvent, SarScoreResult>;
            if (cachedEngine != null)
            {
                return cachedEngine;
            }

            _tracer.TraceInformation("Engine is not cached - creating a new engine");

            IDataView usageView = _environment.CreateDataView(usageItems, _usageDataSchema);
            RoleMappedData mappedUsage = _environment.CreateExamples(usageView, null, weight: weightColumnName, custom: roleToColumn);
            ISchemaBindableMapper bindableMapper = RecommenderScorerTransform.Create(_environment, arguments, _recommender);
            ISchemaBoundMapper boundMapper = bindableMapper.Bind(_environment, mappedUsage.Schema);
            IDataScorerTransform scorer = RecommenderScorerTransform.Create(_environment, arguments, usageView, boundMapper, null);

            var newEngine = _environment.CreateBatchPredictionEngine<SarUsageEvent, SarScoreResult>(scorer, false, _usageDataSchema);

            var policy = new CacheItemPolicy { SlidingExpiration = TimeSpan.FromDays(1) };
            bool result = _enginesCache.Add(cacheKey, newEngine, policy);
            _tracer.TraceVerbose($"Addition of engine to the cache resulted with '{result}'");

            return newEngine;
        }
Beispiel #5
0
        /// <summary>
        /// Attempts to create a bindable mapper through the component catalog, using the name and
        /// settings string of <paramref name="scorerSettings"/>.
        /// </summary>
        /// <param name="env">Host environment.</param>
        /// <param name="predictor">Predictor passed to the mapper factory as an extra argument.</param>
        /// <param name="scorerSettings">Scorer settings supplying the load name and settings string.</param>
        /// <param name="bindable">Receives the created mapper.</param>
        /// <returns>The result of the component catalog's <c>TryCreateInstance</c> call.</returns>
        private static bool TryCreateBindableFromScorer(
            IHostEnvironment env,
            IPredictor predictor,
            ICommandLineComponentFactory scorerSettings,
            out ISchemaBindableMapper bindable)
        {
            Contracts.AssertValue(env);
            env.AssertValue(predictor);
            env.AssertValue(scorerSettings);

            // Look for a mapper factory registered under the same load name as the scorer.
            string loadName = scorerSettings.Name;
            string settings = scorerSettings.GetSettingsString();

            return ComponentCatalog.TryCreateInstance<ISchemaBindableMapper, SignatureBindableMapper>(
                env, out bindable, loadName, settings, predictor);
        }
            /// <summary>
            /// Reads the base binding information from <paramref name="ctx"/> and builds the bindings
            /// for the given environment, bindable mapper and input schema.
            /// </summary>
            public static Bindings Create(ModelLoadContext ctx,
                IHostEnvironment env, ISchemaBindableMapper bindable, DataViewSchema input)
            {
                Contracts.AssertValue(ctx);

                // *** Binary format ***
                // <base info>
                var loadedRoles = LoadBaseInfo(ctx, out string loadedSuffix);

                // Delegates to the role-based overload with user set to false.
                return Create(env, bindable, input, loadedRoles, loadedSuffix, user: false);
            }
Beispiel #7
0
            /// <summary>
            /// Registers a host under <c>LoaderSignature</c> and stores the supplied components.
            /// </summary>
            /// <param name="env">Environment the host is registered with.</param>
            /// <param name="bindable">Stored in <c>_bindable</c>.</param>
            /// <param name="type">Stored in <c>_type</c>.</param>
            /// <param name="getter">Stored in <c>_getter</c>.</param>
            /// <param name="metadataKind">Stored in <c>_metadataKind</c>; asserted non-empty.</param>
            /// <param name="canWrap">Stored in <c>_canWrap</c>; may be null.</param>
            private LabelNameBindableMapper(
                IHostEnvironment env,
                ISchemaBindableMapper bindable,
                VectorType type,
                Delegate getter,
                string metadataKind,
                Func<ISchemaBoundMapper, ColumnType, bool> canWrap)
            {
                Contracts.AssertValue(env);
                _host = env.Register(LoaderSignature);

                // Validate each argument, then store it.
                _host.AssertValue(bindable);
                _host.AssertValue(type);
                _host.AssertValue(getter);
                _host.AssertNonEmpty(metadataKind);
                _host.AssertValueOrNull(canWrap);

                _canWrap = canWrap;
                _metadataKind = metadataKind;
                _getter = getter;
                _type = type;
                _bindable = bindable;
            }
            /// <summary>
            /// Builds new bindings for <paramref name="input"/> by re-binding <paramref name="bindable"/>
            /// with the current row mapper's input column roles, keeping this instance's suffix,
            /// score column kind and predicted column type.
            /// </summary>
            /// <param name="input">Schema to bind against.</param>
            /// <param name="bindable">Mapper to bind; its bound form must implement <see cref="ISchemaBoundRowMapper"/>.</param>
            /// <param name="env">Host environment used for validation and binding.</param>
            /// <returns>Bindings over the new schema.</returns>
            public BindingsImpl ApplyToSchema(DataViewSchema input, ISchemaBindableMapper bindable, IHostEnvironment env)
            {
                Contracts.AssertValue(env);
                env.AssertValue(input);
                env.AssertValue(bindable);

                // Remember the current score column's name so it can be looked up in the new mapper.
                string expectedScoreName = RowMapper.OutputSchema[ScoreColumnIndex].Name;
                var roleMapped = new RoleMappedSchema(input, RowMapper.GetInputColumnRoles());

                // Binding checks compatibility of the predictor input types.
                var reboundRowMapper = bindable.Bind(env, roleMapped) as ISchemaBoundRowMapper;
                env.CheckParam(reboundRowMapper != null, nameof(bindable), "Mapper must implement ISchemaBoundRowMapper");

                bool hasScoreColumn = reboundRowMapper.OutputSchema.TryGetColumnIndex(expectedScoreName, out int reboundScoreIndex);
                env.Check(hasScoreColumn, "Mapper doesn't have expected score column");

                return new BindingsImpl(input, reboundRowMapper, Suffix, ScoreColumnKind, true, reboundScoreIndex, PredColType);
            }
            /// <summary>
            /// Binds <paramref name="bindable"/> to <paramref name="input"/> using the given column
            /// roles and creates the bindings from the resulting row mapper.
            /// </summary>
            /// <param name="env">Host environment used for binding.</param>
            /// <param name="bindable">Mapper to bind; the bound mapper must be an <see cref="ISchemaBoundRowMapper"/>.</param>
            /// <param name="input">Input schema.</param>
            /// <param name="roles">Column role to column name pairs.</param>
            /// <param name="suffix">Column name suffix; may be null.</param>
            /// <param name="user">Forwarded unchanged to the row-mapper overload of <c>Create</c>.</param>
            private static Bindings Create(
                IHostEnvironment env,
                ISchemaBindableMapper bindable,
                DataViewSchema input,
                IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> roles,
                string suffix,
                bool user = true)
            {
                Contracts.AssertValue(env);
                Contracts.AssertValue(bindable);
                Contracts.AssertValue(input);
                Contracts.AssertValue(roles);
                Contracts.AssertValueOrNull(suffix);

                var roleMapped = new RoleMappedSchema(input, roles);
                var boundMapper = bindable.Bind(env, roleMapped);

                // Nothing below relies on this, but a failure here signals a misbehaving bindable.
                Contracts.Assert(boundMapper.InputRoleMappedSchema.Schema == input);

                var boundRowMapper = boundMapper as ISchemaBoundRowMapper;
                Contracts.Check(boundRowMapper != null, "Predictor expected to be a RowMapper!");

                return Create(input, boundRowMapper, suffix, user);
            }
            /// <summary>
            /// Deserialization constructor: loads the predictor and the contribution settings from
            /// <paramref name="ctx"/>.
            /// </summary>
            /// <param name="env">Host environment; must not be null.</param>
            /// <param name="ctx">Model load context; must not be null and must be at a matching model version.</param>
            public BindableMapper(IHostEnvironment env, ModelLoadContext ctx)
            {
                Contracts.CheckValue(env, nameof(env));
                _env = env;
                _env.CheckValue(ctx, nameof(ctx));
                ctx.CheckAtModel(GetVersionInfo());

                // *** Binary format ***
                // int: topContributionsCount
                // int: bottomContributionsCount
                // bool: normalize
                // bool: stringify
                ctx.LoadModel<IFeatureContributionMapper, SignatureLoadModel>(env, out Predictor, ModelFileUtils.DirPredictor);
                GenericMapper = ScoreUtils.GetSchemaBindableMapper(_env, Predictor, null);

                // Zero is a legal count: the non-deserializing constructor only rejects negative values,
                // so a mapper created with a count of 0 must also be loadable.
                // NOTE(review): that constructor does not enforce MaxTopBottom either - confirm whether
                // the upper bound should be checked there or dropped here.
                _topContributionsCount = ctx.Reader.ReadInt32();
                Contracts.CheckDecode(0 <= _topContributionsCount && _topContributionsCount <= MaxTopBottom);
                _bottomContributionsCount = ctx.Reader.ReadInt32();
                Contracts.CheckDecode(0 <= _bottomContributionsCount && _bottomContributionsCount <= MaxTopBottom);
                _normalize = ctx.Reader.ReadBoolByte();
                Stringify = ctx.Reader.ReadBoolByte();
            }
            /// <summary>
            /// Creates a mapper around <paramref name="predictor"/> with the given contribution settings.
            /// </summary>
            /// <param name="env">Host environment; must not be null.</param>
            /// <param name="predictor">Feature contribution predictor; must not be null.</param>
            /// <param name="topContributionsCount">Number of top contributions; must not be negative.</param>
            /// <param name="bottomContributionsCount">Number of bottom contributions; must not be negative.</param>
            /// <param name="normalize">Stored in <c>_normalize</c>.</param>
            /// <param name="stringify">Stored in <c>Stringify</c>.</param>
            public BindableMapper(
                IHostEnvironment env,
                IFeatureContributionMapper predictor,
                int topContributionsCount,
                int bottomContributionsCount,
                bool normalize,
                bool stringify)
            {
                Contracts.CheckValue(env, nameof(env));
                _env = env;
                _env.CheckValue(predictor, nameof(predictor));

                // Negative counts are rejected; zero is accepted.
                bool topIsNegative = topContributionsCount < 0;
                if (topIsNegative)
                {
                    throw env.Except("Number of top contribution must be non negative");
                }

                bool bottomIsNegative = bottomContributionsCount < 0;
                if (bottomIsNegative)
                {
                    throw env.Except("Number of bottom contribution must be non negative");
                }

                Predictor = predictor;
                Stringify = stringify;
                _normalize = normalize;
                _bottomContributionsCount = bottomContributionsCount;
                _topContributionsCount = topContributionsCount;

                GenericMapper = ScoreUtils.GetSchemaBindableMapper(_env, Predictor, null);
            }
Beispiel #12
0
 /// <summary>
 /// Creates a new <see cref="LabelNameBindableMapper"/> that wraps <paramref name="inner"/> and
 /// reuses this instance's host, type, getter, metadata kind and wrap predicate.
 /// </summary>
 /// <param name="inner">Mapper the new instance wraps.</param>
 /// <returns>The newly created mapper.</returns>
 public ISchemaBindableMapper Clone(ISchemaBindableMapper inner)
     => new LabelNameBindableMapper(_host, inner, _type, _getter, _metadataKind, _canWrap);
Beispiel #13
0
 /// <summary>
 /// Passes the environment, registration name and input to the base constructor and stores
 /// the bindable mapper in <c>Bindable</c>.
 /// </summary>
 /// <param name="env">Host environment, forwarded to the base constructor.</param>
 /// <param name="input">Input data view, forwarded to the base constructor.</param>
 /// <param name="registrationName">Registration name, forwarded to the base constructor.</param>
 /// <param name="bindable">Bindable mapper; asserted to be non-null.</param>
 private protected RowToRowScorerBase(
     IHostEnvironment env,
     IDataView input,
     string registrationName,
     ISchemaBindableMapper bindable)
     : base(env, registrationName, input)
 {
     Contracts.AssertValue(bindable);

     Bindable = bindable;
 }
Beispiel #14
0
        /// <summary>
        /// Attempts to create a bindable mapper through the component catalog, using a mapper
        /// sub-component built from the kind and settings of <paramref name="scorerSettings"/>.
        /// </summary>
        /// <param name="env">Host environment.</param>
        /// <param name="predictor">Predictor passed to the mapper factory as an extra argument.</param>
        /// <param name="scorerSettings">Scorer sub-component; asserted to be good.</param>
        /// <param name="bindable">Receives the created mapper.</param>
        /// <returns>The result of the component catalog's <c>TryCreateInstance</c> call.</returns>
        private static bool TryCreateBindableFromScorer(
            IHostEnvironment env,
            IPredictor predictor,
            SubComponent<IDataScorerTransform, SignatureDataScorer> scorerSettings,
            out ISchemaBindableMapper bindable)
        {
            Contracts.AssertValue(env);
            env.AssertValue(predictor);
            env.Assert(scorerSettings.IsGood());

            // Look for a mapper factory registered under the same load name as the scorer.
            var bindableComponent = new SubComponent<ISchemaBindableMapper, SignatureBindableMapper>(
                scorerSettings.Kind,
                scorerSettings.Settings);

            bool created = ComponentCatalog.TryCreateInstance(env, out bindable, bindableComponent, predictor);
            return created;
        }