@@ -29,20 +29,37 @@ namespace cp_algo::math::fft {
29
29
30
30
dft (auto const & a, size_t n): A(n), B(n) {
31
31
init ();
32
- base cur = factor;
33
- base step = bpow (factor, n);
34
- for (size_t i = 0 ; i < std::min (n, size (a)); i++) {
32
+ base b2x32 = bpow (base (2 ), 32 );
33
+ u64x4 cur = {
34
+ (bpow (factor, 1 ) * b2x32).getr (),
35
+ (bpow (factor, 2 ) * b2x32).getr (),
36
+ (bpow (factor, 3 ) * b2x32).getr (),
37
+ (bpow (factor, 4 ) * b2x32).getr ()
38
+ };
39
+ u64x4 step4 = u64x4{} + (bpow (factor, 4 ) * b2x32).getr ();
40
+ u64x4 stepn = u64x4{} + (bpow (factor, n) * b2x32).getr ();
41
+ for (size_t i = 0 ; i < std::min (n, size (a)); i += flen) {
35
42
auto splt = [&](size_t i, auto mul) {
36
- auto ai = i < size (a) ? (a[i] * mul).rem () : 0 ;
37
- auto quo = ai / split;
38
- auto rem = ai % split;
39
- return std::pair{(ftype)rem, (ftype)quo};
43
+ if (i >= size (a)) {
44
+ return std::pair{vftype (), vftype ()};
45
+ }
46
+ u64x4 au = {
47
+ i < size (a) ? a[i].getr () : 0 ,
48
+ i + 1 < size (a) ? a[i + 1 ].getr () : 0 ,
49
+ i + 2 < size (a) ? a[i + 2 ].getr () : 0 ,
50
+ i + 3 < size (a) ? a[i + 3 ].getr () : 0
51
+ };
52
+ au = montgomery_mul (au, mul, mod, imod);
53
+ au = au >= base::mod () ? au - base::mod () : au;
54
+ auto ai = i64x4 (au);
55
+ ai = ai >= base::mod () / 2 ? ai - base::mod () : ai;
56
+ return std::pair{to_double (ai % split), to_double (ai / split)};
40
57
};
41
58
auto [rai, qai] = splt (i, cur);
42
- auto [rani, qani] = splt (n + i, cur * step );
43
- A.set (i, point (rai, rani) );
44
- B.set (i, point (qai, qani) );
45
- cur *= factor ;
59
+ auto [rani, qani] = splt (n + i, montgomery_mul ( cur, stepn, mod, imod) );
60
+ A.at (i) = vpoint (rai, rani);
61
+ B.at (i) = vpoint (qai, qani);
62
+ cur = montgomery_mul (cur, step4, mod, imod) ;
46
63
}
47
64
checkpoint (" dft init" );
48
65
if (n) {
0 commit comments