예제 #1
0
        /// <summary>
        /// Validates the training set carried by <paramref name="context"/>, trains the tree ensemble on it,
        /// and returns the resulting binary model wrapped together with a Platt calibrator.
        /// </summary>
        /// <param name="context">Provides the training set, plus the validation and test sets that are
        /// stored in <c>ValidData</c> and <c>TestData</c> before training starts.</param>
        /// <returns>A calibrated predictor combining the trained ensemble with a Platt calibrator whose
        /// first parameter is the negated <c>_sigmoidParameter</c> and whose second parameter is 0.</returns>
        private protected override IPredictorWithFeatureWeights<float> TrainModelCore(TrainContext context)
        {
            Host.CheckValue(context, nameof(context));

            var trainData = context.TrainingSet;
            ValidData = context.ValidationSet;
            TestData = context.TestSet;

            using (var channel = Host.Start("Training"))
            {
                // Validate the shape of the training data before touching it.
                channel.CheckValue(trainData, nameof(trainData));
                trainData.CheckBinaryLabel();
                trainData.CheckFeatureFloatVector();
                trainData.CheckOptFloatWeight();

                FeatureCount = trainData.Schema.Feature.Type.ValueCount;
                ConvertData(trainData);
                TrainCore(channel);
            }

            var modelParameters = new FastTreeBinaryModelParameters(Host, TrainedEnsemble, FeatureCount, InnerArgs);

            // FastTree binary classification boosting is naturally calibrated: passing the raw scores
            // through a scaled logistic function yields probabilities. That probabilistic interpretation
            // is explained in "From RankNet to LambdaRank to LambdaMART: An Overview" by Chris Burges.
            // The scaling is only correct while the gradient computed in
            // BinaryClassificationObjectiveFunction.GetGradientInOneQuery stays consistent with
            // section 6 of that paper.
            var calibrator = new PlattCalibrator(Host, -1 * _sigmoidParameter, 0);

            return new FeatureWeightsCalibratedPredictor(Host, modelParameters, calibrator);
        }
예제 #2
0
        /// <summary>
        /// Loads a <c>FastTreeBinaryModelParameters</c> from <paramref name="ctx"/> and then attempts to
        /// load a sub-model named "Calibrator" from the same context.
        /// </summary>
        /// <param name="env">Host environment used for validation and for constructing the loaded objects.</param>
        /// <param name="ctx">Model load context; its version is checked against <c>GetVersionInfo()</c>.</param>
        /// <returns>The bare predictor when no calibrator was loaded; otherwise the predictor and the
        /// calibrator wrapped in a <c>SchemaBindableCalibratedModelParameters</c>.</returns>
        private static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            var model = new FastTreeBinaryModelParameters(env, ctx);
            ctx.LoadModelOrNull<ICalibrator, SignatureLoadModel>(env, out ICalibrator loadedCalibrator, @"Calibrator");

            if (loadedCalibrator != null)
            {
                return new SchemaBindableCalibratedModelParameters<FastTreeBinaryModelParameters, ICalibrator>(env, model, loadedCalibrator);
            }

            // Nothing was loaded under "Calibrator": hand back the uncalibrated model.
            return model;
        }
