/// <summary>
/// Produces a new <see cref="XGBoost"/> built from the results of running every tree of
/// <paramref name="model"/> through <c>ReorderTree</c>, keeping the original tree order.
/// </summary>
/// <param name="model">Model whose <c>Trees</c> are read.</param>
/// <param name="reorderMapping">Mapping handed unchanged to <c>ReorderTree</c> for each tree.</param>
/// <returns>A newly constructed model wrapping the reordered trees.</returns>
public static XGBoost ReorderXGBoost(XGBoost model, short[] reorderMapping)
        {
            // Materialize eagerly so the new model owns a fixed array of trees.
            var reordered = (from tree in model.Trees
                             select ReorderTree(tree, reorderMapping)).ToArray();

            return new XGBoost(reordered);
        }
예제 #2
0
        /// <summary>
        /// Parses a text dump of trees into an <see cref="XGBoost"/> model and optionally runs
        /// its preparation steps (in the order compiled, flat, shrink).
        /// </summary>
        /// <param name="allTrees">
        /// Text holding every tree. It is split with <c>treeSplit</c>, and each piece that is not
        /// blank/whitespace is parsed with <c>DecisionTree.Create</c>.
        /// </param>
        /// <param name="prepareCompiled">When true, calls <c>PrepareCompiled()</c> on the model.</param>
        /// <param name="prepareFlat">When true, calls <c>PrepareFlat()</c> on the model.</param>
        /// <param name="prepareShrink">When true, calls <c>PrepareShrink(shrinkArgument)</c> on the model.</param>
        /// <param name="shrinkArgument">
        /// Value passed to <c>PrepareShrink</c>. Defaults to 634, the value that used to be hard-coded
        /// here. NOTE(review): its meaning is defined by <c>PrepareShrink</c>; confirm what it
        /// represents and whether 634 is right for every model.
        /// </param>
        /// <returns>The parsed model, after any requested preparation.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="allTrees"/> is null.</exception>
        public static XGBoost Create(string allTrees, bool prepareCompiled = false, bool prepareFlat = false, bool prepareShrink = false, int shrinkArgument = 634)
        {
            if (allTrees == null)
            {
                throw new ArgumentNullException(nameof(allTrees));
            }

            var treeStrings = treeSplit.Split(allTrees);
            var trees       = treeStrings
                              .Where(m => !string.IsNullOrWhiteSpace(m))
                              .Select(DecisionTree.Create)
                              .ToArray();
            var model = new XGBoost(trees);

            if (prepareCompiled)
            {
                model.PrepareCompiled();
            }
            if (prepareFlat)
            {
                model.PrepareFlat();
            }
            if (prepareShrink)
            {
                model.PrepareShrink(shrinkArgument);
            }
            return model;
        }
        /// <summary>
        /// Greedy pairwise-swap search for a feature renumbering that lowers the score returned by
        /// <c>EvalAveragePages</c> over the paths (from <c>GetAllPaths</c>) of the first
        /// <paramref name="treeCount"/> trees. Lower scores are treated as better.
        /// </summary>
        /// <remarks>
        /// Path nodes with a non-negative <c>FeatureIndex</c> are grouped by that index. Every index
        /// from 0 to the largest index seen gets a group; a placeholder node is created for unused
        /// indices. For each ordered pair of distinct indices, the two groups exchange feature indices
        /// and the paths are re-scored. The swap is kept only if the score strictly decreases;
        /// otherwise it is reverted. Passes repeat until a full pass finds no improvement, or until
        /// more than 25,000 consecutive swaps fail to improve.
        /// The nodes returned by <c>GetAllPaths</c> are mutated in place.
        /// NOTE(review): if those are the model's own nodes, <paramref name="model"/> is modified too.
        /// </remarks>
        /// <param name="model">Model whose trees are searched.</param>
        /// <param name="treeCount">How many leading trees to include (<c>Enumerable.Take</c> semantics).</param>
        /// <returns>
        /// The static <c>bestMap</c> after a final <c>UpdateBestMap</c> call on the best grouping found.
        /// Because it is shared static state, concurrent calls may interfere.
        /// </returns>
        /// <exception cref="InvalidOperationException">The selected trees yield no path nodes.</exception>
        public static short[] Greedy(XGBoost model, int treeCount)
        {
            Console.WriteLine("doing greedy search");
            var allPaths    = model.Trees.Take(treeCount).SelectMany(GetAllPaths).ToArray();
            var nodeLookups = allPaths
                              .SelectMany(m => m)
                              .Distinct()
                              .Where(m => m.FeatureIndex >= 0)
                              .GroupBy(m => m.FeatureIndex)
                              .ToDictionary(m => m.Key, m => m.ToArray());
            // Enumerable.Max throws InvalidOperationException when there are no path nodes.
            var featuresMax = allPaths.SelectMany(m => m.Select(j => j.FeatureIndex)).Max();

            // Ensure every index in [0, featuresMax] has a group, so the swap loops can index
            // nodeLookups directly. Placeholder nodes are not in allPaths, so they never change the
            // score; they only carry OriginalIndex = i into the groups seen by UpdateBestMap.
            for (short i = 0; (int)i <= featuresMax; i++)
            {
                if (!nodeLookups.ContainsKey(i))
                {
                    nodeLookups[i] = new[] { new DecisionTreeNode2 {
                                                 FeatureIndex = i, OriginalIndex = i
                                             } };
                }
            }

            // Assigns the given feature index to every node of a group.
            void Relabel(DecisionTreeNode2[] nodes, short featureIndex)
            {
                for (int x = 0; x < nodes.Length; x++)
                {
                    nodes[x].FeatureIndex = featureIndex;
                }
            }

            double    best                = EvalAveragePages(allPaths);
            int       sinceImprovement    = 0;
            const int maxImprovementDelay = 25_000;

            while (true)
            {
                Console.WriteLine(" loop");
                bool improved = false;
                // Inclusive bounds: featuresMax is itself a real feature index (the fill loop above
                // creates a group for it), so it must be eligible for swapping like every other index.
                for (int i = 0; i <= featuresMax; i++)
                {
                    short ii = (short)i;
                    for (int j = 0; j <= featuresMax; j++)
                    {
                        if (i == j)
                        {
                            continue;
                        }
                        short jj     = (short)j;
                        var   nodesI = nodeLookups[ii];
                        var   nodesJ = nodeLookups[jj];

                        // Tentatively exchange the two groups' feature indices and re-score.
                        Relabel(nodesI, jj);
                        Relabel(nodesJ, ii);
                        double possible = EvalAveragePages(allPaths);

                        if (possible < best)
                        {
                            // Keep the swap. Each group now carries the other's index, so the
                            // dictionary slots are swapped to stay keyed by current FeatureIndex.
                            best             = possible;
                            nodeLookups[ii]  = nodesJ;
                            nodeLookups[jj]  = nodesI;
                            improved         = true;
                            sinceImprovement = 0;
                            UpdateBestMap(nodeLookups);
                            Console.WriteLine($" new best: {best}");
                        }
                        else
                        {
                            // No improvement: revert the swap.
                            sinceImprovement++;
                            Relabel(nodesI, ii);
                            Relabel(nodesJ, jj);
                        }
                        if (sinceImprovement > maxImprovementDelay)
                        {
                            break;
                        }
                    }
                    if (sinceImprovement > maxImprovementDelay)
                    {
                        break;
                    }
                }
                if (!improved || sinceImprovement > maxImprovementDelay)
                {
                    break;
                }
            }
            UpdateBestMap(nodeLookups);
            return bestMap;
        }
        /// <summary>
        /// Placeholder for computing a feature reordering of <paramref name="model"/> that improves
        /// cache locality. Not implemented: it always throws.
        /// </summary>
        /// <param name="model">The model whose features would be reordered.</param>
        /// <returns>Never returns normally.</returns>
        /// <exception cref="NotImplementedException">Always thrown.</exception>
        public static int[] CreateFeatureReorderingMap(XGBoost model)
        {
            // Throw immediately rather than first materializing every path of every tree, which
            // was discarded unused.
            throw new NotImplementedException(nameof(CreateFeatureReorderingMap) + " is not implemented.");
        }