/// <summary>
/// Trains the boosted-tree ensemble on the training set from <paramref name="context"/>
/// and wraps the resulting model with a fixed Platt calibrator.
/// </summary>
private protected override IPredictorWithFeatureWeights<float> TrainModelCore(TrainContext context)
{
    Host.CheckValue(context, nameof(context));
    var trainData = context.TrainingSet;
    ValidData = context.ValidationSet;
    TestData = context.TestSet;

    using (var ch = Host.Start("Training"))
    {
        // Validate the training data before converting it to the internal format.
        ch.CheckValue(trainData, nameof(trainData));
        trainData.CheckBinaryLabel();
        trainData.CheckFeatureFloatVector();
        trainData.CheckOptFloatWeight();
        FeatureCount = trainData.Schema.Feature.Type.ValueCount;
        ConvertData(trainData);
        TrainCore(ch);
    }

    var modelParameters = new FastTreeBinaryModelParameters(Host, TrainedEnsemble, FeatureCount, InnerArgs);

    // The FastTree binary classification boosting is naturally calibrated to output
    // probabilities when transformed through a scaled logistic function, so attach a
    // Platt calibrator with slope -sigmoid and zero offset rather than learning one.
    // The probabilistic interpretation is explained in "From RankNet to LambdaRank to
    // LambdaMART: An Overview" by Chris Burges; the correctness of this scaling relies
    // on BinaryClassificationObjectiveFunction.GetGradientInOneQuery matching the
    // gradient described in section 6 of that paper.
    var calibrator = new PlattCalibrator(Host, -1 * _sigmoidParameter, 0);
    return new FeatureWeightsCalibratedPredictor(Host, modelParameters, calibrator);
}
/// <summary>
/// Deserializes a FastTree binary-classification model from <paramref name="ctx"/>.
/// Returns the bare model when no calibrator was persisted alongside it; otherwise
/// returns the model wrapped together with its calibrator.
/// </summary>
private static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    var predictor = new FastTreeBinaryModelParameters(env, ctx);

    // The calibrator is optional in the serialized model.
    ctx.LoadModelOrNull<ICalibrator, SignatureLoadModel>(env, out var calibrator, @"Calibrator");
    if (calibrator is null)
        return predictor;

    return new SchemaBindableCalibratedModelParameters<FastTreeBinaryModelParameters, ICalibrator>(env, predictor, calibrator);
}
/// <summary>
/// Combines several tree-ensemble models into a single averaged ensemble. Calibrated
/// (Platt) models have the calibrator slope folded into their leaf values first; all
/// models must agree on calibration status and feature count.
/// </summary>
IPredictor IModelCombiner.CombineModels(IEnumerable<IPredictor> models)
{
    _host.CheckValue(models, nameof(models));

    var ensemble = new InternalTreeEnsemble();
    int modelCount = 0;
    int featureCount = -1;
    bool binaryClassifier = false;
    foreach (var model in models)
    {
        modelCount++;

        var predictor = model;
        _host.CheckValue(predictor, nameof(models), "One of the models is null");

        var calibrated = predictor as IWeaklyTypedCalibratedModelParameters;
        double paramA = 1;
        if (calibrated != null)
        {
            _host.Check(calibrated.WeeklyTypedCalibrator is PlattCalibrator,
                "Combining FastTree models can only be done when the models are calibrated with Platt calibrator");
            // BUGFIX: these two statements previously sat outside the null check and
            // dereferenced 'calibrated' unconditionally, throwing NullReferenceException
            // for any uncalibrated model — even though uncalibrated ensembles are
            // explicitly supported below. They must only run when a calibrator exists.
            predictor = calibrated.WeeklyTypedSubModel;
            paramA = -((PlattCalibrator)calibrated.WeeklyTypedCalibrator).Slope;
        }

        var tree = predictor as TreeEnsembleModelParameters;
        if (tree == null)
            throw _host.Except("Model is not a tree ensemble");

        foreach (var t in tree.TrainedEnsemble.Trees)
        {
            // Round-trip through the byte serialization to deep-clone the tree,
            // since the leaf outputs may be mutated below.
            var bytes = new byte[t.SizeInBytes()];
            int position = -1;
            t.ToByteArray(bytes, ref position);
            position = -1;
            var tNew = new InternalRegressionTree(bytes, ref position);
            // Fold the Platt slope into the leaf outputs so the combined ensemble
            // needs only a single shared calibrator.
            if (paramA != 1)
            {
                for (int i = 0; i < tNew.NumLeaves; i++)
                    tNew.SetOutput(i, tNew.LeafValues[i] * paramA);
            }
            ensemble.AddTree(tNew);
        }

        if (modelCount == 1)
        {
            // The first model fixes the calibration status and feature count that
            // every subsequent model must match.
            binaryClassifier = calibrated != null;
            featureCount = tree.InputType.GetValueCount();
        }
        else
        {
            _host.Check((calibrated != null) == binaryClassifier, "Ensemble contains both calibrated and uncalibrated models");
            _host.Check(featureCount == tree.InputType.GetValueCount(), "Found models with different number of features");
        }
    }

    // Guard against an empty sequence: averaging below would divide by zero and the
    // resulting model would carry featureCount == -1.
    _host.Check(modelCount > 0, "Must supply at least one model to combine");

    // Average the ensembles by scaling every leaf by 1/modelCount.
    var scale = 1 / (double)modelCount;
    foreach (var t in ensemble.Trees)
    {
        for (int i = 0; i < t.NumLeaves; i++)
            t.SetOutput(i, t.LeafValues[i] * scale);
    }

    switch (_kind)
    {
        case PredictionKind.BinaryClassification:
            if (!binaryClassifier)
                return new FastTreeBinaryModelParameters(_host, ensemble, featureCount, null);
            // The calibrator slopes were already folded into the leaves, so a unit
            // Platt calibrator (slope -1, offset 0) completes the combined model.
            var cali = new PlattCalibrator(_host, -1, 0);
            var fastTreeModel = new FastTreeBinaryModelParameters(_host, ensemble, featureCount, null);
            return new FeatureWeightsCalibratedModelParameters<FastTreeBinaryModelParameters, PlattCalibrator>(_host, fastTreeModel, cali);
        case PredictionKind.Regression:
            return new FastTreeRegressionModelParameters(_host, ensemble, featureCount, null);
        case PredictionKind.Ranking:
            return new FastTreeRankingModelParameters(_host, ensemble, featureCount, null);
        default:
            _host.Assert(false);
            throw _host.ExceptNotSupp();
    }
}