Exemplo n.º 1
0
        /// <summary>
        /// Grows one regression tree for boosting iteration <paramref name="iTree"/> by greedy splitting.
        /// On each step, the terminal node with the largest split gain is split in two. This repeats until
        /// the tree has <c>maxTreeSize</c> leaves or no terminal node has a usable split. Each leaf's
        /// <c>regionValue</c> is then set from <c>boostTreeLoss.Response</c>.
        /// </summary>
        /// <param name="boostTreeLoss">Supplies the pseudo-responses fitted by this tree and the leaf values.</param>
        /// <param name="iTree">Index of the tree being built; forwarded to the loss.</param>
        /// <param name="findSplit">Split-search strategy applied to each candidate node.</param>
        /// <param name="featureSampler">Shuffled over <c>numFeatures</c> before every split search.</param>
        /// <param name="dataSampler">Shuffled over the candidate node's data-point count before every split search.</param>
        /// <remarks>
        /// Writes <c>this.responses</c> and <c>this.tree</c>, and adds each chosen split's gain to
        /// <c>this.featureImportance</c>.
        /// </remarks>
        private void BuildRegressionTree(BoostTreeLoss boostTreeLoss, int iTree, IFindSplit findSplit, RandomSampler featureSampler, RandomSampler dataSampler)
        {
            this.responses = boostTreeLoss.PseudoResponse(iTree);

            // The root is the single initial region, covering every working data point.
            TreeNode root = new TreeNode();
            root.isTerminal = true;
            root.dataPoints = Vector.IndexArray(this.workIndex.Length);

            // A binary tree with at most maxTreeSize leaves has at most 2 * maxTreeSize - 1 nodes.
            // Node 0 is the root; split step i stores its two children at 2*i+1 and 2*i+2.
            this.tree = new TreeNode[2 * maxTreeSize - 1];
            this.tree[0] = root;

            for (int i = 0; i < maxTreeSize - 1; i++)
            {
                float maxGain = -1;
                int bestRegion = -1;

                // Nodes 0..2*i are populated here. Only terminal nodes without a split (split < 0) are searched.
                // In practice these are the two children created by the previous step, plus any terminal node
                // whose earlier search found nothing (it is searched again with a fresh shuffle).
                for (int j = 0; j < 2 * i + 1; j++)
                {
                    TreeNode curNode = this.tree[j];

                    if (curNode.split < 0 && curNode.isTerminal && curNode.dataPoints.Length >= this.minNumSamples)
                    {
                        dataSampler.Shuffle(curNode.dataPoints.Length);
                        featureSampler.Shuffle(this.numFeatures);

                        Split bestSplit = findSplit.FindBestSplit(this.labelFeatureDataCoded, this.responses, curNode.dataPoints, this.workIndex, featureSampler, dataSampler, this.minNumSamples);

                        // feature < 0 means no usable split was found. Per the original author, this happens
                        // when every dimension's values for this node fall into a single bin.
                        if (bestSplit.feature >= 0)
                        {
                            curNode.split = bestSplit.feature;
                            curNode.gain = (float)bestSplit.gain;
                            curNode.splitValueCoded = bestSplit.iThresh + 0.2F; // add 0.2 to avoid boundary check or floating point rounding
                            curNode.splitValue = this.labelFeatureDataCoded.ConvertToOrigData(curNode.split, curNode.splitValueCoded);
                        }
                    }

                    // Only terminal nodes that actually carry a split are candidates. This keeps a node with no
                    // split from being chosen, whatever TreeNode's initial gain is; split nodes have gain reset to -1.
                    if (curNode.isTerminal && curNode.split >= 0 && curNode.gain > maxGain)
                    {
                        maxGain = curNode.gain;
                        bestRegion = j;
                    }
                }

                if (bestRegion == -1)
                {
                    // No terminal node can be split further.
                    break;
                }

                TreeNode bestNode = this.tree[bestRegion];

                // The (int) cast truncates away the +0.2 added to the bin threshold above.
                // NOTE(review): this assumes iThresh is non-negative — confirm against IFindSplit.
                SplitOneDim(bestNode.dataPoints, bestNode.split, (int)bestNode.splitValueCoded, out bestNode.leftPoints, out bestNode.rightPoints);

                int leftIndex = 2 * i + 1;
                int rightIndex = 2 * i + 2;

                TreeNode leftNode = new TreeNode();
                leftNode.isTerminal = true;
                leftNode.parent = bestRegion;
                leftNode.dataPoints = bestNode.leftPoints;

                TreeNode rightNode = new TreeNode();
                rightNode.isTerminal = true;
                rightNode.parent = bestRegion;
                rightNode.dataPoints = bestNode.rightPoints;

                this.tree[leftIndex] = leftNode;
                this.tree[rightIndex] = rightNode;

                this.featureImportance[bestNode.split] += bestNode.gain;

                // Turn the split node into an internal node; its index arrays now live in the children.
                bestNode.leftChild = leftIndex;
                bestNode.rightChild = rightIndex;
                bestNode.isTerminal = false;
                bestNode.gain = -1;
                bestNode.dataPoints = null;
                bestNode.leftPoints = null;
                bestNode.rightPoints = null;
                GC.Collect(); // deliberate: reclaim the parent's index array before the next split search
            }

            // Compute each leaf's region value, then release the leaf's per-node index arrays.
            for (int i = 0; i < this.tree.Length; i++)
            {
                TreeNode node = this.tree[i];
                if (node == null || !node.isTerminal)
                {
                    continue;
                }

                Debug.Assert(node.dataPoints.Length >= this.minNumSamples, "Regression Tree split has problems");
                float v = boostTreeLoss.Response(node.dataPoints, this.workIndex, iTree);
                // When ROUND is defined, round to 5 decimal places to reduce floating-point differences,
                // so that different split algorithms produce the same model.
            #if ROUND
                node.regionValue = (float)Math.Round(v, 5);
            #else
                node.regionValue = v;
            #endif //ROUND
                node.dataPoints = null;
                node.leftPoints = null;
                node.rightPoints = null;
            }

            // One forced collection after all leaves have dropped their arrays, instead of one per leaf.
            GC.Collect();
        }
