/// <summary>
/// Multiplies <paramref name="mat1"/> by <paramref name="mat2"/> using the vectorized Strassen routine.
/// </summary>
/// <param name="mat1">Left factor; its row count and <c>mat1[0].Length</c> must not exceed S.</param>
/// <param name="mat2">Right factor; its row count and <c>mat2[0].Length</c> must not exceed S.</param>
/// <returns>A matrix with <c>mat1.Length</c> rows and <c>mat2[0].Length</c> columns.</returns>
/// <remarks>
/// Preconditions are expressed with <c>Contract.Assert</c>: the modulus <c>default(T).Mod</c> must be odd
/// and both operands must fit in S x S. Because <c>mat1[0]</c> and <c>mat2[0]</c> are read, each matrix
/// needs at least one row.
/// </remarks>
public StaticModInt<T>[][] Strassen(StaticModInt<T>[][] mat1, StaticModInt<T>[][] mat2)
{
    Contract.Assert(default(T).Mod % 2 == 1);
    Contract.Assert(mat1.Length <= S);
    Contract.Assert(mat2.Length <= S);
    Contract.Assert(mat1[0].Length <= S);
    Contract.Assert(mat2[0].Length <= S);

    var packedLeft = ToVectorize(mat1);
    var packedRight = ToVectorize(mat2);

    // Length of the packed result buffer, and of the prefix of each work buffer that is converted below.
    int blockLen = S8 * S;
    // Work buffers are 1.5x that length; presumably the extra half is scratch space for StrassenImpl.
    int workLen = blockLen * 3 / 2;
    var leftWork = new VectorizedStaticModInt<T>[workLen];
    var rightWork = new VectorizedStaticModInt<T>[workLen];
    var productWork = new VectorizedStaticModInt<T>[workLen];

    PlaceS(S, 0, 0, leftWork.AsSpan(), packedLeft);
    PlaceT(S, 0, 0, rightWork.AsSpan(), packedRight);

    // Convert both operands with Itom (presumably into Montgomery form) before multiplying.
    for (int idx = 0; idx < blockLen; idx++)
    {
        leftWork[idx] = leftWork[idx].Itom();
        rightWork[idx] = rightWork[idx].Itom();
    }

    StrassenImpl(S, leftWork, rightWork, productWork.AsSpan());

    // Convert the product back with Mtoi before unpacking it.
    for (int idx = 0; idx < blockLen; idx++)
    {
        productWork[idx] = productWork[idx].Mtoi();
    }

    var packedResult = new VectorizedStaticModInt<T>[blockLen];
    PlaceRev(S, 0, 0, productWork.AsSpan(), packedResult);
    return ToMatrix(packedResult, mat1.Length, mat2[0].Length);
}
/// <summary>
/// Base-case kernel: multiplies the B x B block held in <paramref name="s"/> by the one held in
/// <paramref name="t"/> with AVX2 and writes the result into <paramref name="u"/>.
/// </summary>
/// <param name="s">
/// Left block; element (row, col) is lane <c>col % 8</c> of <c>s[row * B8 + col / 8]</c>.
/// Modified in place by the normalization loop at the top.
/// </param>
/// <param name="t">
/// Right block; <c>t[j * B + row]</c> is the 8-lane vector holding columns 8j..8j+7 of that row.
/// Modified in place by the normalization loop at the top.
/// </param>
/// <param name="u">Output block; row r, column group j is written to <c>u[r * B8 + j]</c>.</param>
/// <remarks>
/// Each output lane is accumulated as a 64-bit sum of 32x32-bit products and finished by
/// <c>VectorizedStaticModInt&lt;T&gt;.Reduce</c>. Presumably Reduce is a Montgomery reduction that
/// returns the product in the same representation as the inputs; it is not visible here, so confirm.
/// </remarks>
private static void MulSimd(Span<VectorizedStaticModInt<T>> s, Span<VectorizedStaticModInt<T>> t, Span<VectorizedStaticModInt<T>> u)
{
    // Normalize every lane of s and t: a lane greater than M1 (signed 32-bit compare) has M1
    // subtracted once. This presumably maps inputs from [0, 2*M1) into [0, M1]; TODO confirm the
    // caller-side input range.
    for (int i = 0; i < B * B8; i++)
    {
        var cmpS = Avx2.CompareGreaterThan(s[i].Value.AsInt32(), VectorizedStaticModInt<T>.M1.AsInt32()).AsUInt32();
        var cmpT = Avx2.CompareGreaterThan(t[i].Value.AsInt32(), VectorizedStaticModInt<T>.M1.AsInt32()).AsUInt32();
        var difS = Avx2.And(cmpS, VectorizedStaticModInt<T>.M1);
        var difT = Avx2.And(cmpT, VectorizedStaticModInt<T>.M1);
        s[i] = Avx2.Subtract(s[i].Value, difS);
        t[i] = Avx2.Subtract(t[i].Value, difT);
    }
    var m1v = VectorizedStaticModInt<T>.M1.GetElement(0);
    var m2v = VectorizedStaticModInt<T>.M2.GetElement(0);
    var zero = new VectorizedStaticModInt<T>().Value;
    // Every 64-bit lane of th1 equals (m1v << 32) and of th2 equals (m2v << 32): in each pair of
    // 32-bit elements the low one is 0 and the high one is m1v / m2v (x86 is little-endian).
    var th1 = new VectorizedStaticModInt<T>(0, m1v, 0, m1v, 0, m1v, 0, m1v).Value.AsInt64();
    var th2 = new VectorizedStaticModInt<T>(0, m2v, 0, m2v, 0, m2v, 0, m2v).Value.AsInt64();
    // Produce 8 output rows (i .. i + 7) per iteration; assumes B is a multiple of 8.
    for (int i = 0; i < B; i += 8)
    {
        // One 8-column output group j per iteration.
        for (int j = 0; j < B8; j += 1)
        {
            // Accumulators prod02R0 / prod13R0 belong to output row i + R. "02" collects the even
            // columns (lanes 0,2,4,6) and "13" the odd columns (lanes 1,3,5,7), as four 64-bit sums each.
            Vector256<ulong> prod0200 = default;
            Vector256<ulong> prod1300 = default;
            Vector256<ulong> prod0210 = default;
            Vector256<ulong> prod1310 = default;
            Vector256<ulong> prod0220 = default;
            Vector256<ulong> prod1320 = default;
            Vector256<ulong> prod0230 = default;
            Vector256<ulong> prod1330 = default;
            Vector256<ulong> prod0240 = default;
            Vector256<ulong> prod1340 = default;
            Vector256<ulong> prod0250 = default;
            Vector256<ulong> prod1350 = default;
            Vector256<ulong> prod0260 = default;
            Vector256<ulong> prod1360 = default;
            Vector256<ulong> prod0270 = default;
            Vector256<ulong> prod1370 = default;
            // Walk the shared dimension 8 entries at a time. Entry k + l is column k + l of s
            // (lane l of vector k / 8) and row k + l of t.
            for (int k = 0; k < B; k += 8)
            {
                for (int l = 0; l < 8; l++)
                {
                    Vector256<uint> T0 = t[j * B + k + l].Value;
                    // Shuffle control 0xF5 copies lanes 1 and 3 of each 128-bit half into lanes 0,1 and 2,3,
                    // so the even-lane multiply below reaches t's odd columns.
                    var T130 = Avx2.Shuffle(T0, 0xF5);
                    // For each row: broadcast s[i + R][k + l]. Avx2.Multiply on uint vectors
                    // multiplies lanes 0,2,4,6 into full 64-bit products, which are added to the accumulators.
                    var S00 = Vector256.Create(s[(i + 0) * B8 + k / 8].Value.GetElement(l));
                    var ST0200 = Avx2.Multiply(S00, T0);
                    var ST1300 = Avx2.Multiply(S00, T130);
                    prod0200 = Avx2.Add(prod0200, ST0200);
                    prod1300 = Avx2.Add(prod1300, ST1300);
                    var S10 = Vector256.Create(s[(i + 1) * B8 + k / 8].Value.GetElement(l));
                    var ST0210 = Avx2.Multiply(S10, T0);
                    var ST1310 = Avx2.Multiply(S10, T130);
                    prod0210 = Avx2.Add(prod0210, ST0210);
                    prod1310 = Avx2.Add(prod1310, ST1310);
                    var S20 = Vector256.Create(s[(i + 2) * B8 + k / 8].Value.GetElement(l));
                    var ST0220 = Avx2.Multiply(S20, T0);
                    var ST1320 = Avx2.Multiply(S20, T130);
                    prod0220 = Avx2.Add(prod0220, ST0220);
                    prod1320 = Avx2.Add(prod1320, ST1320);
                    var S30 = Vector256.Create(s[(i + 3) * B8 + k / 8].Value.GetElement(l));
                    var ST0230 = Avx2.Multiply(S30, T0);
                    var ST1330 = Avx2.Multiply(S30, T130);
                    prod0230 = Avx2.Add(prod0230, ST0230);
                    prod1330 = Avx2.Add(prod1330, ST1330);
                    var S40 = Vector256.Create(s[(i + 4) * B8 + k / 8].Value.GetElement(l));
                    var ST0240 = Avx2.Multiply(S40, T0);
                    var ST1340 = Avx2.Multiply(S40, T130);
                    prod0240 = Avx2.Add(prod0240, ST0240);
                    prod1340 = Avx2.Add(prod1340, ST1340);
                    var S50 = Vector256.Create(s[(i + 5) * B8 + k / 8].Value.GetElement(l));
                    var ST0250 = Avx2.Multiply(S50, T0);
                    var ST1350 = Avx2.Multiply(S50, T130);
                    prod0250 = Avx2.Add(prod0250, ST0250);
                    prod1350 = Avx2.Add(prod1350, ST1350);
                    var S60 = Vector256.Create(s[(i + 6) * B8 + k / 8].Value.GetElement(l));
                    var ST0260 = Avx2.Multiply(S60, T0);
                    var ST1360 = Avx2.Multiply(S60, T130);
                    prod0260 = Avx2.Add(prod0260, ST0260);
                    prod1360 = Avx2.Add(prod1360, ST1360);
                    var S70 = Vector256.Create(s[(i + 7) * B8 + k / 8].Value.GetElement(l));
                    var ST0270 = Avx2.Multiply(S70, T0);
                    var ST1370 = Avx2.Multiply(S70, T130);
                    prod0270 = Avx2.Add(prod0270, ST0270);
                    prod1370 = Avx2.Add(prod1370, ST1370);
                }
                // After every 8 products, any 64-bit lane with bit 63 set (negative when viewed as
                // signed) has th2 subtracted once, presumably to keep headroom for the next 8 additions.
                // NOTE(review): whether one subtraction always suffices depends on M2 and the input
                // range, neither of which is visible here.
                var cmp0200 = Avx2.CompareGreaterThan(zero.AsInt64(), prod0200.AsInt64());
                var cmp1300 = Avx2.CompareGreaterThan(zero.AsInt64(), prod1300.AsInt64());
                var dif0200 = Avx2.And(cmp0200, th2);
                var dif1300 = Avx2.And(cmp1300, th2);
                prod0200 = Avx2.Subtract(prod0200, dif0200.AsUInt64());
                prod1300 = Avx2.Subtract(prod1300, dif1300.AsUInt64());
                var cmp0210 = Avx2.CompareGreaterThan(zero.AsInt64(), prod0210.AsInt64());
                var cmp1310 = Avx2.CompareGreaterThan(zero.AsInt64(), prod1310.AsInt64());
                var dif0210 = Avx2.And(cmp0210, th2);
                var dif1310 = Avx2.And(cmp1310, th2);
                prod0210 = Avx2.Subtract(prod0210, dif0210.AsUInt64());
                prod1310 = Avx2.Subtract(prod1310, dif1310.AsUInt64());
                var cmp0220 = Avx2.CompareGreaterThan(zero.AsInt64(), prod0220.AsInt64());
                var cmp1320 = Avx2.CompareGreaterThan(zero.AsInt64(), prod1320.AsInt64());
                var dif0220 = Avx2.And(cmp0220, th2);
                var dif1320 = Avx2.And(cmp1320, th2);
                prod0220 = Avx2.Subtract(prod0220, dif0220.AsUInt64());
                prod1320 = Avx2.Subtract(prod1320, dif1320.AsUInt64());
                var cmp0230 = Avx2.CompareGreaterThan(zero.AsInt64(), prod0230.AsInt64());
                var cmp1330 = Avx2.CompareGreaterThan(zero.AsInt64(), prod1330.AsInt64());
                var dif0230 = Avx2.And(cmp0230, th2);
                var dif1330 = Avx2.And(cmp1330, th2);
                prod0230 = Avx2.Subtract(prod0230, dif0230.AsUInt64());
                prod1330 = Avx2.Subtract(prod1330, dif1330.AsUInt64());
                var cmp0240 = Avx2.CompareGreaterThan(zero.AsInt64(), prod0240.AsInt64());
                var cmp1340 = Avx2.CompareGreaterThan(zero.AsInt64(), prod1340.AsInt64());
                var dif0240 = Avx2.And(cmp0240, th2);
                var dif1340 = Avx2.And(cmp1340, th2);
                prod0240 = Avx2.Subtract(prod0240, dif0240.AsUInt64());
                prod1340 = Avx2.Subtract(prod1340, dif1340.AsUInt64());
                var cmp0250 = Avx2.CompareGreaterThan(zero.AsInt64(), prod0250.AsInt64());
                var cmp1350 = Avx2.CompareGreaterThan(zero.AsInt64(), prod1350.AsInt64());
                var dif0250 = Avx2.And(cmp0250, th2);
                var dif1350 = Avx2.And(cmp1350, th2);
                prod0250 = Avx2.Subtract(prod0250, dif0250.AsUInt64());
                prod1350 = Avx2.Subtract(prod1350, dif1350.AsUInt64());
                var cmp0260 = Avx2.CompareGreaterThan(zero.AsInt64(), prod0260.AsInt64());
                var cmp1360 = Avx2.CompareGreaterThan(zero.AsInt64(), prod1360.AsInt64());
                var dif0260 = Avx2.And(cmp0260, th2);
                var dif1360 = Avx2.And(cmp1360, th2);
                prod0260 = Avx2.Subtract(prod0260, dif0260.AsUInt64());
                prod1360 = Avx2.Subtract(prod1360, dif1360.AsUInt64());
                var cmp0270 = Avx2.CompareGreaterThan(zero.AsInt64(), prod0270.AsInt64());
                var cmp1370 = Avx2.CompareGreaterThan(zero.AsInt64(), prod1370.AsInt64());
                var dif0270 = Avx2.And(cmp0270, th2);
                var dif1370 = Avx2.And(cmp1370, th2);
                prod0270 = Avx2.Subtract(prod0270, dif0270.AsUInt64());
                prod1370 = Avx2.Subtract(prod1370, dif1370.AsUInt64());
            }
            // For each row: up to two conditional subtractions of th1 (applied while lane > th1, signed
            // compare). Reduce then combines the even/odd accumulators into one 8-lane output vector.
            // NOTE(review): two subtractions guarantee lane <= th1 only if the lane started below
            // 3 * th1 = 3 * (M1 << 32). Lanes are at best bounded by about 2^63 at this point, so for
            // M1 below roughly 2^31 / 3, Reduce might receive an out-of-range input. Confirm the input
            // range Reduce accepts and the moduli this kernel is used with.
            for (int _ = 0; _ < 2; _++)
            {
                var cmp02 = Avx2.CompareGreaterThan(prod0200.AsInt64(), th1);
                var cmp13 = Avx2.CompareGreaterThan(prod1300.AsInt64(), th1);
                var dif02 = Avx2.And(cmp02, th1);
                var dif13 = Avx2.And(cmp13, th1);
                prod0200 = Avx2.Subtract(prod0200, dif02.AsUInt64());
                prod1300 = Avx2.Subtract(prod1300, dif13.AsUInt64());
            }
            u[(i + 0) * B8 + j + 0] = VectorizedStaticModInt<T>.Reduce(prod0200.AsUInt32(), prod1300.AsUInt32());
            for (int _ = 0; _ < 2; _++)
            {
                var cmp02 = Avx2.CompareGreaterThan(prod0210.AsInt64(), th1);
                var cmp13 = Avx2.CompareGreaterThan(prod1310.AsInt64(), th1);
                var dif02 = Avx2.And(cmp02, th1);
                var dif13 = Avx2.And(cmp13, th1);
                prod0210 = Avx2.Subtract(prod0210, dif02.AsUInt64());
                prod1310 = Avx2.Subtract(prod1310, dif13.AsUInt64());
            }
            u[(i + 1) * B8 + j + 0] = VectorizedStaticModInt<T>.Reduce(prod0210.AsUInt32(), prod1310.AsUInt32());
            for (int _ = 0; _ < 2; _++)
            {
                var cmp02 = Avx2.CompareGreaterThan(prod0220.AsInt64(), th1);
                var cmp13 = Avx2.CompareGreaterThan(prod1320.AsInt64(), th1);
                var dif02 = Avx2.And(cmp02, th1);
                var dif13 = Avx2.And(cmp13, th1);
                prod0220 = Avx2.Subtract(prod0220, dif02.AsUInt64());
                prod1320 = Avx2.Subtract(prod1320, dif13.AsUInt64());
            }
            u[(i + 2) * B8 + j + 0] = VectorizedStaticModInt<T>.Reduce(prod0220.AsUInt32(), prod1320.AsUInt32());
            for (int _ = 0; _ < 2; _++)
            {
                var cmp02 = Avx2.CompareGreaterThan(prod0230.AsInt64(), th1);
                var cmp13 = Avx2.CompareGreaterThan(prod1330.AsInt64(), th1);
                var dif02 = Avx2.And(cmp02, th1);
                var dif13 = Avx2.And(cmp13, th1);
                prod0230 = Avx2.Subtract(prod0230, dif02.AsUInt64());
                prod1330 = Avx2.Subtract(prod1330, dif13.AsUInt64());
            }
            u[(i + 3) * B8 + j + 0] = VectorizedStaticModInt<T>.Reduce(prod0230.AsUInt32(), prod1330.AsUInt32());
            for (int _ = 0; _ < 2; _++)
            {
                var cmp02 = Avx2.CompareGreaterThan(prod0240.AsInt64(), th1);
                var cmp13 = Avx2.CompareGreaterThan(prod1340.AsInt64(), th1);
                var dif02 = Avx2.And(cmp02, th1);
                var dif13 = Avx2.And(cmp13, th1);
                prod0240 = Avx2.Subtract(prod0240, dif02.AsUInt64());
                prod1340 = Avx2.Subtract(prod1340, dif13.AsUInt64());
            }
            u[(i + 4) * B8 + j + 0] = VectorizedStaticModInt<T>.Reduce(prod0240.AsUInt32(), prod1340.AsUInt32());
            for (int _ = 0; _ < 2; _++)
            {
                var cmp02 = Avx2.CompareGreaterThan(prod0250.AsInt64(), th1);
                var cmp13 = Avx2.CompareGreaterThan(prod1350.AsInt64(), th1);
                var dif02 = Avx2.And(cmp02, th1);
                var dif13 = Avx2.And(cmp13, th1);
                prod0250 = Avx2.Subtract(prod0250, dif02.AsUInt64());
                prod1350 = Avx2.Subtract(prod1350, dif13.AsUInt64());
            }
            u[(i + 5) * B8 + j + 0] = VectorizedStaticModInt<T>.Reduce(prod0250.AsUInt32(), prod1350.AsUInt32());
            for (int _ = 0; _ < 2; _++)
            {
                var cmp02 = Avx2.CompareGreaterThan(prod0260.AsInt64(), th1);
                var cmp13 = Avx2.CompareGreaterThan(prod1360.AsInt64(), th1);
                var dif02 = Avx2.And(cmp02, th1);
                var dif13 = Avx2.And(cmp13, th1);
                prod0260 = Avx2.Subtract(prod0260, dif02.AsUInt64());
                prod1360 = Avx2.Subtract(prod1360, dif13.AsUInt64());
            }
            u[(i + 6) * B8 + j + 0] = VectorizedStaticModInt<T>.Reduce(prod0260.AsUInt32(), prod1360.AsUInt32());
            for (int _ = 0; _ < 2; _++)
            {
                var cmp02 = Avx2.CompareGreaterThan(prod0270.AsInt64(), th1);
                var cmp13 = Avx2.CompareGreaterThan(prod1370.AsInt64(), th1);
                var dif02 = Avx2.And(cmp02, th1);
                var dif13 = Avx2.And(cmp13, th1);
                prod0270 = Avx2.Subtract(prod0270, dif02.AsUInt64());
                prod1370 = Avx2.Subtract(prod1370, dif13.AsUInt64());
            }
            u[(i + 7) * B8 + j + 0] = VectorizedStaticModInt<T>.Reduce(prod0270.AsUInt32(), prod1370.AsUInt32());
        }
    }
}