File tree: 1 file changed, +3 -2 lines changed
Original file line number | Diff line number | Diff line change
@@ -1850,6 +1850,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
1850
1850
// load data and store to threadgroup memory
1851
1851
half4x4 temp_a;
1852
1852
dequantize_func (x, il, temp_a);
1853
+ threadgroup_barrier (mem_flags::mem_threadgroup);
1853
1854
#pragma unroll(16)
1854
1855
for (int i = 0 ; i < 16 ; i++) {
1855
1856
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8 ) \
@@ -1895,14 +1896,14 @@ kernel void kernel_mul_mm(device const uchar * src0,
1895
1896
}
1896
1897
} else {
1897
1898
// block is smaller than 64x32, we should avoid writing data outside of the matrix
1899
+ threadgroup_barrier (mem_flags::mem_threadgroup);
1898
1900
threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
1899
1901
+ 32 * (sgitg&1 ) + (16 * (sgitg>>1 )) * BLOCK_SIZE_M;
1900
1902
for (int i = 0 ; i < 8 ; i++) {
1901
- threadgroup_barrier (mem_flags::mem_device);
1902
1903
simdgroup_store (c_res[i], temp_str + 8 * (i%4 ) + 8 * BLOCK_SIZE_M * (i/4 ), BLOCK_SIZE_M);
1903
1904
}
1904
1905
1905
- threadgroup_barrier (mem_flags::mem_device );
1906
+ threadgroup_barrier (mem_flags::mem_threadgroup );
1906
1907
device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
1907
1908
if (sgitg==0 ) {
1908
1909
for (int i = 0 ; i < n_rows; i++) {
You can’t perform that action at this time.
0 commit comments