// Example no. 1
        /// <summary>
        /// Extracts the one-vs-all sub-ensemble for a single class from the trained
        /// multiclass ensemble, where trees are laid out round-robin with a stride of
        /// <c>_numClass</c> starting at the class index.
        /// </summary>
        private InternalTreeEnsemble GetBinaryEnsemble(int classID)
        {
            var perClassEnsemble = new InternalTreeEnsemble();
            for (int treeIndex = classID; treeIndex < TrainedEnsemble.NumTrees; treeIndex += _numClass)
            {
                var tree = TrainedEnsemble.GetTreeAt(treeIndex);
                // Dummy (single-leaf) trees carry no splits, so they are skipped.
                if (tree.NumLeaves > 1)
                    perClassEnsemble.AddTree(tree);
            }
            return perClassEnsemble;
        }
        /// <summary>
        /// Runs the wrapped native LightGBM training loop on <paramref name="dtrain"/> and
        /// stores the resulting tree ensemble in <c>TrainedEnsemble</c>.
        /// </summary>
        /// <param name="ch">Channel for logging and assertions; must be non-null.</param>
        /// <param name="pch">Progress channel forwarded to the native training loop; must be non-null.</param>
        /// <param name="dtrain">Training dataset; must be non-null.</param>
        /// <param name="catMetaData">Categorical feature metadata whose boundaries are used when decoding the trained model.</param>
        /// <param name="dvalid">Optional validation dataset used for early stopping / evaluation; may be null.</param>
        private void TrainCore(IChannel ch, IProgressChannel pch, Dataset dtrain, CategoricalMetaData catMetaData, Dataset dvalid = null)
        {
            Host.AssertValue(ch);
            Host.AssertValue(pch);
            Host.AssertValue(dtrain);
            Host.AssertValueOrNull(dvalid);
            // For multi class, the number of labels is required.
            ch.Assert(((ITrainer)this).PredictionKind != PredictionKind.MulticlassClassification || Options.ContainsKey("num_class"),
                      "LightGBM requires the number of classes to be specified in the parameters.");

            // Only enable one trainer to run at one time.
            // NOTE(review): the lock appears to serialize access to the native LightGBM
            // library process-wide - confirm this matches the library's threading contract.
            lock (LightGbmShared.LockForMultiThreadingInside)
            {
                ch.Info("LightGBM objective={0}", Options["objective"]);
                // The Booster owns native resources; dispose it as soon as the managed
                // ensemble has been extracted.
                using (Booster bst = WrappedLightGbmTraining.Train(ch, pch, Options, dtrain,
                                                                   dvalid: dvalid, numIteration: LightGbmTrainerOptions.NumberOfIterations,
                                                                   verboseEval: LightGbmTrainerOptions.Verbose, earlyStoppingRound: LightGbmTrainerOptions.EarlyStoppingRound))
                {
                    TrainedEnsemble = bst.GetModel(catMetaData.CategoricalBoudaries);
                }
            }
        }
