Beispiel #1
0
        /// <summary>
        /// Multiplies this volume by <paramref name="right"/> on the device by calling cuDNN's
        /// <c>cudnnOpTensor</c> with <c>OpTensorMul</c>, writing the output into <paramref name="result"/>.
        /// </summary>
        /// <remarks>
        /// The operand with the larger total length is passed to cuDNN as tensor A and the other one as
        /// tensor B; the output descriptor is built from A's dimensions.
        /// NOTE(review): the shape of <paramref name="result"/> is not checked here — it is assumed to
        /// already match A's dimensions; confirm against callers.
        /// </remarks>
        /// <param name="right">Second operand; its storage must be a <c>VolumeStorage</c>.</param>
        /// <param name="result">Destination volume; its storage must be a <c>VolumeStorage</c>.</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="right"/> or <paramref name="result"/> is <c>null</c>.
        /// </exception>
        /// <exception cref="ArgumentException">
        /// The storage of <paramref name="right"/> or <paramref name="result"/> is not a <c>VolumeStorage</c>.
        /// </exception>
        /// <exception cref="Exception">cuDNN returned a status other than <c>Success</c>.</exception>
        public override void DoMultiply(Volume <float> right, Volume <float> result)
        {
            // ReferenceEquals is used for the null guards so that no user-defined operator or
            // conversion on Volume can take part in the comparison.
            if (ReferenceEquals(result, null))
            {
                throw new ArgumentNullException(nameof(result));
            }

            var resultStorage = result.Storage as VolumeStorage;

            if (resultStorage == null)
            {
                throw new ArgumentException($"{nameof(result)} storage should be VolumeStorage", nameof(result));
            }

            if (ReferenceEquals(right, null))
            {
                throw new ArgumentNullException(nameof(right));
            }

            var rightStorage = right.Storage as VolumeStorage;

            if (rightStorage == null)
            {
                throw new ArgumentException($"{nameof(right)} storage should be VolumeStorage", nameof(right));
            }

            // Copy to device if not already done
            this._volumeStorage.CopyToDevice();
            rightStorage.CopyToDevice();
            resultStorage.CopyToDevice();

            // Tensor A is whichever operand has the larger total length; tensor B is the other one.
            var aStorage = this._volumeStorage;
            var bStorage = rightStorage;

            if (bStorage.Shape.TotalLength > aStorage.Shape.TotalLength)
            {
                aStorage = rightStorage;
                bStorage = this._volumeStorage;
            }

            var bShape = bStorage.Shape;

            var n = aStorage.Shape.GetDimension(3);
            var c = aStorage.Shape.GetDimension(2);
            var h = aStorage.Shape.GetDimension(1);
            var w = aStorage.Shape.GetDimension(0);

            // Multiply tensors: C = (1 * A) * (1 * B) + 0 * C
            using (var descA = new TensorDescriptor())
            using (var descB = new TensorDescriptor())
            using (var descC = new TensorDescriptor())
            {
                descA.SetTensor4dDescriptor(cudnnTensorFormat.NCHW, cudnnDataType.Float, n, c, h, w);
                descB.SetTensor4dDescriptor(
                    cudnnTensorFormat.NCHW,
                    cudnnDataType.Float,
                    bShape.GetDimension(3),
                    bShape.GetDimension(2),
                    bShape.GetDimension(1),
                    bShape.GetDimension(0));
                descC.SetTensor4dDescriptor(cudnnTensorFormat.NCHW, cudnnDataType.Float, n, c, h, w);

                using (var opt = new OpTensorDescriptor(this._context.CudnnContext))
                {
                    opt.SetOpTensorDescriptor(
                        cudnnOpTensorOp.OpTensorMul,
                        cudnnDataType.Float,
                        cudnnNanPropagation.PropagateNan);

                    // Scaling factors: both inputs are taken as-is, nothing of the previous output is kept.
                    var one  = 1.0f;
                    var zero = 0.0f;

                    var status = CudaDNNNativeMethods.cudnnOpTensor(
                        this._context.CudnnContext.Handle,
                        opt.Desc,
                        ref one, descA.Desc, aStorage.DeviceBuffer.DevicePointer,
                        ref one, descB.Desc, bStorage.DeviceBuffer.DevicePointer,
                        ref zero, descC.Desc, resultStorage.DeviceBuffer.DevicePointer);

                    if (status != cudnnStatus.Success)
                    {
                        throw new Exception(CudaDNNNativeMethods.cudnnGetErrorString(status));
                    }

                    // The freshly computed data only exists in the device buffer.
                    resultStorage.Location = DataLocation.Device;
                }
            }
        }
