Skip to content

Commit 1e12961

Browse files
committed
metal : clean-up (cont)
1 parent dd0d9ed commit 1e12961

File tree

2 files changed

+76
-88
lines changed

2 files changed

+76
-88
lines changed

ggml/src/ggml-metal.m

+1-1
Original file line numberDiff line numberDiff line change
@@ -3187,7 +3187,7 @@ static void ggml_metal_encode_node(
31873187
}
31883188
nsg /= 2;
31893189

3190-
const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
3190+
const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + 2*nsg*ne00)*(sizeof(float)/2);
31913191

31923192
//printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
31933193
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);

ggml/src/ggml-metal.metal

+75-87
Original file line numberDiff line numberDiff line change
@@ -2819,22 +2819,25 @@ kernel void kernel_flash_attn_ext(
28192819
float S[Q] = { [0 ... Q-1] = 0.0h };
28202820
float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 };
28212821

2822+
// thread indices inside the simdgroup
2823+
const short tx = tiisg%4;
2824+
const short ty = tiisg/4;
2825+
28222826
// assume K and V are same shape
28232827
const short ne22 = ne12;
28242828
const short ne23 = ne13;
28252829

2826-
// broadcast
2830+
// broadcast k
28272831
const short rk2 = ne02/ne12;
28282832
const short rk3 = ne03/ne13;
28292833

2830-
const short rv2 = ne02/ne22;
2831-
const short rv3 = ne03/ne23;
2832-
2833-
// k indices
28342834
const short ik2 = iq2/rk2;
28352835
const short ik3 = iq3/rk3;
28362836

2837-
// v indices
2837+
// broadcast v
2838+
const short rv2 = ne02/ne22;
2839+
const short rv3 = ne03/ne23;
2840+
28382841
const short iv2 = iq2/rv2;
28392842
const short iv3 = iq3/rv3;
28402843

