/// <summary>
/// Extracts the one-vs-all sub-ensemble for a single class from a multi-class model.
/// LightGBM stores multi-class trees interleaved: tree <c>i</c> belongs to class
/// <c>i % _numClass</c>, so this walks every <c>_numClass</c>-th tree starting at
/// <paramref name="classID"/>.
/// </summary>
/// <param name="classID">Zero-based index of the class whose trees are collected.</param>
/// <returns>A new ensemble containing only that class's non-dummy trees.</returns>
private TreeEnsemble GetBinaryEnsemble(int classID)
{
    var res = new TreeEnsemble();
    for (int i = classID; i < TrainedEnsemble.NumTrees; i += _numClass)
    {
        // Fetch the tree once (the previous code called GetTreeAt twice per iteration).
        var tree = TrainedEnsemble.GetTreeAt(i);
        // Ignore dummy trees: a single leaf means the tree carries no splits.
        if (tree.NumLeaves > 1)
            res.AddTree(tree);
    }
    return res;
}
        /// <summary>
        /// Runs a single LightGBM training session over <paramref name="dtrain"/> and
        /// stores the resulting model in <c>TrainedEnsemble</c>.
        /// </summary>
        private void TrainCore(IChannel ch, IProgressChannel pch, Dataset dtrain, CategoricalMetaData catMetaData, Dataset dvalid = null)
        {
            Host.AssertValue(ch);
            Host.AssertValue(pch);
            Host.AssertValue(dtrain);
            Host.AssertValueOrNull(dvalid);
            // Multi-class training cannot proceed without an explicit class count.
            ch.Assert(PredictionKind != PredictionKind.MultiClassClassification || Options.ContainsKey("num_class"),
                      "LightGBM requires the number of classes to be specified in the parameters.");

            // Only one trainer may run at a time; serialize on the shared lock.
            lock (LightGbmShared.LockForMultiThreadingInside)
            {
                ch.Info("LightGBM objective={0}", Options["objective"]);
                using (Booster booster = WrappedLightGbmTraining.Train(
                    ch, pch, Options, dtrain,
                    dvalid: dvalid,
                    numIteration: Args.NumBoostRound,
                    verboseEval: Args.VerboseEval,
                    earlyStoppingRound: Args.EarlyStoppingRound))
                {
                    TrainedEnsemble = booster.GetModel(catMetaData.CategoricalBoudaries);
                }
            }
        }
