Ejemplo n.º 1
0
        /// <summary>
        /// Runs one training iteration: fits a regression tree to the adjusted targets, adjusts the
        /// tree outputs, optionally smooths and rescales the tree, updates all scores and appends
        /// the tree to the ensemble.
        /// </summary>
        /// <param name="ch">Channel used for contract checking and error reporting; must not be null.</param>
        /// <param name="activeFeatures">Per-feature flags passed through unchanged to the tree learner.</param>
        /// <returns>The tree that was added to the ensemble, or null when no tree could be learned.</returns>
        internal override InternalRegressionTree TrainingIteration(IChannel ch, bool[] activeFeatures)
        {
            Contracts.CheckValue(ch, nameof(ch));
            // Fit a regression tree to the gradient using least squares.
            InternalRegressionTree tree = TreeLearner.FitTargets(ch, activeFeatures, AdjustTargetsAndSetWeights(ch));

            if (tree == null)
            {
                return null; // Could not learn a tree. Exit.
            }

            bool usingDropout = DropoutRate > 0;

            // Adjust output values of tree by performing a Newton step.
            // REVIEW: This should be part of OptimizingAlgorithm.
            using (Timer.Time(TimerEvent.TreeLearnerAdjustTreeOutputs))
            {
                double[] backupScores = null;
                // When doing dropouts we need to replace the TrainingScores with the scores without the dropped trees.
                if (usingDropout)
                {
                    backupScores          = TrainingScores.Scores;
                    TrainingScores.Scores = _scores;
                }

                try
                {
                    if (AdjustTreeOutputsOverride != null)
                    {
                        AdjustTreeOutputsOverride.AdjustTreeOutputs(ch, tree, TreeLearner.Partitioning, TrainingScores);
                    }
                    else
                    {
                        // Single type test instead of 'is' followed by 'as'.
                        var stepSearch = ObjectiveFunction as IStepSearch;
                        if (stepSearch == null)
                        {
                            throw ch.Except("No AdjustTreeOutputs defined. Objective function should define IStepSearch or AdjustTreeOutputsOverride should be set");
                        }
                        stepSearch.AdjustTreeOutputs(ch, tree, TreeLearner.Partitioning, TrainingScores);
                    }
                }
                finally
                {
                    // Put the original scores back even when the adjustment throws, so the
                    // temporary swap never outlives this block.
                    if (usingDropout)
                    {
                        TrainingScores.Scores = backupScores;
                    }
                }
            }

            if (Smoothing != 0.0)
            {
                SmoothTree(tree, Smoothing);
                UseFastTrainingScoresUpdate = false;
            }

            if (usingDropout)
            {
                // Don't do shrinkage if you do dropouts.
                double scaling = (1.0 / (1.0 + _numberOfDroppedTrees));
                tree.ScaleOutputsBy(scaling);
                _treeScores.Add(tree.GetOutputs(TrainingScores.Dataset));
            }

            UpdateAllScores(ch, tree);
            Ensemble.AddTree(tree);
            return tree;
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Builds a new ensemble from every <c>Objective.NumClass</c>-th tree of the trained ensemble,
        /// starting at index <paramref name="classID"/>.
        /// </summary>
        /// <param name="classID">Index of the first tree to take.</param>
        /// <returns>A new ensemble holding the selected trees in their original order.</returns>
        private Ensemble GetBinaryEnsemble(int classID)
        {
            var stride         = Objective.NumClass;
            var binaryEnsemble = new Ensemble();

            int treeIndex = classID;
            while (treeIndex < TrainedEnsemble.NumTrees)
            {
                binaryEnsemble.AddTree(TrainedEnsemble.GetTreeAt(treeIndex));
                treeIndex += stride;
            }

            return binaryEnsemble;
        }
        /// <summary>
        /// Builds a new ensemble from every <c>_numClass</c>-th tree of the trained ensemble, starting
        /// at index <paramref name="classID"/> and leaving out trees that have at most one leaf.
        /// </summary>
        /// <param name="classID">Index of the first tree to consider.</param>
        /// <returns>A new ensemble holding the selected trees in their original order.</returns>
        private Ensemble GetBinaryEnsemble(int classID)
        {
            var binaryEnsemble = new Ensemble();

            int treeIndex = classID;
            while (treeIndex < TrainedEnsemble.NumTrees)
            {
                // Trees with a single leaf are dummy placeholders; skip them.
                bool hasSplit = TrainedEnsemble.GetTreeAt(treeIndex).NumLeaves > 1;
                if (hasSplit)
                {
                    binaryEnsemble.AddTree(TrainedEnsemble.GetTreeAt(treeIndex));
                }
                treeIndex += _numClass;
            }

            return binaryEnsemble;
        }
Ejemplo n.º 4
0
        /// <summary>
        /// Runs one training iteration: obtains the gradient, derives weighted targets and sample
        /// weights from it, fits a tree with the random-forest tree learner, and adds the tree to
        /// the ensemble when one was produced.
        /// </summary>
        /// <param name="ch">Channel used for contract checking; must not be null.</param>
        /// <param name="activeFeatures">Per-feature flags passed through unchanged to the tree learner.</param>
        /// <returns>The fitted tree, or null when the learner did not produce one.</returns>
        internal override InternalRegressionTree TrainingIteration(IChannel ch, bool[] activeFeatures)
        {
            Contracts.CheckValue(ch, nameof(ch));

            double[] gradient = GetGradient(ch);

            double[] weights  = null;
            double[] adjusted = _gradientWrapper.AdjustTargetAndSetWeights(gradient, ObjectiveFunction, out weights);

            var forestLearner = (RandomForestLeastSquaresTreeLearner)TreeLearner;
            InternalRegressionTree fitted = forestLearner.FitTargets(ch, activeFeatures, adjusted, gradient, weights);

            // Nothing to record when the learner could not produce a tree.
            if (fitted == null)
            {
                return null;
            }

            Ensemble.AddTree(fitted);
            return fitted;
        }
        /// <summary>
        /// Parses the text returned by <c>GetModelString()</c> and rebuilds it as an <see cref="Ensemble"/>.
        /// </summary>
        /// <remarks>
        /// The text is split on '\n'. Every line starting with "Tree=" opens a section of "key=value"
        /// lines that runs until a blank line, the next "Tree=" line, or the end of the text. Sections
        /// with more than one leaf become full regression trees; categorical splits have their
        /// thresholds expanded into feature indices and their children swapped. Sections with a single
        /// leaf become a two-leaf tree carrying the constant output when that output is non-zero, and
        /// a default <c>RegressionTree(2)</c> otherwise.
        /// </remarks>
        /// <param name="categoricalFeatureBoudaries">
        /// Optional lookup table. When non-null, each split feature index read from the model is
        /// replaced by the table entry at that index.
        /// </param>
        /// <returns>An ensemble with one tree per "Tree=" section, in the order they appear.</returns>
        public Ensemble GetModel(int[] categoricalFeatureBoudaries)
        {
            Ensemble res         = new Ensemble();
            string   modelString = GetModelString();

            string[] lines = modelString.Split('\n');
            int      i     = 0;

            while (i < lines.Length)
            {
                // The model text is machine generated, so the marker is compared ordinally.
                if (lines[i].StartsWith("Tree=", System.StringComparison.Ordinal))
                {
                    Dictionary <string, string> kvPairs = new Dictionary <string, string>();
                    ++i;
                    // Stop at the end of the text as well: the last section is not guaranteed to be
                    // followed by a blank line or another "Tree=" line.
                    while (i < lines.Length &&
                           !lines[i].StartsWith("Tree=", System.StringComparison.Ordinal) &&
                           lines[i].Trim().Length != 0)
                    {
                        string[] kv = lines[i].Split('=');
                        Contracts.Check(kv.Length == 2);
                        kvPairs[kv[0].Trim()] = kv[1].Trim();
                        ++i;
                    }
                    int numLeaves = int.Parse(kvPairs["num_leaves"]);
                    int numCat    = int.Parse(kvPairs["num_cat"]);
                    if (numLeaves > 1)
                    {
                        var leftChild                = Str2IntArray(kvPairs["left_child"], ' ');
                        var rightChild               = Str2IntArray(kvPairs["right_child"], ' ');
                        var splitFeature             = Str2IntArray(kvPairs["split_feature"], ' ');
                        var threshold                = Str2DoubleArray(kvPairs["threshold"], ' ');
                        var splitGain                = Str2DoubleArray(kvPairs["split_gain"], ' ');
                        var leafOutput               = Str2DoubleArray(kvPairs["leaf_value"], ' ');
                        var decisionType             = Str2UIntArray(kvPairs["decision_type"], ' ');
                        var defaultValue             = GetDefalutValue(threshold, decisionType);
                        var categoricalSplitFeatures = new int[numLeaves - 1][];
                        var categoricalSplit         = new bool[numLeaves - 1];
                        if (categoricalFeatureBoudaries != null)
                        {
                            // Add offsets to split features.
                            for (int node = 0; node < numLeaves - 1; ++node)
                            {
                                splitFeature[node] = categoricalFeatureBoudaries[splitFeature[node]];
                            }
                        }

                        if (numCat > 0)
                        {
                            var catBoundaries = Str2IntArray(kvPairs["cat_boundaries"], ' ');
                            var catThreshold  = Str2UIntArray(kvPairs["cat_threshold"], ' ');
                            for (int node = 0; node < numLeaves - 1; ++node)
                            {
                                if (GetIsCategoricalSplit(decisionType[node]))
                                {
                                    // For a categorical node the threshold holds an index into catBoundaries.
                                    int catIdx = (int)threshold[node];
                                    var cats   = GetCatThresholds(catThreshold, catBoundaries[catIdx], catBoundaries[catIdx + 1]);
                                    categoricalSplitFeatures[node] = new int[cats.Length];
                                    // Convert Cat thresholds to feature indices.
                                    for (int j = 0; j < cats.Length; ++j)
                                    {
                                        categoricalSplitFeatures[node][j] = splitFeature[node] + cats[j] - 1;
                                    }

                                    splitFeature[node]     = -1;
                                    categoricalSplit[node] = true;
                                    // Swap left and right child.
                                    int t = leftChild[node];
                                    leftChild[node]  = rightChild[node];
                                    rightChild[node] = t;
                                }
                                else
                                {
                                    categoricalSplit[node] = false;
                                }
                            }
                        }
                        RegressionTree tree = RegressionTree.Create(numLeaves, splitFeature, splitGain,
                                                                    threshold.Select(x => (float)(x)).ToArray(), defaultValue.Select(x => (float)(x)).ToArray(), leftChild, rightChild, leafOutput,
                                                                    categoricalSplitFeatures, categoricalSplit);
                        res.AddTree(tree);
                    }
                    else
                    {
                        RegressionTree tree       = new RegressionTree(2);
                        var            leafOutput = Str2DoubleArray(kvPairs["leaf_value"], ' ');
                        if (leafOutput[0] != 0)
                        {
                            // Convert a constant tree to a two-leaf tree so that it is not filtered out by TLC.
                            var categoricalSplitFeatures = new int[1][];
                            var categoricalSplit         = new bool[1];
                            tree = RegressionTree.Create(2, new int[] { 0 }, new double[] { 0 },
                                                         new float[] { 0 }, new float[] { 0 }, new int[] { -1 }, new int[] { -2 }, new double[] { leafOutput[0], leafOutput[0] },
                                                         categoricalSplitFeatures, categoricalSplit);
                        }
                        res.AddTree(tree);
                    }
                }
                else
                {
                    ++i;
                }
            }
            return res;
        }