@@ -2885,15 +2888,12 @@ kernel void kernel_flash_attn_ext(
28852888
}
28862889
} else {
28872890
for (short ii = 0; ii < D16; ii += 4) {
2888-
const short i = tiisg%4;
2889-
const short j = tiisg/4;
2890-
2891-
device const block_q * pk4 = (device const block_q *) ((device const char *) k + ((ic + 8*cc + j)*nb11 + ik2*nb12 + ik3*nb13));
2891+
device const block_q * pk4 = (device const block_q *) ((device const char *) k + ((ic + 8*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
28922892

28932893
if (D16%4 == 0) {
28942894
half4x4 tmp;
2895-
dequantize_func(pk4 + (ii + i)/nl, (ii + i)%nl, tmp);
2896-
skv4[4*j + i] = tmp;
2895+
dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
2896+
skv4[4*ty + tx] = tmp;
28972897

28982898
simdgroup_barrier(mem_flags::mem_threadgroup);
28992899

@@ -2908,10 +2908,10 @@ kernel void kernel_flash_attn_ext(
29082908
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
29092909
}
29102910
} else {
2911-
if (ii + i < D16) {
2911+
if (ii + tx < D16) {
29122912
half4x4 tmp;
2913-
dequantize_func(pk4 + (ii + i)/nl, (ii + i)%nl, tmp);
2914-
skv4[4*j + i] = tmp;
2913+
dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
2914+
skv4[4*ty + tx] = tmp;
29152915
}
29162916

29172917
simdgroup_barrier(mem_flags::mem_threadgroup);
@@ -3006,15 +3006,12 @@ kernel void kernel_flash_attn_ext(
30063006
}
30073007
} else {
30083008
for (short ii = 0; ii < D16; ii += 4) {
3009-
const short i = tiisg%4;
3010-
const short j = tiisg/4;
3011-
3012-
device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 8*cc + j)*nb21 + iv2*nb22 + iv3*nb23));
3009+
device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 8*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
30133010

30143011
if (D16%4 == 0) {
30153012
half4x4 tmp;
3016-
dequantize_func(pv4 + (ii + i)/nl, (ii + i)%nl, tmp);
3017-
skv4[4*j + i] = tmp;
3013+
dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
3014+
skv4[4*ty + tx] = tmp;
30183015

30193016
simdgroup_barrier(mem_flags::mem_threadgroup);
30203017

@@ -3029,10 +3026,10 @@ kernel void kernel_flash_attn_ext(
30293026
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
30303027
}
30313028
} else {
3032-
if (ii + i < D16) {
3029+
if (ii + tx < D16) {
30333030
half4x4 tmp;
3034-
dequantize_func(pv4 + (ii + i)/nl, (ii + i)%nl, tmp);
3035-
skv4[4*j + i] = tmp;
3031+
dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
3032+
skv4[4*ty + tx] = tmp;
30363033
}
30373034

30383035
simdgroup_barrier(mem_flags::mem_threadgroup);
@@ -3187,6 +3184,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_
31873184
template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 128>;
31883185
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 256>;
31893186

3187+
// NOTE: can use half instead of float precision for some extra perf
31903188
// D - head size, Q - queries per threadgroup, C - cache items per threadgroup
31913189
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &), short D, short Q = 1, short C = 32>
31923190
kernel void kernel_flash_attn_ext_vec(
@@ -3239,26 +3237,15 @@ kernel void kernel_flash_attn_ext_vec(
32393237

32403238
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
32413239

3242-
float slope = 1.0f;
3243-
3244-
// ALiBi
3245-
if (max_bias > 0.0f) {
3246-
const uint32_t h = iq2;
3247-
3248-
const float base = h < n_head_log2 ? m0 : m1;
3249-
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
3250-
3251-
slope = pow(base, exp);
3252-
}
3253-
3254-
//threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
3255-
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
3256-
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
3257-
threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4
3258-
threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + Q*T); // scratch buffer for the results
3240+
//threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
3241+
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
3242+
threadgroup half4x4 * sq44 = (threadgroup half4x4 *) (shared + 0*D); // same as above but in half4x4
3243+
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
3244+
threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4
3245+
threadgroup float4x4 * sr44 = (threadgroup float4x4 *) (shared + 2*sgitg*D + Q*T); // scratch buffer for the results
32593246

32603247
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
3261-
half4x4 lo[D16/NW4];
3248+
float4x4 lo[D16/NW4];
32623249

32633250
// load heads from Q to shared memory
32643251
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
@@ -3273,7 +3260,7 @@ kernel void kernel_flash_attn_ext_vec(
32733260

32743261
// zero out lo
32753262
for (short i = 0; i < D16/NW4; i += NW4) {
3276-
lo[i] = half4x4(0.0h);
3263+
lo[i] = float4x4(0.0h);
32773264
}
32783265

32793266
// zero out shared memory SH
@@ -3284,42 +3271,53 @@ kernel void kernel_flash_attn_ext_vec(
32843271
threadgroup_barrier(mem_flags::mem_threadgroup);
32853272

32863273
{
3287-
float S = { 0.0h };
3288-
float M = { -FLT_MAX/2 };
3274+
float S = 0.0h;
3275+
float M = -FLT_MAX/2;
3276+
3277+
// thread indices inside the simdgroup
3278+
const short tx = tiisg%8;
3279+
const short ty = tiisg/8;
32893280

32903281
// assume K and V are same shape
32913282
const short ne22 = ne12;
32923283
const short ne23 = ne13;
32933284

3294-
// broadcast
3285+
// broadcast k
32953286
const short rk2 = ne02/ne12;
32963287
const short rk3 = ne03/ne13;
32973288

3289+
const short ik2 = iq2/rk2;
3290+
const short ik3 = iq3/rk3;
3291+
3292+
// broadcast v
32983293
const short rv2 = ne02/ne22;
32993294
const short rv3 = ne03/ne23;
33003295

3301-
// k indices
3302-
const short ik2 = iq2 / rk2;
3303-
const short ik3 = iq3 / rk3;
3304-
3305-
// v indices
3306-
const short iv2 = iq2 / rv2;
3307-
const short iv3 = iq3 / rv3;
3296+
const short iv2 = iq2/rv2;
3297+
const short iv3 = iq3/rv3;
33083298

33093299
// load the queries from shared memory into local memory
33103300
float4x4 mq[D16/NW4];
33113301

33123302
for (short ii = 0; ii < D16; ii += NW4) {
3313-
short i = ii + tiisg%8;
3314-
mq[ii/NW4][0] = (float4) sq4[4*i + 0];
3315-
mq[ii/NW4][1] = (float4) sq4[4*i + 1];
3316-
mq[ii/NW4][2] = (float4) sq4[4*i + 2];
3317-
mq[ii/NW4][3] = (float4) sq4[4*i + 3];
3303+
mq[ii/NW4] = (float4x4) sq44[ii + tx];
33183304
}
33193305

33203306
// pointer to the mask
33213307
device const half * mp = (device const half *) (mask + iq1*nb31);
33223308

3309+
float slope = 1.0f;
3310+
3311+
// ALiBi
3312+
if (max_bias > 0.0f) {
3313+
const uint32_t h = iq2;
3314+
3315+
const float base = h < n_head_log2 ? m0 : m1;
3316+
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
3317+
3318+
slope = pow(base, exp);
3319+
}
3320+
33233321
// loop over the KV cache
33243322
// each simdgroup handles blocks of Q rows and C columns
33253323
for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
@@ -3331,18 +3329,16 @@ kernel void kernel_flash_attn_ext_vec(
33313329
// Q*K^T
33323330
{
33333331
// each simdgroup processes 1 query and 4 keys
3334-
const short j = tiisg/8;
3335-
#pragma unroll
33363332
for (short cc = 0; cc < C/4; ++cc) {
33373333
float mqk = 0.0;
33383334

3339-
device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + 4*cc + j)*nb11 + ik2*nb12 + ik3*nb13));
3335+
device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + 4*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
33403336

3341-
float4x4 mk;
33423337
#pragma unroll
33433338
for (short ii = 0; ii < D16; ii += NW4) {
3344-
const short i = ii + tiisg%8; // 0..7
3339+
const short i = ii + tx;
33453340

3341+
float4x4 mk;
33463342
dequantize_func(pk + i/nl, i%nl, mk);
33473343

33483344
mqk +=
@@ -3364,16 +3360,16 @@ kernel void kernel_flash_attn_ext_vec(
33643360
mqk += simd_shuffle_down(mqk, 1);
33653361

33663362
// mqk = mqk*scale + mask*slope
3367-
if (tiisg%8 == 0) {
3363+
if (tx == 0) {
33683364
mqk *= scale;
33693365

33703366
if (logit_softcap != 0.0f) {
33713367
mqk = logit_softcap*precise::tanh(mqk);
33723368
}
33733369

3374-
mqk += (mask != q) ? ((float) mp[ic + 4*cc + j])*slope : (float) 0.0f;
3370+
mqk += (mask != q) ? ((float) mp[ic + 4*cc + ty])*slope : (float) 0.0f;
33753371

3376-
ss[4*cc + j] = mqk;
3372+
ss[4*cc + ty] = mqk;
33773373
}
33783374
}
33793375
}
@@ -3408,20 +3404,20 @@ kernel void kernel_flash_attn_ext_vec(
34083404

34093405
// O = O + (Q*K^T)*V
34103406
{
3411-
const short j = tiisg/8;
34123407
#pragma unroll
34133408
for (short cc = 0; cc < C/4; ++cc) {
3414-
device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 4*cc + j)*nb21 + iv2*nb22 + iv3*nb23));
3409+
device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 4*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
3410+
3411+
const float4x4 lss(ss[4*cc + ty]);
34153412

3416-
float4x4 mv;
3417-
const float4x4 lss(ss[4*cc + j]);
34183413
#pragma unroll
34193414
for (short ii = 0; ii < D16; ii += NW4) {
3420-
const short i = ii + tiisg%8;
3415+
const short i = ii + tx;
34213416

3417+
float4x4 mv;
34223418
dequantize_func(pv4 + i/nl, i%nl, mv);
34233419

3424-
lo[ii/NW4] += (half4x4)(mv*lss);
3420+
lo[ii/NW4] += mv*lss;
34253421
}
34263422
}
34273423
}
@@ -3458,14 +3454,8 @@ kernel void kernel_flash_attn_ext_vec(
34583454
}
34593455

34603456
// store results to shared memory
3461-
for (short ii = 0; ii < D16; ii += NW4) {
3462-
short i = ii + tiisg;
3463-
if (tiisg < 8) {
3464-
sr4[4*i + 0] = lo[ii/NW4][0];
3465-
sr4[4*i + 1] = lo[ii/NW4][1];
3466-
sr4[4*i + 2] = lo[ii/NW4][2];
3467-
sr4[4*i + 3] = lo[ii/NW4][3];
3468-
}
3457+
for (short i = tiisg; i < D16; i += NW4) {
3458+
sr44[i] = lo[i/NW4];
34693459
}
34703460

34713461
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -3492,24 +3482,22 @@ kernel void kernel_flash_attn_ext_vec(
34923482
}
34933483

34943484
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
3495-
for (short ii = 0; ii < D4; ii += NW) {
3496-
short i = ii + tiisg;
3497-
sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
3485+
for (short i = tiisg; i < D16; i += NW) {
3486+
sr44[i] = sr44[i]*ms0 + sr44[i + r*D16]*ms1;
34983487
}
34993488
}
35003489

35013490
threadgroup_barrier(mem_flags::mem_threadgroup);
35023491
}
35033492

3504-
device float4 * dst4 = (device float4 *) dst;
3493+
device float4x4 * dst44 = (device float4x4 *) dst;
35053494

35063495
// final rescale with 1/S and store to global memory
35073496
if (sgitg == 0) {
35083497
const float S = ss[0];
35093498

3510-
for (short ii = 0; ii < D4; ii += NW) {
3511-
short i = ii + tiisg;
3512-
dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S;
3499+
for (short i = tiisg; i < D16; i += NW) {
3500+
dst44[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = sr44[i]/S;
35133501
}
35143502
}
35153503
}

0 commit comments

Comments
 (0)