Skip to content

Commit 1694eb5

Browse files
committed
vectorize dft init
1 parent 1fbc92f commit 1694eb5

File tree

2 files changed

+34
-13
lines changed

2 files changed

+34
-13
lines changed

cp-algo/math/fft.hpp

+28-11
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,37 @@ namespace cp_algo::math::fft {
2929

3030
dft(auto const& a, size_t n): A(n), B(n) {
3131
init();
32-
base cur = factor;
33-
base step = bpow(factor, n);
34-
for(size_t i = 0; i < std::min(n, size(a)); i++) {
32+
base b2x32 = bpow(base(2), 32);
33+
u64x4 cur = {
34+
(bpow(factor, 1) * b2x32).getr(),
35+
(bpow(factor, 2) * b2x32).getr(),
36+
(bpow(factor, 3) * b2x32).getr(),
37+
(bpow(factor, 4) * b2x32).getr()
38+
};
39+
u64x4 step4 = u64x4{} + (bpow(factor, 4) * b2x32).getr();
40+
u64x4 stepn = u64x4{} + (bpow(factor, n) * b2x32).getr();
41+
for(size_t i = 0; i < std::min(n, size(a)); i += flen) {
3542
auto splt = [&](size_t i, auto mul) {
36-
auto ai = i < size(a) ? (a[i] * mul).rem() : 0;
37-
auto quo = ai / split;
38-
auto rem = ai % split;
39-
return std::pair{(ftype)rem, (ftype)quo};
43+
if(i >= size(a)) {
44+
return std::pair{vftype(), vftype()};
45+
}
46+
u64x4 au = {
47+
i < size(a) ? a[i].getr() : 0,
48+
i + 1 < size(a) ? a[i + 1].getr() : 0,
49+
i + 2 < size(a) ? a[i + 2].getr() : 0,
50+
i + 3 < size(a) ? a[i + 3].getr() : 0
51+
};
52+
au = montgomery_mul(au, mul, mod, imod);
53+
au = au >= base::mod() ? au - base::mod() : au;
54+
auto ai = i64x4(au);
55+
ai = ai >= base::mod() / 2 ? ai - base::mod() : ai;
56+
return std::pair{to_double(ai % split), to_double(ai / split)};
4057
};
4158
auto [rai, qai] = splt(i, cur);
42-
auto [rani, qani] = splt(n + i, cur * step);
43-
A.set(i, point(rai, rani));
44-
B.set(i, point(qai, qani));
45-
cur *= factor;
59+
auto [rani, qani] = splt(n + i, montgomery_mul(cur, stepn, mod, imod));
60+
A.at(i) = vpoint(rai, rani);
61+
B.at(i) = vpoint(qai, qani);
62+
cur = montgomery_mul(cur, step4, mod, imod);
4663
}
4764
checkpoint("dft init");
4865
if(n) {

cp-algo/util/simd.hpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,15 @@ namespace cp_algo {
1616
return a < 0 ? -a : a;
1717
}
1818

19+
// https://stackoverflow.com/a/77376595
20+
// works for ints in (-2^51, 2^51)
21+
static constexpr dx4 magic = dx4() + (3ULL << 51);
1922
[[gnu::always_inline]] inline i64x4 lround(dx4 x) {
20-
// https://stackoverflow.com/a/77376595
21-
static constexpr dx4 magic = dx4() + double(3ULL << 51);
2223
return i64x4(x + magic) - i64x4(magic);
2324
}
25+
// Turns the 64-bit integer lanes of x into double lanes using the shared
// `magic` constant: i64x4(magic) is added to x, the sum is viewed as dx4,
// and `magic` is subtracted again. This mirrors lround, which applies the
// same constant in the opposite direction.
// NOTE(review): this presumably relies on the i64x4(...) / dx4(...) casts
// reinterpreting lane bits rather than converting values (as GCC vector
// extension casts do) -- confirm against the type definitions, which are
// not visible in this hunk.
// Per the comment on `magic`, the trick is stated to work only for integers
// in (-2^51, 2^51); inputs outside that range are not handled here.
[[gnu::always_inline]] inline dx4 to_double(i64x4 x) {
26+
return dx4(x + i64x4(magic)) - magic;
27+
}
2428

2529
[[gnu::always_inline]] inline dx4 round(dx4 a) {
2630
#ifdef __AVX2__

0 commit comments

Comments
 (0)