/// <summary>
/// Reads serialized bindings from <paramref name="ctx"/>, binds <paramref name="bindable"/> to
/// <paramref name="input"/> using the loaded column roles, and builds a <see cref="BindingsImpl"/>
/// around the resulting row mapper.
/// </summary>
/// <param name="ctx">Load context the base info, score column kind and score column name are read from.</param>
/// <param name="input">Input schema the bindable is bound against.</param>
/// <param name="env">Host environment used for all checks and passed to the bind call.</param>
/// <param name="bindable">Mapper to bind; the bound result must be an <see cref="ISchemaBoundRowMapper"/>.</param>
/// <param name="outputTypeMatches">Predicate the score column type must satisfy, otherwise the decode check fails.</param>
/// <param name="getPredColType">Computes the predicted column type from the score column type and the row mapper.</param>
/// <returns>The bindings built from the loaded information.</returns>
public static BindingsImpl Create(ModelLoadContext ctx, DataViewSchema input,
                                  IHostEnvironment env, ISchemaBindableMapper bindable,
                                  Func<DataViewType, bool> outputTypeMatches, Func<DataViewType, ISchemaBoundRowMapper, DataViewType> getPredColType)
            {
                Contracts.AssertValue(env);
                env.AssertValue(ctx);
                env.AssertValue(input);
                env.AssertValue(bindable);
                env.AssertValue(outputTypeMatches);
                env.AssertValue(getPredColType);

                // *** Binary format ***
                // <base info>
                // int: id of the scores column kind (metadata output)
                // int: id of the column used for deriving the predicted label column

                string suffix;
                var roles = LoadBaseInfo(ctx, out suffix);

                string scoreKind = ctx.LoadNonEmptyString();
                string scoreCol = ctx.LoadNonEmptyString();

                var mapper = bindable.Bind(env, new RoleMappedSchema(input, roles));
                var rowMapper = mapper as ISchemaBoundRowMapper;

                // The requirement is on the *bound* mapper: it has to be a row mapper. The parameter
                // itself is already an ISchemaBindableMapper, so the message names the bound interface.
                env.CheckParam(rowMapper != null, nameof(bindable),
                    "Bindable expected to bind to an " + nameof(ISchemaBoundRowMapper) + "!");

                // Find the score column of the mapper; a missing column means the model data is inconsistent.
                int scoreColIndex;
                env.CheckDecode(mapper.OutputSchema.TryGetColumnIndex(scoreCol, out scoreColIndex));

                var scoreType = mapper.OutputSchema[scoreColIndex].Type;
                env.CheckDecode(outputTypeMatches(scoreType));

                var predColType = getPredColType(scoreType, rowMapper);

                return new BindingsImpl(input, rowMapper, suffix, scoreKind, false, scoreColIndex, predColType);
            }