Exemplo n.º 2
0
        /// <summary>
        /// Stores the training configuration and immediately grows a regression tree for boosting
        /// iteration <paramref name="iTree"/>.
        /// </summary>
        /// <param name="labelFeatureDataCoded">Coded feature data; its <c>NumFeatures</c> sizes the feature-importance array.</param>
        /// <param name="boostTreeLoss">Loss supplying the pseudo-responses and leaf values.</param>
        /// <param name="iTree">Index of the tree being built.</param>
        /// <param name="workIndex">Working index array; its length is the number of data points the root covers.</param>
        /// <param name="featureSampler">Sampler shuffled over the feature count before each split search.</param>
        /// <param name="dataSampler">Sampler shuffled over a node's data-point count before each split search.</param>
        /// <param name="maxTreeSize">Maximum number of leaves in the tree.</param>
        /// <param name="minNumSamples">Minimum number of data points a node needs to be searched for a split.</param>
        /// <param name="findSplit">Split-search strategy.</param>
        /// <param name="tempSpace">Scratch space handed to <c>InitTempSpace</c>.</param>
        public RegressionTree(LabelFeatureDataCoded labelFeatureDataCoded, BoostTreeLoss boostTreeLoss, int iTree, int[] workIndex,
                              RandomSampler featureSampler, RandomSampler dataSampler,
                              int maxTreeSize, int minNumSamples,
                              IFindSplit findSplit, TempSpace tempSpace)
        {
            this.labelFeatureDataCoded = labelFeatureDataCoded;
            this.workIndex = workIndex;
            this.numFeatures = labelFeatureDataCoded.NumFeatures;
            this.maxTreeSize = maxTreeSize;
            this.featureImportance = new float[this.numFeatures];
            this.minNumSamples = minNumSamples;

            //distributed setting
            this.adjustFactor = 1.0F;

            InitTempSpace(tempSpace);

            // BuildRegressionTree declares (..., featureSampler, dataSampler). Both arguments have the same type,
            // so a swap would compile, but it would shuffle each sampler over the wrong population.
            BuildRegressionTree(boostTreeLoss, iTree, findSplit, featureSampler, dataSampler);

            // Deliberate forced collection (author's choice) to release tree-building temporaries right away.
            GC.Collect();
        }