/// <summary>
        /// Creates the prediction transformer from an in-memory model and the schema it was trained on.
        /// </summary>
        /// <param name="host">Host environment; must not be null. A child host is registered from it.</param>
        /// <param name="model">The field-aware factorization machine model parameters.</param>
        /// <param name="trainSchema">Schema in which every entry of <paramref name="featureColumns"/> is looked up.</param>
        /// <param name="featureColumns">Names of the feature columns; must be non-empty.</param>
        /// <param name="threshold">Value forwarded to the binary classifier scorer as its threshold.</param>
        /// <param name="thresholdColumn">Name forwarded to the binary classifier scorer as its threshold column; must be non-empty.</param>
        internal FieldAwareFactorizationMachinePredictionTransformer(
            IHostEnvironment host,
            FieldAwareFactorizationMachineModelParameters model,
            DataViewSchema trainSchema,
            string[] featureColumns,
            float threshold = 0f,
            string thresholdColumn = DefaultColumnNames.Score)
            : base(Contracts.CheckRef(host, nameof(host)).Register(nameof(FieldAwareFactorizationMachinePredictionTransformer)), model, trainSchema)
        {
            Host.CheckNonEmpty(thresholdColumn, nameof(thresholdColumn));
            Host.CheckNonEmpty(featureColumns, nameof(featureColumns));

            _threshold = threshold;
            _thresholdColumn = thresholdColumn;
            FeatureColumns = featureColumns;

            // Look up each feature column in the training schema and record its type;
            // a name that is not present is reported as a schema mismatch.
            var resolvedTypes = new DataViewType[featureColumns.Length];
            for (int idx = 0; idx < featureColumns.Length; idx++)
            {
                string columnName = featureColumns[idx];
                if (trainSchema.TryGetColumnIndex(columnName, out int columnIndex))
                {
                    resolvedTypes[idx] = trainSchema[columnIndex].Type;
                    continue;
                }

                throw Host.ExceptSchemaMismatch(nameof(featureColumns), "feature", columnName);
            }
            FeatureColumnTypes = resolvedTypes;

            BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);

            // The scorer is built over an empty view of the training schema.
            var roleMapped = GetSchema();
            var scorerArgs = new BinaryClassifierScorer.Arguments();
            scorerArgs.Threshold = _threshold;
            scorerArgs.ThresholdColumn = _thresholdColumn;

            var emptyInput = new EmptyDataView(Host, trainSchema);
            var boundMapper = BindableMapper.Bind(Host, roleMapped);
            Scorer = new BinaryClassifierScorer(Host, scorerArgs, emptyInput, boundMapper, roleMapped);
        }
Пример #2
0
        /// <summary>
        /// Rebuilds <c>Scorer</c> from the current training schema, label/feature column names,
        /// threshold and threshold column.
        /// </summary>
        private void SetScorer()
        {
            var roleMapped = new RoleMappedSchema(TrainSchema, LabelColumnName, FeatureColumnName);

            var scorerArgs = new BinaryClassifierScorer.Arguments();
            scorerArgs.Threshold = Threshold;
            scorerArgs.ThresholdColumn = ThresholdColumn;

            // The scorer is bound against an empty view that only carries the training schema.
            var emptyInput = new EmptyDataView(Host, TrainSchema);
            var boundMapper = BindableMapper.Bind(Host, roleMapped);

            Scorer = new BinaryClassifierScorer(Host, scorerArgs, emptyInput, boundMapper, roleMapped);
        }
Пример #3
0
        /// <summary>
        /// Creates the binary prediction transformer from a model and the schema of its input.
        /// </summary>
        /// <param name="env">Host environment; must not be null. A child host is registered from it.</param>
        /// <param name="model">The model handed to the base transformer.</param>
        /// <param name="inputSchema">Input schema used for the role mapping and the empty scorer input.</param>
        /// <param name="featureColumn">Name of the feature column.</param>
        /// <param name="threshold">Value forwarded to the binary classifier scorer as its threshold.</param>
        /// <param name="thresholdColumn">Name forwarded to the binary classifier scorer as its threshold column; must be non-empty.</param>
        public BinaryPredictionTransformer(
            IHostEnvironment env,
            TModel model,
            ISchema inputSchema,
            string featureColumn,
            float threshold = 0f,
            string thresholdColumn = DefaultColumnNames.Score)
            : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(BinaryPredictionTransformer<TModel>)), model, inputSchema, featureColumn)
        {
            Host.CheckNonEmpty(thresholdColumn, nameof(thresholdColumn));

            // No label column is mapped here; only the feature role is supplied.
            var roleMapped = new RoleMappedSchema(inputSchema, null, featureColumn);

            Threshold = threshold;
            ThresholdColumn = thresholdColumn;

            var scorerArgs = new BinaryClassifierScorer.Arguments();
            scorerArgs.Threshold = Threshold;
            scorerArgs.ThresholdColumn = ThresholdColumn;

            var emptyInput = new EmptyDataView(Host, inputSchema);
            var boundMapper = BindableMapper.Bind(Host, roleMapped);
            _scorer = new BinaryClassifierScorer(Host, scorerArgs, emptyInput, boundMapper, roleMapped);
        }
