Exemplo n.º 1
0
        /// <summary>
        /// Invokes <c>ConvolutionBackwardFilter</c> on the DNN context resolved for <paramref name="x"/>,
        /// passing <paramref name="x"/>, <paramref name="dy"/> and <paramref name="dw"/> together with
        /// their descriptors and the supplied workspace buffer.
        /// </summary>
        /// <param name="algo">Algorithm selector; cast to <c>cudnnConvolutionBwdFilterAlgo</c> before the call.</param>
        /// <param name="cd">Convolution settings, turned into a descriptor via <c>GetConvDescriptor</c> using the element type of <paramref name="x"/>.</param>
        /// <param name="workspace">Storage handed to the call as a byte buffer starting at element 0 and spanning <c>ByteLength</c> bytes.</param>
        /// <param name="x">Tensor passed first to the call; also used to look up the tensor context and DNN handle.</param>
        /// <param name="dy">Tensor passed second to the call (presumably the gradient w.r.t. the convolution output; confirm with callers).</param>
        /// <param name="dw">Tensor passed last, described with a filter descriptor (presumably receives the filter gradient; confirm with callers).</param>
        public static void ConvolutionBackwardFilter(DNNConvolutionBwdFilterAlgo algo, Cpu.ConvolutionDesc2d cd, CudaStorage workspace, NDArray x, NDArray dy, NDArray dw)
        {
            using (var dnn = CudaHelpers.TSContextForTensor(x).DNNForTensor(x))
            {
                // NOTE(review): this descriptor is never disposed in this method; confirm whether
                // GetConvDescriptor hands back an owned resource that should be released here.
                var convolution = GetConvDescriptor(cd, x.ElementType);

                // Stacked usings: acquired top to bottom, released bottom to top.
                // The 'false' flag presumably marks the wrapper as non-owning, so that disposing it
                // leaves the caller's workspace allocation alone. TODO confirm against CudaDeviceVariable.
                using (var scratch = new CudaDeviceVariable<byte>(workspace.DevicePtrAtElement(0), false, workspace.ByteLength))
                using (var inputData = GetDeviceVar(x))
                using (var gradOutData = GetDeviceVar(dy))
                using (var gradFilterData = GetDeviceVar(dw))
                using (var inputDesc = GetDescriptor(x))
                using (var gradOutDesc = GetDescriptor(dy))
                using (var gradFilterDesc = GetFilterDescriptor(dw))
                {
                    var nativeAlgo = (cudnnConvolutionBwdFilterAlgo)algo;

                    // The leading 1 and the 0 after the workspace are presumably the alpha/beta
                    // scaling factors (result overwrites dw rather than accumulating). TODO confirm.
                    dnn.Value.ConvolutionBackwardFilter(
                        1,
                        inputDesc, inputData,
                        gradOutDesc, gradOutData,
                        convolution,
                        nativeAlgo,
                        scratch,
                        0,
                        gradFilterDesc, gradFilterData);
                }
            }
        }
Exemplo n.º 2
0
        /// <summary>
        /// Invokes <c>ConvolutionForward</c> on the DNN context resolved for <paramref name="x"/>,
        /// passing <paramref name="x"/>, <paramref name="w"/> and <paramref name="y"/> together with
        /// their descriptors and the supplied workspace buffer.
        /// </summary>
        /// <param name="algo">Algorithm selector; cast to <c>cudnnConvolutionFwdAlgo</c> before the call.</param>
        /// <param name="cd">Convolution settings, turned into a descriptor via <c>GetConvDescriptor</c> using the element type of <paramref name="x"/>.</param>
        /// <param name="workspace">Storage handed to the call as a byte buffer starting at element 0 and spanning <c>ByteLength</c> bytes.</param>
        /// <param name="x">Tensor passed first to the call; also used to look up the tensor context and DNN handle.</param>
        /// <param name="w">Tensor passed second to the call, described with a filter descriptor.</param>
        /// <param name="y">Tensor passed last to the call (presumably receives the convolution result; confirm with callers).</param>
        public static void ConvForward(DNNConvolutionFwdAlgo algo, Cpu.ConvolutionDesc2d cd, CudaStorage workspace, NDArray x, NDArray w, NDArray y)
        {
            using (var dnn = CudaHelpers.TSContextForTensor(x).DNNForTensor(x))
            {
                // NOTE(review): this descriptor is never disposed in this method; confirm whether
                // GetConvDescriptor hands back an owned resource that should be released here.
                var convolution = GetConvDescriptor(cd, x.ElementType);

                // Stacked usings: acquired top to bottom, released bottom to top.
                // The 'false' flag presumably marks the wrapper as non-owning, so that disposing it
                // leaves the caller's workspace allocation alone. TODO confirm against CudaDeviceVariable.
                using (var scratch = new CudaDeviceVariable<byte>(workspace.DevicePtrAtElement(0), false, workspace.ByteLength))
                using (var inputData = GetDeviceVar(x))
                using (var filterData = GetDeviceVar(w))
                using (var outputData = GetDeviceVar(y))
                using (var inputDesc = GetDescriptor(x))
                using (var filterDesc = GetFilterDescriptor(w))
                using (var outputDesc = GetDescriptor(y))
                {
                    var nativeAlgo = (cudnnConvolutionFwdAlgo)algo;

                    // The leading 1 and the 0 after the workspace are presumably the alpha/beta
                    // scaling factors (result overwrites y rather than accumulating). TODO confirm.
                    dnn.Value.ConvolutionForward(
                        1,
                        inputDesc, inputData,
                        filterDesc, filterData,
                        convolution,
                        nativeAlgo,
                        scratch,
                        0,
                        outputDesc, outputData);
                }
            }
        }