Ejemplo n.º 1
0
        /// <summary>
        /// CPU GEMM: computes c = alpha * op(a) * op(b) + beta * c, where op(x) is x as stored or
        /// transposed according to <paramref name="transA"/> / <paramref name="transB"/>.
        /// The tensors are treated as column-major matrices (dimension 0 must have stride 1).
        /// Float32 goes through <c>SGEMM.Run</c>; Float64 goes through <c>OpenBlasNative.dgemm_</c>.
        /// </summary>
        /// <param name="transA">Transpose flag for <paramref name="a"/>; its byte value is passed as the BLAS TRANS character
        /// (NOTE(review): assumes BlasOp's underlying values are the BLAS characters — confirm against the enum).</param>
        /// <param name="transB">Transpose flag for <paramref name="b"/>, same encoding as <paramref name="transA"/>.</param>
        /// <param name="alpha">Scale applied to op(a) * op(b).</param>
        /// <param name="a">Left operand; op(a) must be m x k.</param>
        /// <param name="b">Right operand; op(b) must be k x n.</param>
        /// <param name="beta">Scale applied to the existing contents of <paramref name="c"/>.</param>
        /// <param name="c">Result matrix, m x n; written in place.</param>
        /// <exception cref="ArgumentException">A tensor is not contiguous in dimension 0, the element types differ,
        /// or the shapes are incompatible.</exception>
        /// <exception cref="NotSupportedException"><paramref name="c"/>'s element type is neither Float32 nor Float64.</exception>
        /// <exception cref="OverflowException">A size or stride does not fit in BLAS's 32-bit integers.</exception>
        private static void GemmOp(BlasOp transA, BlasOp transB, float alpha, Tensor a, Tensor b, float beta, Tensor c)
        {
            if (a.Strides[0] != 1)
            {
                throw new ArgumentException("a must be contiguous in the first dimension (column major / fortran order)");
            }

            if (b.Strides[0] != 1)
            {
                throw new ArgumentException("b must be contiguous in the first dimension (column major / fortran order)");
            }

            if (c.Strides[0] != 1)
            {
                throw new ArgumentException("c must be contiguous in the first dimension (column major / fortran order)");
            }

            // All three raw buffers are reinterpreted as c's element type below, so a and b must match it.
            if (a.ElementType != c.ElementType || b.ElementType != c.ElementType)
            {
                throw new ArgumentException("a, b and c must have the same element type (got " + a.ElementType + ", " + b.ElementType + ", " + c.ElementType + ")");
            }

            // dimensions: op(a) is (m x k), op(b) is (k x n), c is (m x n)
            bool nta = transA == BlasOp.NonTranspose;
            bool ntb = transB == BlasOp.NonTranspose;

            // checked: BLAS takes 32-bit sizes; fail loudly instead of silently truncating.
            int  m  = checked((int)a.Sizes[nta ? 0 : 1]);
            int  k  = checked((int)b.Sizes[ntb ? 0 : 1]);
            int  n  = checked((int)b.Sizes[ntb ? 1 : 0]);
            long kA = a.Sizes[nta ? 1 : 0];

            // BLAS trusts m/n/k and the leading dimensions; a mismatch here would read or write out of range.
            if (kA != k)
            {
                throw new ArgumentException("inner dimensions do not match: op(a) has " + kA + " columns but op(b) has " + k + " rows");
            }

            if (c.Sizes[0] != m || c.Sizes[1] != n)
            {
                throw new ArgumentException("c must be " + m + " x " + n + " but is " + c.Sizes[0] + " x " + c.Sizes[1]);
            }

            int lda = LeadingDimension(a);
            int ldb = LeadingDimension(b);
            int ldc = LeadingDimension(c);

            unsafe
            {
                byte transa = (byte)transA;
                byte transb = (byte)transB;

                if (c.ElementType == DType.Float32)
                {
                    float *aPtrSingle = (float *)CpuNativeHelpers.GetBufferStart(a);
                    float *bPtrSingle = (float *)CpuNativeHelpers.GetBufferStart(b);
                    float *cPtrSingle = (float *)CpuNativeHelpers.GetBufferStart(c);

                    // SGEMM.Run takes the transpose flags as one-character strings.
                    string transaText = System.Text.ASCIIEncoding.ASCII.GetString(&transa, 1);
                    string transbText = System.Text.ASCIIEncoding.ASCII.GetString(&transb, 1);

                    SGEMM sgemm = new SGEMM();
                    sgemm.Run(transaText, transbText, m, n, k, alpha, aPtrSingle, lda, bPtrSingle, ldb, beta, cPtrSingle, ldc);
                }
                else if (c.ElementType == DType.Float64)
                {
                    double *aPtrDouble  = (double *)CpuNativeHelpers.GetBufferStart(a);
                    double *bPtrDouble  = (double *)CpuNativeHelpers.GetBufferStart(b);
                    double *cPtrDouble  = (double *)CpuNativeHelpers.GetBufferStart(c);
                    // dgemm_ takes every scalar by pointer, so widen alpha/beta into addressable locals.
                    double  alphaDouble = alpha;
                    double  betaDouble  = beta;
                    OpenBlasNative.dgemm_(&transa, &transb, &m, &n, &k, &alphaDouble, aPtrDouble, &lda, bPtrDouble, &ldb, &betaDouble, cPtrDouble, &ldc);
                }
                else
                {
                    throw new NotSupportedException("CPU GEMM with element type " + c.ElementType + " not supported");
                }
            }

            // Column stride of a column-major matrix, as a BLAS leading dimension.
            // BLAS rejects ld < max(1, rows). When the matrix has at most one column (or no rows),
            // the column stride is never used to address memory, so substitute that minimum
            // (e.g. an unsqueezed vector whose column stride is 1).
            int LeadingDimension(Tensor t)
            {
                long rows = t.Sizes[0];
                if (t.Sizes[1] <= 1 || rows == 0)
                {
                    return checked((int)Math.Max(1L, rows));
                }

                return checked((int)t.Strides[1]);
            }
        }
