/// <summary>
        /// Re-creates a <see cref="MatrixFactorizationPredictionTransformer"/> from the <see cref="ModelLoadContext"/>
        /// in which the original transformer was saved.
        /// </summary>
        /// <param name="host">Host environment; a child host named after this transformer is registered from it.</param>
        /// <param name="ctx">Load context holding the serialized transformer.</param>
        public MatrixFactorizationPredictionTransformer(IHostEnvironment host, ModelLoadContext ctx)
            : base(Contracts.CheckRef(host, nameof(host)).Register(nameof(MatrixFactorizationPredictionTransformer)), ctx)
        {
            // Serialized layout, read in this exact order:
            //   <base info>   (consumed by the base constructor)
            //   string: name of the column holding the matrix's column ids
            //   string: name of the column holding the matrix's row ids
            MatrixColumnIndexColumnName = ctx.LoadString();
            MatrixRowIndexColumnName = ctx.LoadString();

            // Both index columns must be present in the training schema; record their types.
            if (!TrainSchema.TryGetColumnIndex(MatrixColumnIndexColumnName, out int columnIdsIndex))
            {
                throw Host.ExceptSchemaMismatch(
                    nameof(MatrixColumnIndexColumnName),
                    RecommenderUtils.MatrixColumnIndexKind.Value,
                    MatrixColumnIndexColumnName);
            }
            MatrixColumnIndexColumnType = TrainSchema.GetColumnType(columnIdsIndex);

            if (!TrainSchema.TryGetColumnIndex(MatrixRowIndexColumnName, out int rowIdsIndex))
            {
                throw Host.ExceptSchemaMismatch(
                    nameof(MatrixRowIndexColumnName),
                    RecommenderUtils.MatrixRowIndexKind.Value,
                    MatrixRowIndexColumnName);
            }
            MatrixRowIndexColumnType = TrainSchema.GetColumnType(rowIdsIndex);

            BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, Model);

            // Bind the mapper to the schema from GetSchema() and wrap it in a generic scorer
            // (empty column suffix) over an empty view of the training schema.
            var roleSchema = GetSchema();
            var scorerArgs = new GenericScorer.Arguments { Suffix = "" };
            var emptyInput = new EmptyDataView(Host, TrainSchema);
            var boundMapper = BindableMapper.Bind(Host, roleSchema);
            Scorer = new GenericScorer(Host, scorerArgs, emptyInput, boundMapper, roleSchema);
        }
        /// <summary>
        /// Restores the optional feature column name from <paramref name="ctx"/>, resolves its type in the
        /// training schema, and creates the bindable mapper for the loaded predictor.
        /// </summary>
        private protected SingleFeaturePredictionTransformerBase(IHost host, ModelLoadContext ctx, TModel model)
            : base(host, ctx, model)
        {
            FeatureColumnName = ctx.LoadStringOrNull();

            // A null name is allowed and leaves FeatureColumnType null; a non-null name must exist in the training schema.
            if (FeatureColumnName != null)
            {
                if (!TrainSchema.TryGetColumnIndex(FeatureColumnName, out int featureIndex))
                {
                    throw Host.ExceptSchemaMismatch(nameof(FeatureColumnName), "feature", FeatureColumnName);
                }
                FeatureColumnType = TrainSchema[featureIndex].Type;
            }
            else
            {
                FeatureColumnType = null;
            }

            BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, ModelAsPredictor);
        }
        // Example #3
        /// <summary>
        /// Restores the optional feature column name from <paramref name="ctx"/>, resolves its type in the
        /// training schema, and creates the bindable mapper for the loaded model.
        /// </summary>
        internal SingleFeaturePredictionTransformerBase(IHost host, ModelLoadContext ctx)
            : base(host, ctx)
        {
            FeatureColumn = ctx.LoadStringOrNull();

            // A null name is allowed and yields a null FeatureColumnType; a non-null name must be found in the training schema.
            bool hasFeatureColumn = FeatureColumn != null;
            int featureIndex = -1;
            if (hasFeatureColumn && !TrainSchema.TryGetColumnIndex(FeatureColumn, out featureIndex))
            {
                throw Host.ExceptSchemaMismatch(nameof(FeatureColumn), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn);
            }

            FeatureColumnType = hasFeatureColumn ? TrainSchema[featureIndex].Type : null;

            BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, Model);
        }
        /// <summary>
        /// Re-creates a field-aware factorization machine prediction transformer from the
        /// <see cref="ModelLoadContext"/> in which it was saved.
        /// </summary>
        /// <param name="host">Host environment; a child host named after this transformer is registered from it.</param>
        /// <param name="ctx">Load context holding the serialized transformer.</param>
        internal FieldAwareFactorizationMachinePredictionTransformer(IHostEnvironment host, ModelLoadContext ctx)
            : base(Contracts.CheckRef(host, nameof(host)).Register(nameof(FieldAwareFactorizationMachinePredictionTransformer)), ctx)
        {
            // Serialized layout, read in this exact order:
            //   <base info>                    (consumed by the base constructor)
            //   ids of strings: feature column names, one per model field
            //   float: scorer threshold
            //   id of string: scorer threshold column

            // The number of feature columns equals Model.FieldCount (FFM uses more than one).
            int fieldCount = Model.FieldCount;
            var names = new string[fieldCount];
            var types = new DataViewType[fieldCount];

            for (int field = 0; field < fieldCount; field++)
            {
                string name = ctx.LoadString();
                names[field] = name;
                if (!TrainSchema.TryGetColumnIndex(name, out int schemaIndex))
                {
                    throw Host.ExceptSchemaMismatch(nameof(FeatureColumns), "feature", name);
                }
                types[field] = TrainSchema[schemaIndex].Type;
            }

            // Publish the arrays only after every column has been validated.
            FeatureColumns = names;
            FeatureColumnTypes = types;

            // The threshold value precedes its column name in the stream.
            _threshold = ctx.Reader.ReadSingle();
            _thresholdColumn = ctx.LoadString();

            BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, Model);

            // Bind the mapper to the schema from GetSchema() and wrap it in a binary-classifier scorer
            // configured with the loaded threshold, over an empty view of the training schema.
            var roleSchema = GetSchema();
            var scorerArgs = new BinaryClassifierScorer.Arguments
            {
                Threshold = _threshold,
                ThresholdColumn = _thresholdColumn
            };
            var emptyInput = new EmptyDataView(Host, TrainSchema);
            var boundMapper = BindableMapper.Bind(Host, roleSchema);
            Scorer = new BinaryClassifierScorer(Host, scorerArgs, emptyInput, boundMapper, roleSchema);
        }