Beispiel #2
0
        /// <summary>
        /// Runs the cuDNN tensor operation <paramref name="op"/> through <c>cudnnOpTensor</c>, using this
        /// volume's storage and, when supplied, <paramref name="right"/> as inputs and
        /// <paramref name="result"/> as the output.
        /// </summary>
        /// <remarks>
        /// With two operands, the one with the larger total length is passed as tensor A. When
        /// <paramref name="right"/> is <c>null</c>, tensor A's descriptor and device pointer are passed
        /// for tensor B as well.
        /// </remarks>
        /// <param name="right">Optional second operand; may be <c>null</c>.</param>
        /// <param name="op">cuDNN tensor operation to configure on the op descriptor.</param>
        /// <param name="result">Destination volume; its storage must be a <c>VolumeStorage</c>.</param>
        /// <exception cref="ArgumentException">
        /// The storage of <paramref name="result"/>, or of a non-null <paramref name="right"/>, is not a
        /// <c>VolumeStorage</c>.
        /// </exception>
        /// <exception cref="Exception">cuDNN returned a status other than <c>Success</c>.</exception>
        private void Op(Volume <double> right, cudnnOpTensorOp op, Volume <double> result)
        {
            if (!(result.Storage is VolumeStorage outputStorage))
            {
                throw new ArgumentException($"{nameof(result)} storage should be VolumeStorage", nameof(result));
            }

            VolumeStorage rightStorage = null;
            if (right != null)
            {
                rightStorage = (right.Storage as VolumeStorage)
                               ?? throw new ArgumentException($"{nameof(right)} storage should be VolumeStorage", nameof(right));
            }

            // Upload every buffer taking part in the operation (no-op when already on the device).
            this._volumeStorage.CopyToDevice();
            if (rightStorage != null)
            {
                rightStorage.CopyToDevice();
            }

            outputStorage.CopyToDevice();

            // Pick operand A (the larger input) and operand B (the smaller one, or none at all).
            var primary = this._volumeStorage;
            var secondary = rightStorage;
            if (secondary != null && secondary.Shape.TotalLength > primary.Shape.TotalLength)
            {
                primary = rightStorage;
                secondary = this._volumeStorage;
            }

            Shape secondaryShape = secondary?.Shape;

            // Dimensions are stored innermost-first, while the NCHW descriptor wants them outermost-first.
            var batch = primary.Shape.Dimensions[3];
            var channels = primary.Shape.Dimensions[2];
            var height = primary.Shape.Dimensions[1];
            var width = primary.Shape.Dimensions[0];

            using var tensorA = new TensorDescriptor();
            using var tensorB = new TensorDescriptor();
            using var tensorC = new TensorDescriptor();

            tensorA.SetTensor4dDescriptor(cudnnTensorFormat.NCHW, cudnnDataType.Double, batch, channels, height, width);

            // Tensor B's descriptor is only configured when a second operand actually exists.
            if (secondaryShape != null)
            {
                tensorB.SetTensor4dDescriptor(
                    cudnnTensorFormat.NCHW,
                    cudnnDataType.Double,
                    secondaryShape.Dimensions[3],
                    secondaryShape.Dimensions[2],
                    secondaryShape.Dimensions[1],
                    secondaryShape.Dimensions[0]);
            }

            tensorC.SetTensor4dDescriptor(cudnnTensorFormat.NCHW, cudnnDataType.Double, batch, channels, height, width);

            using var operation = new OpTensorDescriptor(this._context.CudnnContext);
            operation.SetOpTensorDescriptor(op, cudnnDataType.Double, cudnnNanPropagation.PropagateNan);

            // Scaling factors: inputs are used unscaled and the previous output content is discarded.
            double inputScale = 1.0;
            double outputScale = 0.0;

            // Without a second operand, A's descriptor and buffer stand in for B.
            var operandBDesc = secondary != null ? tensorB.Desc : tensorA.Desc;
            var operandBPointer = secondary?.DeviceBuffer.DevicePointer ?? primary.DeviceBuffer.DevicePointer;

            var status = CudaDNNNativeMethods.cudnnOpTensor(
                this._context.CudnnContext.Handle,
                operation.Desc,
                ref inputScale, tensorA.Desc, primary.DeviceBuffer.DevicePointer,
                ref inputScale, operandBDesc, operandBPointer,
                ref outputScale, tensorC.Desc, outputStorage.DeviceBuffer.DevicePointer);

            if (status != cudnnStatus.Success)
            {
                throw new Exception(CudaDNNNativeMethods.cudnnGetErrorString(status));
            }

            // The up-to-date data now lives in the device buffer.
            outputStorage.Location = DataLocation.Device;
        }