Ejemplo n.º 2
0
        /// <summary>
        /// CPU GEMM: computes c = alpha * op(a) * op(b) + beta * c, where op(x) is x as stored or
        /// transposed according to <paramref name="transA"/> / <paramref name="transB"/>.
        /// The tensors are treated as column-major matrices (dimension 0 must have stride 1).
        /// Float32 goes through <c>OpenBlasNative.sgemm_</c>; Float64 goes through <c>OpenBlasNative.dgemm_</c>.
        /// </summary>
        /// <param name="transA">Transpose flag for <paramref name="a"/>; its byte value is passed as the BLAS TRANS character
        /// (NOTE(review): assumes BlasOp's underlying values are the BLAS characters — confirm against the enum).</param>
        /// <param name="transB">Transpose flag for <paramref name="b"/>, same encoding as <paramref name="transA"/>.</param>
        /// <param name="alpha">Scale applied to op(a) * op(b).</param>
        /// <param name="a">Left operand; op(a) must be m x k.</param>
        /// <param name="b">Right operand; op(b) must be k x n.</param>
        /// <param name="beta">Scale applied to the existing contents of <paramref name="c"/>.</param>
        /// <param name="c">Result matrix, m x n; written in place.</param>
        /// <exception cref="ArgumentException">A tensor is not contiguous in dimension 0, the element types differ,
        /// or the shapes are incompatible.</exception>
        /// <exception cref="NotSupportedException"><paramref name="c"/>'s element type is neither Float32 nor Float64.</exception>
        /// <exception cref="OverflowException">A size or stride does not fit in BLAS's 32-bit integers.</exception>
        private static void GemmOp(BlasOp transA, BlasOp transB, float alpha, Tensor a, Tensor b, float beta, Tensor c)
        {
            if (a.Strides[0] != 1)
            {
                throw new ArgumentException("a must be contiguous in the first dimension (column major / fortran order)");
            }

            if (b.Strides[0] != 1)
            {
                throw new ArgumentException("b must be contiguous in the first dimension (column major / fortran order)");
            }

            if (c.Strides[0] != 1)
            {
                throw new ArgumentException("c must be contiguous in the first dimension (column major / fortran order)");
            }

            // All three raw buffers are reinterpreted as c's element type below, so a and b must match it.
            if (a.ElementType != c.ElementType || b.ElementType != c.ElementType)
            {
                throw new ArgumentException("a, b and c must have the same element type (got " + a.ElementType + ", " + b.ElementType + ", " + c.ElementType + ")");
            }

            // dimensions: op(a) is (m x k), op(b) is (k x n), c is (m x n)
            var nta = transA == BlasOp.NonTranspose;
            var ntb = transB == BlasOp.NonTranspose;

            // checked: BLAS takes 32-bit sizes; fail loudly instead of silently truncating.
            var  m  = checked((int)a.Sizes[nta ? 0 : 1]);
            var  k  = checked((int)b.Sizes[ntb ? 0 : 1]);
            var  n  = checked((int)b.Sizes[ntb ? 1 : 0]);
            long kA = a.Sizes[nta ? 1 : 0];

            // BLAS trusts m/n/k and the leading dimensions; a mismatch here would read or write out of range.
            if (kA != k)
            {
                throw new ArgumentException("inner dimensions do not match: op(a) has " + kA + " columns but op(b) has " + k + " rows");
            }

            if (c.Sizes[0] != m || c.Sizes[1] != n)
            {
                throw new ArgumentException("c must be " + m + " x " + n + " but is " + c.Sizes[0] + " x " + c.Sizes[1]);
            }

            var lda = LeadingDimension(a);
            var ldb = LeadingDimension(b);
            var ldc = LeadingDimension(c);

            unsafe
            {
                var transa = (byte)transA;
                var transb = (byte)transB;

                if (c.ElementType == DType.Float32)
                {
                    var aPtrSingle = (float *)CpuNativeHelpers.GetBufferStart(a);
                    var bPtrSingle = (float *)CpuNativeHelpers.GetBufferStart(b);
                    var cPtrSingle = (float *)CpuNativeHelpers.GetBufferStart(c);

                    // sgemm_ takes every scalar by pointer; alpha/beta are already float, so pass their addresses directly.
                    OpenBlasNative.sgemm_(&transa, &transb, &m, &n, &k, &alpha, aPtrSingle, &lda, bPtrSingle, &ldb, &beta, cPtrSingle, &ldc);
                }
                else if (c.ElementType == DType.Float64)
                {
                    var    aPtrDouble  = (double *)CpuNativeHelpers.GetBufferStart(a);
                    var    bPtrDouble  = (double *)CpuNativeHelpers.GetBufferStart(b);
                    var    cPtrDouble  = (double *)CpuNativeHelpers.GetBufferStart(c);
                    // dgemm_ takes every scalar by pointer, so widen alpha/beta into addressable locals.
                    double alphaDouble = alpha;
                    double betaDouble  = beta;
                    OpenBlasNative.dgemm_(&transa, &transb, &m, &n, &k, &alphaDouble, aPtrDouble, &lda, bPtrDouble, &ldb, &betaDouble, cPtrDouble, &ldc);
                }
                else
                {
                    throw new NotSupportedException("CPU GEMM with element type " + c.ElementType + " not supported");
                }
            }

            // Column stride of a column-major matrix, as a BLAS leading dimension.
            // BLAS rejects ld < max(1, rows). When the matrix has at most one column (or no rows),
            // the column stride is never used to address memory, so substitute that minimum
            // (e.g. an unsqueezed vector whose column stride is 1).
            int LeadingDimension(Tensor t)
            {
                var rows = t.Sizes[0];
                if (t.Sizes[1] <= 1 || rows == 0)
                {
                    return checked((int)Math.Max(1L, rows));
                }

                return checked((int)t.Strides[1]);
            }
        }