Ejemplo n.º 1
0
        /// <summary>
        /// Runs one training iteration: fits a regression tree to the adjusted targets, adjusts its
        /// outputs, optionally smooths and rescales it, updates all scores and adds it to the ensemble.
        /// </summary>
        /// <param name="ch">Channel used for argument validation and for raising errors.</param>
        /// <param name="activeFeatures">Feature mask passed through to <c>TreeLearner.FitTargets</c>.</param>
        /// <returns>The tree that was added to the ensemble, or null when no tree could be learned.</returns>
        public override RegressionTree TrainingIteration(IChannel ch, bool[] activeFeatures)
        {
            Contracts.CheckValue(ch, nameof(ch));
            // Fit a regression tree to the gradient using least squares.
            RegressionTree tree = TreeLearner.FitTargets(ch, activeFeatures, AdjustTargetsAndSetWeights(ch));

            if (tree == null)
            {
                return null; // Could not learn a tree. Exit.
            }

            // Adjust output values of tree by performing a Newton step.

            // REVIEW: This should be part of OptimizingAlgorithm.
            using (Timer.Time(TimerEvent.TreeLearnerAdjustTreeOutputs))
            {
                double[] backupScores = null;
                bool scoresSwapped = false;

                // When doing dropouts we need to replace the TrainingScores with the scores
                // without the dropped trees.
                if (DropoutRate > 0)
                {
                    backupScores = TrainingScores.Scores;
                    TrainingScores.Scores = _scores;
                    scoresSwapped = true;
                }

                try
                {
                    if (AdjustTreeOutputsOverride != null)
                    {
                        AdjustTreeOutputsOverride.AdjustTreeOutputs(ch, tree, TreeLearner.Partitioning, TrainingScores);
                    }
                    else
                    {
                        var stepSearch = ObjectiveFunction as IStepSearch;
                        if (stepSearch == null)
                        {
                            throw ch.Except("No AdjustTreeOutputs defined. Objective function should define IStepSearch or AdjustTreeOutputsOverride should be set");
                        }

                        stepSearch.AdjustTreeOutputs(ch, tree, TreeLearner.Partitioning, TrainingScores);
                    }
                }
                finally
                {
                    // Return the original scores even when the adjustment throws, so the
                    // dropout substitution never leaks out of this block.
                    if (scoresSwapped)
                    {
                        TrainingScores.Scores = backupScores;
                    }
                }
            }

            if (Smoothing != 0.0)
            {
                SmoothTree(tree, Smoothing);
                // Smoothing was applied, so the fast training-score update path is switched off.
                UseFastTrainingScoresUpdate = false;
            }

            if (DropoutRate > 0)
            {
                // Don't do shrinkage if you do dropouts.
                double scaling = 1.0 / (1.0 + _numberOfDroppedTrees);
                tree.ScaleOutputsBy(scaling);
                _treeScores.Add(tree.GetOutputs(TrainingScores.Dataset));
            }

            UpdateAllScores(ch, tree);
            Ensemble.AddTree(tree);
            return tree;
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Builds a new ensemble from one solution of a lasso fit, keeping only the trees whose
        /// weight in that solution is non-zero.
        /// </summary>
        /// <param name="fit">Lasso fit supplying the weights, tree indices and intercepts.</param>
        /// <param name="solutionIdx">Index of the solution to read from <paramref name="fit"/>.</param>
        /// <param name="originalEnsemble">Ensemble the trees are taken from.</param>
        /// <returns>An ensemble holding the selected trees, with its bias set to the solution's intercept.</returns>
        /// <remarks>
        /// The <c>Weight</c> of every tree obtained from <paramref name="originalEnsemble"/> is overwritten
        /// with the weight from the fit before the tree is added to the result.
        /// </remarks>
        private Ensemble GetEnsembleFromSolution(LassoFit fit, int solutionIdx, Ensemble originalEnsemble)
        {
            var selected = new Ensemble();
            int count = fit.NumberOfWeights[solutionIdx];

            for (int slot = 0; slot < count; ++slot)
            {
                double coefficient = fit.CompressedWeights[solutionIdx][slot];
                if (coefficient == 0)
                {
                    // Zero weight: this tree is not part of the solution.
                    continue;
                }

                RegressionTree picked = originalEnsemble.GetTreeAt(fit.Indices[slot]);
                picked.Weight = coefficient;
                selected.AddTree(picked);
            }

            selected.Bias = fit.Intercepts[solutionIdx];
            return selected;
        }
Ejemplo n.º 3
0
        /// <summary>
        /// Combines several tree-ensemble predictors into a single predictor. Every tree of every
        /// model is copied into one ensemble and each leaf output is divided by the number of models.
        /// </summary>
        /// <param name="models">
        /// The predictors to combine. Each must be a <c>FastTreePredictionWrapper</c>, optionally wrapped
        /// in a <c>CalibratedPredictorBase</c> whose calibrator is a <c>PlattCalibrator</c>.
        /// </param>
        /// <returns>
        /// A binary, regression or ranking predictor chosen by <c>_kind</c>. For binary classification
        /// with calibrated inputs the result is wrapped with a <c>PlattCalibrator(_host, -1, 0)</c>.
        /// </returns>
        /// <remarks>
        /// Errors are raised through <c>_host</c> when <paramref name="models"/> is null or empty, when
        /// an element is null or not a tree ensemble, when a calibrator is not a Platt calibrator,
        /// when calibrated and uncalibrated models are mixed, or when feature counts differ.
        /// </remarks>
        public IPredictorProducing<float> CombineModels(IEnumerable<IPredictorProducing<float>> models)
        {
            _host.CheckValue(models, nameof(models));

            var ensemble = new Ensemble();
            int modelCount = 0;
            int featureCount = -1;
            bool binaryClassifier = false;

            foreach (var model in models)
            {
                modelCount++;

                var predictor = model;
                _host.CheckValue(predictor, nameof(models), "One of the models is null");

                var calibrated = predictor as CalibratedPredictorBase;
                double paramA = 1;
                if (calibrated != null)
                {
                    var platt = calibrated.Calibrator as PlattCalibrator;
                    _host.Check(platt != null,
                                "Combining FastTree models can only be done when the models are calibrated with Platt calibrator");
                    predictor = calibrated.SubPredictor;
                    // The leaf outputs are multiplied by the negated Platt slope; the combined
                    // predictor is later paired with a calibrator whose first argument is -1.
                    paramA = -platt.ParamA;
                }

                var tree = predictor as FastTreePredictionWrapper;
                if (tree == null)
                {
                    throw _host.Except("Model is not a tree ensemble");
                }

                foreach (var t in tree.TrainedEnsemble.Trees)
                {
                    // Copy the tree through its byte representation so the source model's
                    // tree is not modified by the rescaling below.
                    var bytes = new byte[t.SizeInBytes()];
                    int position = -1;
                    t.ToByteArray(bytes, ref position);
                    position = -1;
                    var tNew = new RegressionTree(bytes, ref position);
                    if (paramA != 1)
                    {
                        for (int i = 0; i < tNew.NumLeaves; i++)
                        {
                            tNew.SetOutput(i, tNew.LeafValues[i] * paramA);
                        }
                    }
                    ensemble.AddTree(tNew);
                }

                if (modelCount == 1)
                {
                    // The first model fixes the calibration mode and feature count that
                    // every later model must match.
                    binaryClassifier = calibrated != null;
                    featureCount = tree.InputType.ValueCount;
                }
                else
                {
                    _host.Check((calibrated != null) == binaryClassifier, "Ensemble contains both calibrated and uncalibrated models");
                    _host.Check(featureCount == tree.InputType.ValueCount, "Found models with different number of features");
                }
            }

            // Without this check an empty sequence would give scale = 1 / 0.0 and leave
            // featureCount at -1, producing a meaningless predictor.
            _host.Check(modelCount > 0, "At least one model is required to combine models");

            var scale = 1 / (double)modelCount;

            foreach (var t in ensemble.Trees)
            {
                for (int i = 0; i < t.NumLeaves; i++)
                {
                    t.SetOutput(i, t.LeafValues[i] * scale);
                }
            }

            switch (_kind)
            {
            case PredictionKind.BinaryClassification:
                if (!binaryClassifier)
                {
                    return new FastTreeBinaryPredictor(_host, ensemble, featureCount, null);
                }

                var cali = new PlattCalibrator(_host, -1, 0);
                return new FeatureWeightsCalibratedPredictor(_host, new FastTreeBinaryPredictor(_host, ensemble, featureCount, null), cali);

            case PredictionKind.Regression:
                return new FastTreeRegressionPredictor(_host, ensemble, featureCount, null);

            case PredictionKind.Ranking:
                return new FastTreeRankingPredictor(_host, ensemble, featureCount, null);

            default:
                _host.Assert(false);
                throw _host.ExceptNotSupp();
            }
        }