예제 #1
0
        /// <summary>
        /// Single-precision matrix-vector product, <c>result = mat * vec</c>, computed by OpenBLAS <c>sgemv_</c>.
        /// </summary>
        /// <param name="result">Destination vector; overwritten with the product (beta is 0, so prior contents are not accumulated).</param>
        /// <param name="mat">Matrix operand; must be row-major, i.e. have unit stride along its last dimension.</param>
        /// <param name="vec">Vector operand; read with stride <c>vec.Strides[0]</c>.</param>
        /// <exception cref="ArgumentException">lhs must be contiguous in the last dimension</exception>
        private static void Run_M_V_float(NDArray result, NDArray mat, NDArray vec)
        {
            // BLAS assumes column-major storage. A row-major matrix reinterpreted as column-major is
            // its own transpose, so it is passed unchanged and BLAS is asked for op(A) = A^T ('t').
            bool isRowMajor = mat.Strides[1] == 1;
            if (!isRowMajor)
            {
                throw new ArgumentException("lhs must be contiguous in the last dimension");
            }

            unsafe
            {
                float* y = (float*)CpuNativeHelpers.GetBufferStart(result);
                float* a = (float*)CpuNativeHelpers.GetBufferStart(mat);
                float* x = (float*)CpuNativeHelpers.GetBufferStart(vec);

                // In BLAS's column-major view, mat has Shape[1] rows and Shape[0] columns,
                // with consecutive columns Strides[0] elements apart (the leading dimension).
                byte op = (byte)'t';
                int rows = (int)mat.Shape[1];
                int cols = (int)mat.Shape[0];
                int strideX = (int)vec.Strides[0];
                int leadingDim = (int)mat.Strides[0];
                int strideY = (int)result.Strides[0];
                float scale = 1;
                float accumulate = 0;
                OpenBlasNative.sgemv_(&op, &rows, &cols, &scale, a, &leadingDim, x, &strideX, &accumulate, y, &strideY);
            }
        }
예제 #2
0
        /// <summary>
        /// Double-precision matrix-vector product, <c>result = lhs * rhs</c>, computed by OpenBLAS <c>dgemv_</c>.
        /// </summary>
        /// <param name="result">Destination vector; overwritten with the product (beta is 0, so prior contents are not accumulated).</param>
        /// <param name="lhs">Matrix operand; must be row-major, i.e. have unit stride along its last dimension.</param>
        /// <param name="rhs">Vector operand; read with stride <c>rhs.Strides[0]</c>. Presumably its length equals
        /// <c>lhs.Shape[1]</c> — TODO confirm callers validate this.</param>
        /// <exception cref="ArgumentException">lhs must be contiguous in the last dimension</exception>
        private static void Run_M_V_double(NDArray result, NDArray lhs, NDArray rhs)
        {
            // Require lhs to be row-major. BLAS expects column-major matrices, so a row-major lhs looks
            // to BLAS like its own transpose; we therefore tell BLAS to transpose it ('t').
            if (lhs.Strides[1] != 1)
            {
                throw new ArgumentException("lhs must be contiguous in the last dimension");
            }

            unsafe
            {
                var resultPtr = (double *)CpuNativeHelpers.GetBufferStart(result);
                var lhsPtr    = (double *)CpuNativeHelpers.GetBufferStart(lhs);
                var rhsPtr    = (double *)CpuNativeHelpers.GetBufferStart(rhs);

                // The matrix A is lhs and the vector x is rhs. Viewed column-major, lhs has
                // Shape[1] rows and Shape[0] columns, with leading dimension Strides[0].
                // (rhs is a vector, so it has no Shape[1] and no leading dimension.)
                byte   trans = (byte)'t';
                int    m     = (int)lhs.Shape[1];
                int    n     = (int)lhs.Shape[0];
                int    lda   = (int)lhs.Strides[0];
                int    incx  = (int)rhs.Strides[0];
                int    incy  = (int)result.Strides[0];
                double alpha = 1;
                double beta  = 0;
                OpenBlasNative.dgemv_(&trans, &m, &n, &alpha, lhsPtr, &lda, rhsPtr, &incx, &beta, resultPtr, &incy);
            }
        }