예제 #3
0
        /// <summary>
        /// Combines several tree-ensemble models into one ensemble whose leaf outputs are averaged
        /// over the number of input models.
        /// </summary>
        /// <remarks>
        /// Each input model may be either a bare <c>TreeEnsembleModelParameters</c> or one wrapped in an
        /// <c>IWeaklyTypedCalibratedModelParameters</c> whose calibrator is a <c>PlattCalibrator</c>.
        /// For calibrated models the leaf values are multiplied by the negated calibrator slope, and the
        /// combined binary model is paired with a new <c>PlattCalibrator(_host, -1, 0)</c>.
        /// All models must agree on being calibrated or not, and on the number of features.
        /// </remarks>
        /// <param name="models">The models to combine; no element may be null.</param>
        /// <returns>A FastTree binary, regression or ranking model (selected by <c>_kind</c>) built from
        /// the combined ensemble. The binary model is wrapped with a calibrator when the inputs were calibrated.</returns>
        IPredictor IModelCombiner.CombineModels(IEnumerable<IPredictor> models)
        {
            _host.CheckValue(models, nameof(models));

            var ensemble = new InternalTreeEnsemble();
            int modelCount = 0;
            int featureCount = -1;
            bool binaryClassifier = false;

            foreach (var model in models)
            {
                modelCount++;

                var predictor = model;
                _host.CheckValue(predictor, nameof(models), "One of the models is null");

                var calibrated = predictor as IWeaklyTypedCalibratedModelParameters;
                double paramA = 1;
                if (calibrated != null)
                {
                    _host.Check(calibrated.WeeklyTypedCalibrator is PlattCalibrator,
                        "Combining FastTree models can only be done when the models are calibrated with Platt calibrator");

                    // Unwrap the sub-model and read the slope only for calibrated models. Doing this
                    // outside the null check would throw NullReferenceException for every uncalibrated
                    // model; those keep paramA == 1 and are used as they are.
                    predictor = calibrated.WeeklyTypedSubModel;
                    paramA = -((PlattCalibrator)calibrated.WeeklyTypedCalibrator).Slope;
                }

                var tree = predictor as TreeEnsembleModelParameters;
                if (tree == null)
                {
                    throw _host.Except("Model is not a tree ensemble");
                }

                foreach (var t in tree.TrainedEnsemble.Trees)
                {
                    // Copy each tree through a byte-array round trip so the input model's trees are
                    // not modified when the copies' outputs are rescaled below.
                    var bytes = new byte[t.SizeInBytes()];
                    int position = -1;
                    t.ToByteArray(bytes, ref position);
                    position = -1;
                    var tNew = new InternalRegressionTree(bytes, ref position);

                    if (paramA != 1)
                    {
                        for (int i = 0; i < tNew.NumLeaves; i++)
                        {
                            tNew.SetOutput(i, tNew.LeafValues[i] * paramA);
                        }
                    }

                    ensemble.AddTree(tNew);
                }

                if (modelCount == 1)
                {
                    // The first model fixes the expectations every later model is checked against.
                    binaryClassifier = calibrated != null;
                    featureCount = tree.InputType.GetValueCount();
                }
                else
                {
                    _host.Check((calibrated != null) == binaryClassifier, "Ensemble contains both calibrated and uncalibrated models");
                    _host.Check(featureCount == tree.InputType.GetValueCount(), "Found models with different number of features");
                }
            }

            // Average the combined ensemble: every leaf output is divided by the number of models.
            var scale = 1 / (double)modelCount;

            foreach (var t in ensemble.Trees)
            {
                for (int i = 0; i < t.NumLeaves; i++)
                {
                    t.SetOutput(i, t.LeafValues[i] * scale);
                }
            }

            switch (_kind)
            {
                case PredictionKind.BinaryClassification:
                    if (!binaryClassifier)
                    {
                        return new FastTreeBinaryModelParameters(_host, ensemble, featureCount, null);
                    }

                    // The original slopes were already applied to the leaf values above, so the
                    // combined model is paired with a calibrator constructed with -1 and 0.
                    var cali = new PlattCalibrator(_host, -1, 0);
                    var fastTreeModel = new FastTreeBinaryModelParameters(_host, ensemble, featureCount, null);
                    return new FeatureWeightsCalibratedModelParameters<FastTreeBinaryModelParameters, PlattCalibrator>(_host, fastTreeModel, cali);

                case PredictionKind.Regression:
                    return new FastTreeRegressionModelParameters(_host, ensemble, featureCount, null);

                case PredictionKind.Ranking:
                    return new FastTreeRankingModelParameters(_host, ensemble, featureCount, null);

                default:
                    _host.Assert(false);
                    throw _host.ExceptNotSupp();
            }
        }