/// <summary>
/// Single-precision matrix-vector product: hands <paramref name="mat"/> and <paramref name="vec"/>
/// to BLAS <c>sgemv</c> (alpha = 1, beta = 0) and writes the output into <paramref name="result"/>.
/// </summary>
/// <param name="result">Destination vector; its first stride is used as the output increment.</param>
/// <param name="mat">Matrix operand. Must be row-major, i.e. have stride 1 in its last dimension.</param>
/// <param name="vec">Vector operand; its first stride is used as the input increment.</param>
/// <exception cref="ArgumentException">
/// Thrown when <paramref name="mat"/> is not contiguous in its last dimension.
/// </exception>
private static void Run_M_V_float(NDArray result, NDArray mat, NDArray vec)
{
    // BLAS assumes column-major storage. A row-major matrix looks like its own transpose
    // to BLAS, so we insist on row-major input and ask BLAS to transpose it back.
    if (mat.Strides[1] != 1)
    {
        throw new ArgumentException("lhs must be contiguous in the last dimension");
    }

    unsafe
    {
        float* outPtr = (float*)CpuNativeHelpers.GetBufferStart(result);
        float* matrixPtr = (float*)CpuNativeHelpers.GetBufferStart(mat);
        float* vectorPtr = (float*)CpuNativeHelpers.GetBufferStart(vec);

        // 't' => BLAS computes y = alpha * A^T * x + beta * y.
        byte transposeFlag = (byte)'t';

        // Dimensions of the matrix as BLAS sees it (column-major view of the row-major data).
        int blasRows = (int)mat.Shape[1];
        int blasCols = (int)mat.Shape[0];

        int vectorStep = (int)vec.Strides[0];
        int leadingDim = (int)mat.Strides[0];
        int outStep = (int)result.Strides[0];

        // alpha = 1 and beta = 0: plain product, previous contents of result are ignored.
        float scale = 1.0f;
        float accumulate = 0.0f;

        OpenBlasNative.sgemv_(
            &transposeFlag,
            &blasRows,
            &blasCols,
            &scale,
            matrixPtr,
            &leadingDim,
            vectorPtr,
            &vectorStep,
            &accumulate,
            outPtr,
            &outStep);
    }
}
/// <summary>
/// Double-precision matrix-vector product: hands the matrix <paramref name="lhs"/> and the vector
/// <paramref name="rhs"/> to BLAS <c>dgemv</c> (alpha = 1, beta = 0) and writes the output into
/// <paramref name="result"/>.
/// </summary>
/// <param name="result">Destination vector; its first stride is used as the output increment.</param>
/// <param name="lhs">Matrix operand. Must be row-major, i.e. have stride 1 in its last dimension.</param>
/// <param name="rhs">Vector operand; its first stride is used as the input increment.</param>
/// <exception cref="ArgumentException">
/// Thrown when <paramref name="lhs"/> is not contiguous in its last dimension.
/// </exception>
private static void Run_M_V_double(NDArray result, NDArray lhs, NDArray rhs)
{
    // BLAS assumes column-major storage. A row-major matrix looks like its own transpose
    // to BLAS, so we insist on row-major input and ask BLAS to transpose it back.
    if (lhs.Strides[1] != 1)
    {
        throw new ArgumentException("lhs must be contiguous in the last dimension");
    }

    unsafe
    {
        var resultPtr = (double*)CpuNativeHelpers.GetBufferStart(result);
        var lhsPtr = (double*)CpuNativeHelpers.GetBufferStart(lhs);
        var rhsPtr = (double*)CpuNativeHelpers.GetBufferStart(rhs);

        // 't' => BLAS computes y = alpha * A^T * x + beta * y.
        byte trans = (byte)'t';

        // The matrix is lhs, so its sizes and leading dimension must come from lhs.
        // (rhs is the vector and has no second dimension to read.)
        int m = (int)lhs.Shape[1];
        int n = (int)lhs.Shape[0];
        int lda = (int)lhs.Strides[0];

        // Element increments for the input vector (rhs) and the output vector (result).
        int incx = (int)rhs.Strides[0];
        int incy = (int)result.Strides[0];

        // alpha = 1 and beta = 0: plain product, previous contents of result are ignored.
        double alpha = 1;
        double beta = 0;

        // Argument order is (A, lda, x, incx): matrix first, then vector.
        OpenBlasNative.dgemv_(&trans, &m, &n, &alpha, lhsPtr, &lda, rhsPtr, &incx, &beta, resultPtr, &incy);
    }
}
/// <summary>
/// Dispatches a general matrix-matrix multiply to BLAS <c>sgemm</c> or <c>dgemm</c>, selected by
/// the element type of <paramref name="c"/>. In BLAS terms:
/// c = alpha * op(a) * op(b) + beta * c.
/// </summary>
/// <param name="transA">Transpose mode for <paramref name="a"/>; its byte value is passed to BLAS as-is.</param>
/// <param name="transB">Transpose mode for <paramref name="b"/>; its byte value is passed to BLAS as-is.</param>
/// <param name="alpha">Scale applied to the product (widened to double on the Float64 path).</param>
/// <param name="a">Left operand; must have stride 1 in its first dimension (column-major).</param>
/// <param name="b">Right operand; must have stride 1 in its first dimension (column-major).</param>
/// <param name="beta">Scale applied to the existing contents of <paramref name="c"/> (widened to double on the Float64 path).</param>
/// <param name="c">Output matrix; must have stride 1 in its first dimension (column-major).</param>
/// <exception cref="ArgumentException">
/// Thrown when <paramref name="a"/>, <paramref name="b"/> or <paramref name="c"/> is not contiguous
/// in its first dimension.
/// </exception>
/// <exception cref="NotSupportedException">
/// Thrown when the element type of <paramref name="c"/> is neither Float32 nor Float64.
/// </exception>
/// <remarks>
/// Only the element type of <paramref name="c"/> is inspected; the buffers of <paramref name="a"/>
/// and <paramref name="b"/> are reinterpreted with that same type without a check here.
/// </remarks>
private static void GemmOp(BlasOp transA, BlasOp transB, float alpha, NDArray a, NDArray b, float beta, NDArray c)
{
    // BLAS expects column-major (Fortran order) operands: unit stride along dimension 0.
    if (a.Strides[0] != 1)
    {
        throw new ArgumentException("a must be contiguous in the first dimension (column major / fortran order)");
    }

    if (b.Strides[0] != 1)
    {
        throw new ArgumentException("b must be contiguous in the first dimension (column major / fortran order)");
    }

    if (c.Strides[0] != 1)
    {
        throw new ArgumentException("c must be contiguous in the first dimension (column major / fortran order)");
    }

    unsafe
    {
        // Shapes after applying the transpose flags: (rows x inner) * (inner x cols) = (rows x cols).
        bool aAsIs = transA == BlasOp.NonTranspose;
        bool bAsIs = transB == BlasOp.NonTranspose;

        byte opA = (byte)transA;
        byte opB = (byte)transB;

        int aRowAxis = aAsIs ? 0 : 1;
        int bInnerAxis = bAsIs ? 0 : 1;
        int bColAxis = bAsIs ? 1 : 0;

        int rows = (int)a.Shape[aRowAxis];
        int inner = (int)b.Shape[bInnerAxis];
        int cols = (int)b.Shape[bColAxis];

        // Leading dimensions: distance (in elements) between consecutive columns.
        int strideA = (int)a.Strides[1];
        int strideB = (int)b.Strides[1];
        int strideC = (int)c.Strides[1];

        if (c.ElementType == DType.Float32)
        {
            float* aData = (float*)CpuNativeHelpers.GetBufferStart(a);
            float* bData = (float*)CpuNativeHelpers.GetBufferStart(b);
            float* cData = (float*)CpuNativeHelpers.GetBufferStart(c);

            OpenBlasNative.sgemm_(
                &opA, &opB,
                &rows, &cols, &inner,
                &alpha, aData, &strideA,
                bData, &strideB,
                &beta, cData, &strideC);
            return;
        }

        if (c.ElementType == DType.Float64)
        {
            double* aData = (double*)CpuNativeHelpers.GetBufferStart(a);
            double* bData = (double*)CpuNativeHelpers.GetBufferStart(b);
            double* cData = (double*)CpuNativeHelpers.GetBufferStart(c);

            // dgemm takes double scalars, so widen the float parameters into locals.
            double alphaWide = alpha;
            double betaWide = beta;

            OpenBlasNative.dgemm_(
                &opA, &opB,
                &rows, &cols, &inner,
                &alphaWide, aData, &strideA,
                bData, &strideB,
                &betaWide, cData, &strideC);
            return;
        }

        throw new NotSupportedException("CPU GEMM with element type " + c.ElementType + " not supported");
    }
}
/// <summary>
/// Double-precision dot product: calls BLAS <c>ddot</c> on <paramref name="lhs"/> and
/// <paramref name="rhs"/> and stores the scalar at the start of <paramref name="result"/>'s buffer.
/// </summary>
/// <param name="result">Destination; only the first element of its buffer is written.</param>
/// <param name="lhs">First vector; supplies the element count (its first dimension) and its own stride.</param>
/// <param name="rhs">Second vector; only its stride is read, its length is not checked here.</param>
private static void Run_Dot_double(NDArray result, NDArray lhs, NDArray rhs)
{
    unsafe
    {
        double* outPtr = (double*)CpuNativeHelpers.GetBufferStart(result);
        double* xPtr = (double*)CpuNativeHelpers.GetBufferStart(lhs);
        double* yPtr = (double*)CpuNativeHelpers.GetBufferStart(rhs);

        // Element count comes from lhs alone.
        int count = (int)lhs.Shape[0];
        int xStep = (int)lhs.Strides[0];
        int yStep = (int)rhs.Strides[0];

        double dot = OpenBlasNative.ddot_(&count, xPtr, &xStep, yPtr, &yStep);
        outPtr[0] = dot;
    }
}