/// <summary>
        /// Single-precision matrix-vector product. Forwards <paramref name="mat"/> and
        /// <paramref name="vec"/> to the BLAS routine <c>sgemv_</c> with the transpose flag set,
        /// alpha = 1 and beta = 0, writing the output into <paramref name="result"/>'s buffer.
        /// </summary>
        /// <param name="result">Destination tensor; its first stride is passed as the output increment.</param>
        /// <param name="mat">Matrix operand; the stride of its second dimension must be 1.</param>
        /// <param name="vec">Vector operand; its first stride is passed as the input increment.</param>
        /// <exception cref="ArgumentException">
        /// Thrown when the second-dimension stride of <paramref name="mat"/> is not 1.
        /// </exception>
        private static void Run_M_V_float(Tensor result, Tensor mat, Tensor vec)
        {
            // BLAS assumes column-major storage while mat has to be row-major here,
            // so the call below asks BLAS to read the buffer as a transposed matrix.
            if (mat.Strides[1] != 1)
            {
                throw new ArgumentException("lhs must be contiguous in the last dimension");
            }

            unsafe
            {
                float* outputPtr = (float*)CpuNativeHelpers.GetBufferStart(result);
                float* matrixPtr = (float*)CpuNativeHelpers.GetBufferStart(mat);
                float* vectorPtr = (float*)CpuNativeHelpers.GetBufferStart(vec);

                // Dimensions as BLAS sees them: the row-major matrix viewed column-major
                // has its two extents exchanged.
                byte transposeFlag = (byte)'t';
                int blasRows = (int)mat.Shape[1];
                int blasCols = (int)mat.Shape[0];
                int vectorStep = (int)vec.Strides[0];
                int leadingDim = (int)mat.Strides[0];
                int outputStep = (int)result.Strides[0];

                // alpha = 1 and beta = 0: plain product, nothing accumulated from the old output.
                float scale = 1;
                float accumulate = 0;

                OpenBlasNative.sgemv_(
                    &transposeFlag,
                    &blasRows,
                    &blasCols,
                    &scale,
                    matrixPtr,
                    &leadingDim,
                    vectorPtr,
                    &vectorStep,
                    &accumulate,
                    outputPtr,
                    &outputStep);
            }
        }
        /// <summary>
        /// Double-precision matrix-vector product. Forwards the matrix <paramref name="lhs"/> and
        /// the vector <paramref name="rhs"/> to the BLAS routine <c>dgemv_</c> with the transpose
        /// flag set, alpha = 1 and beta = 0, writing the output into <paramref name="result"/>'s buffer.
        /// </summary>
        /// <param name="result">Destination tensor; its first stride is passed as the output increment.</param>
        /// <param name="lhs">Matrix operand; the stride of its second dimension must be 1.</param>
        /// <param name="rhs">Vector operand; its first stride is passed as the input increment.</param>
        /// <exception cref="ArgumentException">
        /// Thrown when the second-dimension stride of <paramref name="lhs"/> is not 1.
        /// </exception>
        private static void Run_M_V_double(Tensor result, Tensor lhs, Tensor rhs)
        {
            // Require lhs to be row-major. This means we must tell BLAS to transpose it (BLAS expects column-major matrices)
            if (lhs.Strides[1] != 1)
            {
                throw new ArgumentException("lhs must be contiguous in the last dimension");
            }

            unsafe
            {
                var resultPtr = (double *)CpuNativeHelpers.GetBufferStart(result);
                var lhsPtr    = (double *)CpuNativeHelpers.GetBufferStart(lhs);
                var rhsPtr    = (double *)CpuNativeHelpers.GetBufferStart(rhs);

                // gemv signature: (trans, m, n, alpha, A, lda, x, incx, beta, y, incy).
                // The matrix dimensions and leading dimension must come from lhs (the matrix),
                // and the increments from the two vectors. Previously they were read from rhs
                // and the matrix/vector pointers were passed in swapped positions, unlike the
                // single-precision sibling Run_M_V_float.
                byte   trans = (byte)'t';
                int    m     = (int)lhs.Shape[1];
                int    n     = (int)lhs.Shape[0];
                int    lda   = (int)lhs.Strides[0];
                int    incx  = (int)rhs.Strides[0];
                int    incy  = (int)result.Strides[0];
                double alpha = 1;
                double beta  = 0;
                OpenBlasNative.dgemv_(&trans, &m, &n, &alpha, lhsPtr, &lda, rhsPtr, &incx, &beta, resultPtr, &incy);
            }
        }
