
Commit 4024f91

SlyEcho committed with ardfork, funnbot, and Engininja2

Add intrinsics polyfills for AMD

Co-authored-by: ardfork <134447697+ardfork@users.noreply.github.com>
Co-authored-by: funnbot <22226942+funnbot@users.noreply.github.com>
Co-authored-by: Engininja2 <139037756+Engininja2@users.noreply.github.com>
1 parent ab62128 commit 4024f91

File tree: 3 files changed (+38 -17 lines)

  CMakeLists.txt
  Makefile
  ggml-cuda.cu

CMakeLists.txt (-1)

@@ -379,7 +379,6 @@ if (LLAMA_HIPBLAS)
     target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
     target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y})
     target_compile_definitions(ggml-rocm PRIVATE K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})
-    target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_FORCE_DMMV)
     set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX)
     target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::rocblas roc::hipblas)

Makefile (-1)

@@ -302,7 +302,6 @@ ggml-cuda.o: CXXFLAGS += $(addprefix --offload-arch=,$(GPU_TARGETS))
 ggml-cuda.o: CXXFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
 ggml-cuda.o: CXXFLAGS += -DGGML_CUDA_MMV_Y=$(LLAMA_CUDA_MMV_Y)
 ggml-cuda.o: CXXFLAGS += -DGGML_CUDA_MMQ_Y=$(LLAMA_CUDA_MMQ_Y)
-ggml-cuda.o: CXXFLAGS += -DGGML_CUDA_FORCE_DMMV
 ggml-cuda.o: CXXFLAGS += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER)
 ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
 	$(CXX) $(CXXFLAGS) -x hip -c -o $@ $<

ggml-cuda.cu (+38 -15)
@@ -75,6 +75,29 @@
 
 #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
 
+#if defined(GGML_USE_HIPBLAS)
+#define __CUDA_ARCH__ 1300
+
+typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
+static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
+    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
+    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
+    const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
+    return reinterpret_cast<const int&>(c);
+}
+
+static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
+#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
+    c = __builtin_amdgcn_sdot4(a, b, c, false);
+#else
+    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
+    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
+    c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
+#endif
+    return c;
+}
+#endif
+
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
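Note on the polyfills above: __dp4a(a, b, c) treats each int as four packed signed bytes and returns c plus their byte-wise dot product, and __vsubss4 is a byte-wise subtraction with signed saturation. Below is a minimal host-side C++ sketch of the same semantics, useful for sanity-checking the device versions; the helper names dp4a_ref and vsubss4_ref are illustrative, not part of the commit.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Reference for __dp4a: accumulate four signed byte products into c.
static int dp4a_ref(int a, int b, int c) {
    int8_t va[4], vb[4];
    std::memcpy(va, &a, 4);
    std::memcpy(vb, &b, 4);
    for (int i = 0; i < 4; ++i) {
        c += va[i] * vb[i];
    }
    return c;
}

// Reference for __vsubss4: byte-wise a - b with signed saturation.
static int vsubss4_ref(int a, int b) {
    int8_t va[4], vb[4], vc[4];
    std::memcpy(va, &a, 4);
    std::memcpy(vb, &b, 4);
    for (int i = 0; i < 4; ++i) {
        vc[i] = (int8_t) std::clamp(int(va[i]) - int(vb[i]), -128, 127);
    }
    int c;
    std::memcpy(&c, vc, 4);
    return c;
}

int main() {
    // bytes of 0x01020304 (little-endian) are {04,03,02,01};
    // dot with {01,01,01,01} gives 10, plus c = 5 -> 15
    printf("%d\n", dp4a_ref(0x01020304, 0x01010101, 5));
    // top byte: 0x7f - (-1) = 128 saturates to 127, so 7f000000
    printf("%08x\n", (unsigned) vsubss4_ref(0x7F000000, 0xFF000000));
    return 0;
}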
@@ -1396,8 +1419,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
         return;
     }
 
-    y[ib].ds.x = d;
-    y[ib].ds.y = sum;
+    reinterpret_cast<half&>(y[ib].ds.x) = d;
+    reinterpret_cast<half&>(y[ib].ds.y) = sum;
 }
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
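The store change above appears to work around a CUDA/HIP header difference: CUDA's half2 members accept a float assignment directly, while the HIP headers type those members differently, so the store is routed through an explicit half lvalue. A minimal sketch of the pattern, assuming CUDA's half2 member layout; the function name store_ds is illustrative.

#include <cuda_fp16.h>

// Write the two halves of a half2 without naming the member's exact type.
__device__ void store_ds(half2 &ds, float d, float sum) {
    reinterpret_cast<half&>(ds.x) = d;    // low half
    reinterpret_cast<half&>(ds.y) = sum;  // high half
}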
@@ -1609,8 +1632,8 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
 #else
     const float2 dm8f = __half22float2(dm8);
     const float2 ds8f = __half22float2(ds8);
-    const float d8d8 = dm8.x * ds8.x;
-    const float m8s8 = dm8.y * ds8.y;
+    const float d8d8 = __low2float(dm8) * __low2float(ds8);
+    const float m8s8 = __high2float(dm8) * __high2float(ds8);
 #endif // GGML_CUDA_F16
 
 // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it
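The read side mirrors this: instead of dm8.x / ds8.y, the low and high halves are fetched with the __low2float / __high2float (and __low2half / __high2half) intrinsics, which both the CUDA and HIP fp16 headers provide. A small sketch showing the equivalence on CUDA; the kernel name show_half2 is illustrative.

#include <cstdio>
#include <cuda_fp16.h>

__global__ void show_half2() {
    const half2 ds = __floats2half2_rn(0.5f, 8.0f); // low = 0.5f, high = 8.0f
    const float lo = __low2float(ds);   // replaces float(ds.x)
    const float hi = __high2float(ds);  // replaces float(ds.y)
    printf("lo=%f hi=%f\n", lo, hi);    // lo=0.500000 hi=8.000000
}

int main() {
    show_half2<<<1, 1>>>();
    cudaDeviceSynchronize();
    return 0;
}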
@@ -2380,7 +2403,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
         u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
     }
 
-    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, bq8_1->ds.x);
+    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, __low2half(bq8_1->ds));
 }
 
 static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
@@ -2478,7 +2501,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
 #pragma unroll
     for (int i = 0; i < QR2_K; ++ i) {
         u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
-        d8[i] = bq8_1[bq8_offset + i].ds.x;
+        d8[i] = __low2half(bq8_1[bq8_offset + i].ds);
     }
 
     return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8);
@@ -2605,7 +2628,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
 #pragma unroll
     for (int i = 0; i < QR3_K; ++i) {
         u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
-        d8[i] = bq8_1[bq8_offset + i].ds.x;
+        d8[i] = __low2half(bq8_1[bq8_offset + i].ds);
     }
 
     return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
@@ -2782,7 +2805,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
 
     for (int i = 0; i < QR4_K; ++i) {
         const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
-        d8[i] = bq8i->ds.x;
+        d8[i] = __low2half(bq8i->ds);
 
         const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
         u[2*i+0] = q8[0];
@@ -2809,8 +2832,8 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
     const float dall = bq4_K->d[0];
     const float dmin = bq4_K->d[1];
 
-    const float d8_1 = bq8_1[0].ds.x;
-    const float d8_2 = bq8_1[1].ds.x;
+    const float d8_1 = __low2float(bq8_1[0].ds);
+    const float d8_2 = __low2float(bq8_1[1].ds);
 
     const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));
     const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);
@@ -2977,7 +3000,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
 #pragma unroll
     for (int i = 0; i < QR5_K; ++i) {
         const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
-        d8[i] = bq8i->ds.x;
+        d8[i] = __low2float(bq8i->ds);
 
         const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
         u[2*i+0] = q8[0];
@@ -2995,8 +3018,8 @@
 
     const float d = bq5_K->d;
 
-    const float d8_1 = bq8_1[0].ds.x;
-    const float d8_2 = bq8_1[1].ds.x;
+    const float d8_1 = __low2half(bq8_1[0].ds);
+    const float d8_2 = __low2half(bq8_1[1].ds);
 
     const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));
     const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);
@@ -3157,7 +3180,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
 #pragma unroll
     for (int i = 0; i < QR6_K; ++i) {
         u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1);
-        d8[i] = bq8_1[bq8_offset + 2*i].ds.x;
+        d8[i] = __low2half(bq8_1[bq8_offset + 2*i].ds);
     }
 
     return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);
@@ -3336,7 +3359,7 @@ static __global__ void mul_mat_q(
             *dsi_dst = *dsi_src;
         } else {
            float * dfi_dst = (float *) dsi_dst;
-            *dfi_dst = (*dsi_src).x;
+            *dfi_dst = __low2half(*dsi_src);
         }
     }
