/// <summary>
/// Trains a FastTree boosted-tree ensemble for binary classification and wraps it in a
/// Platt calibrator whose slope is the negated sigmoid parameter and whose offset is zero.
/// </summary>
/// <param name="context">Supplies the training set and the (possibly absent) validation set.</param>
/// <returns>A calibrated predictor that also exposes feature weights.</returns>
protected override IPredictorWithFeatureWeights<float> TrainModelCore(TrainContext context)
{
    Host.CheckValue(context, nameof(context));

    ValidData = context.ValidationSet;
    var trainData = context.TrainingSet;

    using (var channel = Host.Start("Training"))
    {
        channel.CheckValue(trainData, nameof(trainData));

        // Validate the data schema: binary label, float feature vector, optional float weight.
        trainData.CheckBinaryLabel();
        trainData.CheckFeatureFloatVector();
        trainData.CheckOptFloatWeight();

        FeatureCount = trainData.Schema.Feature.Type.ValueCount;
        ConvertData(trainData);
        TrainCore(channel);
    }

    var ensemblePredictor = new FastTreeBinaryPredictor(Host, TrainedEnsemble, FeatureCount, InnerArgs);

    // The FastTree binary classification boosting is naturally calibrated to output
    // probabilities when its scores pass through a scaled logistic function, so a Platt
    // calibrator with slope -sigmoid and offset 0 is used instead of a fitted one.
    // This interpretation is explained in "From RankNet to LambdaRank to LambdaMART:
    // An Overview" by Chris Burges. It holds only while the gradient computed in
    // BinaryClassificationObjectiveFunction.GetGradientInOneQuery stays consistent with
    // section 6 of that paper.
    var plattCalibrator = new PlattCalibrator(Host, -1 * _sigmoidParameter, 0);
    return new FeatureWeightsCalibratedPredictor(Host, ensemblePredictor, plattCalibrator);
}
/// <summary>
/// Loads a <see cref="FastTreeBinaryPredictor"/> from <paramref name="ctx"/>. The method also tries
/// to load an optional sub-model named "Calibrator". If that sub-model is found, the predictor is
/// wrapped in a <see cref="SchemaBindableCalibratedPredictor"/>; otherwise the bare predictor is returned.
/// </summary>
/// <param name="env">Host environment used for validation and for loading the calibrator.</param>
/// <param name="ctx">Model load context positioned at this predictor's model.</param>
/// <returns>The loaded predictor, calibrated when a calibrator was stored alongside it.</returns>
public static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    var binaryPredictor = new FastTreeBinaryPredictor(env, ctx);

    ICalibrator calibrator;
    ctx.LoadModelOrNull<ICalibrator, SignatureLoadModel>(env, out calibrator, "Calibrator");

    // A stored calibrator means the saved model was calibrated; restore that wrapping.
    if (calibrator != null)
    {
        return new SchemaBindableCalibratedPredictor(env, binaryPredictor, calibrator);
    }

    return binaryPredictor;
}