Пример #4
0
        /// <summary>
        /// Restores the binary prediction transformer from a saved model context.
        /// </summary>
        /// <param name="env">Host environment; must not be null. A child host is registered from it.</param>
        /// <param name="ctx">Load context; the base constructor reads from it first, then the threshold and threshold column are read here.</param>
        public BinaryPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx)
            : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(BinaryPredictionTransformer<TModel>)), ctx)
        {
            // *** Binary format ***
            // <base info>
            // float: scorer threshold
            // id of string: scorer threshold column

            // The two reads below must stay in this order to match the format above.
            Threshold = ctx.Reader.ReadSingle();
            ThresholdColumn = ctx.LoadString();

            // No label column is mapped here; only the feature role is supplied.
            var roleMapped = new RoleMappedSchema(TrainSchema, null, FeatureColumn);

            var scorerArgs = new BinaryClassifierScorer.Arguments();
            scorerArgs.Threshold = Threshold;
            scorerArgs.ThresholdColumn = ThresholdColumn;

            var emptyInput = new EmptyDataView(Host, TrainSchema);
            var boundMapper = BindableMapper.Bind(Host, roleMapped);
            _scorer = new BinaryClassifierScorer(Host, scorerArgs, emptyInput, boundMapper, roleMapped);
        }
        /// <summary>
        /// Restores the prediction transformer from a saved model context.
        /// </summary>
        /// <param name="host">Host environment; must not be null. A child host is registered from it.</param>
        /// <param name="ctx">Load context; the base constructor reads from it first, then the feature column names, threshold and threshold column are read here.</param>
        internal FieldAwareFactorizationMachinePredictionTransformer(IHostEnvironment host, ModelLoadContext ctx)
            : base(Contracts.CheckRef(host, nameof(host)).Register(nameof(FieldAwareFactorizationMachinePredictionTransformer)), ctx)
        {
            // *** Binary format ***
            // <base info>
            // ids of strings: feature columns.
            // float: scorer threshold
            // id of string: scorer threshold column

            // The number of stored feature column names is taken from the model's field count.
            int fieldCount = Model.FieldCount;

            var loadedNames = new string[fieldCount];
            var loadedTypes = new DataViewType[fieldCount];

            for (int slot = 0; slot < fieldCount; slot++)
            {
                string columnName = ctx.LoadString();
                loadedNames[slot] = columnName;

                // Every stored name must still exist in the training schema.
                if (TrainSchema.TryGetColumnIndex(columnName, out int columnIndex))
                {
                    loadedTypes[slot] = TrainSchema[columnIndex].Type;
                    continue;
                }

                throw Host.ExceptSchemaMismatch(nameof(FeatureColumns), "feature", columnName);
            }

            FeatureColumns = loadedNames;
            FeatureColumnTypes = loadedTypes;

            // These reads follow the feature column names, matching the format above.
            _threshold = ctx.Reader.ReadSingle();
            _thresholdColumn = ctx.LoadString();

            BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, Model);

            var roleMapped = GetSchema();
            var scorerArgs = new BinaryClassifierScorer.Arguments();
            scorerArgs.Threshold = _threshold;
            scorerArgs.ThresholdColumn = _thresholdColumn;

            var emptyInput = new EmptyDataView(Host, TrainSchema);
            var boundMapper = BindableMapper.Bind(Host, roleMapped);
            Scorer = new BinaryClassifierScorer(Host, scorerArgs, emptyInput, boundMapper, roleMapped);
        }
Пример #6
0
        /// <summary>
        /// Creates the prediction transformer from a trained predictor and the schema it was trained on.
        /// </summary>
        /// <param name="host">Host environment; must not be null. A child host is registered from it.</param>
        /// <param name="model">The field-aware factorization machine predictor.</param>
        /// <param name="trainSchema">Schema in which every entry of <paramref name="featureColumns"/> is looked up.</param>
        /// <param name="featureColumns">Names of the feature columns; must be non-null and contain at least one name.</param>
        /// <param name="threshold">Value forwarded to the binary classifier scorer as its threshold.</param>
        /// <param name="thresholdColumn">Name forwarded to the binary classifier scorer as its threshold column; must be non-empty.</param>
        public FieldAwareFactorizationMachinePredictionTransformer(IHostEnvironment host, FieldAwareFactorizationMachinePredictor model, Schema trainSchema,
                                                                   string[] featureColumns, float threshold = 0f, string thresholdColumn = DefaultColumnNames.Score)
            : base(Contracts.CheckRef(host, nameof(host)).Register(nameof(FieldAwareFactorizationMachinePredictionTransformer)), model, trainSchema)
        {
            Host.CheckNonEmpty(thresholdColumn, nameof(thresholdColumn));
            _threshold       = threshold;
            _thresholdColumn = thresholdColumn;

            Host.CheckValue(featureColumns, nameof(featureColumns));
            int featCount = featureColumns.Length;

            // An array length is never negative, so the check has to be strict (> 0)
            // for an empty feature list to actually be rejected.
            Host.Check(featCount > 0, "Empty features column.");

            FeatureColumns     = featureColumns;
            FeatureColumnTypes = new ColumnType[featCount];

            // Look up each feature column in the training schema and record its type;
            // a name that is not present is reported as a schema mismatch.
            for (int i = 0; i < featCount; i++)
            {
                string feat = featureColumns[i];
                if (!trainSchema.TryGetColumnIndex(feat, out int col))
                {
                    throw Host.ExceptSchemaMismatch(nameof(featureColumns), RoleMappedSchema.ColumnRole.Feature.Value, feat);
                }
                FeatureColumnTypes[i] = trainSchema.GetColumnType(col);
            }

            BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);

            // The scorer is built over an empty view of the training schema.
            var schema = GetSchema();
            var args   = new BinaryClassifierScorer.Arguments {
                Threshold = _threshold, ThresholdColumn = _thresholdColumn
            };

            Scorer = new BinaryClassifierScorer(Host, args, new EmptyDataView(Host, trainSchema), BindableMapper.Bind(Host, schema), schema);
        }
