/// <summary>
/// Forwards a general matrix-matrix multiply over CPU tensors to a BLAS routine chosen by the
/// element type of <paramref name="c"/>: the <c>SGEMM</c> class for <c>DType.Float32</c> and
/// <c>OpenBlasNative.dgemm_</c> for <c>DType.Float64</c>.
/// </summary>
/// <param name="transA">Operation code for <paramref name="a"/>; its byte value is handed to BLAS.</param>
/// <param name="transB">Operation code for <paramref name="b"/>; its byte value is handed to BLAS.</param>
/// <param name="alpha">Scalar forwarded as the BLAS <c>alpha</c> argument.</param>
/// <param name="a">Left operand; <c>Strides[0]</c> must be 1.</param>
/// <param name="b">Right operand; <c>Strides[0]</c> must be 1.</param>
/// <param name="beta">Scalar forwarded as the BLAS <c>beta</c> argument.</param>
/// <param name="c">Output operand; <c>Strides[0]</c> must be 1. Its element type selects the routine.</param>
/// <exception cref="ArgumentException">An operand's first-dimension stride is not 1.</exception>
/// <exception cref="NotSupportedException">
/// <paramref name="c"/> has an element type other than Float32 or Float64.
/// </exception>
/// <remarks>
/// Only the element type of <paramref name="c"/> is inspected here; the buffers of
/// <paramref name="a"/> and <paramref name="b"/> are reinterpreted as that same type.
/// NOTE(review): presumably callers guarantee matching element types and compatible shapes - confirm.
/// </remarks>
private static void GemmOp(BlasOp transA, BlasOp transB, float alpha, Tensor a, Tensor b, float beta, Tensor c)
{
    // BLAS expects column-major storage, so every operand needs a unit stride along dimension 0.
    if (a.Strides[0] != 1)
    {
        throw new ArgumentException("a must be contiguous in the first dimension (column major / fortran order)");
    }

    if (b.Strides[0] != 1)
    {
        throw new ArgumentException("b must be contiguous in the first dimension (column major / fortran order)");
    }

    if (c.Strides[0] != 1)
    {
        throw new ArgumentException("c must be contiguous in the first dimension (column major / fortran order)");
    }

    unsafe
    {
        // Problem shape: (rows x inner) * (inner x cols) = (rows x cols).
        var aIsPlain = transA == BlasOp.NonTranspose;
        var bIsPlain = transB == BlasOp.NonTranspose;
        var opCodeA = (byte)transA;
        var opCodeB = (byte)transB;

        // rows comes from a; both inner and cols are read from b.
        var rows = (int)a.Sizes[aIsPlain ? 0 : 1];
        var inner = (int)b.Sizes[bIsPlain ? 0 : 1];
        var cols = (int)b.Sizes[bIsPlain ? 1 : 0];

        // The stride of dimension 1 serves as each operand's leading dimension.
        var leadA = (int)a.Strides[1];
        var leadB = (int)b.Strides[1];
        var leadC = (int)c.Strides[1];

        if (c.ElementType == DType.Float32)
        {
            var aData = (float *)CpuNativeHelpers.GetBufferStart(a);
            var bData = (float *)CpuNativeHelpers.GetBufferStart(b);
            var cData = (float *)CpuNativeHelpers.GetBufferStart(c);

            var kernel = new SGEMM();

            // SGEMM.Run takes the operation codes as strings, so each byte is decoded as one ASCII character.
            var opTextA = System.Text.Encoding.ASCII.GetString(&opCodeA, 1);
            var opTextB = System.Text.Encoding.ASCII.GetString(&opCodeB, 1);

            kernel.Run(opTextA, opTextB, rows, cols, inner, alpha, aData, leadA, bData, leadB, beta, cData, leadC);
            return;
        }

        if (c.ElementType == DType.Float64)
        {
            var aData = (double *)CpuNativeHelpers.GetBufferStart(a);
            var bData = (double *)CpuNativeHelpers.GetBufferStart(b);
            var cData = (double *)CpuNativeHelpers.GetBufferStart(c);

            // dgemm_ takes its scalars by pointer-to-double, so the float inputs are widened into locals first.
            double alphaWide = alpha;
            double betaWide = beta;

            OpenBlasNative.dgemm_(&opCodeA, &opCodeB, &rows, &cols, &inner, &alphaWide, aData, &leadA, bData, &leadB, &betaWide, cData, &leadC);
            return;
        }

        throw new NotSupportedException("CPU GEMM with element type " + c.ElementType + " not supported");
    }
}
/// <summary>
/// Forwards a general matrix-matrix multiply over CPU tensors to the OpenBLAS entry point chosen
/// by the element type of <paramref name="c"/>: <c>OpenBlasNative.sgemm_</c> for
/// <c>DType.Float32</c> and <c>OpenBlasNative.dgemm_</c> for <c>DType.Float64</c>.
/// </summary>
/// <param name="transA">Operation code for <paramref name="a"/>; its byte value is handed to BLAS.</param>
/// <param name="transB">Operation code for <paramref name="b"/>; its byte value is handed to BLAS.</param>
/// <param name="alpha">Scalar forwarded as the BLAS <c>alpha</c> argument.</param>
/// <param name="a">Left operand; <c>Strides[0]</c> must be 1.</param>
/// <param name="b">Right operand; <c>Strides[0]</c> must be 1.</param>
/// <param name="beta">Scalar forwarded as the BLAS <c>beta</c> argument.</param>
/// <param name="c">Output operand; <c>Strides[0]</c> must be 1. Its element type selects the routine.</param>
/// <exception cref="ArgumentException">An operand's first-dimension stride is not 1.</exception>
/// <exception cref="NotSupportedException">
/// <paramref name="c"/> has an element type other than Float32 or Float64.
/// </exception>
/// <remarks>
/// Only the element type of <paramref name="c"/> is inspected here; the buffers of
/// <paramref name="a"/> and <paramref name="b"/> are reinterpreted as that same type.
/// NOTE(review): presumably callers guarantee matching element types and compatible shapes - confirm.
/// </remarks>
private static void GemmOp(BlasOp transA, BlasOp transB, float alpha, Tensor a, Tensor b, float beta, Tensor c)
{
    // BLAS expects column-major storage, so every operand needs a unit stride along dimension 0.
    if (a.Strides[0] != 1)
    {
        throw new ArgumentException("a must be contiguous in the first dimension (column major / fortran order)");
    }

    if (b.Strides[0] != 1)
    {
        throw new ArgumentException("b must be contiguous in the first dimension (column major / fortran order)");
    }

    if (c.Strides[0] != 1)
    {
        throw new ArgumentException("c must be contiguous in the first dimension (column major / fortran order)");
    }

    unsafe
    {
        // Problem shape: (rows x inner) * (inner x cols) = (rows x cols).
        bool aIsPlain = transA == BlasOp.NonTranspose;
        bool bIsPlain = transB == BlasOp.NonTranspose;
        byte opCodeA = (byte)transA;
        byte opCodeB = (byte)transB;

        // rows comes from a; both inner and cols are read from b.
        int rows = (int)a.Sizes[aIsPlain ? 0 : 1];
        int inner = (int)b.Sizes[bIsPlain ? 0 : 1];
        int cols = (int)b.Sizes[bIsPlain ? 1 : 0];

        // The stride of dimension 1 serves as each operand's leading dimension.
        int leadA = (int)a.Strides[1];
        int leadB = (int)b.Strides[1];
        int leadC = (int)c.Strides[1];

        if (c.ElementType == DType.Float32)
        {
            float *aData = (float *)CpuNativeHelpers.GetBufferStart(a);
            float *bData = (float *)CpuNativeHelpers.GetBufferStart(b);
            float *cData = (float *)CpuNativeHelpers.GetBufferStart(c);

            // Every scalar argument is passed by address; alpha and beta are used directly as floats.
            OpenBlasNative.sgemm_(&opCodeA, &opCodeB, &rows, &cols, &inner, &alpha, aData, &leadA, bData, &leadB, &beta, cData, &leadC);
            return;
        }

        if (c.ElementType == DType.Float64)
        {
            double *aData = (double *)CpuNativeHelpers.GetBufferStart(a);
            double *bData = (double *)CpuNativeHelpers.GetBufferStart(b);
            double *cData = (double *)CpuNativeHelpers.GetBufferStart(c);

            // dgemm_ takes its scalars by pointer-to-double, so the float inputs are widened into locals first.
            double alphaWide = alpha;
            double betaWide = beta;

            OpenBlasNative.dgemm_(&opCodeA, &opCodeB, &rows, &cols, &inner, &alphaWide, aData, &leadA, bData, &leadB, &betaWide, cData, &leadC);
            return;
        }

        throw new NotSupportedException("CPU GEMM with element type " + c.ElementType + " not supported");
    }
}