/// <summary>
        ///     Queries the amount of reserve space needed to run dropout with the input dimensions given by
        ///     <paramref name="xDesc"/>.
        ///     The same reserve space is expected to be passed to cudnnDropoutForward and cudnnDropoutBackward, and its
        ///     contents are expected to remain unchanged between the cudnnDropoutForward and cudnnDropoutBackward calls.
        /// </summary>
        /// <param name="xDesc">Previously initialized tensor descriptor describing the input to a dropout operation.</param>
        /// <returns>The reserve space size, in bytes, reported by cudnnDropoutGetReserveSpaceSize.</returns>
        /// <exception cref="System.ArgumentNullException"><paramref name="xDesc"/> is <c>null</c>.</exception>
        /// <exception cref="CudaDNNException">The native call returned a status other than <c>cudnnStatus.Success</c>.</exception>
        public SizeT GetDropoutReserveSpaceSize(TensorDescriptor xDesc)
        {
            // Fail fast with a descriptive exception rather than a NullReferenceException when reading xDesc.Desc.
            if (xDesc is null)
            {
                throw new System.ArgumentNullException(nameof(xDesc));
            }

            var sizeInBytes = new SizeT();
            var res         = CudaDNNNativeMethods.cudnnDropoutGetReserveSpaceSize(xDesc.Desc, ref sizeInBytes);

            if (res != cudnnStatus.Success)
            {
                throw new CudaDNNException(res);
            }
            return sizeInBytes;
        }
Exemplo n.º 2
0
        /// <summary>
        ///     Queries the amount of reserve space needed to run dropout with the input dimensions given by
        ///     <paramref name="xDesc"/>.
        ///     The same reserve space is expected to be passed to cudnnDropoutForward and cudnnDropoutBackward, and its
        ///     contents are expected to remain unchanged between the cudnnDropoutForward and cudnnDropoutBackward calls.
        /// </summary>
        /// <param name="xDesc">Previously initialized tensor descriptor describing the input to a dropout operation.</param>
        /// <returns>The reserve space size, in bytes, reported by cudnnDropoutGetReserveSpaceSize.</returns>
        /// <exception cref="System.ArgumentNullException"><paramref name="xDesc"/> is <c>null</c>.</exception>
        /// <exception cref="CudaDNNException">The native call returned a status other than <c>cudnnStatus.Success</c>.</exception>
        public SizeT GetDropoutReserveSpaceSize(TensorDescriptor xDesc)
        {
            // Fail fast with a descriptive exception rather than a NullReferenceException when reading xDesc.Desc.
            if (xDesc is null)
            {
                throw new System.ArgumentNullException(nameof(xDesc));
            }

            var sizeInBytes = new SizeT();
            var res         = CudaDNNNativeMethods.cudnnDropoutGetReserveSpaceSize(xDesc.Desc, ref sizeInBytes);

            if (res != cudnnStatus.Success)
            {
                throw new CudaDNNException(res);
            }
            return sizeInBytes;
        }