/// <summary>
/// Checks <c>MultidimensionalArray.Multiply</c> with an index transformation T on the
/// result index m:
///   res[i,m] = alpha * sum_k A[i,k,T(m)] * B[T(m),k] + beta * res[i,m],
/// where T(m) is given by the integer array <c>mTrafo</c>. The product is evaluated with
/// both operand orders and compared against a straightforward loop implementation.
/// </summary>
public static void MultiplyTrafoTest0()
{
    // One summation index (k); the result index m is mapped through mTrafo before it is
    // used to address A and B. Only a single, fairly large K is exercised here.
    foreach (int K in new int[] { 210 })
    {
        int I = 120;
        int M = 43;
        MultidimensionalArray A = MultidimensionalArray.Create(I, K, 2 * M);
        MultidimensionalArray B = MultidimensionalArray.Create(2 * M, K);
        Console.WriteLine("number of operands in A: " + A.Length);
        Console.WriteLine("number of operands in B: " + B.Length);

        // mTrafo[m] selects which of the 2*M slices of A and B feeds result column m.
        int[] mTrafo = new int[M];
        MultidimensionalArray ResTst1 = MultidimensionalArray.Create(I, M);
        MultidimensionalArray ResTst2 = MultidimensionalArray.Create(I, M);
        MultidimensionalArray ResChck = MultidimensionalArray.Create(I, M);

        // Fill operands with random values. All three results start from identical content,
        // so the beta * res term contributes the same amount in every variant.
        Random rnd = new Random();
        A.ApplyAll(x => rnd.NextDouble());
        B.ApplyAll(x => rnd.NextDouble());
        ResTst1.ApplyAll(x => rnd.NextDouble());
        ResChck.Set(ResTst1);
        ResTst2.Set(ResTst1);

        // Random index map into [0, 2*M); duplicate targets are allowed.
        for (int m = 0; m < M; m++)
        {
            mTrafo[m] = rnd.Next(2 * M);
            Debug.Assert(mTrafo[m] < 2 * M);
        }

        double alpha = 0.67;
        double beta = 1.3;

        // The same contraction with the operands in either order. The meaning of the trailing
        // 'true' argument is defined by MultiplyProgram.Compile (not visible here).
        var mp1 = MultidimensionalArray.MultiplyProgram.Compile("im", "ikT(m)", "T(m)k", true);
        var mp2 = MultidimensionalArray.MultiplyProgram.Compile("im", "T(m)k", "ikT(m)", true);

        // Tensorized multiplication (only the first call is timed).
        Stopwatch TenMult = new Stopwatch();
        TenMult.Start();
        ResTst1.Multiply(alpha, A, B, beta, ref mp1, mTrafo);
        TenMult.Stop();
        ResTst2.Multiply(alpha, B, A, beta, ref mp2, mTrafo);
        Console.WriteLine("runtime of tensorized multiplication: " + TenMult.ElapsedMilliseconds + " millisec.");

        // Comparison code: explicit loops computing the same result.
        Stopwatch RefMult = new Stopwatch();
        RefMult.Start();
        double errSum = 0;
        for (int i = 0; i < I; i++)
        {
            for (int m = 0; m < M; m++)
            {
                int m_trf = mTrafo[m];

                // Summation over k.
                double sum = 0;
                for (int k = 0; k < K; k++)
                {
                    sum += A[i, k, m_trf] * B[m_trf, k];
                }
                ResChck[i, m] = sum * alpha + ResChck[i, m] * beta;

                errSum += Math.Abs(ResTst1[i, m] - ResChck[i, m]);
                errSum += Math.Abs(ResTst2[i, m] - ResChck[i, m]);
            }
        }
        RefMult.Stop();
        Console.WriteLine("runtime of loop multiplication: " + RefMult.ElapsedMilliseconds + " millisec.");

        // Absolute error summed over every entry of both tensorized results.
        Console.WriteLine("total error: " + errSum);
        double thres = 1.0e-6;
        Assert.IsTrue(errSum < thres);
    }
}
/// <summary>
/// Checks <c>MultidimensionalArray.Multiply</c> for a contraction over two summation
/// indices (k, r):
///   res[i,m,n] = alpha * sum_r sum_k A[i,r,k,m] * B[i,k,n,r] + beta * res[i,m,n].
/// The product is evaluated with both operand orders and compared against plain loops.
/// </summary>
public static void MultiplyTest3()
{
    // Two sizes of K are used, presumably so that both an unrolled and a generic
    // code path inside Multiply get exercised.
    foreach (int K in new int[] { 2, 21 })
    {
        const int I = 12, M = 43, N = 63, R = 21;

        MultidimensionalArray A = MultidimensionalArray.Create(I, R, K, M);
        MultidimensionalArray B = MultidimensionalArray.Create(I, K, N, R);
        Console.WriteLine($"number of operands in A: {A.Length}");
        Console.WriteLine($"number of operands in B: {B.Length}");

        MultidimensionalArray resultAB = MultidimensionalArray.Create(I, M, N);
        MultidimensionalArray resultBA = MultidimensionalArray.Create(I, M, N);
        MultidimensionalArray expected = MultidimensionalArray.Create(I, M, N);

        // Random operands; the three result arrays share the same random starting values.
        var random = new Random();
        A.ApplyAll(x => random.NextDouble());
        B.ApplyAll(x => random.NextDouble());
        resultAB.ApplyAll(x => random.NextDouble());
        expected.Set(resultAB);
        resultBA.Set(resultAB);

        double alpha = 0.67, beta = 1.3;

        // Library implementation, operands in both orders; only the A*B call is timed.
        Stopwatch tensorTimer = Stopwatch.StartNew();
        resultAB.Multiply(alpha, A, B, beta, "imn", "irkm", "iknr");
        tensorTimer.Stop();
        resultBA.Multiply(alpha, B, A, beta, "imn", "iknr", "irkm");
        Console.WriteLine($"runtime of tensorized multiplication: {tensorTimer.ElapsedMilliseconds} millisec.");

        // Reference contraction for one output entry: r is the outer, k the inner sum.
        double Contract(int i, int m, int n)
        {
            double acc = 0;
            for (int r = 0; r < R; r++)
            {
                for (int k = 0; k < K; k++)
                {
                    acc += A[i, r, k, m] * B[i, k, n, r];
                }
            }
            return acc;
        }

        Stopwatch loopTimer = Stopwatch.StartNew();
        double totalError = 0;
        for (int i = 0; i < I; i++)
        {
            for (int n = 0; n < N; n++)
            {
                for (int m = 0; m < M; m++)
                {
                    double previous = expected[i, m, n];
                    expected[i, m, n] = Contract(i, m, n) * alpha + previous * beta;

                    // Accumulated one term at a time to keep the floating-point sum order.
                    totalError += Math.Abs(resultAB[i, m, n] - expected[i, m, n]);
                    totalError += Math.Abs(resultBA[i, m, n] - expected[i, m, n]);
                }
            }
        }
        loopTimer.Stop();
        Console.WriteLine($"runtime of loop multiplication: {loopTimer.ElapsedMilliseconds} millisec.");

        Console.WriteLine($"total error: {totalError}");
        const double tolerance = 1.0e-6;
        Assert.IsTrue(totalError < tolerance);
    }
}