예제 #2
0
        /// <summary>
        /// Returns the batch prediction engine cached for the given scoring arguments,
        /// creating and caching a new engine when none is found.
        /// </summary>
        /// <param name="usageItems">Usage events the data view of a newly created engine is built from.</param>
        /// <param name="sarScoringArguments">Scoring options: recommendation count, history inclusion, optional reference date and decay.</param>
        /// <returns>The cached engine, or the newly created one.</returns>
        private BatchPredictionEngine<SarUsageEvent, SarScoreResult> GetOrCreateBatchPredictionEngine(
            IList<SarUsageEvent> usageItems, SarScoringArguments sarScoringArguments)
        {
            // The recommendation count is rounded so that similar requests share one cached engine.
            var scorerArgs = new RecommenderScorerTransform.Arguments();
            scorerArgs.recommendationCount = GetRoundedNumberOfResults(sarScoringArguments.RecommendationCount);
            scorerArgs.includeHistory = sarScoringArguments.IncludeHistory;

            // Role -> column name mapping for the usage data.
            var roleColumns = new Dictionary<RoleMappedSchema.ColumnRole, string>();
            roleColumns.Add(new RoleMappedSchema.ColumnRole("Item"), "Item");
            roleColumns.Add(new RoleMappedSchema.ColumnRole("User"), "user");

            string weightColumnName = null;

            if (sarScoringArguments.ReferenceDate.HasValue)
            {
                // The reference date is rounded up to the start of the next day, again to
                // improve the cache hit rate.
                DateTime startOfNextDay = sarScoringArguments.ReferenceDate.Value.Date.AddDays(1);
                scorerArgs.referenceDate = startOfNextDay.ToString("s");

                if (sarScoringArguments.Decay.HasValue)
                {
                    scorerArgs.decay = sarScoringArguments.Decay.Value.TotalDays;
                }

                roleColumns.Add(new RoleMappedSchema.ColumnRole("Date"), "date");
                weightColumnName = "weight";
            }

            // The cache key is made of every argument that influences the engine.
            string cacheKey = $"{scorerArgs.recommendationCount}|{scorerArgs.includeHistory}|{scorerArgs.referenceDate}|{scorerArgs.decay}";

            _tracer.TraceVerbose("Trying to find the engine in the cache");
            var cachedEngine = _enginesCache.Get(cacheKey) as BatchPredictionEngine<SarUsageEvent, SarScoreResult>;
            if (cachedEngine != null)
            {
                return cachedEngine;
            }

            _tracer.TraceInformation("Engine is not cached - creating a new engine");

            IDataView usageView = _environment.CreateDataView(usageItems, _usageDataSchema);
            RoleMappedData mappedUsage = _environment.CreateExamples(usageView, null, weight: weightColumnName, custom: roleColumns);
            ISchemaBindableMapper bindableMapper = RecommenderScorerTransform.Create(_environment, scorerArgs, _recommender);
            ISchemaBoundMapper boundMapper = bindableMapper.Bind(_environment, mappedUsage.Schema);
            IDataScorerTransform scorer = RecommenderScorerTransform.Create(_environment, scorerArgs, usageView, boundMapper, null);

            var newEngine = _environment.CreateBatchPredictionEngine<SarUsageEvent, SarScoreResult>(scorer, false, _usageDataSchema);

            // Engines expire after a day without use.
            var policy = new CacheItemPolicy();
            policy.SlidingExpiration = TimeSpan.FromDays(1);

            bool added = _enginesCache.Add(cacheKey, newEngine, policy);
            _tracer.TraceVerbose($"Addition of engine to the cache resulted with '{added}'");

            return newEngine;
        }
            /// <summary>
            /// Binds the inner bindable to <paramref name="schema"/>. When a wrap predicate is set and
            /// rejects the bound mapper, that mapper is returned as is; otherwise the result of
            /// dispatching <c>CreateBound</c> through <c>Utils.MarshalInvoke</c> is returned.
            /// </summary>
            /// <param name="env">Host environment forwarded to the inner bind and to <c>CreateBound</c>.</param>
            /// <param name="schema">Role-mapped schema to bind against.</param>
            ISchemaBoundMapper ISchemaBindableMapper.Bind(IHostEnvironment env, RoleMappedSchema schema)
            {
                var bound = _bindable.Bind(env, schema);

                // No predicate means "always wrap"; otherwise the predicate decides.
                bool shouldWrap = _canWrap == null || _canWrap(bound, _type);
                if (!shouldWrap)
                {
                    return bound;
                }

                Contracts.Assert(bound is ISchemaBoundRowMapper);

                return Utils.MarshalInvoke(
                    CreateBound<int>,
                    _type.ItemType.RawType,
                    env,
                    (ISchemaBoundRowMapper)bound,
                    _type,
                    _getter,
                    _metadataKind,
                    _canWrap);
            }
            /// <summary>
            /// Creates bindings for a new input schema by re-binding <paramref name="bindable"/> with this
            /// instance's input column roles, keeping the current suffix, score column kind and
            /// predicted column type.
            /// </summary>
            /// <param name="input">The schema to bind against.</param>
            /// <param name="bindable">Mapper to bind; the bound result must implement <see cref="ISchemaBoundRowMapper"/>.</param>
            /// <param name="env">Host environment used for the checks and the bind call.</param>
            /// <returns>New bindings whose score column index is looked up by name in the newly bound mapper.</returns>
            public BindingsImpl ApplyToSchema(DataViewSchema input, ISchemaBindableMapper bindable, IHostEnvironment env)
            {
                Contracts.AssertValue(env);
                env.AssertValue(input);
                env.AssertValue(bindable);

                // The score column is carried over by name, since its index may differ in the new mapper.
                string scoreColumnName = RowMapper.OutputSchema[ScoreColumnIndex].Name;
                var roleMapped = new RoleMappedSchema(input, RowMapper.GetInputColumnRoles());

                // Binding checks compatibility of the predictor input types.
                var boundRowMapper = bindable.Bind(env, roleMapped) as ISchemaBoundRowMapper;
                env.CheckParam(boundRowMapper != null, nameof(bindable), "Mapper must implement ISchemaBoundRowMapper");

                int newScoreIndex;
                env.Check(
                    boundRowMapper.OutputSchema.TryGetColumnIndex(scoreColumnName, out newScoreIndex),
                    "Mapper doesn't have expected score column");

                return new BindingsImpl(input, boundRowMapper, Suffix, ScoreColumnKind, true, newScoreIndex, PredColType);
            }
예제 #5
0
            /// <summary>
            /// Binds <paramref name="bindable"/> to <paramref name="input"/> with the given column roles and
            /// creates the bindings from the resulting row mapper and the column name suffix.
            /// </summary>
            /// <param name="env">Host environment passed to the bind call.</param>
            /// <param name="bindable">Mapper to bind; the bound result must be an <see cref="ISchemaBoundRowMapper"/>.</param>
            /// <param name="input">Input schema.</param>
            /// <param name="roles">Column role to column name pairs used to build the role-mapped schema.</param>
            /// <param name="suffix">Column name suffix; may be null.</param>
            /// <param name="user">Forwarded unchanged to the row-mapper based <c>Create</c> overload.</param>
            private static Bindings Create(IHostEnvironment env, ISchemaBindableMapper bindable, DataViewSchema input,
                                           IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> roles, string suffix, bool user = true)
            {
                Contracts.AssertValue(env);
                Contracts.AssertValue(bindable);
                Contracts.AssertValue(input);
                Contracts.AssertValue(roles);
                Contracts.AssertValueOrNull(suffix);

                var roleMapped = new RoleMappedSchema(input, roles);
                var bound = bindable.Bind(env, roleMapped);

                // Nothing below relies on this, but a bindable that hands back a different
                // schema than the one it was given has misbehaved.
                Contracts.Assert(bound.InputRoleMappedSchema.Schema == input);

                var boundRowMapper = bound as ISchemaBoundRowMapper;
                Contracts.Check(boundRowMapper != null, "Predictor expected to be a RowMapper!");

                return Create(input, boundRowMapper, suffix, user);
            }