|
1 | 1 | #ifndef CP_ALGO_MATH_CVECTOR_HPP
|
2 | 2 | #define CP_ALGO_MATH_CVECTOR_HPP
|
| 3 | +#include "../util/simd.hpp" |
3 | 4 | #include "../util/complex.hpp"
|
4 | 5 | #include "../util/checkpoint.hpp"
|
5 | 6 | #include "../util/big_alloc.hpp"
|
6 |
| -#include <experimental/simd> |
7 | 7 | #include <ranges>
|
8 | 8 |
|
9 | 9 | namespace stdx = std::experimental;
|
10 | 10 | namespace cp_algo::math::fft {
|
11 |
| - using ftype = double; |
12 | 11 | static constexpr size_t flen = 4;
|
13 |
| - static constexpr size_t bytes = flen * sizeof(ftype); |
| 12 | + using ftype = double; |
| 13 | + using vftype = simd<ftype, flen>; |
14 | 14 | using point = complex<ftype>;
|
15 |
| - using vftype [[gnu::vector_size(bytes)]] = ftype; |
16 | 15 | using vpoint = complex<vftype>;
|
17 | 16 | static constexpr vftype vz = {};
|
18 | 17 | vpoint vi(vpoint const& r) {
|
19 | 18 | return {-imag(r), real(r)};
|
20 | 19 | }
|
21 |
| - vftype abs(vftype a) { |
22 |
| - return a < 0 ? -a : a; |
23 |
| - } |
24 |
| - using i64x4 [[gnu::vector_size(bytes)]] = int64_t; |
25 |
| - using u64x4 [[gnu::vector_size(bytes)]] = uint64_t; |
26 |
| - auto lround(vftype a) { |
27 |
| - return __builtin_convertvector(a < 0 ? a - 0.5 : a + 0.5, i64x4); |
28 |
| - } |
29 |
| - auto round(vftype a) { |
30 |
| - return __builtin_convertvector(lround(a), vftype); |
31 |
| - } |
32 |
| - u64x4 montgomery_reduce(u64x4 x, u64x4 mod, u64x4 imod) { |
33 |
| - auto x_ninv = _mm256_mul_epu32(__m256i(x), __m256i(imod)); |
34 |
| - auto x_res = _mm256_add_epi64(__m256i(x), _mm256_mul_epu32(x_ninv, __m256i(mod))); |
35 |
| - return u64x4(_mm256_bsrli_epi128(x_res, 4)); |
36 |
| - } |
37 |
| - u64x4 montgomery_mul(u64x4 x, u64x4 y, u64x4 mod, u64x4 imod) { |
38 |
| - return montgomery_reduce(u64x4(_mm256_mul_epu32(__m256i(x), __m256i(y))), mod, imod); |
39 |
| - } |
40 | 20 |
|
41 | 21 | struct cvector {
|
42 | 22 | std::vector<vpoint, big_alloc<vpoint>> r;
|
@@ -99,8 +79,7 @@ namespace cp_algo::math::fft {
|
99 | 79 | }
|
100 | 80 | template<int step>
|
101 | 81 | static void exec_on_eval(size_t n, size_t k, auto &&callback) {
|
102 |
| - point factor = root(4 * step * n); |
103 |
| - callback(factor * eval_point(step * k)); |
| 82 | + callback(root(4 * step * n) * eval_point(step * k)); |
104 | 83 | }
|
105 | 84 |
|
106 | 85 | void dot(cvector const& t) {
|
|
0 commit comments