/// <summary>
/// P/Invoke binding for cuBLAS single-precision GEMM:
/// C = alpha * op(A) * op(B) + beta * C, with matrices in column-major order (per the cuBLAS documentation).
/// </summary>
/// <remarks>
/// <paramref name="alpha"/> and <paramref name="beta"/> are raw pointers, so callers need an unsafe context.
/// The three matrix arrays are marshaled through <c>CudaMarshaler</c>; presumably it copies host data to device
/// memory and back — TODO confirm against the marshaler implementation. <paramref name="matrixC"/> carries no
/// [Out] attribute, so whether results are copied back to the managed array depends entirely on CudaMarshaler — verify.
/// NOTE(review): the handle-based (v2) cuBLAS routine is exported as <c>cublasSgemm_v2</c>; the unsuffixed
/// <c>cublasSgemm</c> symbol is the legacy, handle-less API. Unless the [DllImport] attribute (not visible here)
/// sets EntryPoint = "cublasSgemm_v2", this binding may resolve to the wrong native function — confirm.
/// </remarks>
/// <param name="handle">cuBLAS library context.</param>
/// <param name="transA">Operation op(A) applied to <paramref name="matrixA"/>.</param>
/// <param name="transB">Operation op(B) applied to <paramref name="matrixB"/>.</param>
/// <param name="m">Rows of op(A) and of C.</param>
/// <param name="n">Columns of op(B) and of C.</param>
/// <param name="k">Columns of op(A) / rows of op(B).</param>
/// <param name="alpha">Pointer to the scalar multiplier of op(A) * op(B).</param>
/// <param name="matrixA">Matrix A.</param>
/// <param name="lda">Leading dimension of A.</param>
/// <param name="matrixB">Matrix B.</param>
/// <param name="ldb">Leading dimension of B.</param>
/// <param name="beta">Pointer to the scalar multiplier of the existing C.</param>
/// <param name="matrixC">Matrix C, overwritten with the result.</param>
/// <param name="ldc">Leading dimension of C.</param>
/// <returns>The cuBLAS status code of the call.</returns>
public static extern cublasStatus_t cublasSgemm(cublasHandle_t handle, cublasOperation_t transA, cublasOperation_t transB, int m, int n, int k, float *alpha, [MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(CudaMarshaler))] float[] matrixA, int lda, [MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(CudaMarshaler))] float[] matrixB, int ldb, float *beta, [MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(CudaMarshaler))] float[] matrixC, int ldc);
/// <summary>
/// P/Invoke binding for cuBLAS single-precision complex GEMM:
/// C = alpha * op(A) * op(B) + beta * C, column-major (per the cuBLAS documentation).
/// </summary>
/// <remarks>
/// <paramref name="A"/>, <paramref name="B"/> and <paramref name="C"/> are raw pointers; cuBLAS expects them to
/// address device memory. <paramref name="alpha"/> and <paramref name="beta"/> are passed by reference, i.e. as
/// pointers to host scalars (cuBLAS default pointer mode).
/// NOTE(review): assumes <c>cuComplex</c> is a blittable struct laid out as two consecutive floats (real, imag),
/// matching CUDA's cuComplex — confirm its StructLayout.
/// NOTE(review): the handle-based (v2) routine is exported as <c>cublasCgemm_v2</c>; the unsuffixed symbol is the
/// legacy handle-less API. Unless the [DllImport] attribute (not visible here) sets EntryPoint = "cublasCgemm_v2",
/// this binding may call the wrong native function — confirm.
/// </remarks>
/// <param name="handle">cuBLAS handle (as an opaque pointer).</param>
/// <param name="transa">Operation op(A).</param>
/// <param name="transb">Operation op(B).</param>
/// <param name="m">Rows of op(A) and of C.</param>
/// <param name="n">Columns of op(B) and of C.</param>
/// <param name="k">Columns of op(A) / rows of op(B).</param>
/// <param name="alpha">Scalar multiplier of op(A) * op(B).</param>
/// <param name="A">Pointer to matrix A.</param>
/// <param name="lda">Leading dimension of A.</param>
/// <param name="B">Pointer to matrix B.</param>
/// <param name="ldb">Leading dimension of B.</param>
/// <param name="beta">Scalar multiplier of the existing C.</param>
/// <param name="C">Pointer to matrix C, overwritten with the result.</param>
/// <param name="ldc">Leading dimension of C.</param>
/// <returns>The cuBLAS status code of the call.</returns>
public static extern cublasStatus_t cublasCgemm( IntPtr handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, ref cuComplex alpha, IntPtr A, int lda, IntPtr B, int ldb, ref cuComplex beta, IntPtr C, int ldc );
/// <summary>
/// P/Invoke binding for the cuBLAS v2 double-precision GEMM entry point:
/// C = alpha * op(A) * op(B) + beta * C, column-major (per the cuBLAS documentation).
/// </summary>
/// <remarks>
/// Bound to the exported <c>_v2</c> symbol, which takes a cuBLAS handle as its first argument.
/// <paramref name="A"/>, <paramref name="B"/> and <paramref name="C"/> are raw pointers; cuBLAS expects them to
/// address device memory. <paramref name="alpha"/> and <paramref name="beta"/> are passed by reference, i.e. as
/// pointers to host scalars (cuBLAS default pointer mode).
/// </remarks>
/// <param name="handle">cuBLAS handle (as an opaque pointer).</param>
/// <param name="transa">Operation op(A).</param>
/// <param name="transb">Operation op(B).</param>
/// <param name="m">Rows of op(A) and of C.</param>
/// <param name="n">Columns of op(B) and of C.</param>
/// <param name="k">Columns of op(A) / rows of op(B).</param>
/// <param name="alpha">Scalar multiplier of op(A) * op(B).</param>
/// <param name="A">Pointer to matrix A.</param>
/// <param name="lda">Leading dimension of A.</param>
/// <param name="B">Pointer to matrix B.</param>
/// <param name="ldb">Leading dimension of B.</param>
/// <param name="beta">Scalar multiplier of the existing C.</param>
/// <param name="C">Pointer to matrix C, overwritten with the result.</param>
/// <param name="ldc">Leading dimension of C.</param>
/// <returns>The cuBLAS status code of the call.</returns>
public static extern cublasStatus_t cublasDgemm_v2( IntPtr handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, ref double alpha, IntPtr A, int lda, IntPtr B, int ldb, ref double beta, IntPtr C, int ldc );
/// <summary>
/// P/Invoke binding for cuBLAS single-precision GEMM on raw pointers:
/// C = alpha * op(A) * op(B) + beta * C, column-major (per the cuBLAS documentation).
/// </summary>
/// <remarks>
/// <paramref name="A"/>, <paramref name="B"/> and <paramref name="C"/> are raw pointers; cuBLAS expects them to
/// address device memory. <paramref name="alpha"/> and <paramref name="beta"/> are passed by reference, i.e. as
/// pointers to host scalars (cuBLAS default pointer mode).
/// NOTE(review): the handle-based (v2) routine is exported as <c>cublasSgemm_v2</c>; the unsuffixed
/// <c>cublasSgemm</c> symbol is the legacy handle-less API with a different signature. Unless the [DllImport]
/// attribute (not visible here) sets EntryPoint = "cublasSgemm_v2", this binding may call the wrong native
/// function — confirm.
/// </remarks>
/// <param name="handle">cuBLAS handle (as an opaque pointer).</param>
/// <param name="transa">Operation op(A).</param>
/// <param name="transb">Operation op(B).</param>
/// <param name="m">Rows of op(A) and of C.</param>
/// <param name="n">Columns of op(B) and of C.</param>
/// <param name="k">Columns of op(A) / rows of op(B).</param>
/// <param name="alpha">Scalar multiplier of op(A) * op(B).</param>
/// <param name="A">Pointer to matrix A.</param>
/// <param name="lda">Leading dimension of A.</param>
/// <param name="B">Pointer to matrix B.</param>
/// <param name="ldb">Leading dimension of B.</param>
/// <param name="beta">Scalar multiplier of the existing C.</param>
/// <param name="C">Pointer to matrix C, overwritten with the result.</param>
/// <param name="ldc">Leading dimension of C.</param>
/// <returns>The cuBLAS status code of the call.</returns>
public static extern cublasStatus_t cublasSgemm( IntPtr handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, ref float alpha, IntPtr A, int lda, IntPtr B, int ldb, ref float beta, IntPtr C, int ldc );
/// <summary>
/// Multiplies two N x N matrices with cuBLAS SGEMM and compares the result element-wise
/// against the CPU <c>reference</c> implementation.
/// </summary>
/// <remarks>
/// Prints "DONE" when every element agrees within <c>Tolerance</c>. If SGEMM returns a
/// non-success status, or any element differs (NaN results included), a diagnostic is
/// printed and the process exits with code 1.
/// </remarks>
/// <param name="args">Unused.</param>
static void Main(string[] args)
{
    const int N = 1024;
    // Maximum absolute difference accepted between the cuBLAS and CPU results.
    const double Tolerance = 1.0E-3;

    Console.WriteLine("Execution cublas matrix mul with sizes ({0}, {1}) x ({2}, {3})", N, N, N, N);

    NaiveMatrix matrixA = new NaiveMatrix(N, N);
    NaiveMatrix matrixB = new NaiveMatrix(N, N);
    NaiveMatrix res = new NaiveMatrix(N, N);
    NaiveMatrix res_net = new NaiveMatrix(N, N);
    matrixA.FillMatrix();
    matrixB.FillMatrix();

    // GEMM computes C = alpha * op(A) * op(B) + beta * C; alpha = 1, beta = 0 yields the plain product.
    float alpha = 1.0f;
    float beta = 0.0f;

    cublas cublas = new cublas();
    cublasHandle_t handle;
    cublas.Create(out handle);
    cublasOperation_t transA = cublasOperation_t.CUBLAS_OP_N;
    cublasOperation_t transB = cublasOperation_t.CUBLAS_OP_N;

    // NOTE(review): cuBLAS interprets matrices as column-major. If NaiveMatrix stores its values
    // row-major, this call computes B * A in row-major terms — confirm `reference` uses the same convention.
    cublasStatus_t status;
    try
    {
        status = cublasSgemm(handle, transA, transB, N, N, N, &alpha, matrixA.Values, N, matrixB.Values, N, &beta, res.Values, N);
    }
    finally
    {
        // Release the cuBLAS handle even if the call throws (e.g. during argument marshaling).
        cublas.Destroy(handle);
    }

    // Fail fast with the actual cuBLAS status rather than reporting a misleading element mismatch.
    if (status != cublasStatus_t.CUBLAS_STATUS_SUCCESS)
    {
        Console.WriteLine("cublasSgemm failed with status {0}", status);
        Environment.Exit(1);
    }

    reference(matrixA, matrixB, res_net, N);

    for (int i = 0; i < N * N; ++i)
    {
        // Written as !(diff < tol) instead of diff >= tol so that a NaN difference is reported, not skipped.
        if (!(Math.Abs(res[i] - res_net[i]) < Tolerance))
        {
            Console.WriteLine("Error at {0}, expected {1}, got {2}", i, res_net[i], res[i]);
            Environment.Exit(1);
        }
    }

    Console.WriteLine("DONE");
}