Example #3
0
 /// <summary>
 /// Constructs the regression predictor by forwarding the trained tree ensemble,
 /// feature count, and serialized trainer arguments to the tree-predictor base class.
 /// </summary>
 internal LightGbmRegressionPredictor(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
     : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
 {
 }
Example #4
0
 /// <summary>
 /// Constructs the regression model parameters by forwarding the trained tree ensemble,
 /// feature count, and serialized trainer arguments to the tree-predictor base class.
 /// </summary>
 public LightGbmRegressionModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
     : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
 {
 }
Example #5
0
        /// <summary>
        /// Parses the text representation of the trained booster (obtained from
        /// <c>GetModelString</c>) into a <see cref="TreeEnsemble"/>. Each "Tree="
        /// section is read as key=value pairs and converted into a regression tree.
        /// </summary>
        /// <param name="categoricalFeatureBoudaries">
        /// Optional mapping from LightGBM feature index to the corresponding slot in the
        /// original (expanded) feature vector; applied to split-feature indices when
        /// non-null.
        /// </param>
        /// <returns>The reconstructed ensemble of regression trees.</returns>
        public TreeEnsemble GetModel(int[] categoricalFeatureBoudaries)
        {
            TreeEnsemble res         = new TreeEnsemble();
            string       modelString = GetModelString();

            string[] lines = modelString.Split('\n');
            int      i     = 0;

            for (; i < lines.Length;)
            {
                if (lines[i].StartsWith("Tree="))
                {
                    // Collect the key=value lines describing this tree, stopping at the
                    // next tree section or a blank separator line.
                    Dictionary <string, string> kvPairs = new Dictionary <string, string>();
                    ++i;
                    // Bounds-check i: the model string may not end with a blank line
                    // after the last tree section, in which case the original code
                    // indexed past the end of the array.
                    while (i < lines.Length && !lines[i].StartsWith("Tree=") && lines[i].Trim().Length != 0)
                    {
                        string[] kv = lines[i].Split('=');
                        Contracts.Check(kv.Length == 2);
                        kvPairs[kv[0].Trim()] = kv[1].Trim();
                        ++i;
                    }
                    int numLeaves = int.Parse(kvPairs["num_leaves"]);
                    int numCat    = int.Parse(kvPairs["num_cat"]);
                    if (numLeaves > 1)
                    {
                        // A tree with L leaves has L-1 internal (split) nodes.
                        var leftChild                = Str2IntArray(kvPairs["left_child"], ' ');
                        var rightChild               = Str2IntArray(kvPairs["right_child"], ' ');
                        var splitFeature             = Str2IntArray(kvPairs["split_feature"], ' ');
                        var threshold                = Str2DoubleArray(kvPairs["threshold"], ' ');
                        var splitGain                = Str2DoubleArray(kvPairs["split_gain"], ' ');
                        var leafOutput               = Str2DoubleArray(kvPairs["leaf_value"], ' ');
                        var decisionType             = Str2UIntArray(kvPairs["decision_type"], ' ');
                        var defaultValue             = GetDefalutValue(threshold, decisionType);
                        var categoricalSplitFeatures = new int[numLeaves - 1][];
                        var categoricalSplit         = new bool[numLeaves - 1];
                        if (categoricalFeatureBoudaries != null)
                        {
                            // Translate LightGBM feature indices back to slots in the
                            // original feature vector.
                            for (int node = 0; node < numLeaves - 1; ++node)
                            {
                                splitFeature[node] = categoricalFeatureBoudaries[splitFeature[node]];
                            }
                        }

                        if (numCat > 0)
                        {
                            // For categorical splits, "threshold" holds an index into the
                            // cat_boundaries/cat_threshold bitset tables.
                            var catBoundaries = Str2IntArray(kvPairs["cat_boundaries"], ' ');
                            var catThreshold  = Str2UIntArray(kvPairs["cat_threshold"], ' ');
                            for (int node = 0; node < numLeaves - 1; ++node)
                            {
                                if (GetIsCategoricalSplit(decisionType[node]))
                                {
                                    int catIdx = (int)threshold[node];
                                    var cats   = GetCatThresholds(catThreshold, catBoundaries[catIdx], catBoundaries[catIdx + 1]);
                                    categoricalSplitFeatures[node] = new int[cats.Length];
                                    // Convert category thresholds to feature indices.
                                    for (int j = 0; j < cats.Length; ++j)
                                    {
                                        categoricalSplitFeatures[node][j] = splitFeature[node] + cats[j] - 1;
                                    }

                                    // Mark the node as categorical; -1 flags "no single
                                    // split feature" for such nodes.
                                    splitFeature[node]     = -1;
                                    categoricalSplit[node] = true;
                                    // Swap left and right child: categorical-split
                                    // direction is inverted relative to numeric splits.
                                    int t = leftChild[node];
                                    leftChild[node]  = rightChild[node];
                                    rightChild[node] = t;
                                }
                                else
                                {
                                    categoricalSplit[node] = false;
                                }
                            }
                        }
                        RegressionTree tree = RegressionTree.Create(numLeaves, splitFeature, splitGain,
                                                                    threshold.Select(x => (float)(x)).ToArray(), defaultValue.Select(x => (float)(x)).ToArray(), leftChild, rightChild, leafOutput,
                                                                    categoricalSplitFeatures, categoricalSplit);
                        res.AddTree(tree);
                    }
                    else
                    {
                        RegressionTree tree       = new RegressionTree(2);
                        var            leafOutput = Str2DoubleArray(kvPairs["leaf_value"], ' ');
                        if (leafOutput[0] != 0)
                        {
                            // Convert a constant (single-leaf) tree into an equivalent
                            // two-leaf tree so it is not filtered out downstream: both
                            // leaves emit the same constant value.
                            var categoricalSplitFeatures = new int[1][];
                            var categoricalSplit         = new bool[1];
                            tree = RegressionTree.Create(2, new int[] { 0 }, new double[] { 0 },
                                                         new float[] { 0 }, new float[] { 0 }, new int[] { -1 }, new int[] { -2 }, new double[] { leafOutput[0], leafOutput[0] },
                                                         categoricalSplitFeatures, categoricalSplit);
                        }
                        res.AddTree(tree);
                    }
                }
                else
                {
                    ++i;
                }
            }
            return(res);
        }
Example #6
0
 /// <summary>
 /// Constructs the conjugate-gradient-descent optimizer, forwarding the ensemble,
 /// training data, initial scores, and gradient wrapper to the base optimizer, and
 /// allocating the per-document direction buffer (<c>_currentDk</c>).
 /// </summary>
 public ConjugateGradientDescent(TreeEnsemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
     : base(ensemble, trainData, initTrainScores, gradientWrapper)
 {
     _currentDk = new double[trainData.NumDocs];
 }
Example #7
0
 // REVIEW: When the FastTree application is decoupled from the tree learner and boosting logic, this class should be removed.
 /// <summary>
 /// Constructs the random-forest optimizer, forwarding all state to the base
 /// optimizer and keeping its own reference to the gradient wrapper.
 /// </summary>
 public RandomForestOptimizer(TreeEnsemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
     : base(ensemble, trainData, initTrainScores, gradientWrapper)
 {
     _gradientWrapper = gradientWrapper;
 }
Example #8
0
 /// <summary>
 /// Constructs the accelerated-gradient-descent optimizer, forwarding all state to
 /// the base optimizer and disabling the fast training-score update path.
 /// </summary>
 public AcceleratedGradientDescent(TreeEnsemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
     : base(ensemble, trainData, initTrainScores, gradientWrapper)
 {
     UseFastTrainingScoresUpdate = false;
 }