/// <summary>
/// Builds the state object bound to a single input row and a tree ensemble.
/// Stores the arguments, obtains the feature getter from the row, and allocates
/// one leaf-id slot and one (initially empty) path-id list per tree.
/// </summary>
/// <param name="ectx">Exception context; asserted non-null and then used for the remaining assertions.</param>
/// <param name="input">Row the feature getter is obtained from; asserted non-null.</param>
/// <param name="ensemble">Tree ensemble; asserted non-null and asserted to have at least one tree.</param>
/// <param name="numLeaves">Leaf count, stored as-is; no validation is done here.</param>
/// <param name="featureIndex">Index handed to <c>input.GetGetter</c> to obtain the feature vector getter.</param>
public State(IExceptionContext ectx, IRow input, FastTreePredictionWrapper ensemble, int numLeaves, int featureIndex)
{
    // Argument checks: the context itself is checked through the static
    // Contracts helper, everything else through the context.
    Contracts.AssertValue(ectx);
    _ectx = ectx;
    _ectx.AssertValue(input);
    _ectx.AssertValue(ensemble);
    _ectx.Assert(ensemble.NumTrees > 0);

    // Plain argument capture.
    _input = input;
    _ensemble = ensemble;
    _numTrees = ensemble.NumTrees;
    _numLeaves = numLeaves;

    // Feature buffer starts out as the default VBuffer; the getter comes from the row.
    _src = default(VBuffer<float>);
    _featureGetter = input.GetGetter<VBuffer<float>>(featureIndex);

    // All three cached positions start at -1
    // (presumably meaning "nothing cached yet" -- confirm against the readers of these fields).
    _cachedPosition = -1;
    _cachedLeafBuilderPosition = -1;
    _cachedPathBuilderPosition = -1;

    // Per-tree storage: one leaf id and one path-id list for every tree.
    _leafIds = new int[_numTrees];
    _pathIds = new List<int>[_numTrees];
    for (int tree = 0; tree < _pathIds.Length; tree++)
    {
        _pathIds[tree] = new List<int>();
    }
}
/// <summary>
/// Adds up <c>NumLeaves</c> over every tree returned by <c>ensemble.GetTrees()</c>.
/// </summary>
/// <param name="ensemble">Tree ensemble to inspect; asserted non-null.</param>
/// <returns>The sum of the per-tree leaf counts (0 when there are no trees).</returns>
private static int CountLeaves(FastTreePredictionWrapper ensemble)
{
    Contracts.AssertValue(ensemble);

    int leafSum = 0;
    // GetTrees() is called exactly once; the result is walked front to back.
    foreach (var tree in ensemble.GetTrees())
    {
        leafSum += tree.NumLeaves;
    }

    return leafSum;
}
/// <summary>
/// Creates the bindable mapper from a predictor. A <c>CalibratedPredictorBase</c>
/// is unwrapped to its <c>SubPredictor</c> first; the resulting predictor must be a
/// <c>FastTreePredictionWrapper</c>, otherwise the host check fails. The total leaf
/// count of the ensemble is computed once and cached.
/// </summary>
/// <param name="env">Host environment; checked non-null and used to register the host under <c>LoaderSignature</c>.</param>
/// <param name="args">Arguments; checked non-null (not otherwise read in this constructor).</param>
/// <param name="predictor">Predictor to wrap, optionally calibrated; checked non-null.</param>
public TreeEnsembleFeaturizerBindableMapper(IHostEnvironment env, Arguments args, IPredictor predictor)
{
    Contracts.CheckValue(env, nameof(env));
    _host = env.Register(LoaderSignature);
    _host.CheckValue(args, nameof(args));
    _host.CheckValue(predictor, nameof(predictor));

    // Look through a calibrator wrapper, if there is one, to reach the underlying predictor.
    IPredictor inner = predictor;
    var calibrated = predictor as CalibratedPredictorBase;
    if (calibrated != null)
    {
        inner = calibrated.SubPredictor;
    }

    // Anything that is not a FastTreePredictionWrapper leaves _ensemble null and fails the check.
    _ensemble = inner as FastTreePredictionWrapper;
    _host.Check(_ensemble != null, "Predictor in model file does not have compatible type");

    _totalLeafCount = CountLeaves(_ensemble);
}