/// <summary>
/// Stores <paramref name="nodeIndex"/> on the parent of <paramref name="parentItem"/>,
/// in the slot selected by the item's position: <c>LeftIndex</c> for
/// <c>NodePositionType.Left</c>, <c>RightIndex</c> for <c>NodePositionType.Right</c>.
/// Any other position leaves the parent untouched.
/// </summary>
/// <param name="nodeIndex">Index to record on the parent.</param>
/// <param name="parentItem">Item whose <c>Values.Position</c> selects the slot and whose <c>Parent</c> is updated.</param>
void SetParentLeafIndex(int nodeIndex, GBMTreeCreationItem parentItem)
{
    var position = parentItem.Values.Position;

    if (position == NodePositionType.Left)
    {
        parentItem.Parent.LeftIndex = nodeIndex;
        return;
    }

    if (position == NodePositionType.Right)
    {
        parentItem.Parent.RightIndex = nodeIndex;
    }
}
/// <summary>
/// Drains <paramref name="featureIndices"/>, calling <c>FindBestSplit</c> once for every
/// feature index that is successfully dequeued. Returns when the queue yields no more items.
/// </summary>
/// <remarks>
/// NOTE(review): the concurrent queue/bag parameters suggest several workers run this
/// method in parallel over a shared queue - confirm against the caller.
/// </remarks>
/// <param name="observations">Forwarded unchanged to <c>FindBestSplit</c>.</param>
/// <param name="residuals">Forwarded unchanged to <c>FindBestSplit</c>.</param>
/// <param name="targets">Forwarded unchanged to <c>FindBestSplit</c>.</param>
/// <param name="predictions">Forwarded unchanged to <c>FindBestSplit</c>.</param>
/// <param name="orderedElements">Forwarded unchanged to <c>FindBestSplit</c>.</param>
/// <param name="parentItem">Forwarded unchanged to <c>FindBestSplit</c>.</param>
/// <param name="parentInSample">Forwarded unchanged to <c>FindBestSplit</c>.</param>
/// <param name="featureIndices">Queue of feature indices still to be evaluated; consumed by this method.</param>
/// <param name="results">Collection handed to <c>FindBestSplit</c> to receive split results.</param>
void SplitWorker(
    F64Matrix observations,
    double[] residuals,
    double[] targets,
    double[] predictions,
    int[][] orderedElements,
    GBMTreeCreationItem parentItem,
    bool[] parentInSample,
    ConcurrentQueue<int> featureIndices,
    ConcurrentBag<GBMSplitResult> results)
{
    // TryDequeue always assigns its out argument, so no initial value is needed.
    int nextFeature;

    while (featureIndices.TryDequeue(out nextFeature))
    {
        FindBestSplit(
            observations,
            residuals,
            targets,
            predictions,
            orderedElements,
            parentItem,
            parentInSample,
            nextFeature,
            results);
    }
}
/// <summary>
/// Searches a single feature for the lowest-cost split of the samples belonging to
/// <paramref name="parentItem"/> and, if one is found, adds it to <paramref name="results"/>.
/// </summary>
/// <remarks>
/// Samples are visited in the order given by <c>orderedElements[featureIndex]</c>
/// (presumably sorted by feature value - confirm against the caller), skipping to
/// positions returned by <c>NextAllowedIndex</c>. After each visited sample,
/// <c>m_loss.UpdateSplitConstants</c> updates the left/right statistics; the scan stops
/// once <c>right.Samples</c> is no longer positive. A candidate is only considered when
/// both sides hold at least <c>m_minimumSplitSize</c> samples and the feature value
/// changes between the previous and the current sample.
/// </remarks>
/// <param name="observations">Matrix the feature values are read from via <c>At(row, featureIndex)</c>.</param>
/// <param name="residuals">Per-sample residuals passed to the loss when updating split statistics.</param>
/// <param name="targets">Per-sample targets passed to the loss when updating split statistics.</param>
/// <param name="predictions">Not read by this method.</param>
/// <param name="orderedElements">Per-feature arrays of sample indices defining the scan order.</param>
/// <param name="parentItem">Node being split; supplies depth, sample count, cost and the initial right-side statistics.</param>
/// <param name="parentInSample">Passed to <c>NextAllowedIndex</c> to pick the positions to visit.</param>
/// <param name="featureIndex">Feature (column) to evaluate.</param>
/// <param name="results">Receives one <c>GBMSplitResult</c> when a valid split exists for this feature.</param>
void FindBestSplit(
    F64Matrix observations,
    double[] residuals,
    double[] targets,
    double[] predictions,
    int[][] orderedElements,
    GBMTreeCreationItem parentItem,
    bool[] parentInSample,
    int featureIndex,
    ConcurrentBag<GBMSplitResult> results)
{
    // FeatureIndex == -1 marks "no split found yet"; Cost starts at the maximum
    // so the first valid candidate always wins.
    var best = new GBMSplit
    {
        Depth = parentItem.Depth,
        FeatureIndex = -1,
        SplitIndex = -1,
        SplitValue = -1,
        Cost = double.MaxValue,
        LeftConstant = -1,
        RightConstant = -1,
        SampleCount = parentItem.Values.Samples
    };

    var bestLeft = GBMSplitInfo.NewEmpty();
    var bestRight = GBMSplitInfo.NewEmpty();

    // The scan starts with an empty left side and the parent's values on the right.
    var left = GBMSplitInfo.NewEmpty();
    var right = parentItem.Values.Copy(NodePositionType.Right);

    var ordering = orderedElements[featureIndex];
    var position = NextAllowedIndex(0, ordering, parentInSample);

    // No allowed sample or valid split left.
    var nothingToSplit = position >= ordering.Length || ordering.Length == 1;
    if (nothingToSplit)
    {
        return;
    }

    var sampleIndex = ordering[position];
    m_loss.UpdateSplitConstants(ref left, ref right, targets[sampleIndex], residuals[sampleIndex]);

    var previousValue = observations.At(sampleIndex, featureIndex);

    while (right.Samples > 0)
    {
        position = NextAllowedIndex(position + 1, ordering, parentInSample);
        sampleIndex = ordering[position];
        var currentValue = observations.At(sampleIndex, featureIndex);

        // A threshold can only be placed between two distinct feature values,
        // and both sides must satisfy the minimum split size.
        if (Math.Min(left.Samples, right.Samples) >= m_minimumSplitSize
            && previousValue != currentValue)
        {
            var candidateCost = left.Cost + right.Cost;

            if (candidateCost < best.Cost)
            {
                best.FeatureIndex = featureIndex;
                best.SplitIndex = position;
                best.SplitValue = (previousValue + currentValue) * 0.5;
                best.LeftError = left.Cost;
                best.RightError = right.Cost;
                best.Cost = candidateCost;
                best.CostImprovement = parentItem.Values.Cost - candidateCost;
                best.LeftConstant = left.BestConstant;
                best.RightConstant = right.BestConstant;

                bestLeft = left.Copy();
                bestRight = right.Copy();
            }
        }

        m_loss.UpdateSplitConstants(ref left, ref right, targets[sampleIndex], residuals[sampleIndex]);
        previousValue = currentValue;
    }

    // Nothing beat the initial sentinel: no valid split for this feature.
    if (best.FeatureIndex == -1)
    {
        return;
    }

    results.Add(new GBMSplitResult
    {
        BestSplit = best,
        Left = bestLeft,
        Right = bestRight
    });
}