Ejemplo n.º 1
0
                /// <summary>
                /// Initializes per-row state for the tree-ensemble featurizer. It stores the input row
                /// and the ensemble, obtains a getter for the feature vector column, and allocates one
                /// leaf-id slot and one path-id list per tree.
                /// </summary>
                /// <param name="ectx">Exception context used for all subsequent assertions; asserted non-null.</param>
                /// <param name="input">Row to read features from; asserted non-null.</param>
                /// <param name="ensemble">Tree ensemble; asserted non-null and to contain at least one tree.</param>
                /// <param name="numLeaves">Leaf count stored as-is in <c>_numLeaves</c>.</param>
                /// <param name="featureIndex">Column index passed to <c>input.GetGetter</c> for the feature vector.</param>
                public State(IExceptionContext ectx, IRow input, FastTreePredictionWrapper ensemble, int numLeaves, int featureIndex)
                {
                    // The context itself is checked with the static helper, since it is then used for every other check.
                    Contracts.AssertValue(ectx);
                    _ectx = ectx;
                    _ectx.AssertValue(input);
                    _ectx.AssertValue(ensemble);
                    _ectx.Assert(ensemble.NumTrees > 0);

                    // Captured references and sizes.
                    _input = input;
                    _ensemble = ensemble;
                    _numTrees = _ensemble.NumTrees;
                    _numLeaves = numLeaves;

                    // Feature vector buffer plus the getter that reads the column at featureIndex.
                    _src = default(VBuffer<float>);
                    _featureGetter = input.GetGetter<VBuffer<float>>(featureIndex);

                    // Per-tree storage: one leaf id and one (initially empty) path list for each tree.
                    _leafIds = new int[_numTrees];
                    _pathIds = new List<int>[_numTrees];
                    for (int tree = 0; tree < _pathIds.Length; tree++)
                    {
                        _pathIds[tree] = new List<int>();
                    }

                    // NOTE(review): -1 presumably marks "nothing cached yet"; confirm against the code that reads these positions.
                    _cachedPosition = -1;
                    _cachedLeafBuilderPosition = -1;
                    _cachedPathBuilderPosition = -1;
                }
Ejemplo n.º 2
0
        /// <summary>
        /// Adds up <c>NumLeaves</c> over every tree returned by <c>ensemble.GetTrees()</c>.
        /// </summary>
        /// <param name="ensemble">Tree ensemble whose leaves are counted; asserted non-null.</param>
        /// <returns>The total number of leaves across all trees of the ensemble.</returns>
        private static int CountLeaves(FastTreePredictionWrapper ensemble)
        {
            Contracts.AssertValue(ensemble);

            int leafTotal = 0;
            foreach (var tree in ensemble.GetTrees())
            {
                leafTotal += tree.NumLeaves;
            }

            return leafTotal;
        }
Ejemplo n.º 3
0
        /// <summary>
        /// Creates the bindable mapper around the tree ensemble held by <paramref name="predictor"/>.
        /// If the predictor is a <c>CalibratedPredictorBase</c>, its <c>SubPredictor</c> is used instead.
        /// </summary>
        /// <param name="env">Host environment; checked non-null, and a host named <c>LoaderSignature</c> is registered from it.</param>
        /// <param name="args">Featurizer arguments; checked non-null.</param>
        /// <param name="predictor">Model to featurize with; must be, or wrap, a <c>FastTreePredictionWrapper</c>.</param>
        public TreeEnsembleFeaturizerBindableMapper(IHostEnvironment env, Arguments args, IPredictor predictor)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(LoaderSignature);
            _host.CheckValue(args, nameof(args));
            _host.CheckValue(predictor, nameof(predictor));

            // Unwrap a calibrator so the underlying tree ensemble can be reached.
            var calibrated = predictor as CalibratedPredictorBase;
            if (calibrated != null)
            {
                predictor = calibrated.SubPredictor;
            }

            _ensemble = predictor as FastTreePredictionWrapper;
            _host.Check(_ensemble != null, "Predictor in model file does not have compatible type");

            // Cache the total leaf count across all trees of the ensemble.
            _totalLeafCount = CountLeaves(_ensemble);
        }