예제 #1
0
        /// <summary>
        /// Parses the text returned by <c>GetModelString()</c> into a tree ensemble plus the
        /// parameters recorded alongside the model.
        /// </summary>
        /// <remarks>
        /// Each block that starts with a <c>Tree=</c> line is read as <c>key=value</c> lines up to the
        /// next <c>Tree=</c> line, the next blank line, or the end of the text. Outside tree blocks,
        /// lines starting with '[' are read as <c>[name: value]</c> parameter entries; entries that do
        /// not split into exactly a name and a value are ignored.
        /// </remarks>
        /// <returns>The parsed ensemble and the parameters built from the bracketed entries.</returns>
        /// <exception cref="FormatException">
        /// A line inside a tree block does not split into exactly two parts on '='.
        /// </exception>
        public (Tree.Ensemble, Parameters) GetModel()
        {
            Tree.Ensemble res         = new Tree.Ensemble();
            string        modelString = GetModelString();

            string[] lines = modelString.Split('\n');
            var      prms  = new Dictionary <string, string>();
            int      i     = 0;

            while (i < lines.Length)
            {
                if (lines[i].StartsWith("Tree=", StringComparison.Ordinal))
                {
                    var kvPairs = new Dictionary <string, string>();
                    ++i;

                    // Collect the key=value lines of this tree. The index check matters: the last
                    // tree block may run to the end of the text without a terminating blank line.
                    while (i < lines.Length &&
                           !lines[i].StartsWith("Tree=", StringComparison.Ordinal) &&
                           lines[i].Trim().Length != 0)
                    {
                        string[] kv = lines[i].Split('=');
                        if (kv.Length != 2)
                        {
                            throw new FormatException();
                        }
                        kvPairs[kv[0].Trim()] = kv[1].Trim();
                        ++i;
                    }

                    // The model text is machine-written, so parse numbers culture-independently.
                    var invariant = System.Globalization.CultureInfo.InvariantCulture;
                    int numLeaves = int.Parse(kvPairs["num_leaves"], invariant);
                    int numCat    = int.Parse(kvPairs["num_cat"], invariant);
                    if (numLeaves > 1)
                    {
                        var leftChild    = Str2IntArray(kvPairs["left_child"], ' ');
                        var rightChild   = Str2IntArray(kvPairs["right_child"], ' ');
                        var splitFeature = Str2IntArray(kvPairs["split_feature"], ' ');
                        var threshold    = Str2DoubleArray(kvPairs["threshold"], ' ');
                        var splitGain    = Str2DoubleArray(kvPairs["split_gain"], ' ');
                        var leafOutput   = Str2DoubleArray(kvPairs["leaf_value"], ' ');
                        var decisionType = Str2UIntArray(kvPairs["decision_type"], ' ');

                        // See 'AvoidInf' in lightgbm source: thresholds equal to +/-1e300 are
                        // mapped back to +/-infinity. The exact comparison is intentional.
                        for (var j = 0; j < threshold.Length; j++)
                        {
                            var t = threshold[j];
                            if (t == 1e300)
                            {
                                threshold[j] = double.PositiveInfinity;
                            }
                            else if (t == -1e300)
                            {
                                threshold[j] = double.NegativeInfinity;
                            }
                        }

                        var defaultValue = GetDefaultValue(threshold, decisionType);

                        // One flag per internal node; left all-false when the tree has no
                        // categorical splits (num_cat == 0).
                        var categoricalSplit = new bool[numLeaves - 1];
                        var catBoundaries    = Array.Empty <int>();
                        var catThresholds    = Array.Empty <uint>();
                        if (numCat > 0)
                        {
                            catBoundaries = Str2IntArray(kvPairs["cat_boundaries"], ' ');
                            catThresholds = Str2UIntArray(kvPairs["cat_threshold"], ' ');
                            for (int node = 0; node < numLeaves - 1; ++node)
                            {
                                categoricalSplit[node] = GetIsCategoricalSplit(decisionType[node]);
                            }
                        }

                        var tree = Tree.RegressionTree.Create(
                            numLeaves,
                            splitFeature,
                            splitGain,
                            threshold,
                            defaultValue,
                            leftChild,
                            rightChild,
                            leafOutput,
                            catBoundaries,
                            catThresholds,
                            categoricalSplit);
                        res.AddTree(tree);
                    }
                    else
                    {
                        // always need to add tree, otherwise multiclass will be wrong.
                        // A single-leaf tree is represented as one split whose two leaves both
                        // carry the same constant output.
                        var leafOutput = Str2DoubleArray(kvPairs["leaf_value"], ' ');
                        var tree       = Tree.RegressionTree.Create(
                            2,
                            new int[] { 0 },
                            new double[] { 0 },
                            new double[] { 0 },
                            new double[] { 0 },
                            new int[] { -1 },
                            new int[] { -2 },
                            new double[] { leafOutput[0], leafOutput[0] },
                            new int[] { },
                            new uint[] { },
                            new bool[] { false });
                        res.AddTree(tree);
                    }
                }
                else
                {
                    // [objective: binary]
                    if (lines[i].StartsWith("[", StringComparison.Ordinal))
                    {
                        var bits = lines[i].Split(new char[] { '[', ']', ' ', ':' }, StringSplitOptions.RemoveEmptyEntries);
                        if (bits.Length == 2)   // ignores, e.g. [data: ]
                        {
                            prms.Add(bits[0], bits[1]);
                        }
                    }
                    ++i;
                }
            }

            // extract parameters
            var p = new Parameters {
                Common    = _helperCommon.FromParameters(prms),
                Dataset   = _helperDataset.FromParameters(prms),
                Objective = _helperObjective.FromParameters(prms),
                Learning  = _helperLearning.FromParameters(prms)
            };

            return(res, p);
        }
