/// <summary>
/// Fills <paramref name="dst"/> with one slot name per leaf across all trees of the ensemble,
/// formatted as "Tree{t:000}Leaf{l:000}" in tree-then-leaf order.
/// </summary>
/// <param name="col">Column index. It is not read by this method.</param>
/// <param name="dst">
/// Destination buffer. Its values array is reused when it already holds at least
/// <c>_totalLeafCount</c> entries; otherwise a new array of that size is allocated.
/// </param>
private void GetLeafSlotNames(int col, ref VBuffer<ReadOnlyMemory<char>> dst)
{
    var names = dst.Values;
    if (Utils.Size(names) < _totalLeafCount)
    {
        names = new ReadOnlyMemory<char>[_totalLeafCount];
    }

    int i = 0;
    int t = 0;
    foreach (var tree in _ensemble.GetTrees())
    {
        for (int l = 0; l < tree.NumLeaves; l++)
        {
            // Slot names are identifiers, not UI text, so format culture-independently.
            names[i++] = string.Format(
                System.Globalization.CultureInfo.InvariantCulture,
                "Tree{0:000}Leaf{1:000}", t, l).AsMemory();
        }
        t++;
    }

    // Every leaf of every tree must have produced exactly one name.
    _host.Assert(i == _totalLeafCount);
    dst = new VBuffer<ReadOnlyMemory<char>>(_totalLeafCount, names, dst.Indices);
}
/// <summary>
/// Writes into <paramref name="dst"/> a vector of length <c>_numLeaves</c> with a 1 at the
/// position of each tree's reached leaf for the current input row. Positions are laid out by
/// concatenating the leaf ranges of the trees in order.
/// </summary>
/// <param name="dst">Destination buffer receiving the builder's result.</param>
/// <remarks>
/// The builder is rebuilt only when the input position differs from the position it was last
/// built for; otherwise the previously built contents are emitted again.
/// </remarks>
public void GetLeafIds(ref VBuffer<float> dst)
{
    EnsureCachedPosition();
    _ectx.Assert(_input.Position >= 0);
    _ectx.Assert(_cachedPosition == _input.Position);

    bool builderIsStale = _cachedLeafBuilderPosition != _input.Position;
    if (builderIsStale)
    {
        // Lazily allocate the builder on first use.
        if (_leafIdBuilder == null)
        {
            _leafIdBuilder = BufferBuilder<float>.CreateDefault();
        }
        _leafIdBuilder.Reset(_numLeaves, false);

        int leafBase = 0;
        var trees = _ensemble.GetTrees();
        for (int treeIdx = 0; treeIdx < trees.Length; treeIdx++)
        {
            // Shift the tree-local leaf id into this tree's range of the global vector.
            int slot = leafBase + _leafIds[treeIdx];
            _leafIdBuilder.AddFeature(slot, 1);
            leafBase += trees[treeIdx].NumLeaves;
        }

        _cachedLeafBuilderPosition = _input.Position;
    }

    _ectx.AssertValue(_leafIdBuilder);
    _leafIdBuilder.GetResult(ref dst);
}
/// <summary>
/// Returns the total number of leaves summed over every tree in <paramref name="ensemble"/>.
/// </summary>
/// <param name="ensemble">The tree ensemble to inspect. It is asserted to be non-null.</param>
/// <returns>The sum of <c>NumLeaves</c> over all trees returned by <c>GetTrees()</c>.</returns>
private static int CountLeaves(FastTreePredictionWrapper ensemble)
{
    Contracts.AssertValue(ensemble);

    int leafTotal = 0;
    foreach (var tree in ensemble.GetTrees())
    {
        leafTotal += tree.NumLeaves;
    }
    return leafTotal;
}