/// <summary>
/// Forwards a four-tensor (A, B, C, D) einsum request to <c>Einsum_Native</c> and
/// turns a failing return value into an exception.
/// </summary>
/// <remarks>
/// Every argument is passed through unchanged, except that the four
/// <c>CudaDataType</c> values and the <c>CutensorComputeType</c> value are cast to
/// <c>int</c> for the native call.
/// NOTE(review): the "_d" suffix suggests the data pointers refer to device
/// memory, and alpha/beta presumably point to scalars whose width matches
/// <paramref name="typeCompute"/> - confirm against Einsum_Native.
/// </remarks>
/// <exception cref="Exception">
/// Thrown when <c>Einsum_Native</c> returns a non-zero value.
/// </exception>
public static void Einsum(
    void *A_d, int nmodeA, int *modeA, long *extentA, long *strideA, CudaDataType typeA,
    void *B_d, int nmodeB, int *modeB, long *extentB, long *strideB, CudaDataType typeB,
    void *C_d, int nmodeC, int *modeC, long *extentC, long *strideC, CudaDataType typeC,
    void *D_d, int nmodeD, int *modeD, long *extentD, long *strideD, CudaDataType typeD,
    void *alpha, void *beta, CutensorComputeType typeCompute)
{
    int status = Einsum_Native(
        A_d, nmodeA, modeA, extentA, strideA, (int)typeA,
        B_d, nmodeB, modeB, extentB, strideB, (int)typeB,
        C_d, nmodeC, modeC, extentC, strideC, (int)typeC,
        D_d, nmodeD, modeD, extentD, strideD, (int)typeD,
        alpha, beta, (int)typeCompute);

    // Zero is the only return value treated as success.
    if (status == 0)
    {
        return;
    }

    throw new Exception("Cutensor Einsum Error!");
}
/// <summary>
/// Convenience overload that accepts the alpha and beta scalars as
/// <c>double</c> values, converts them to the scalar type selected by
/// <paramref name="typeCompute"/>, and forwards to the pointer-based
/// <c>Einsum</c> overload.
/// </summary>
/// <remarks>
/// Scalar conversion by compute type:
/// <c>CUTENSOR_COMPUTE_16F</c> -> <c>Half</c>,
/// <c>CUTENSOR_COMPUTE_32F</c> -> <c>float</c>,
/// <c>CUTENSOR_COMPUTE_64F</c> -> <c>double</c>,
/// <c>CUTENSOR_COMPUTE_32I</c> -> <c>int</c> (the cast drops the fractional part).
/// NOTE(review): the scalar width that Einsum_Native expects for each compute
/// type is not visible here - confirm this mapping against the native side.
/// </remarks>
/// <param name="alphaval">Value passed on as the alpha scalar after conversion.</param>
/// <param name="betaval">Value passed on as the beta scalar after conversion.</param>
/// <param name="typeCompute">Selects the scalar type and is forwarded unchanged.</param>
/// <exception cref="ArgumentOutOfRangeException">
/// Thrown when <paramref name="typeCompute"/> is not one of the four compute
/// types listed above. Previously such a call returned without doing anything.
/// </exception>
/// <exception cref="Exception">
/// Propagated from the pointer-based overload when the native call fails.
/// </exception>
public static void Einsum(
    void *A_d, int nmodeA, int *modeA, long *extentA, long *strideA, CudaDataType typeA,
    void *B_d, int nmodeB, int *modeB, long *extentB, long *strideB, CudaDataType typeB,
    void *C_d, int nmodeC, int *modeC, long *extentC, long *strideC, CudaDataType typeC,
    void *D_d, int nmodeD, int *modeD, long *extentD, long *strideD, CudaDataType typeD,
    double alphaval, double betaval, CutensorComputeType typeCompute)
{
    // Each branch keeps its scalars in locals of the concrete type so that
    // their addresses can be handed to the void* overload for the duration
    // of the call.
    if (typeCompute == CutensorComputeType.CUTENSOR_COMPUTE_16F)
    {
        Half alpha = (Half)alphaval;
        Half beta = (Half)betaval;
        Einsum(
            A_d, nmodeA, modeA, extentA, strideA, typeA,
            B_d, nmodeB, modeB, extentB, strideB, typeB,
            C_d, nmodeC, modeC, extentC, strideC, typeC,
            D_d, nmodeD, modeD, extentD, strideD, typeD,
            &alpha, &beta, typeCompute);
    }
    else if (typeCompute == CutensorComputeType.CUTENSOR_COMPUTE_32F)
    {
        float alpha = (float)alphaval;
        float beta = (float)betaval;
        Einsum(
            A_d, nmodeA, modeA, extentA, strideA, typeA,
            B_d, nmodeB, modeB, extentB, strideB, typeB,
            C_d, nmodeC, modeC, extentC, strideC, typeC,
            D_d, nmodeD, modeD, extentD, strideD, typeD,
            &alpha, &beta, typeCompute);
    }
    else if (typeCompute == CutensorComputeType.CUTENSOR_COMPUTE_64F)
    {
        double alpha = alphaval;
        double beta = betaval;
        Einsum(
            A_d, nmodeA, modeA, extentA, strideA, typeA,
            B_d, nmodeB, modeB, extentB, strideB, typeB,
            C_d, nmodeC, modeC, extentC, strideC, typeC,
            D_d, nmodeD, modeD, extentD, strideD, typeD,
            &alpha, &beta, typeCompute);
    }
    else if (typeCompute == CutensorComputeType.CUTENSOR_COMPUTE_32I)
    {
        int alpha = (int)alphaval;
        int beta = (int)betaval;
        Einsum(
            A_d, nmodeA, modeA, extentA, strideA, typeA,
            B_d, nmodeB, modeB, extentB, strideB, typeB,
            C_d, nmodeC, modeC, extentC, strideC, typeC,
            D_d, nmodeD, modeD, extentD, strideD, typeD,
            &alpha, &beta, typeCompute);
    }
    else
    {
        // No scalar conversion is defined for this compute type. Fail loudly
        // instead of returning as if the operation had been performed.
        throw new ArgumentOutOfRangeException(
            nameof(typeCompute),
            typeCompute,
            "Unsupported compute type for Einsum with double alpha/beta.");
    }
}