예제 #2
0
        /// <summary>
        /// Parses the text returned by <c>GetModelString()</c> into a tree ensemble plus the
        /// parameters recorded alongside the model, including optional linear-leaf data.
        /// </summary>
        /// <remarks>
        /// Each block that starts with a <c>Tree=</c> line is read as <c>key=value</c> lines up to the
        /// next <c>Tree=</c> line, the next blank line, or the end of the text. Outside tree blocks,
        /// lines starting with '[' are read as <c>[name: value]</c> parameter entries. Parameter
        /// names still left in the dictionary after the helpers have run are reported on the console.
        /// </remarks>
        /// <returns>The parsed ensemble and the parameters built from the bracketed entries.</returns>
        /// <exception cref="FormatException">
        /// A line inside a tree block does not split into exactly two parts on '=', or the
        /// linear-leaf arrays are inconsistent with the per-leaf feature counts.
        /// </exception>
        public (Tree.Ensemble, Parameters) GetModel()
        {
            Tree.Ensemble res = new Tree.Ensemble();
            string modelString = GetModelString();
            string[] lines = modelString.Split('\n');
            var prms = new Dictionary<string, string>();
            var delimiters = new char[] { ' ' };
            // The model text is machine-written, so parse numbers culture-independently.
            var invariant = System.Globalization.CultureInfo.InvariantCulture;
            int i = 0;
            while (i < lines.Length)
            {
                if (lines[i].StartsWith("Tree=", StringComparison.Ordinal))
                {
                    var kvPairs = new Dictionary<string, string>();
                    ++i;

                    // Collect the key=value lines of this tree. The index check matters: the last
                    // tree block may run to the end of the text without a terminating blank line.
                    while (i < lines.Length &&
                           !lines[i].StartsWith("Tree=", StringComparison.Ordinal) &&
                           lines[i].Trim().Length != 0)
                    {
                        string[] kv = lines[i].Split('=');
                        if (kv.Length != 2)
                        {
                            throw new FormatException();
                        }
                        kvPairs[kv[0].Trim()] = kv[1].Trim();
                        ++i;
                    }

                    int numLeaves = int.Parse(kvPairs["num_leaves"], invariant);
                    int numCat = int.Parse(kvPairs["num_cat"], invariant);
                    var leftChild = Str2IntArray(kvPairs["left_child"], delimiters);
                    var rightChild = Str2IntArray(kvPairs["right_child"], delimiters);
                    var splitFeature = Str2IntArray(kvPairs["split_feature"], delimiters);
                    var threshold = Str2DoubleArray(kvPairs["threshold"], delimiters);
                    var splitGain = Str2DoubleArray(kvPairs["split_gain"], delimiters);
                    var leafOutput = Str2DoubleArray(kvPairs["leaf_value"], delimiters);
                    var decisionType = Str2UIntArray(kvPairs["decision_type"], delimiters);

                    // See 'AvoidInf' in lightgbm source: thresholds equal to +/-1e300 are
                    // mapped back to +/-infinity. The exact comparison is intentional.
                    for (var j = 0; j < threshold.Length; j++)
                    {
                        var t = threshold[j];
                        if (t == 1e300)
                        {
                            threshold[j] = double.PositiveInfinity;
                        }
                        else if (t == -1e300)
                        {
                            threshold[j] = double.NegativeInfinity;
                        }
                    }

                    var defaultValue = GetDefaultValue(threshold, decisionType);

                    // One flag per internal node; left all-false when the tree has no
                    // categorical splits (num_cat == 0).
                    var categoricalSplit = new bool[numLeaves - 1];
                    var catBoundaries = Array.Empty<int>();
                    var catThresholds = Array.Empty<uint>();
                    if (numCat > 0)
                    {
                        catBoundaries = Str2IntArray(kvPairs["cat_boundaries"], delimiters);
                        catThresholds = Str2UIntArray(kvPairs["cat_threshold"], delimiters);
                        for (int node = 0; node < numLeaves - 1; ++node)
                        {
                            categoricalSplit[node] = GetIsCategoricalSplit(decisionType[node]);
                        }
                    }

                    double[] leafConst = null;
                    int[][] leafFeaturesUnpacked = null;
                    double[][] leafCoeffUnpacked = null;

                    // A tree block without an is_linear entry is treated as a non-linear tree
                    // instead of failing with KeyNotFoundException.
                    var isLinear = kvPairs.TryGetValue("is_linear", out var isLinearText)
                                   && int.Parse(isLinearText, invariant) > 0;
                    if (isLinear)
                    {
                        leafConst = Str2DoubleArray(kvPairs["leaf_const"], delimiters);
                        var numFeatures = Str2IntArray(kvPairs["num_features"], delimiters);
                        var leafFeatures = Str2IntArray(kvPairs["leaf_features"], delimiters);
                        var leafCoeff = Str2DoubleArray(kvPairs["leaf_coeff"], delimiters);

                        // leaf_features / leaf_coeff are flat arrays; num_features[j] says how many
                        // consecutive entries belong to entry j. Unpack them into jagged arrays.
                        leafFeaturesUnpacked = new int[numFeatures.Length][];
                        leafCoeffUnpacked = new double[numFeatures.Length][];
                        var idx = 0;
                        for (var j = 0; j < numFeatures.Length; j++)
                        {
                            var len = numFeatures[j];

                            // Reject counts that would read past either flat array, so malformed
                            // input surfaces as a parse error rather than an index error.
                            if (len < 0 || len > leafFeatures.Length - idx || len > leafCoeff.Length - idx)
                            {
                                throw new FormatException("Failed to parse leaf features");
                            }

                            leafFeaturesUnpacked[j] = new int[len];
                            leafCoeffUnpacked[j] = new double[len];
                            for (var k = 0; k < len; k++)
                            {
                                leafFeaturesUnpacked[j][k] = leafFeatures[idx];
                                leafCoeffUnpacked[j][k] = leafCoeff[idx];
                                idx++;
                            }
                        }

                        // Every flat entry must have been consumed by the per-leaf counts.
                        if (idx != leafFeatures.Length)
                        {
                            throw new FormatException("Failed to parse leaf features");
                        }
                    }

                    var tree = Tree.RegressionTree.Create(
                                    numLeaves,
                                    splitFeature,
                                    splitGain,
                                    threshold,
                                    defaultValue,
                                    leftChild,
                                    rightChild,
                                    leafOutput,
                                    catBoundaries,
                                    catThresholds,
                                    categoricalSplit,
                                    isLinear,
                                    leafConst,
                                    leafFeaturesUnpacked,
                                    leafCoeffUnpacked);
                    res.AddTree(tree);
                }
                else
                {
                    // [objective: binary]
                    if (lines[i].StartsWith("[", StringComparison.Ordinal))
                    {
                        var bits = lines[i].Split(new char[] { '[', ']', ' ', ':' }, StringSplitOptions.RemoveEmptyEntries);
                        if (bits.Length == 2)   // ignores, e.g. [data: ]
                        {
                            prms.Add(bits[0], bits[1]);
                        }
                    }
                    ++i;
                }
            }

            // extract parameters
            var p = new Parameters {
                Common = _helperCommon.FromParameters(prms),
                Dataset = _helperDataset.FromParameters(prms),
                Objective = _helperObjective.FromParameters(prms),
                Learning = _helperLearning.FromParameters(prms)
                };

            // irrelevant parameter for managed trees which always use NaN for missing value
            prms.Remove("zero_as_missing");
            prms.Remove("saved_feature_importance_type");

            // NOTE(review): this warning presumes the FromParameters helpers remove the keys they
            // consume, so anything left over is unrecognised — confirm against the helpers.
            if (prms.Count > 0)
            {
                Console.WriteLine($"WARNING: Unknown new parameters {String.Join(",", prms.Keys)}");
            }
            
            return (res, p);
        }