Пример #7
0
        /// <summary>
        /// Creates a new wrapper around the same inner model and feature column, scored with
        /// different scorer arguments. The input schema is taken from the source of the current
        /// scorer transform.
        /// </summary>
        /// <param name="scorerArgs">Scorer arguments passed to the new wrapper.</param>
        /// <returns>A new <c>BinaryScorerWrapper&lt;TModel&gt;</c> built with <paramref name="scorerArgs"/>.</returns>
        /// <exception cref="System.InvalidOperationException">
        /// The wrapped transform is null or is not an <c>IDataScorerTransform</c>.
        /// </exception>
        public BinaryScorerWrapper<TModel> Clone(BinaryClassifierScorer.Arguments scorerArgs)
        {
            // An unchecked 'as' cast would surface a failure as a NullReferenceException on
            // the next line; fail explicitly with a message that names the actual problem.
            if (!(_xf is IDataScorerTransform scorer))
            {
                throw new System.InvalidOperationException("The wrapped transform is not an IDataScorerTransform.");
            }

            return new BinaryScorerWrapper<TModel>(_env, InnerModel, scorer.Source.Schema, _featureColumn, scorerArgs);
        }
Пример #8
0
        /// <summary>
        /// Builds a binary classifier scorer for <paramref name="model"/> over an empty data view
        /// that carries <paramref name="schema"/>.
        /// </summary>
        /// <param name="env">Host environment used for every component created here.</param>
        /// <param name="schema">Schema given to the empty data view.</param>
        /// <param name="featureColumn">Name of the feature column used in the role mapping.</param>
        /// <param name="model">Model for which a schema-bindable mapper is obtained.</param>
        /// <param name="args">Scorer arguments; also serialized into the scorer factory settings.</param>
        /// <returns>The constructed <c>BinaryClassifierScorer</c>.</returns>
        private static IDataView MakeScorer(IHostEnvironment env, ISchema schema, string featureColumn, TModel model, BinaryClassifierScorer.Arguments args)
        {
            // Settings string has the form "Binary{<settings>}", where <settings> is what
            // CmdParser reports for args relative to a default-constructed Arguments.
            var defaults = new BinaryClassifierScorer.Arguments();
            var factorySettings = $"Binary{{{CmdParser.GetSettings(env, args, defaults)}}}";

            var scorerFactory = CmdParser.CreateComponentFactory(
                typeof(IComponentFactory<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform>),
                typeof(SignatureDataScorer),
                factorySettings);

            var bindableMapper = ScoreUtils.GetSchemaBindableMapper(env, model, scorerFactorySettings: scorerFactory);

            // Role-map an empty view; opt: true is passed through to RoleMappedData.
            var emptyInput = new EmptyDataView(env, schema);
            var roleMappedData = new RoleMappedData(emptyInput, "Label", featureColumn, opt: true);

            var scorerInput = roleMappedData.Data;
            var boundMapper = bindableMapper.Bind(env, roleMappedData.Schema);
            return new BinaryClassifierScorer(env, args, scorerInput, boundMapper, roleMappedData.Schema);
        }
Пример #9
0
 /// <summary>
 /// Creates the wrapper by building a scorer through <c>MakeScorer</c> and passing it,
 /// together with the model and feature column, to the base constructor.
 /// </summary>
 /// <param name="env">Host environment.</param>
 /// <param name="model">Model to wrap.</param>
 /// <param name="inputSchema">Schema handed to <c>MakeScorer</c>.</param>
 /// <param name="featureColumn">Name of the feature column.</param>
 /// <param name="args">Scorer arguments handed to <c>MakeScorer</c>.</param>
 public BinaryScorerWrapper(
     IHostEnvironment env,
     TModel model,
     ISchema inputSchema,
     string featureColumn,
     BinaryClassifierScorer.Arguments args)
     : base(
         env,
         MakeScorer(env, inputSchema, featureColumn, model, args),
         model,
         featureColumn)
 {
     // All work happens in the base constructor and MakeScorer.
 }