示例#3
0
        /// <summary>
        /// General matrix multiply for CPU tensors, dispatched on the element type of
        /// <paramref name="c"/>: Float32 goes through <c>SGEMM.Run</c>, Float64 through
        /// <c>OpenBlasNative.dgemm_</c>. Following the BLAS gemm convention this is expected to
        /// produce c = alpha * op(a) * op(b) + beta * c.
        /// </summary>
        /// <param name="transA">Operation applied to <paramref name="a"/>; its byte value is handed to BLAS.</param>
        /// <param name="transB">Operation applied to <paramref name="b"/>; its byte value is handed to BLAS.</param>
        /// <param name="alpha">Scale factor for the product.</param>
        /// <param name="a">Left operand; must have stride 1 in its first dimension.</param>
        /// <param name="b">Right operand; must have stride 1 in its first dimension.</param>
        /// <param name="beta">Scale factor for the existing contents of <paramref name="c"/>.</param>
        /// <param name="c">Output operand; must have stride 1 in its first dimension.</param>
        /// <exception cref="ArgumentException">Thrown when a, b or c does not have stride 1 in its first dimension.</exception>
        /// <exception cref="NotSupportedException">Thrown when the element type of c is neither Float32 nor Float64.</exception>
        private static void GemmOp(BlasOp transA, BlasOp transB, float alpha, Tensor a, Tensor b, float beta, Tensor c)
        {
            // Every operand has to be column-major, i.e. unit stride along dimension 0.
            if (a.Strides[0] != 1)
            {
                throw new ArgumentException("a must be contiguous in the first dimension (column major / fortran order)");
            }

            if (b.Strides[0] != 1)
            {
                throw new ArgumentException("b must be contiguous in the first dimension (column major / fortran order)");
            }

            if (c.Strides[0] != 1)
            {
                throw new ArgumentException("c must be contiguous in the first dimension (column major / fortran order)");
            }

            unsafe
            {
                // Dimensions: (m x k) * (k x n) = (m x n)
                bool aPlain       = transA == BlasOp.NonTranspose;
                bool bPlain       = transB == BlasOp.NonTranspose;
                byte opCodeA      = (byte)transA;
                byte opCodeB      = (byte)transB;
                int  outRows      = (int)a.Sizes[aPlain ? 0 : 1];
                int  innerDim     = (int)b.Sizes[bPlain ? 0 : 1];
                int  outCols      = (int)b.Sizes[bPlain ? 1 : 0];
                int  leadingA     = (int)a.Strides[1];
                int  leadingB     = (int)b.Strides[1];
                int  leadingC     = (int)c.Strides[1];

                var  elementType  = c.ElementType;
                bool isSingle     = elementType == DType.Float32;

                // Reject anything that is neither single nor double precision.
                if (!isSingle && elementType != DType.Float64)
                {
                    throw new NotSupportedException("CPU GEMM with element type " + elementType + " not supported");
                }

                if (isSingle)
                {
                    float* aData = (float*)CpuNativeHelpers.GetBufferStart(a);
                    float* bData = (float*)CpuNativeHelpers.GetBufferStart(b);
                    float* cData = (float*)CpuNativeHelpers.GetBufferStart(c);

                    // The managed SGEMM kernel takes the operation codes as one-character strings.
                    SGEMM  kernel  = new SGEMM();
                    string opTextA = System.Text.Encoding.ASCII.GetString(&opCodeA, 1);
                    string opTextB = System.Text.Encoding.ASCII.GetString(&opCodeB, 1);
                    kernel.Run(opTextA, opTextB, outRows, outCols, innerDim, alpha, aData, leadingA, bData, leadingB, beta, cData, leadingC);
                }
                else
                {
                    double* aData = (double*)CpuNativeHelpers.GetBufferStart(a);
                    double* bData = (double*)CpuNativeHelpers.GetBufferStart(b);
                    double* cData = (double*)CpuNativeHelpers.GetBufferStart(c);

                    // dgemm_ needs double-precision scalars passed by address.
                    double wideAlpha = alpha;
                    double wideBeta  = beta;
                    OpenBlasNative.dgemm_(&opCodeA, &opCodeB, &outRows, &outCols, &innerDim, &wideAlpha, aData, &leadingA, bData, &leadingB, &wideBeta, cData, &leadingC);
                }
            }
        }
