@@ -931,6 +931,101 @@ inline static float vaddvq_f32(float32x4_t v) {
 #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
 #endif

+ #elif defined(__AVX512F__)
+
+ #define GGML_SIMD
+
+ // F32 AVX512
+
+ #define GGML_F32_STEP 64
+ #define GGML_F32_EPR  16
+
+ #define GGML_F32x16         __m512
+ #define GGML_F32x16_ZERO    _mm512_setzero_ps()
+ #define GGML_F32x16_SET1(x) _mm512_set1_ps(x)
+ #define GGML_F32x16_LOAD    _mm512_loadu_ps
+ #define GGML_F32x16_STORE   _mm512_storeu_ps
+ // _mm512_fmadd_ps is defined in AVX512F so no guard is required
+ #define GGML_F32x16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
+ #define GGML_F32x16_ADD     _mm512_add_ps
+ #define GGML_F32x16_MUL     _mm512_mul_ps
+ #define GGML_F32x16_REDUCE(res, x)                \
+ do {                                              \
+     int offset = GGML_F32_ARR >> 1;               \
+     for (int i = 0; i < offset; ++i) {            \
+         x[i] = _mm512_add_ps(x[i], x[offset+i]);  \
+     }                                             \
+     offset >>= 1;                                 \
+     for (int i = 0; i < offset; ++i) {            \
+         x[i] = _mm512_add_ps(x[i], x[offset+i]);  \
+     }                                             \
+     offset >>= 1;                                 \
+     for (int i = 0; i < offset; ++i) {            \
+         x[i] = _mm512_add_ps(x[i], x[offset+i]);  \
+     }                                             \
+     res = _mm512_reduce_add_ps(x[0]);             \
+ } while (0)
+
+ // TODO: is this optimal ?
+
+ #define GGML_F32_VEC        GGML_F32x16
+ #define GGML_F32_VEC_ZERO   GGML_F32x16_ZERO
+ #define GGML_F32_VEC_SET1   GGML_F32x16_SET1
+ #define GGML_F32_VEC_LOAD   GGML_F32x16_LOAD
+ #define GGML_F32_VEC_STORE  GGML_F32x16_STORE
+ #define GGML_F32_VEC_FMA    GGML_F32x16_FMA
+ #define GGML_F32_VEC_ADD    GGML_F32x16_ADD
+ #define GGML_F32_VEC_MUL    GGML_F32x16_MUL
+ #define GGML_F32_VEC_REDUCE GGML_F32x16_REDUCE
+
+ // F16 AVX512
+
+ // F16 AVX
+
+ #define GGML_F16_STEP 64
+ #define GGML_F16_EPR  16
+
+ // AVX512 has FP16 extension (AVX512_FP16) but I don't have it on my machine so I use FP32 instead
+
+ #define GGML_F32Cx16             __m512
+ #define GGML_F32Cx16_ZERO        _mm512_setzero_ps()
+ #define GGML_F32Cx16_SET1(x)     _mm512_set1_ps(x)
+
+ // unlike _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in AVX512F
+ // so F16C guard isn't required
+ #define GGML_F32Cx16_LOAD(x)     _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(x)))
+ #define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0))
+
+ #define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
+ #define GGML_F32Cx16_ADD         _mm512_add_ps
+ #define GGML_F32Cx16_MUL         _mm512_mul_ps
+ #define GGML_F32Cx16_REDUCE(res, x)               \
+ do {                                              \
+     int offset = GGML_F32_ARR >> 1;               \
+     for (int i = 0; i < offset; ++i) {            \
+         x[i] = _mm512_add_ps(x[i], x[offset+i]);  \
+     }                                             \
+     offset >>= 1;                                 \
+     for (int i = 0; i < offset; ++i) {            \
+         x[i] = _mm512_add_ps(x[i], x[offset+i]);  \
+     }                                             \
+     offset >>= 1;                                 \
+     for (int i = 0; i < offset; ++i) {            \
+         x[i] = _mm512_add_ps(x[i], x[offset+i]);  \
+     }                                             \
+     res = _mm512_reduce_add_ps(x[0]);             \
+ } while (0)
+
+ #define GGML_F16_VEC                GGML_F32Cx16
+ #define GGML_F16_VEC_ZERO           GGML_F32Cx16_ZERO
+ #define GGML_F16_VEC_SET1           GGML_F32Cx16_SET1
+ #define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx16_LOAD(p)
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx16_STORE(p, r[i])
+ #define GGML_F16_VEC_FMA            GGML_F32Cx16_FMA
+ #define GGML_F16_VEC_ADD            GGML_F32Cx16_ADD
+ #define GGML_F16_VEC_MUL            GGML_F32Cx16_MUL
+ #define GGML_F16_VEC_REDUCE         GGML_F32Cx16_REDUCE
+

 #elif defined(__AVX__)

 #define GGML_SIMD
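
The hunk only defines the 512-bit vector macros; the generic SIMD loops elsewhere in ggml.c consume them through the GGML_F32_VEC_* aliases, processing GGML_F32_STEP (64) floats per iteration as GGML_F32_ARR = GGML_F32_STEP/GGML_F32_EPR = 4 independent accumulators and folding them at the end with GGML_F32x16_REDUCE. A minimal, self-contained sketch of that pattern, assuming AVX512F; dot_f32_avx512 is an illustrative name, not the upstream ggml_vec_dot_f32, and the three defines just mirror the values above (GGML_F32_ARR itself is defined elsewhere in ggml.c).

// Minimal dot-product sketch showing how the macros above are consumed.
// Hypothetical helper, not upstream code.
#include <immintrin.h>
#include <stddef.h>

#define GGML_F32_STEP 64                               // floats per unrolled iteration
#define GGML_F32_EPR  16                               // floats per __m512 register
#define GGML_F32_ARR  (GGML_F32_STEP/GGML_F32_EPR)     // accumulators kept in flight (4)

static float dot_f32_avx512(size_t n, const float * x, const float * y) {
    __m512 sum[GGML_F32_ARR];
    for (int j = 0; j < GGML_F32_ARR; ++j) {
        sum[j] = _mm512_setzero_ps();                  // GGML_F32_VEC_ZERO
    }

    const size_t np = n - (n % GGML_F32_STEP);         // bulk part, a multiple of 64
    for (size_t i = 0; i < np; i += GGML_F32_STEP) {
        for (int j = 0; j < GGML_F32_ARR; ++j) {
            const __m512 ax = _mm512_loadu_ps(x + i + j*GGML_F32_EPR);  // GGML_F32_VEC_LOAD
            const __m512 ay = _mm512_loadu_ps(y + i + j*GGML_F32_EPR);
            sum[j] = _mm512_fmadd_ps(ax, ay, sum[j]);                   // GGML_F32_VEC_FMA
        }
    }

    // GGML_F32x16_REDUCE: fold the accumulators pairwise (4 -> 2 -> 1),
    // then horizontally add the lanes of the survivor.
    for (int offset = GGML_F32_ARR >> 1; offset > 0; offset >>= 1) {
        for (int j = 0; j < offset; ++j) {
            sum[j] = _mm512_add_ps(sum[j], sum[offset + j]);
        }
    }
    float res = _mm512_reduce_add_ps(sum[0]);

    for (size_t i = np; i < n; ++i) {                  // scalar tail
        res += x[i]*y[i];
    }
    return res;
}

Keeping several accumulators in flight hides FMA latency, which is why the reduce macro has to merge GGML_F32_ARR registers instead of just one before the horizontal add.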
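For F16 the data stays 16-bit in memory while all arithmetic runs in fp32 lanes: GGML_F32Cx16_LOAD widens 16 halfs with _mm512_cvtph_ps and GGML_F32Cx16_STORE narrows back with _mm512_cvtps_ph, both available in plain AVX512F (the 256-bit conversions need F16C, hence the comment in the hunk). A minimal sketch of that round trip, assuming a uint16_t storage type standing in for ggml_fp16_t and n a multiple of 16; scale_f16_avx512 is an illustrative name, not upstream API.

// Scale an fp16 array in place: load as fp16, compute in fp32, store as fp16.
// Hypothetical helper mirroring the GGML_F32Cx16_LOAD/STORE macros above.
#include <immintrin.h>
#include <stdint.h>
#include <stddef.h>

typedef uint16_t half_t;   // stand-in for ggml_fp16_t (16-bit IEEE half storage)

static void scale_f16_avx512(size_t n, half_t * x, float v) {
    const __m512 vv = _mm512_set1_ps(v);                                // GGML_F32Cx16_SET1
    for (size_t i = 0; i < n; i += 16) {                                // GGML_F16_EPR halfs per step
        // widen 16 halfs to fp32 lanes (GGML_F32Cx16_LOAD)
        __m512 ax = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x + i)));
        ax = _mm512_mul_ps(ax, vv);                                     // GGML_F32Cx16_MUL
        // narrow back to fp16 and store (GGML_F32Cx16_STORE); 0 = round to nearest even
        _mm256_storeu_si256((__m256i *)(x + i), _mm512_cvtps_ph(ax, 0));
    }
}

Because the math itself is done in fp32, numerical behavior matches the AVX2 path; only the width changes, from 8 to 16 elements per register.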