/// <summary>
/// Runs a single double-complex GEMM on the device through <c>Cublas.Zgemm</c>.
/// NOTE(review): presumably this computes C = alpha * op(A) * op(B) + beta * C as in cuBLAS
/// <c>cublasZgemm</c>, with op() chosen by the <c>Transa</c>/<c>Transb</c> members of the
/// enclosing scope (not visible here) — confirm.
/// </summary>
/// <param name="m">Row count passed through to cuBLAS.</param>
/// <param name="n">Column count passed through to cuBLAS.</param>
/// <param name="k">Inner dimension passed through to cuBLAS.</param>
/// <param name="alpha">Scalar multiplier; uploaded to the device and passed by pointer.</param>
/// <param name="A">Host data for A, copied to the device.</param>
/// <param name="lda">Leading dimension of A.</param>
/// <param name="B">Host data for B, copied to the device.</param>
/// <param name="ldb">Leading dimension of B.</param>
/// <param name="beta">Scalar multiplier for C; uploaded to the device and passed by pointer.</param>
/// <param name="C">Host data for C, copied to the device as the accumulator input.</param>
/// <param name="ldc">Leading dimension of C.</param>
/// <returns>The device copy of C after the call, gathered back to the host.</returns>
public static double2[] Zgemm(int m, int n, int k, double2 alpha, double2[] A, int lda, double2[] B, int ldb, double2 beta, double2[] C, int ldc)
{
    // Every device buffer is scoped by its own using, so all of them are released
    // after the result has been gathered (or if anything below throws).
    using (var deviceA = Worker.Malloc(A))
    using (var deviceB = Worker.Malloc(B))
    using (var deviceC = Worker.Malloc(C))
    // alpha/beta travel as device pointers.
    // NOTE(review): this assumes the cuBLAS handle uses device pointer mode — confirm.
    using (var deviceAlpha = Worker.Malloc(new[] { alpha }))
    using (var deviceBeta = Worker.Malloc(new[] { beta }))
    {
        Cublas.Zgemm(Transa, Transb, m, n, k,
                     deviceAlpha.Ptr,
                     deviceA.Ptr, lda,
                     deviceB.Ptr, ldb,
                     deviceBeta.Ptr,
                     deviceC.Ptr, ldc);

        // The result is returned via Gather() rather than written back into the host array C.
        var result = deviceC.Gather();
        return result;
    }
}
/// <summary>
/// Runs a batched double-complex GEMM on the device through the batched <c>Cublas.Zgemm</c>
/// overload, one multiplication per (hAs[i], hBs[i], hCs[i]) triple.
/// NOTE(review): presumably each entry computes C[i] = alpha * op(A[i]) * op(B[i]) + beta * C[i]
/// as in cuBLAS <c>cublasZgemmBatched</c>, with op() chosen by the <c>Transa</c>/<c>Transb</c>
/// members of the enclosing scope (not visible here) — confirm.
/// </summary>
/// <param name="m">Row count passed through to cuBLAS (shared by every batch entry).</param>
/// <param name="n">Column count passed through to cuBLAS (shared by every batch entry).</param>
/// <param name="k">Inner dimension passed through to cuBLAS (shared by every batch entry).</param>
/// <param name="alpha">Scalar multiplier; uploaded to the device and passed by pointer.</param>
/// <param name="hAs">Host data for each A matrix.</param>
/// <param name="lda">Leading dimension shared by all A matrices.</param>
/// <param name="hBs">Host data for each B matrix; must have the same length as <paramref name="hAs"/>.</param>
/// <param name="ldb">Leading dimension shared by all B matrices.</param>
/// <param name="beta">Scalar multiplier for C; uploaded to the device and passed by pointer.</param>
/// <param name="hCs">Host data for each C matrix; must have the same length as <paramref name="hAs"/>.</param>
/// <param name="ldc">Leading dimension shared by all C matrices.</param>
/// <returns>The device copy of each C matrix after the call, gathered back to the host, in batch order.
/// An empty batch yields an empty array without touching the device.</returns>
/// <exception cref="System.ArgumentNullException">Any of the three batch arrays is null.</exception>
/// <exception cref="System.ArgumentException">The three batch arrays differ in length.</exception>
public static double2[][] Zgemm(int m, int n, int k, double2 alpha, double2[][] hAs, int lda, double2[][] hBs, int ldb, double2 beta, double2[][] hCs, int ldc)
{
    if (hAs == null) { throw new System.ArgumentNullException("hAs"); }
    if (hBs == null) { throw new System.ArgumentNullException("hBs"); }
    if (hCs == null) { throw new System.ArgumentNullException("hCs"); }

    var batchCount = hAs.Length;

    // cuBLAS reads exactly batchCount entries from each device pointer array; a shorter
    // B or C batch would make it read past the end of that array on the device.
    if (hBs.Length != batchCount)
    {
        throw new System.ArgumentException("hBs must contain the same number of matrices as hAs.", "hBs");
    }
    if (hCs.Length != batchCount)
    {
        throw new System.ArgumentException("hCs must contain the same number of matrices as hAs.", "hCs");
    }

    // Nothing to multiply: skip the zero-length device allocations entirely.
    if (batchCount == 0)
    {
        return new double2[0][];
    }

    // Every per-matrix device buffer, in allocation order. Each buffer is recorded the
    // moment Malloc returns, so the finally block can release whatever was allocated
    // even if a later Malloc throws part-way through the batch.
    var owned = new System.Collections.Generic.List<System.IDisposable>(3 * batchCount);
    try
    {
        var dAs = (from hA in hAs select TrackAllocation(owned, Worker.Malloc(hA))).ToArray();
        var dBs = (from hB in hBs select TrackAllocation(owned, Worker.Malloc(hB))).ToArray();
        var dCs = (from hC in hCs select TrackAllocation(owned, Worker.Malloc(hC))).ToArray();

        // The batched API takes device arrays of device pointers, one per matrix.
        using (var dAPtrs = Worker.Malloc((from dA in dAs select dA.Ptr).ToArray()))
        using (var dBPtrs = Worker.Malloc((from dB in dBs select dB.Ptr).ToArray()))
        using (var dCPtrs = Worker.Malloc((from dC in dCs select dC.Ptr).ToArray()))
        // NOTE(review): alpha/beta are passed as device pointers; assumes device pointer mode — confirm.
        using (var dAlpha = Worker.Malloc(new[] { alpha }))
        using (var dBeta = Worker.Malloc(new[] { beta }))
        {
            Cublas.Zgemm(Transa, Transb, m, n, k, dAlpha.Ptr, dAPtrs.Ptr, lda, dBPtrs.Ptr, ldb, dBeta.Ptr, dCPtrs.Ptr, ldc, batchCount);
        }

        return (from dC in dCs select dC.Gather()).ToArray();
    }
    finally
    {
        foreach (var resource in owned)
        {
            resource.Dispose();
        }
    }
}

/// <summary>
/// Adds <paramref name="resource"/> to <paramref name="owned"/> and returns it unchanged.
/// This lets callers record each allocation as it happens without naming the resource type.
/// </summary>
private static T TrackAllocation<T>(System.Collections.Generic.List<System.IDisposable> owned, T resource) where T : System.IDisposable
{
    owned.Add(resource);
    return resource;
}