Пример #1
0
        /// <summary>
        /// Predicts the output for one row by delegating to <c>PredictOneOff</c>.
        /// When compiled with both DEBUG and MORE_CHECKING defined, the result is additionally
        /// compared with the one returned by <c>PredictN</c> and an exception is thrown
        /// if the two disagree (different number of outputs, or any output differing by more than 1e-5).
        /// </summary>
        /// <param name="features">one row</param>
        /// <param name="predictedValues">Results of the prediction</param>
        /// <param name="internalBuffer">buffers forwarded to <c>PredictOneOff</c></param>
        /// <param name="outputMargin">Whether to output the raw untransformed margin value.</param>
        /// <param name="ntreeLimit">Limit number of trees in the prediction; defaults to 0 (use all trees).</param>
        public void Predict(ref VBuffer<Float> features,
                            ref VBuffer<Float> predictedValues,
                            ref XGBoostTreeBuffer internalBuffer,
                            bool outputMargin = true,
                            int ntreeLimit = 0)
        {
            PredictOneOff(ref features, ref predictedValues, ref internalBuffer, outputMargin, ntreeLimit);

#if (DEBUG && MORE_CHECKING)
            // This part checks that the function PredictOneOff, which relies on a customized version
            // of XGBoost, produces the same result as the official API.
            // This makes the prediction terribly slow as the prediction is computed twice
            // and the second call (PredictN) cannot be parallelized (lock protected).
            VBuffer<Float> check = new VBuffer<Float>();
            DMatrix data;
            if (features.IsDense)
            {
                data = new DMatrix(features.Values, 1, (uint)features.Count);
            }
            else
            {
                // Sparse row: a single row described by its row pointer [0, nb] and the explicit indices.
                int nb = features.Count;
                var indptr = new ulong[] { 0, (uint)nb };
                var indices = new uint[nb];
                for (int i = 0; i < nb; ++i)
                {
                    indices[i] = (uint)features.Indices[i];
                }
                data = new DMatrix((uint)features.Length, indptr, indices, features.Values, 1, (uint)nb);
            }

            PredictN(data, ref check, outputMargin, ntreeLimit);
            if (check.Count != predictedValues.Count)
            {
                throw Contracts.Except("Mismatch between official API and custom API (dimension).\n"
                                       + FormatPredictionMismatch(ref features, ref predictedValues, ref check));
            }
            for (int i = 0; i < check.Count; ++i)
            {
                // Compare every output, not only the first one.
                if (Math.Abs(check.Values[i] - predictedValues.Values[i]) > 1e-5)
                {
                    // The message must be built before the second computation overwrites predictedValues.
                    string message = FormatPredictionMismatch(ref features, ref predictedValues, ref check);
                    // Run the custom prediction once more so that the report shows whether it is reproducible.
                    PredictOneOff(ref features, ref predictedValues, ref internalBuffer, outputMargin, ntreeLimit);
                    message += string.Format("\nSecond computation\n{0}", JoinForMessage(predictedValues.Values));
                    throw Contracts.Except("Mismatch between official API and custom API (output).\n" + message);
                }
            }
#endif
        }

#if (DEBUG && MORE_CHECKING)
        /// <summary>
        /// Builds the diagnostic text describing the input row and both predictions
        /// (custom API and official API) when they disagree.
        /// </summary>
        private static string FormatPredictionMismatch(ref VBuffer<Float> features,
                                                       ref VBuffer<Float> predictedValues,
                                                       ref VBuffer<Float> check)
        {
            return string.Format(
                "Count={0} Length={1} IsDense={2}\nValues={3}\nIndices={4}\nCustom Ouput={5}\nOfficial API={6}",
                features.Count, features.Length, features.IsDense,
                JoinForMessage(features.Values),
                JoinForMessage(features.Indices),
                JoinForMessage(predictedValues.Values),
                JoinForMessage(check.Values));
        }

        /// <summary>
        /// Joins the elements of an array with ", "; returns an empty string for a null array.
        /// </summary>
        private static string JoinForMessage<T>(T[] values)
        {
            return values == null
                ? ""
                : string.Join(", ", values.Select(c => c.ToString()).ToArray());
        }
