private TreeEnsemble GetBinaryEnsemble(int classID)
{
    var res = new TreeEnsemble();
    for (int i = classID; i < TrainedEnsemble.NumTrees; i += _numClass)
    {
        // Ignore dummy trees.
        if (TrainedEnsemble.GetTreeAt(i).NumLeaves > 1)
            res.AddTree(TrainedEnsemble.GetTreeAt(i));
    }
    return res;
}
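// LightGBM stores a multiclass model as a single flat list of
// num_class * num_iterations trees, interleaved by class, so tree i belongs
// to class (i % _numClass); GetBinaryEnsemble walks that list with a stride
// of _numClass to recover one class's binary ensemble. A minimal usage
// sketch follows; this helper is illustrative only, not part of the
// original code.
private TreeEnsemble[] GetAllBinaryEnsembles()
{
    var perClassEnsembles = new TreeEnsemble[_numClass];
    for (int classID = 0; classID < _numClass; ++classID)
        perClassEnsembles[classID] = GetBinaryEnsemble(classID);
    return perClassEnsembles;
}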
private void TrainCore(IChannel ch, IProgressChannel pch, Dataset dtrain, CategoricalMetaData catMetaData, Dataset dvalid = null)
{
    Host.AssertValue(ch);
    Host.AssertValue(pch);
    Host.AssertValue(dtrain);
    Host.AssertValueOrNull(dvalid);
    // For multiclass classification, the number of labels is required.
    ch.Assert(PredictionKind != PredictionKind.MultiClassClassification || Options.ContainsKey("num_class"),
        "LightGBM requires the number of classes to be specified in the parameters.");

    // Allow only one trainer to run at a time.
    lock (LightGbmShared.LockForMultiThreadingInside)
    {
        ch.Info("LightGBM objective={0}", Options["objective"]);
        using (Booster bst = WrappedLightGbmTraining.Train(ch, pch, Options, dtrain,
            dvalid: dvalid, numIteration: Args.NumBoostRound,
            verboseEval: Args.VerboseEval, earlyStoppingRound: Args.EarlyStoppingRound))
        {
            TrainedEnsemble = bst.GetModel(catMetaData.CategoricalBoudaries);
        }
    }
}
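// TrainCore asserts that multiclass training specifies "num_class" and logs
// the "objective" entry, both read from the shared Options dictionary. A
// sketch of populating those keys before training; this helper and the
// numberOfClasses parameter are hypothetical, and "multiclass" is one of
// LightGBM's multiclass objective names.
private void ConfigureMulticlassOptions(int numberOfClasses)
{
    Options["objective"] = "multiclass";
    Options["num_class"] = numberOfClasses;
}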
internal LightGbmRegressionPredictor(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
    : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
{
}
public LightGbmRegressionModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
    : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
{
}
public TreeEnsemble GetModel(int[] categoricalFeatureBoundaries)
{
    TreeEnsemble res = new TreeEnsemble();
    string modelString = GetModelString();
    string[] lines = modelString.Split('\n');
    int i = 0;
    for (; i < lines.Length;)
    {
        if (lines[i].StartsWith("Tree="))
        {
            // Collect the key=value pairs that describe this tree.
            Dictionary<string, string> kvPairs = new Dictionary<string, string>();
            ++i;
            while (!lines[i].StartsWith("Tree=") && lines[i].Trim().Length != 0)
            {
                string[] kv = lines[i].Split('=');
                Contracts.Check(kv.Length == 2);
                kvPairs[kv[0].Trim()] = kv[1].Trim();
                ++i;
            }
            int numLeaves = int.Parse(kvPairs["num_leaves"]);
            int numCat = int.Parse(kvPairs["num_cat"]);
            if (numLeaves > 1)
            {
                var leftChild = Str2IntArray(kvPairs["left_child"], ' ');
                var rightChild = Str2IntArray(kvPairs["right_child"], ' ');
                var splitFeature = Str2IntArray(kvPairs["split_feature"], ' ');
                var threshold = Str2DoubleArray(kvPairs["threshold"], ' ');
                var splitGain = Str2DoubleArray(kvPairs["split_gain"], ' ');
                var leafOutput = Str2DoubleArray(kvPairs["leaf_value"], ' ');
                var decisionType = Str2UIntArray(kvPairs["decision_type"], ' ');
                var defaultValue = GetDefalutValue(threshold, decisionType);
                var categoricalSplitFeatures = new int[numLeaves - 1][];
                var categoricalSplit = new bool[numLeaves - 1];
                if (categoricalFeatureBoundaries != null)
                {
                    // Add offsets to split features.
                    for (int node = 0; node < numLeaves - 1; ++node)
                        splitFeature[node] = categoricalFeatureBoundaries[splitFeature[node]];
                }

                if (numCat > 0)
                {
                    var catBoundaries = Str2IntArray(kvPairs["cat_boundaries"], ' ');
                    var catThreshold = Str2UIntArray(kvPairs["cat_threshold"], ' ');
                    for (int node = 0; node < numLeaves - 1; ++node)
                    {
                        if (GetIsCategoricalSplit(decisionType[node]))
                        {
                            // For a categorical split, the threshold indexes into the
                            // cat_boundaries/cat_threshold arrays rather than being a value.
                            int catIdx = (int)threshold[node];
                            var cats = GetCatThresholds(catThreshold, catBoundaries[catIdx], catBoundaries[catIdx + 1]);
                            categoricalSplitFeatures[node] = new int[cats.Length];
                            // Convert categorical thresholds to feature indices.
                            for (int j = 0; j < cats.Length; ++j)
                                categoricalSplitFeatures[node][j] = splitFeature[node] + cats[j] - 1;
                            splitFeature[node] = -1;
                            categoricalSplit[node] = true;
                            // Swap left and right child.
                            int t = leftChild[node];
                            leftChild[node] = rightChild[node];
                            rightChild[node] = t;
                        }
                        else
                        {
                            categoricalSplit[node] = false;
                        }
                    }
                }
                RegressionTree tree = RegressionTree.Create(numLeaves, splitFeature, splitGain,
                    threshold.Select(x => (float)x).ToArray(), defaultValue.Select(x => (float)x).ToArray(),
                    leftChild, rightChild, leafOutput, categoricalSplitFeatures, categoricalSplit);
                res.AddTree(tree);
            }
            else
            {
                RegressionTree tree = new RegressionTree(2);
                var leafOutput = Str2DoubleArray(kvPairs["leaf_value"], ' ');
                if (leafOutput[0] != 0)
                {
                    // Convert a constant tree into a two-leaf tree so it is not filtered out by TLC.
                    var categoricalSplitFeatures = new int[1][];
                    var categoricalSplit = new bool[1];
                    tree = RegressionTree.Create(2, new int[] { 0 }, new double[] { 0 },
                        new float[] { 0 }, new float[] { 0 }, new int[] { -1 }, new int[] { -2 },
                        new double[] { leafOutput[0], leafOutput[0] },
                        categoricalSplitFeatures, categoricalSplit);
                }
                res.AddTree(tree);
            }
        }
        else
        {
            ++i;
        }
    }
    return res;
}
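// GetModel relies on small parsing helpers (Str2IntArray, Str2UIntArray,
// Str2DoubleArray) defined elsewhere in the wrapper. Minimal sketches of
// what such helpers might look like, assuming System.Linq and
// System.Globalization are imported and tokens are invariant-culture
// numbers; the real implementations may differ, e.g. in how they tolerate
// tokens such as "inf" in the threshold list.
private static int[] Str2IntArray(string str, char delimiter)
    => str.Split(delimiter).Select(x => int.Parse(x, CultureInfo.InvariantCulture)).ToArray();

private static uint[] Str2UIntArray(string str, char delimiter)
    => str.Split(delimiter).Select(x => uint.Parse(x, CultureInfo.InvariantCulture)).ToArray();

private static double[] Str2DoubleArray(string str, char delimiter)
    => str.Split(delimiter)
        .Select(x => double.TryParse(x, NumberStyles.Any, CultureInfo.InvariantCulture, out var v) ? v : double.NaN)
        .ToArray();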
public ConjugateGradientDescent(TreeEnsemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
    : base(ensemble, trainData, initTrainScores, gradientWrapper)
{
    _currentDk = new double[trainData.NumDocs];
}
// REVIEW: When the FastTree application is decoupled from the tree learner and boosting logic, this class should be removed.
public RandomForestOptimizer(TreeEnsemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
    : base(ensemble, trainData, initTrainScores, gradientWrapper)
{
    _gradientWrapper = gradientWrapper;
}
public AcceleratedGradientDescent(TreeEnsemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
    : base(ensemble, trainData, initTrainScores, gradientWrapper)
{
    UseFastTrainingScoresUpdate = false;
}
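// ConjugateGradientDescent, RandomForestOptimizer, and AcceleratedGradientDescent
// all pass construction straight through to the same gradient-descent base and
// then adjust a single knob (a direction buffer, the gradient wrapper, or the
// score-update mode). A schematic of that pattern, assuming the GradientDescent
// base class these optimizers extend in FastTree; MyOptimizer is hypothetical,
// not part of the codebase.
public class MyOptimizer : GradientDescent
{
    public MyOptimizer(TreeEnsemble ensemble, Dataset trainData,
        double[] initTrainScores, IGradientAdjuster gradientWrapper)
        : base(ensemble, trainData, initTrainScores, gradientWrapper)
    {
        // Customize one behavior of the base, e.g.
        // UseFastTrainingScoresUpdate = false; as AcceleratedGradientDescent does.
    }
}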