Skip to content

Commit 14b1d7e

Browse files
authored
metal : add missing barriers for mul-mat (#2699)
1 parent 226255b commit 14b1d7e

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

ggml-metal.metal

+3-2
Original file line numberDiff line numberDiff line change
@@ -1850,6 +1850,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
18501850
//load data and store to threadgroup memory
18511851
half4x4 temp_a;
18521852
dequantize_func(x, il, temp_a);
1853+
threadgroup_barrier(mem_flags::mem_threadgroup);
18531854
#pragma unroll(16)
18541855
for (int i = 0; i < 16; i++) {
18551856
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
@@ -1895,14 +1896,14 @@ kernel void kernel_mul_mm(device const uchar * src0,
18951896
}
18961897
} else {
18971898
// block is smaller than 64x32, we should avoid writing data outside of the matrix
1899+
threadgroup_barrier(mem_flags::mem_threadgroup);
18981900
threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
18991901
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
19001902
for (int i = 0; i < 8; i++) {
1901-
threadgroup_barrier(mem_flags::mem_device);
19021903
simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
19031904
}
19041905

1905-
threadgroup_barrier(mem_flags::mem_device);
1906+
threadgroup_barrier(mem_flags::mem_threadgroup);
19061907
device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
19071908
if (sgitg==0) {
19081909
for (int i = 0; i < n_rows; i++) {

0 commit comments

Comments
 (0)