/// <summary>
/// Asserts that <paramref name="actual"/> matches <paramref name="expected"/> field by field.
/// Depth, FeatureIndex, LeftIndex, RightIndex and SampleCount are compared with plain equality;
/// LeftConstant, LeftError, RightConstant, RightError and SplitValue are compared within the
/// <c>m_delta</c> tolerance. Assertions run in a fixed order, so the first mismatching field
/// is the one reported.
/// </summary>
void AssertGBMNode(GBMNode expected, GBMNode actual)
{
    var tolerance = m_delta;

    Assert.AreEqual(expected.Depth, actual.Depth);
    Assert.AreEqual(expected.FeatureIndex, actual.FeatureIndex);

    // Left-hand side of the split.
    Assert.AreEqual(expected.LeftConstant, actual.LeftConstant, tolerance);
    Assert.AreEqual(expected.LeftError, actual.LeftError, tolerance);
    Assert.AreEqual(expected.LeftIndex, actual.LeftIndex);

    // Right-hand side of the split.
    Assert.AreEqual(expected.RightConstant, actual.RightConstant, tolerance);
    Assert.AreEqual(expected.RightError, actual.RightError, tolerance);
    Assert.AreEqual(expected.RightIndex, actual.RightIndex);

    Assert.AreEqual(expected.SampleCount, actual.SampleCount);
    Assert.AreEqual(expected.SplitValue, actual.SplitValue, tolerance);
}
/// <summary>
/// Converts the text dump of a single XGBoost tree into the flat <see cref="GBMNode"/> list
/// used by SharpLearning.
/// </summary>
/// <param name="textTree">
/// Text dump of one tree, one node per line ("\n" or "\r\n" line endings).
/// Lines containing "booster" are ignored.
/// </param>
/// <returns>
/// A list whose element 0 is a special placeholder root node. It is followed by one node per
/// XGBoost split (non-leaf) line, in ascending XGBoost node-index order. Leaf lines are not
/// emitted as nodes. A leaf child's value is stored in the parent's LeftConstant/RightConstant,
/// with the matching LeftIndex/RightIndex left at -1. A split child is referenced by its
/// position in the returned list. Depth and SampleCount are not assigned here.
/// </returns>
private static List<GBMNode> ConvertXGBoostNodesToGBMNodes(string textTree)
{
    // "\r\n" is listed first so it is matched as a whole before "\n";
    // otherwise CRLF input would leave "\r"-only entries behind.
    var newLine = new[] { "\r\n", "\n" };
    var lines = textTree.Split(newLine, StringSplitOptions.RemoveEmptyEntries);

    // Index every node line by its XGBoost node index (booster header lines removed).
    var linesByIndex = lines
        .Where(l => !l.Contains("booster"))
        .ToDictionary(l => ParseNodeIndex(l), l => l);

    // Split nodes in ascending XGBoost order. Dictionary enumeration order is unspecified,
    // so the keys are sorted explicitly instead of relying on insertion order.
    var splitIndices = linesByIndex
        .Where(kv => !IsLeaf(kv.Value))
        .Select(kv => kv.Key)
        .OrderBy(i => i)
        .ToList();

    // Map each XGBoost split index to its position in the output list (position 0 is the
    // special root node). Child links are resolved through this map instead of a running
    // counter, so they stay correct even when XGBoost ids are not breadth-first ordered
    // (e.g. loss-guided growth).
    var gbmIndexOf = splitIndices
        .Select((xgbIndex, position) => new { xgbIndex, position })
        .ToDictionary(p => p.xgbIndex, p => p.position + 1);

    var nodes = new List<GBMNode>(splitIndices.Count + 1)
    {
        // Add special root node for SharpLearning.
        new GBMNode
        {
            FeatureIndex = -1,
            SplitValue = -1,
            LeftConstant = 0.5,
            RightConstant = 0.5,
        },
    };

    foreach (var xgbIndex in splitIndices)
    {
        var line = linesByIndex[xgbIndex];

        var featureIndex = ParseFeatureIndex(line);
        var splitValue = ParseSplitValue(line);
        var yesIndex = ParseYesIndex(line);
        var noIndex = ParseNoIndex(line);

        var node = new GBMNode
        {
            FeatureIndex = featureIndex,
            SplitValue = splitValue,
            LeftConstant = -1,
            LeftIndex = -1,
            RightConstant = -1,
            RightIndex = -1
        };

        // Leaf children are folded into the constant; split children are linked by list position.
        var left = linesByIndex[yesIndex];
        if (IsLeaf(left))
        {
            node.LeftConstant = ParseLeafValue(left);
        }
        else
        {
            node.LeftIndex = gbmIndexOf[yesIndex];
        }

        var right = linesByIndex[noIndex];
        if (IsLeaf(right))
        {
            node.RightConstant = ParseLeafValue(right);
        }
        else
        {
            node.RightIndex = gbmIndexOf[noIndex];
        }

        nodes.Add(node);
    }

    return nodes;
}