/// <summary>
/// Builds a prediction transformer around an already-trained field-aware factorization machine model.
/// </summary>
/// <param name="host">Host environment; must not be null. A child host named after this class is registered from it.</param>
/// <param name="model">The trained model parameters, forwarded to the base class and used to build the bindable mapper.</param>
/// <param name="trainSchema">Schema of the training data; every feature column name is looked up in it.</param>
/// <param name="featureColumns">Names of the feature columns; must be non-empty and each name must exist in <paramref name="trainSchema"/>.</param>
/// <param name="threshold">Threshold handed to the binary classifier scorer.</param>
/// <param name="thresholdColumn">Name of the column the scorer applies the threshold to; must be non-empty.</param>
internal FieldAwareFactorizationMachinePredictionTransformer(IHostEnvironment host, FieldAwareFactorizationMachineModelParameters model, DataViewSchema trainSchema,
    string[] featureColumns, float threshold = 0f, string thresholdColumn = DefaultColumnNames.Score)
    : base(Contracts.CheckRef(host, nameof(host)).Register(nameof(FieldAwareFactorizationMachinePredictionTransformer)), model, trainSchema)
{
    // Validate the caller-supplied arguments before touching any state.
    Host.CheckNonEmpty(thresholdColumn, nameof(thresholdColumn));
    Host.CheckNonEmpty(featureColumns, nameof(featureColumns));

    _threshold = threshold;
    _thresholdColumn = thresholdColumn;
    FeatureColumns = featureColumns;

    // Resolve the type of every feature column from the training schema, in the order given.
    var resolvedTypes = new DataViewType[featureColumns.Length];
    for (int slot = 0; slot < featureColumns.Length; slot++)
    {
        string columnName = featureColumns[slot];
        if (!trainSchema.TryGetColumnIndex(columnName, out int columnIndex))
        {
            throw Host.ExceptSchemaMismatch(nameof(featureColumns), "feature", columnName);
        }

        resolvedTypes[slot] = trainSchema[columnIndex].Type;
    }
    FeatureColumnTypes = resolvedTypes;

    BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);

    // The scorer is bound against an empty view over the training schema.
    var roleSchema = GetSchema();
    var scorerArgs = new BinaryClassifierScorer.Arguments
    {
        Threshold = _threshold,
        ThresholdColumn = _thresholdColumn
    };
    var emptyInput = new EmptyDataView(Host, trainSchema);
    var boundMapper = BindableMapper.Bind(Host, roleSchema);
    Scorer = new BinaryClassifierScorer(Host, scorerArgs, emptyInput, boundMapper, roleSchema);
}
/// <summary>
/// Scenario: train and calibrate a linear binary classifier on the sentiment data, evaluate it with the
/// default scorer, then rebuild the scorer with a different threshold (0.01 on the Probability column)
/// over the same predictor and evaluate again, without retraining.
/// </summary>
void ReconfigurablePrediction()
{
    var dataPath = GetDataPath(SentimentDataPath);
    var testDataPath = GetDataPath(SentimentTestPath);

    using (var env = new TlcEnvironment(seed: 1, conc: 1))
    {
        // Data pipeline: text loader followed by the text featurizing transform.
        var source = new MultiFileSource(dataPath);
        var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), source);
        var featurized = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader);

        // Training on a cached view of the featurized data.
        var trainerArgs = new LinearClassificationTrainer.Arguments { NumThreads = 1 };
        var trainer = new LinearClassificationTrainer(env, trainerArgs);
        var cachedView = new CacheDataView(env, featurized, prefetch: null);
        var trainRoles = new RoleMappedData(cachedView, label: "Label", feature: "Features");
        IPredictor predictor = trainer.Train(new Runtime.TrainContext(trainRoles));

        // Calibration replaces the raw predictor with a calibrated one.
        using (var ch = env.Start("Calibrator training"))
        {
            predictor = CalibratorUtils.TrainCalibrator(env, ch, new PlattCalibratorTrainer(env), int.MaxValue, predictor, trainRoles);
        }

        // First evaluation: scorer obtained from ScoreUtils, evaluator with default arguments.
        var scoreRoles = new RoleMappedData(featurized, label: "Label", feature: "Features");
        IDataScorerTransform defaultScorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema);
        var defaultEvalData = new RoleMappedData(defaultScorer, label: "Label", feature: "Features", opt: true);
        var defaultEvaluator = new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments());
        var defaultMetricsDict = defaultEvaluator.Evaluate(defaultEvalData);
        var metrics = BinaryClassificationMetrics.FromMetrics(env, defaultMetricsDict["OverallMetrics"], defaultMetricsDict["ConfusionMatrix"])[0];

        // Second evaluation: same predictor, but a scorer rebuilt by hand with a new threshold
        // applied to the probability column instead of the raw score.
        var bindable = ScoreUtils.GetSchemaBindableMapper(env, predictor, null);
        var boundMapper = bindable.Bind(env, trainRoles.Schema);
        var reconfiguredArgs = new BinaryClassifierScorer.Arguments
        {
            Threshold = 0.01f,
            ThresholdColumn = DefaultColumnNames.Probability
        };
        var reconfiguredScorer = new BinaryClassifierScorer(env, reconfiguredArgs, scoreRoles.Data, boundMapper, trainRoles.Schema);
        var reconfiguredEvalData = new RoleMappedData(reconfiguredScorer, label: "Label", feature: "Features", opt: true);

        // The evaluator is configured with the same threshold, interpreted on the probability rather than the raw score.
        var reconfiguredEvaluatorArgs = new BinaryClassifierMamlEvaluator.Arguments()
        {
            Threshold = 0.01f,
            UseRawScoreThreshold = false
        };
        var reconfiguredEvaluator = new BinaryClassifierMamlEvaluator(env, reconfiguredEvaluatorArgs);
        var reconfiguredMetricsDict = reconfiguredEvaluator.Evaluate(reconfiguredEvalData);
        var newMetrics = BinaryClassificationMetrics.FromMetrics(env, reconfiguredMetricsDict["OverallMetrics"], reconfiguredMetricsDict["ConfusionMatrix"])[0];
    }
}
/// <summary>
/// Reconstructs the prediction transformer from a saved model.
/// </summary>
/// <param name="host">Host environment; must not be null. A child host named after this class is registered from it.</param>
/// <param name="ctx">Load context positioned at this transformer's data; the base class reads its own part first.</param>
internal FieldAwareFactorizationMachinePredictionTransformer(IHostEnvironment host, ModelLoadContext ctx)
    : base(Contracts.CheckRef(host, nameof(host)).Register(nameof(FieldAwareFactorizationMachinePredictionTransformer)), ctx)
{
    // *** Binary format ***
    // <base info>
    // ids of strings: feature columns.
    // float: scorer threshold
    // id of string: scorer threshold column

    // The number of feature columns is not stored separately; it is taken from the model's
    // field count. FAFM uses more than one feature column.
    int fieldCount = Model.FieldCount;
    var names = new string[fieldCount];
    var types = new DataViewType[fieldCount];

    for (int field = 0; field < fieldCount; field++)
    {
        string columnName = ctx.LoadString();
        names[field] = columnName;

        // Each saved feature column must still be present in the training schema.
        if (!TrainSchema.TryGetColumnIndex(columnName, out int columnIndex))
        {
            throw Host.ExceptSchemaMismatch(nameof(FeatureColumns), "feature", columnName);
        }

        types[field] = TrainSchema[columnIndex].Type;
    }

    FeatureColumns = names;
    FeatureColumnTypes = types;

    // Scorer settings follow the feature column names in the stream.
    _threshold = ctx.Reader.ReadSingle();
    _thresholdColumn = ctx.LoadString();

    BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, Model);

    // The scorer is bound against an empty view over the training schema.
    var roleSchema = GetSchema();
    var scorerArgs = new BinaryClassifierScorer.Arguments
    {
        Threshold = _threshold,
        ThresholdColumn = _thresholdColumn
    };
    var emptyInput = new EmptyDataView(Host, TrainSchema);
    var boundMapper = BindableMapper.Bind(Host, roleSchema);
    Scorer = new BinaryClassifierScorer(Host, scorerArgs, emptyInput, boundMapper, roleSchema);
}
/// <summary>
/// Builds a prediction transformer around an already-trained field-aware factorization machine predictor.
/// </summary>
/// <param name="host">Host environment; must not be null. A child host named after this class is registered from it.</param>
/// <param name="model">The trained predictor, forwarded to the base class and used to build the bindable mapper.</param>
/// <param name="trainSchema">Schema of the training data; every feature column name is looked up in it.</param>
/// <param name="featureColumns">Names of the feature columns; must be non-null, contain at least one name, and each name must exist in <paramref name="trainSchema"/>.</param>
/// <param name="threshold">Threshold handed to the binary classifier scorer.</param>
/// <param name="thresholdColumn">Name of the column the scorer applies the threshold to; must be non-empty.</param>
public FieldAwareFactorizationMachinePredictionTransformer(IHostEnvironment host, FieldAwareFactorizationMachinePredictor model, ISchema trainSchema,
    string[] featureColumns, float threshold = 0f, string thresholdColumn = DefaultColumnNames.Score)
    : base(Contracts.CheckRef(host, nameof(host)).Register(nameof(FieldAwareFactorizationMachinePredictionTransformer)), model, trainSchema)
{
    Host.CheckNonEmpty(thresholdColumn, nameof(thresholdColumn));
    _threshold = threshold;
    _thresholdColumn = thresholdColumn;

    Host.CheckValue(featureColumns, nameof(featureColumns));
    int featCount = featureColumns.Length;
    // An array length is never negative, so the check must be strict to actually reject
    // an empty feature list (the previous ">= 0" condition could never fail).
    Host.Check(featCount > 0, "Empty features column.");

    FeatureColumns = featureColumns;
    FeatureColumnTypes = new ColumnType[featCount];

    // Resolve the type of every feature column from the training schema, in the order given.
    int i = 0;
    foreach (var feat in featureColumns)
    {
        if (!trainSchema.TryGetColumnIndex(feat, out int col))
        {
            throw Host.ExceptSchemaMismatch(nameof(featureColumns), RoleMappedSchema.ColumnRole.Feature.Value, feat);
        }

        FeatureColumnTypes[i++] = trainSchema.GetColumnType(col);
    }

    BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);

    // The scorer is bound against an empty view over the training schema.
    var schema = GetSchema();
    var args = new BinaryClassifierScorer.Arguments { Threshold = _threshold, ThresholdColumn = _thresholdColumn };
    _scorer = new BinaryClassifierScorer(Host, args, new EmptyDataView(Host, trainSchema), BindableMapper.Bind(Host, schema), schema);
}