Skip to content

Commit 88a38e3

Browse files
committed
Factor out simd
1 parent 9dda1be commit 88a38e3

File tree

2 files changed

+65
-25
lines changed

2 files changed

+65
-25
lines changed

cp-algo/math/cvector.hpp

+4-25
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,22 @@
11
#ifndef CP_ALGO_MATH_CVECTOR_HPP
22
#define CP_ALGO_MATH_CVECTOR_HPP
3+
#include "../util/simd.hpp"
34
#include "../util/complex.hpp"
45
#include "../util/checkpoint.hpp"
56
#include "../util/big_alloc.hpp"
6-
#include <experimental/simd>
77
#include <ranges>
88

99
namespace stdx = std::experimental;
1010
namespace cp_algo::math::fft {
11-
using ftype = double;
1211
static constexpr size_t flen = 4;
13-
static constexpr size_t bytes = flen * sizeof(ftype);
12+
using ftype = double;
13+
using vftype = simd<ftype, flen>;
1414
using point = complex<ftype>;
15-
using vftype [[gnu::vector_size(bytes)]] = ftype;
1615
using vpoint = complex<vftype>;
1716
static constexpr vftype vz = {};
1817
// Builds a new vpoint whose first component is -imag(r) and whose second is
// real(r). Presumably this is multiplication of r by the imaginary unit,
// assuming complex<> is constructed as {re, im} — confirm against complex.hpp.
//
// Declared inline because this is a function definition living in a header:
// without it, every translation unit including the header would emit its own
// external definition.
inline vpoint vi(vpoint const& r) {
    return {-imag(r), real(r)};
}
21-
vftype abs(vftype a) {
22-
return a < 0 ? -a : a;
23-
}
24-
using i64x4 [[gnu::vector_size(bytes)]] = int64_t;
25-
using u64x4 [[gnu::vector_size(bytes)]] = uint64_t;
26-
auto lround(vftype a) {
27-
return __builtin_convertvector(a < 0 ? a - 0.5 : a + 0.5, i64x4);
28-
}
29-
auto round(vftype a) {
30-
return __builtin_convertvector(lround(a), vftype);
31-
}
32-
u64x4 montgomery_reduce(u64x4 x, u64x4 mod, u64x4 imod) {
33-
auto x_ninv = _mm256_mul_epu32(__m256i(x), __m256i(imod));
34-
auto x_res = _mm256_add_epi64(__m256i(x), _mm256_mul_epu32(x_ninv, __m256i(mod)));
35-
return u64x4(_mm256_bsrli_epi128(x_res, 4));
36-
}
37-
u64x4 montgomery_mul(u64x4 x, u64x4 y, u64x4 mod, u64x4 imod) {
38-
return montgomery_reduce(u64x4(_mm256_mul_epu32(__m256i(x), __m256i(y))), mod, imod);
39-
}
4020

4121
struct cvector {
4222
std::vector<vpoint, big_alloc<vpoint>> r;
@@ -99,8 +79,7 @@ namespace cp_algo::math::fft {
9979
}
10080
template<int step>
10181
static void exec_on_eval(size_t n, size_t k, auto &&callback) {
102-
point factor = root(4 * step * n);
103-
callback(factor * eval_point(step * k));
82+
callback(root(4 * step * n) * eval_point(step * k));
10483
}
10584

10685
void dot(cvector const& t) {

cp-algo/util/simd.hpp

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#ifndef CP_ALGO_UTIL_SIMD_HPP
2+
#define CP_ALGO_UTIL_SIMD_HPP
3+
#include <experimental/simd>
4+
#include <cstdint>
5+
#include <cstddef>
6+
namespace cp_algo {
7+
// Vector of `len` lanes of `T`, expressed through the GNU `vector_size`
// attribute; the resulting type occupies len * sizeof(T) bytes.
// Names are std::-qualified because <cstddef>/<cstdint> only guarantee the
// declarations inside namespace std.
template<typename T, std::size_t len>
using simd [[gnu::vector_size(len * sizeof(T))]] = T;

// Convenience aliases for the integer lane layouts used in this header.
using i64x4 = simd<std::int64_t, 4>;   // 4 lanes of signed 64-bit
using u64x4 = simd<std::uint64_t, 4>;  // 4 lanes of unsigned 64-bit
using u32x8 = simd<std::uint32_t, 8>;  // 8 lanes of unsigned 32-bit
using u32x4 = simd<std::uint32_t, 4>;  // 4 lanes of unsigned 32-bit
13+
14+
// Lane-wise absolute value of the vector `a`.
//
// When __AVX2__ is defined the argument is passed straight to _mm256_and_pd,
// so on that path `Simd` is expected to be a 256-bit vector of four doubles
// — TODO confirm no caller instantiates it with another vector type.
template<typename Simd>
Simd abs(Simd a) {
#ifdef __AVX2__
    // Clear only the sign bit of each lane: the mask keeps every exponent and
    // mantissa bit (0x7FFF'FFFF'FFFF'FFFF). Masking with the bit pattern of
    // +infinity instead would also wipe the mantissa (e.g. 3.0 -> 2.0).
    return _mm256_and_pd(a, _mm256_castsi256_pd(_mm256_set1_epi64x(INT64_MAX)));
#else
    // Vector comparison yields a per-lane mask; the conditional selects per lane.
    return a < 0 ? -a : a;
#endif
}
22+
23+
template<typename Simd>
24+
i64x4 lround(Simd a) {
25+
#ifdef __AVX2__
26+
return __builtin_convertvector(_mm256_round_pd(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC), i64x4);
27+
#else
28+
return __builtin_convertvector(a < 0 ? a - 0.5 : a + 0.5, i64x4);
29+
#endif
30+
}
31+
32+
// Rounds each lane of `a` to an integral value, keeping the vector type.
//
// AVX2 path: a single _mm256_round_pd in _MM_FROUND_TO_NEAREST_INT mode with
// exceptions suppressed.
// Fallback path: goes through lround() and converts the integer lanes back
// to `Simd`.
template<typename Simd>
Simd round(Simd a) {
#ifdef __AVX2__
    return _mm256_round_pd(
        a,
        _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC
    );
#else
    auto const as_integers = lround(a);
    return __builtin_convertvector(as_integers, Simd);
#endif
}
40+
41+
// Lane-wise Montgomery-style reduction step with a 2^32 radix.
//
// For every 64-bit lane this computes
//     (x + lo32(x * imod) * mod) >> 32
// where lo32 denotes the low 32 bits. No final conditional subtraction is
// performed here.
// Presumably `imod` holds -mod^{-1} modulo 2^32 and `mod` fits in 32 bits —
// confirm against the callers.
//
// x    : values to reduce, one per lane
// mod  : modulus, one per lane
// imod : precomputed inverse term, one per lane
//
// Declared inline because the definition lives in a header and would
// otherwise be emitted as an external definition by every including
// translation unit.
inline u64x4 montgomery_reduce(u64x4 x, u64x4 mod, u64x4 imod) {
#ifdef __AVX2__
    // _mm256_mul_epu32 multiplies the low 32 bits of each 64-bit lane.
    auto x_ninv = _mm256_mul_epu32(__m256i(x), __m256i(imod));
    auto x_res = _mm256_add_epi64(__m256i(x), _mm256_mul_epu32(x_ninv, __m256i(mod)));
    // Byte shift by 4 within each 128-bit half moves the high 32 bits of each
    // lane down into its low 32 bits.
    return u64x4(_mm256_bsrli_epi128(x_res, 4));
#else
    auto x_ninv = x * imod;
    // (v << 32) >> 32 keeps only the low 32 bits of each lane.
    auto x_res = x + ((x_ninv << 32) >> 32) * mod;
    return u64x4(x_res >> 32);
#endif
}
52+
53+
// Lane-wise product of `x` and `y` followed by montgomery_reduce().
//
// AVX2 path: the product is formed with _mm256_mul_epu32, i.e. from the low
// 32 bits of each 64-bit lane of `x` and `y`.
// Fallback path: the product is the plain 64-bit lane multiplication x * y.
// Presumably operands are expected to fit in 32 bits so that both paths
// agree — confirm against the callers.
//
// Declared inline because the definition lives in a header and would
// otherwise be emitted as an external definition by every including
// translation unit.
inline u64x4 montgomery_mul(u64x4 x, u64x4 y, u64x4 mod, u64x4 imod) {
#ifdef __AVX2__
    return montgomery_reduce(u64x4(_mm256_mul_epu32(__m256i(x), __m256i(y))), mod, imod);
#else
    return montgomery_reduce(x * y, mod, imod);
#endif
}
60+
}
61+
#endif // CP_ALGO_UTIL_SIMD_HPP

0 commit comments

Comments
 (0)