// Example no. 3
 /// <summary>
 /// Creates binary-classification model parameters by delegating to the shared
 /// tree-ensemble base constructor with this type's registration name.
 /// </summary>
 /// <param name="env">Host environment.</param>
 /// <param name="trainedEnsemble">The trained tree ensemble backing this model.</param>
 /// <param name="featureCount">Number of features the model was trained on.</param>
 /// <param name="innerArgs">Serialized trainer arguments forwarded to the base class.</param>
 internal LightGbmBinaryModelParameters(IHostEnvironment env, InternalTreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
     : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
 {
 }
// Example no. 4
        /// <summary>
        /// Parses the text representation of the trained LightGBM booster (as produced by
        /// <c>GetModelString</c>) into an <c>InternalTreeEnsemble</c>.
        /// </summary>
        /// <param name="categoricalFeatureBoudaries">
        /// Optional remapping from LightGBM feature index to a feature slot in the original
        /// feature vector; when null, split feature indices are used as-is.
        /// </param>
        /// <returns>The reconstructed tree ensemble.</returns>
        public InternalTreeEnsemble GetModel(int[] categoricalFeatureBoudaries)
        {
            InternalTreeEnsemble res = new InternalTreeEnsemble();
            string modelString       = GetModelString();

            // The model text is line-oriented: each tree starts with a "Tree=" header,
            // followed by "key=value" attribute lines until a blank line or the next header.
            string[] lines = modelString.Split('\n');
            int      i     = 0;

            for (; i < lines.Length;)
            {
                if (lines[i].StartsWith("Tree="))
                {
                    // Collect this tree's attributes into a key/value map.
                    Dictionary <string, string> kvPairs = new Dictionary <string, string>();
                    ++i;
                    // NOTE(review): this assumes every tree section is terminated by a blank
                    // line or another "Tree=" header before the array ends - confirm the model
                    // string always carries a trailing blank line, otherwise this can index
                    // past the end of 'lines'.
                    while (!lines[i].StartsWith("Tree=") && lines[i].Trim().Length != 0)
                    {
                        string[] kv = lines[i].Split('=');
                        Contracts.Check(kv.Length == 2);
                        kvPairs[kv[0].Trim()] = kv[1].Trim();
                        ++i;
                    }
                    int numberOfLeaves = int.Parse(kvPairs["num_leaves"], CultureInfo.InvariantCulture);
                    int numCat         = int.Parse(kvPairs["num_cat"], CultureInfo.InvariantCulture);
                    if (numberOfLeaves > 1)
                    {
                        // A real tree: decode the space-separated per-node arrays.
                        var leftChild                = Str2IntArray(kvPairs["left_child"], ' ');
                        var rightChild               = Str2IntArray(kvPairs["right_child"], ' ');
                        var splitFeature             = Str2IntArray(kvPairs["split_feature"], ' ');
                        var threshold                = Str2DoubleArray(kvPairs["threshold"], ' ');
                        var splitGain                = Str2DoubleArray(kvPairs["split_gain"], ' ');
                        var leafOutput               = Str2DoubleArray(kvPairs["leaf_value"], ' ');
                        var decisionType             = Str2UIntArray(kvPairs["decision_type"], ' ');
                        var defaultValue             = GetDefalutValue(threshold, decisionType);
                        // One entry per internal node (a tree with N leaves has N-1 internal nodes).
                        var categoricalSplitFeatures = new int[numberOfLeaves - 1][];
                        var categoricalSplit         = new bool[numberOfLeaves - 1];
                        if (categoricalFeatureBoudaries != null)
                        {
                            // Add offsets to split features.
                            for (int node = 0; node < numberOfLeaves - 1; ++node)
                            {
                                splitFeature[node] = categoricalFeatureBoudaries[splitFeature[node]];
                            }
                        }

                        if (numCat > 0)
                        {
                            // cat_boundaries delimits each categorical split's slice of cat_threshold.
                            var catBoundaries = Str2IntArray(kvPairs["cat_boundaries"], ' ');
                            var catThreshold  = Str2UIntArray(kvPairs["cat_threshold"], ' ');
                            for (int node = 0; node < numberOfLeaves - 1; ++node)
                            {
                                if (GetIsCategoricalSplit(decisionType[node]))
                                {
                                    // For categorical splits, LightGBM stores the index into
                                    // cat_boundaries in the threshold slot.
                                    int catIdx = (int)threshold[node];
                                    var cats   = GetCatThresholds(catThreshold, catBoundaries[catIdx], catBoundaries[catIdx + 1]);
                                    categoricalSplitFeatures[node] = new int[cats.Length];
                                    // Convert Cat thresholds to feature indices.
                                    for (int j = 0; j < cats.Length; ++j)
                                    {
                                        categoricalSplitFeatures[node][j] = splitFeature[node] + cats[j] - 1;
                                    }

                                    // -1 marks the node as categorical; the real features live
                                    // in categoricalSplitFeatures.
                                    splitFeature[node]     = -1;
                                    categoricalSplit[node] = true;
                                    // Swap left and right child.
                                    int t = leftChild[node];
                                    leftChild[node]  = rightChild[node];
                                    rightChild[node] = t;
                                }
                                else
                                {
                                    categoricalSplit[node] = false;
                                }
                            }
                        }
                        InternalRegressionTree tree = InternalRegressionTree.Create(numberOfLeaves, splitFeature, splitGain,
                                                                                    threshold.Select(x => (float)(x)).ToArray(), defaultValue.Select(x => (float)(x)).ToArray(), leftChild, rightChild, leafOutput,
                                                                                    categoricalSplitFeatures, categoricalSplit);
                        res.AddTree(tree);
                    }
                    else
                    {
                        // Degenerate single-leaf tree: represent it as a trivial two-leaf tree.
                        InternalRegressionTree tree = new InternalRegressionTree(2);
                        var leafOutput = Str2DoubleArray(kvPairs["leaf_value"], ' ');
                        if (leafOutput[0] != 0)
                        {
                            // Convert Constant tree to Two-leaf tree, avoid being filter by TLC.
                            // Both leaves output the same constant, so the split is a no-op.
                            var categoricalSplitFeatures = new int[1][];
                            var categoricalSplit         = new bool[1];
                            tree = InternalRegressionTree.Create(2, new int[] { 0 }, new double[] { 0 },
                                                                 new float[] { 0 }, new float[] { 0 }, new int[] { -1 }, new int[] { -2 }, new double[] { leafOutput[0], leafOutput[0] },
                                                                 categoricalSplitFeatures, categoricalSplit);
                        }
                        res.AddTree(tree);
                    }
                }
                else
                {
                    // Not a tree header (global model metadata etc.); skip the line.
                    ++i;
                }
            }
            return(res);
        }
// Example no. 5
 // REVIEW: When the FastTree application is decoupled from the tree learner and boosting logic, this class should be removed.
 /// <summary>
 /// Creates a random-forest flavored optimizer; keeps a reference to the gradient
 /// wrapper in addition to passing it to the base optimizer.
 /// </summary>
 /// <param name="ensemble">The ensemble being grown.</param>
 /// <param name="trainData">Training dataset.</param>
 /// <param name="initTrainScores">Initial per-document training scores.</param>
 /// <param name="gradientWrapper">Adjusts gradients before tree fitting.</param>
 internal RandomForestOptimizer(InternalTreeEnsemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
     : base(ensemble, trainData, initTrainScores, gradientWrapper)
 {
     _gradientWrapper = gradientWrapper;
 }
// Example no. 6
 /// <summary>
 /// Creates a conjugate-gradient descent optimizer and allocates the per-document
 /// direction buffer used across iterations.
 /// </summary>
 /// <param name="ensemble">The ensemble being grown.</param>
 /// <param name="trainData">Training dataset; its document count sizes the direction buffer.</param>
 /// <param name="initTrainScores">Initial per-document training scores.</param>
 /// <param name="gradientWrapper">Adjusts gradients before tree fitting.</param>
 public ConjugateGradientDescent(InternalTreeEnsemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
     : base(ensemble, trainData, initTrainScores, gradientWrapper)
 {
     _currentDk = new double[trainData.NumDocs];
 }
// Example no. 7
 /// <summary>
 /// Creates an accelerated (Nesterov-style) gradient descent optimizer; disables the
 /// fast training-scores update because this scheme scores against shifted positions.
 /// </summary>
 /// <param name="ensemble">The ensemble being grown.</param>
 /// <param name="trainData">Training dataset.</param>
 /// <param name="initTrainScores">Initial per-document training scores.</param>
 /// <param name="gradientWrapper">Adjusts gradients before tree fitting.</param>
 internal AcceleratedGradientDescent(InternalTreeEnsemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
     : base(ensemble, trainData, initTrainScores, gradientWrapper)
 {
     // NOTE(review): presumably disabled because accelerated descent evaluates at the
     // lookahead point rather than the current scores - confirm against the base class.
     UseFastTrainingScoresUpdate = false;
 }