예제 #3
0
        /// <summary>
        /// General matrix multiply on column-major (Fortran-order) operands via OpenBLAS:
        /// <c>c = alpha * op(a) * op(b) + beta * c</c>, where op(x) is x or its transpose as selected by
        /// <paramref name="transA"/> and <paramref name="transB"/>.
        /// </summary>
        /// <param name="transA">Operation applied to <paramref name="a"/>; its value is cast to a byte and passed as
        /// BLAS TRANSA (presumably a BLAS op character — confirm against the BlasOp definition).</param>
        /// <param name="transB">Operation applied to <paramref name="b"/>; passed as BLAS TRANSB in the same way.</param>
        /// <param name="alpha">Scale applied to op(a) * op(b); widened to double for Float64 outputs.</param>
        /// <param name="a">Left operand; must be column-major (unit stride in its first dimension).</param>
        /// <param name="b">Right operand; must be column-major (unit stride in its first dimension).</param>
        /// <param name="beta">Scale applied to the existing contents of <paramref name="c"/>; widened to double for Float64 outputs.</param>
        /// <param name="c">Output matrix; must be column-major. Its element type selects sgemm_ or dgemm_.
        /// Note: the element types of <paramref name="a"/> and <paramref name="b"/> are not checked here.</param>
        /// <exception cref="ArgumentException">
        /// a must be contiguous in the first dimension (column major / fortran order)
        /// or
        /// b must be contiguous in the first dimension (column major / fortran order)
        /// or
        /// c must be contiguous in the first dimension (column major / fortran order)
        /// </exception>
        /// <exception cref="NotSupportedException">CPU GEMM with element type " + c.ElementType + " not supported</exception>
        private static void GemmOp(BlasOp transA, BlasOp transB, float alpha, NDArray a, NDArray b, float beta, NDArray c)
        {
            // BLAS consumes column-major storage: each operand needs unit stride along its first dimension.
            if (a.Strides[0] != 1)
            {
                throw new ArgumentException("a must be contiguous in the first dimension (column major / fortran order)");
            }

            if (b.Strides[0] != 1)
            {
                throw new ArgumentException("b must be contiguous in the first dimension (column major / fortran order)");
            }

            if (c.Strides[0] != 1)
            {
                throw new ArgumentException("c must be contiguous in the first dimension (column major / fortran order)");
            }

            // Shapes: op(a) is (m x k), op(b) is (k x n), c is (m x n).
            bool aAsIs = transA == BlasOp.NonTranspose;
            bool bAsIs = transB == BlasOp.NonTranspose;

            byte opA = (byte)transA;
            byte opB = (byte)transB;

            int m = aAsIs ? (int)a.Shape[0] : (int)a.Shape[1];
            int k = bAsIs ? (int)b.Shape[0] : (int)b.Shape[1];
            int n = bAsIs ? (int)b.Shape[1] : (int)b.Shape[0];

            // Column-major leading dimension: distance between consecutive columns.
            int lda = (int)a.Strides[1];
            int ldb = (int)b.Strides[1];
            int ldc = (int)c.Strides[1];

            unsafe
            {
                if (c.ElementType == DType.Float32)
                {
                    float* pa = (float*)CpuNativeHelpers.GetBufferStart(a);
                    float* pb = (float*)CpuNativeHelpers.GetBufferStart(b);
                    float* pc = (float*)CpuNativeHelpers.GetBufferStart(c);

                    OpenBlasNative.sgemm_(&opA, &opB, &m, &n, &k, &alpha, pa, &lda, pb, &ldb, &beta, pc, &ldc);
                    return;
                }

                if (c.ElementType == DType.Float64)
                {
                    double* pa = (double*)CpuNativeHelpers.GetBufferStart(a);
                    double* pb = (double*)CpuNativeHelpers.GetBufferStart(b);
                    double* pc = (double*)CpuNativeHelpers.GetBufferStart(c);

                    // dgemm_ takes double scalars; the float arguments are widened.
                    double alpha64 = alpha;
                    double beta64 = beta;
                    OpenBlasNative.dgemm_(&opA, &opB, &m, &n, &k, &alpha64, pa, &lda, pb, &ldb, &beta64, pc, &ldc);
                    return;
                }
            }

            throw new NotSupportedException("CPU GEMM with element type " + c.ElementType + " not supported");
        }
예제 #4
0
        /// <summary>
        /// Computes the inner product of two double-precision vectors with OpenBLAS <c>ddot_</c>
        /// and writes it to the first element of <paramref name="result"/>.
        /// </summary>
        /// <param name="result">Destination; the scalar dot product is stored at its buffer start.</param>
        /// <param name="lhs">First vector; <c>lhs.Shape[0]</c> supplies the element count.</param>
        /// <param name="rhs">Second vector; presumably at least as long as <paramref name="lhs"/> — TODO confirm callers check this.</param>
        private static void Run_Dot_double(NDArray result, NDArray lhs, NDArray rhs)
        {
            unsafe
            {
                double* output = (double*)CpuNativeHelpers.GetBufferStart(result);
                double* x      = (double*)CpuNativeHelpers.GetBufferStart(lhs);
                double* y      = (double*)CpuNativeHelpers.GetBufferStart(rhs);

                // Element count comes from lhs; each vector is walked with its own stride.
                int count   = (int)lhs.Shape[0];
                int strideX = (int)lhs.Strides[0];
                int strideY = (int)rhs.Strides[0];

                double dot = OpenBlasNative.ddot_(&count, x, &strideX, y, &strideY);
                output[0] = dot;
            }
        }