示例#4
0
        /// <summary>
        /// General matrix multiply for CPU tensors, dispatched on the element type of
        /// <paramref name="c"/>: Float32 calls <c>OpenBlasNative.sgemm_</c>, Float64 calls
        /// <c>OpenBlasNative.dgemm_</c>. Following the BLAS gemm convention this is expected to
        /// produce c = alpha * op(a) * op(b) + beta * c.
        /// </summary>
        /// <param name="transA">Operation applied to <paramref name="a"/>; its byte value is handed to BLAS.</param>
        /// <param name="transB">Operation applied to <paramref name="b"/>; its byte value is handed to BLAS.</param>
        /// <param name="alpha">Scale factor for the product.</param>
        /// <param name="a">Left operand; must have stride 1 in its first dimension.</param>
        /// <param name="b">Right operand; must have stride 1 in its first dimension.</param>
        /// <param name="beta">Scale factor for the existing contents of <paramref name="c"/>.</param>
        /// <param name="c">Output operand; must have stride 1 in its first dimension.</param>
        /// <exception cref="ArgumentException">Thrown when a, b or c does not have stride 1 in its first dimension.</exception>
        /// <exception cref="NotSupportedException">Thrown when the element type of c is neither Float32 nor Float64.</exception>
        private static void GemmOp(BlasOp transA, BlasOp transB, float alpha, Tensor a, Tensor b, float beta, Tensor c)
        {
            // BLAS expects column-major data: each operand needs unit stride along dimension 0.
            if (a.Strides[0] != 1)
            {
                throw new ArgumentException("a must be contiguous in the first dimension (column major / fortran order)");
            }

            if (b.Strides[0] != 1)
            {
                throw new ArgumentException("b must be contiguous in the first dimension (column major / fortran order)");
            }

            if (c.Strides[0] != 1)
            {
                throw new ArgumentException("c must be contiguous in the first dimension (column major / fortran order)");
            }

            unsafe
            {
                // Dimensions: (m x k) * (k x n) = (m x n)
                bool leftUntransposed  = transA == BlasOp.NonTranspose;
                bool rightUntransposed = transB == BlasOp.NonTranspose;
                byte leftOp            = (byte)transA;
                byte rightOp           = (byte)transB;
                int  rowCount          = (int)a.Sizes[leftUntransposed ? 0 : 1];
                int  sharedDim         = (int)b.Sizes[rightUntransposed ? 0 : 1];
                int  colCount          = (int)b.Sizes[rightUntransposed ? 1 : 0];
                int  strideA           = (int)a.Strides[1];
                int  strideB           = (int)b.Strides[1];
                int  strideC           = (int)c.Strides[1];

                if (c.ElementType == DType.Float32)
                {
                    float* leftData  = (float*)CpuNativeHelpers.GetBufferStart(a);
                    float* rightData = (float*)CpuNativeHelpers.GetBufferStart(b);
                    float* outData   = (float*)CpuNativeHelpers.GetBufferStart(c);

                    OpenBlasNative.sgemm_(&leftOp, &rightOp, &rowCount, &colCount, &sharedDim, &alpha, leftData, &strideA, rightData, &strideB, &beta, outData, &strideC);
                    return;
                }

                if (c.ElementType == DType.Float64)
                {
                    double* leftData  = (double*)CpuNativeHelpers.GetBufferStart(a);
                    double* rightData = (double*)CpuNativeHelpers.GetBufferStart(b);
                    double* outData   = (double*)CpuNativeHelpers.GetBufferStart(c);

                    // dgemm_ takes its scalars as doubles by address, so widen the float arguments.
                    double alpha64 = alpha;
                    double beta64  = beta;
                    OpenBlasNative.dgemm_(&leftOp, &rightOp, &rowCount, &colCount, &sharedDim, &alpha64, leftData, &strideA, rightData, &strideB, &beta64, outData, &strideC);
                    return;
                }

                throw new NotSupportedException("CPU GEMM with element type " + c.ElementType + " not supported");
            }
        }
        /// <summary>
        /// Double-precision dot product: calls the BLAS routine <c>ddot_</c> on the buffers of
        /// <paramref name="lhs"/> and <paramref name="rhs"/> and stores the returned scalar at the
        /// start of <paramref name="result"/>'s buffer.
        /// </summary>
        /// <param name="result">Tensor whose first buffer element receives the value.</param>
        /// <param name="lhs">First vector; supplies the element count (first shape entry) and its first stride.</param>
        /// <param name="rhs">Second vector; supplies its first stride.</param>
        private static void Run_Dot_double(Tensor result, Tensor lhs, Tensor rhs)
        {
            unsafe
            {
                double* destination = (double*)CpuNativeHelpers.GetBufferStart(result);
                double* leftData    = (double*)CpuNativeHelpers.GetBufferStart(lhs);
                double* rightData   = (double*)CpuNativeHelpers.GetBufferStart(rhs);

                // The element count is taken from lhs only; rhs contributes just its stride.
                int elementCount = (int)lhs.Shape[0];
                int leftStep     = (int)lhs.Strides[0];
                int rightStep    = (int)rhs.Strides[0];

                destination[0] = OpenBlasNative.ddot_(&elementCount, leftData, &leftStep, rightData, &rightStep);
            }
        }
        /// <summary>
        /// Single-precision dot product: calls the BLAS routine <c>sdot_</c> on the buffers of
        /// <paramref name="lhs"/> and <paramref name="rhs"/> and stores the returned scalar at the
        /// start of <paramref name="result"/>'s buffer.
        /// </summary>
        /// <param name="result">Tensor whose first buffer element receives the value.</param>
        /// <param name="lhs">First vector; supplies the element count (first size entry) and its first stride.</param>
        /// <param name="rhs">Second vector; supplies its first stride.</param>
        private static void Run_Dot_float(Tensor result, Tensor lhs, Tensor rhs)
        {
            unsafe
            {
                var destination = (float*)CpuNativeHelpers.GetBufferStart(result);
                var leftData    = (float*)CpuNativeHelpers.GetBufferStart(lhs);
                var rightData   = (float*)CpuNativeHelpers.GetBufferStart(rhs);

                // The element count is taken from lhs only; rhs contributes just its stride.
                var elementCount = (int)lhs.Sizes[0];
                var leftStep     = (int)lhs.Strides[0];
                var rightStep    = (int)rhs.Strides[0];

                destination[0] = OpenBlasNative.sdot_(&elementCount, leftData, &leftStep, rightData, &rightStep);
            }
        }
        /// <summary>
        /// Single-precision dot product for <c>NDArray</c> operands: calls the BLAS routine
        /// <c>sdot_</c> on the buffers of <paramref name="lhs"/> and <paramref name="rhs"/> and
        /// stores the returned scalar at the start of <paramref name="result"/>'s buffer.
        /// </summary>
        /// <param name="result">Array whose first buffer element receives the value.</param>
        /// <param name="lhs">First vector; supplies the element count (first shape entry) and its first stride.</param>
        /// <param name="rhs">Second vector; supplies its first stride.</param>
        private static void Run_Dot_float(NDArray result, NDArray lhs, NDArray rhs)
        {
            unsafe
            {
                float* target = (float*)CpuNativeHelpers.GetBufferStart(result);
                float* first  = (float*)CpuNativeHelpers.GetBufferStart(lhs);
                float* second = (float*)CpuNativeHelpers.GetBufferStart(rhs);

                // The length comes from lhs alone; each operand contributes its own stride.
                int length       = (int)lhs.Shape[0];
                int firstStride  = (int)lhs.Strides[0];
                int secondStride = (int)rhs.Strides[0];

                target[0] = OpenBlasNative.sdot_(&length, first, &firstStride, second, &secondStride);
            }
        }