/// <summary>
/// Creates the transformer from a trained model: validates the inputs, resolves the type of every
/// feature column against <paramref name="trainSchema"/>, and builds the binary classifier scorer.
/// </summary>
/// <param name="host">Host environment; a child host registered under this type's name is handed to the base constructor.</param>
/// <param name="model">Model parameters passed to the base constructor and used to obtain the bindable mapper.</param>
/// <param name="trainSchema">Schema in which every entry of <paramref name="featureColumns"/> is looked up.</param>
/// <param name="featureColumns">Names of the feature columns; must be non-empty and each name must exist in <paramref name="trainSchema"/>.</param>
/// <param name="threshold">Threshold forwarded to the scorer arguments.</param>
/// <param name="thresholdColumn">Name of the column the threshold is applied to; must be non-empty.</param>
internal FieldAwareFactorizationMachinePredictionTransformer(
            IHostEnvironment host,
            FieldAwareFactorizationMachineModelParameters model,
            DataViewSchema trainSchema,
            string[] featureColumns,
            float threshold = 0f,
            string thresholdColumn = DefaultColumnNames.Score)
            : base(Contracts.CheckRef(host, nameof(host)).Register(nameof(FieldAwareFactorizationMachinePredictionTransformer)), model, trainSchema)
        {
            Host.CheckNonEmpty(thresholdColumn, nameof(thresholdColumn));
            Host.CheckNonEmpty(featureColumns, nameof(featureColumns));

            _threshold = threshold;
            _thresholdColumn = thresholdColumn;
            FeatureColumns = featureColumns;

            // Resolve the type of each feature column, failing on the first name missing from the schema.
            var resolvedTypes = new DataViewType[featureColumns.Length];
            for (int idx = 0; idx < featureColumns.Length; idx++)
            {
                string columnName = featureColumns[idx];
                if (trainSchema.TryGetColumnIndex(columnName, out int columnIndex))
                {
                    resolvedTypes[idx] = trainSchema[columnIndex].Type;
                    continue;
                }

                throw Host.ExceptSchemaMismatch(nameof(featureColumns), "feature", columnName);
            }
            FeatureColumnTypes = resolvedTypes;

            BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);

            var outputSchema = GetSchema();
            var scorerArgs = new BinaryClassifierScorer.Arguments();
            scorerArgs.Threshold = _threshold;
            scorerArgs.ThresholdColumn = _thresholdColumn;

            // The scorer is bound against an empty view that only carries the training schema.
            var emptyInput = new EmptyDataView(Host, trainSchema);
            var boundMapper = BindableMapper.Bind(Host, outputSchema);
            Scorer = new BinaryClassifierScorer(Host, scorerArgs, emptyInput, boundMapper, outputSchema);
        }
        /// <summary>
        /// Scenario: trains a linear classifier with a Platt calibrator on the data found via
        /// <c>SentimentDataPath</c>, evaluates it with the default scorer, then builds a second
        /// <c>BinaryClassifierScorer</c> that thresholds the Probability column at 0.01 and evaluates again.
        /// </summary>
        void ReconfigurablePrediction()
        {
            const string labelColumn = "Label";
            const string featureColumn = "Features";
            const float probabilityThreshold = 0.01f;

            var trainFile = GetDataPath(SentimentDataPath);
            // Resolved but not consumed below; kept so the lookup itself still runs.
            var testFile = GetDataPath(SentimentTestPath);

            using (var env = new TlcEnvironment(seed: 1, conc: 1))
            {
                // Pipeline: text loader followed by the text featurizer.
                var source = new MultiFileSource(trainFile);
                var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), source);
                var featurized = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader);

                // Train on a cached copy of the featurized data.
                var trainerArgs = new LinearClassificationTrainer.Arguments();
                trainerArgs.NumThreads = 1;
                var trainer = new LinearClassificationTrainer(env, trainerArgs);

                var cachedView = new CacheDataView(env, featurized, prefetch: null);
                var trainRoles = new RoleMappedData(cachedView, label: labelColumn, feature: featureColumn);
                IPredictor uncalibrated = trainer.Train(new Runtime.TrainContext(trainRoles));

                IPredictor predictor;
                using (var ch = env.Start("Calibrator training"))
                {
                    predictor = CalibratorUtils.TrainCalibrator(env, ch, new PlattCalibratorTrainer(env), int.MaxValue, uncalibrated, trainRoles);
                }

                // Score and evaluate with the default scorer.
                var scoreRoles = new RoleMappedData(featurized, label: labelColumn, feature: featureColumn);
                IDataScorerTransform defaultScorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema);

                var defaultEvalData = new RoleMappedData(defaultScorer, label: labelColumn, feature: featureColumn, opt: true);
                var defaultEvaluator = new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments());
                var defaultMetricsDict = defaultEvaluator.Evaluate(defaultEvalData);

                var metrics = BinaryClassificationMetrics.FromMetrics(
                    env,
                    defaultMetricsDict["OverallMetrics"],
                    defaultMetricsDict["ConfusionMatrix"])[0];

                // Re-score the same predictor with a probability-based threshold.
                var bindable = ScoreUtils.GetSchemaBindableMapper(env, predictor, null);
                var boundMapper = bindable.Bind(env, trainRoles.Schema);

                var rescoreArgs = new BinaryClassifierScorer.Arguments();
                rescoreArgs.Threshold = probabilityThreshold;
                rescoreArgs.ThresholdColumn = DefaultColumnNames.Probability;
                var thresholdedScorer = new BinaryClassifierScorer(env, rescoreArgs, scoreRoles.Data, boundMapper, trainRoles.Schema);

                // Evaluate the re-scored data with a matching evaluator threshold.
                var thresholdedEvalData = new RoleMappedData(thresholdedScorer, label: labelColumn, feature: featureColumn, opt: true);

                var thresholdedEvalArgs = new BinaryClassifierMamlEvaluator.Arguments();
                thresholdedEvalArgs.Threshold = probabilityThreshold;
                thresholdedEvalArgs.UseRawScoreThreshold = false;
                var thresholdedEvaluator = new BinaryClassifierMamlEvaluator(env, thresholdedEvalArgs);

                var thresholdedMetricsDict = thresholdedEvaluator.Evaluate(thresholdedEvalData);
                var newMetrics = BinaryClassificationMetrics.FromMetrics(
                    env,
                    thresholdedMetricsDict["OverallMetrics"],
                    thresholdedMetricsDict["ConfusionMatrix"])[0];
            }
        }
        /// <summary>
        /// Restores the transformer from a model load context: reads the feature column names, the scorer
        /// threshold and the threshold column, then rebuilds the binary classifier scorer.
        /// </summary>
        /// <param name="host">Host environment; a child host registered under this type's name is handed to the base constructor.</param>
        /// <param name="ctx">Load context positioned at this transformer's data; also consumed by the base constructor.</param>
        internal FieldAwareFactorizationMachinePredictionTransformer(IHostEnvironment host, ModelLoadContext ctx)
            : base(Contracts.CheckRef(host, nameof(host)).Register(nameof(FieldAwareFactorizationMachinePredictionTransformer)), ctx)
        {
            // *** Binary format ***
            // <base info>
            // ids of strings: feature columns.
            // float: scorer threshold
            // id of string: scorer threshold column

            // One feature column is stored per model field; FAFM uses more than one.
            int fieldCount = Model.FieldCount;

            var loadedNames = new string[fieldCount];
            var loadedTypes = new DataViewType[fieldCount];

            // Names are read one at a time and validated immediately, so the read order from ctx is unchanged.
            int position = 0;
            while (position < fieldCount)
            {
                string columnName = ctx.LoadString();
                loadedNames[position] = columnName;

                if (!TrainSchema.TryGetColumnIndex(columnName, out int columnIndex))
                {
                    throw Host.ExceptSchemaMismatch(nameof(FeatureColumns), "feature", columnName);
                }

                loadedTypes[position] = TrainSchema[columnIndex].Type;
                position++;
            }

            FeatureColumns = loadedNames;
            FeatureColumnTypes = loadedTypes;

            // Threshold first, then its column name, matching the binary format above.
            _threshold = ctx.Reader.ReadSingle();
            _thresholdColumn = ctx.LoadString();

            BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, Model);

            var outputSchema = GetSchema();
            var scorerArgs = new BinaryClassifierScorer.Arguments();
            scorerArgs.Threshold = _threshold;
            scorerArgs.ThresholdColumn = _thresholdColumn;

            // The scorer is bound against an empty view that only carries the training schema.
            var emptyInput = new EmptyDataView(Host, TrainSchema);
            var boundMapper = BindableMapper.Bind(Host, outputSchema);
            Scorer = new BinaryClassifierScorer(Host, scorerArgs, emptyInput, boundMapper, outputSchema);
        }
        /// <summary>
        /// Creates the transformer from a trained predictor: validates the inputs, resolves the type of every
        /// feature column against <paramref name="trainSchema"/>, and builds the binary classifier scorer.
        /// </summary>
        /// <param name="host">Host environment; a child host registered under this type's name is handed to the base constructor.</param>
        /// <param name="model">Predictor passed to the base constructor and used to obtain the bindable mapper.</param>
        /// <param name="trainSchema">Schema in which every entry of <paramref name="featureColumns"/> is looked up.</param>
        /// <param name="featureColumns">Names of the feature columns; must be non-null, contain at least one name, and each name must exist in <paramref name="trainSchema"/>.</param>
        /// <param name="threshold">Threshold forwarded to the scorer arguments.</param>
        /// <param name="thresholdColumn">Name of the column the threshold is applied to; must be non-empty.</param>
        public FieldAwareFactorizationMachinePredictionTransformer(IHostEnvironment host, FieldAwareFactorizationMachinePredictor model, ISchema trainSchema,
                                                                   string[] featureColumns, float threshold = 0f, string thresholdColumn = DefaultColumnNames.Score)
            : base(Contracts.CheckRef(host, nameof(host)).Register(nameof(FieldAwareFactorizationMachinePredictionTransformer)), model, trainSchema)
        {
            Host.CheckNonEmpty(thresholdColumn, nameof(thresholdColumn));
            _threshold       = threshold;
            _thresholdColumn = thresholdColumn;

            Host.CheckValue(featureColumns, nameof(featureColumns));
            int featCount = featureColumns.Length;

            // An array length is never negative, so the comparison must be strict
            // for an empty feature list to be rejected.
            Host.Check(featCount > 0, "Empty features column.");

            FeatureColumns     = featureColumns;
            FeatureColumnTypes = new ColumnType[featCount];

            // Resolve the type of each feature column, failing on the first name missing from the schema.
            int i = 0;

            foreach (var feat in featureColumns)
            {
                if (!trainSchema.TryGetColumnIndex(feat, out int col))
                {
                    throw Host.ExceptSchemaMismatch(nameof(featureColumns), RoleMappedSchema.ColumnRole.Feature.Value, feat);
                }
                FeatureColumnTypes[i++] = trainSchema.GetColumnType(col);
            }

            BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);

            var schema = GetSchema();
            var args   = new BinaryClassifierScorer.Arguments {
                Threshold = _threshold, ThresholdColumn = _thresholdColumn
            };

            // The scorer is bound against an empty view that only carries the training schema.
            _scorer = new BinaryClassifierScorer(Host, args, new EmptyDataView(Host, trainSchema), BindableMapper.Bind(Host, schema), schema);
        }