#endif
Пример #2
0
        /// <summary>
        /// Computes the prediction for a single row.
        /// Relies on a modified XGBoost API which does not use caches.
        /// </summary>
        /// <param name="vbuf">one row</param>
        /// <param name="predictedValues">Results of the prediction</param>
        /// <param name="internalBuffer">buffers allocated by Microsoft.ML and handed over to XGBoost so that XGBoost does not allocate caches on its own</param>
        /// <param name="outputMargin">Whether to output the raw untransformed margin value.</param>
        /// <param name="ntreeLimit">Limit number of trees in the prediction; defaults to 0 (use all trees).</param>
        public void PredictOneOff(ref VBuffer<Float> vbuf, ref VBuffer<Float> predictedValues,
                                  ref XGBoostTreeBuffer internalBuffer, bool outputMargin = true, int ntreeLimit = 0)
        {
            Contracts.Check(internalBuffer != null);

            // REVIEW xadupre: XGBoost is able to return one output per tree (pred_leaf=true).
            // With that option the output becomes a (nsample, ntrees) matrix where each entry
            // is the index of the leaf reached by the sample in the corresponding tree.
            // Leaf indices are only unique within a tree: leaf 1 may exist in both tree 0 and tree 1.
            // It would be enabled with:
            //    option_mask |= 0x02;
            // This might be an interesting feature to implement.

            // Bit 0x01 of the option mask is set when the margin is requested.
            int optionMask = outputMargin ? 0x01 : 0x00;

            uint entryCount = (uint)vbuf.Count;
            uint outputLength = 0;
            uint predBufferLength = 0;

            // The modified API does not let XGBoost manage its own caches.
            // Microsoft.ML allocates them and passes them to XGBoost:
            // the cache holding the features comes first, because XGBoosterPredictOutputSize
            // needs it to tell which cache sizes the prediction requires.
#if (XGB_EXTENDED)
            internalBuffer.ResizeEntries(entryCount, vbuf.Length);
#else
            internalBuffer.ResizeEntries(entryCount);
#endif

            unsafe
            {
                fixed (float* pValues = vbuf.Values)
                fixed (int* pIndices = vbuf.Indices)
                fixed (byte* pEntries = internalBuffer.XGBoostEntries)
                {
                    // No indices are given for a dense row.
                    WrappedXGBoostInterface.XGBoosterCopyEntries(
                        (IntPtr)pEntries, ref entryCount, pValues, vbuf.IsDense ? null : pIndices, float.NaN);
                    WrappedXGBoostInterface.XGBoosterPredictOutputSize(
                        _handle, (IntPtr)pEntries, entryCount, optionMask, (uint)ntreeLimit,
                        ref outputLength, ref predBufferLength);
                }
            }

            // Second step: the caches receiving the prediction are sized.
            internalBuffer.ResizeOutputs(outputLength, predBufferLength, ref predictedValues);

            unsafe
            {
                fixed (byte* pEntries = internalBuffer.XGBoostEntries)
                fixed (float* pPredictions = predictedValues.Values)
                fixed (float* pPredBuffer = internalBuffer.PredBuffer)
                fixed (uint* pPredCounter = internalBuffer.PredCounter)
                {
                    WrappedXGBoostInterface.XGBoosterPredictNoInsideCache(
                        _handle, (IntPtr)pEntries, entryCount, optionMask, (uint)ntreeLimit,
                        outputLength, predBufferLength, pPredictions, pPredBuffer, pPredCounter
#if (XGB_EXTENDED)
                        , internalBuffer.RegTreeFVec
#endif
                        );
                }
            }
        }