@@ -2819,22 +2819,25 @@ kernel void kernel_flash_attn_ext(
2819
2819
float S[Q] = { [0 ... Q-1 ] = 0 .0h };
2820
2820
float M[Q] = { [0 ... Q-1 ] = -FLT_MAX/2 };
2821
2821
2822
+ // thread indices inside the simdgroup
2823
+ const short tx = tiisg%4 ;
2824
+ const short ty = tiisg/4 ;
2825
+
2822
2826
// assume K and V are same shape
2823
2827
const short ne22 = ne12;
2824
2828
const short ne23 = ne13;
2825
2829
2826
- // broadcast
2830
+ // broadcast k
2827
2831
const short rk2 = ne02/ne12;
2828
2832
const short rk3 = ne03/ne13;
2829
2833
2830
- const short rv2 = ne02/ne22;
2831
- const short rv3 = ne03/ne23;
2832
-
2833
- // k indices
2834
2834
const short ik2 = iq2/rk2;
2835
2835
const short ik3 = iq3/rk3;
2836
2836
2837
- // v indices
2837
+ // broadcast v
2838
+ const short rv2 = ne02/ne22;
2839
+ const short rv3 = ne03/ne23;
2840
+
2838
2841
const short iv2 = iq2/rv2;
2839
2842
const short iv3 = iq3/rv3;
2840
2843
@@ -2885,15 +2888,12 @@ kernel void kernel_flash_attn_ext(
2885
2888
}
2886
2889
} else {
2887
2890
for (short ii = 0 ; ii < D16; ii += 4 ) {
2888
- const short i = tiisg%4 ;
2889
- const short j = tiisg/4 ;
2890
-
2891
- device const block_q * pk4 = (device const block_q *) ((device const char *) k + ((ic + 8 *cc + j)*nb11 + ik2*nb12 + ik3*nb13));
2891
+ device const block_q * pk4 = (device const block_q *) ((device const char *) k + ((ic + 8 *cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
2892
2892
2893
2893
if (D16%4 == 0 ) {
2894
2894
half4x4 tmp;
2895
- dequantize_func (pk4 + (ii + i )/nl, (ii + i )%nl, tmp);
2896
- skv4[4 *j + i ] = tmp;
2895
+ dequantize_func (pk4 + (ii + tx )/nl, (ii + tx )%nl, tmp);
2896
+ skv4[4 *ty + tx ] = tmp;
2897
2897
2898
2898
simdgroup_barrier (mem_flags::mem_threadgroup);
2899
2899
@@ -2908,10 +2908,10 @@ kernel void kernel_flash_attn_ext(
2908
2908
simdgroup_multiply_accumulate (mqk, mq[2 *(ii + k) + 1 ], mk, mqk);
2909
2909
}
2910
2910
} else {
2911
- if (ii + i < D16) {
2911
+ if (ii + tx < D16) {
2912
2912
half4x4 tmp;
2913
- dequantize_func (pk4 + (ii + i )/nl, (ii + i )%nl, tmp);
2914
- skv4[4 *j + i ] = tmp;
2913
+ dequantize_func (pk4 + (ii + tx )/nl, (ii + tx )%nl, tmp);
2914
+ skv4[4 *ty + tx ] = tmp;
2915
2915
}
2916
2916
2917
2917
simdgroup_barrier (mem_flags::mem_threadgroup);
@@ -3006,15 +3006,12 @@ kernel void kernel_flash_attn_ext(
3006
3006
}
3007
3007
} else {
3008
3008
for (short ii = 0 ; ii < D16; ii += 4 ) {
3009
- const short i = tiisg%4 ;
3010
- const short j = tiisg/4 ;
3011
-
3012
- device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 8 *cc + j)*nb21 + iv2*nb22 + iv3*nb23));
3009
+ device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 8 *cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
3013
3010
3014
3011
if (D16%4 == 0 ) {
3015
3012
half4x4 tmp;
3016
- dequantize_func (pv4 + (ii + i )/nl, (ii + i )%nl, tmp);
3017
- skv4[4 *j + i ] = tmp;
3013
+ dequantize_func (pv4 + (ii + tx )/nl, (ii + tx )%nl, tmp);
3014
+ skv4[4 *ty + tx ] = tmp;
3018
3015
3019
3016
simdgroup_barrier (mem_flags::mem_threadgroup);
3020
3017
@@ -3029,10 +3026,10 @@ kernel void kernel_flash_attn_ext(
3029
3026
simdgroup_multiply_accumulate (lo[2 *(ii + k) + 1 ], ms, mv, lo[2 *(ii + k) + 1 ]);
3030
3027
}
3031
3028
} else {
3032
- if (ii + i < D16) {
3029
+ if (ii + tx < D16) {
3033
3030
half4x4 tmp;
3034
- dequantize_func (pv4 + (ii + i )/nl, (ii + i )%nl, tmp);
3035
- skv4[4 *j + i ] = tmp;
3031
+ dequantize_func (pv4 + (ii + tx )/nl, (ii + tx )%nl, tmp);
3032
+ skv4[4 *ty + tx ] = tmp;
3036
3033
}
3037
3034
3038
3035
simdgroup_barrier (mem_flags::mem_threadgroup);
@@ -3187,6 +3184,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_
3187
3184
template [[host_name(" kernel_flash_attn_ext_q8_0_h128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2 , dequantize_q8_0, 128 >;
3188
3185
template [[host_name(" kernel_flash_attn_ext_q8_0_h256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2 , dequantize_q8_0, 256 >;
3189
3186
3187
+ // NOTE: can use half instead of float precision for some extra perf
3190
3188
// D - head size, Q - queries per threadgroup, C - cache items per threadgroup
3191
3189
template <typename block_q, short nl, void (*dequantize_func)(device const block_q *, short , thread float4x4 &), short D, short Q = 1 , short C = 32 >
3192
3190
kernel void kernel_flash_attn_ext_vec (
@@ -3239,26 +3237,15 @@ kernel void kernel_flash_attn_ext_vec(
3239
3237
3240
3238
const short T = D + 2 *nsg*SH; // shared memory size per query in (half)
3241
3239
3242
- float slope = 1 .0f ;
3243
-
3244
- // ALiBi
3245
- if (max_bias > 0 .0f ) {
3246
- const uint32_t h = iq2;
3247
-
3248
- const float base = h < n_head_log2 ? m0 : m1;
3249
- const int exp = h < n_head_log2 ? h + 1 : 2 *(h - n_head_log2) + 1 ;
3250
-
3251
- slope = pow (base, exp );
3252
- }
3253
-
3254
- // threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
3255
- threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0 *D); // same as above but in half4
3256
- threadgroup float * ss = (threadgroup float *) (shared + 2 *sgitg*SH + 1 *D); // scratch buffer for attention and diagonal matrix
3257
- threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2 *sgitg*SH + 1 *D); // same as above but in half4
3258
- threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + Q*T); // scratch buffer for the results
3240
+ // threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
3241
+ threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0 *D); // same as above but in half4
3242
+ threadgroup half4x4 * sq44 = (threadgroup half4x4 *) (shared + 0 *D); // same as above but in half4x4
3243
+ threadgroup float * ss = (threadgroup float *) (shared + 2 *sgitg*SH + 1 *D); // scratch buffer for attention and diagonal matrix
3244
+ threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2 *sgitg*SH + 1 *D); // same as above but in half4
3245
+ threadgroup float4x4 * sr44 = (threadgroup float4x4 *) (shared + 2 *sgitg*D + Q*T); // scratch buffer for the results
3259
3246
3260
3247
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
3261
- half4x4 lo[D16/NW4];
3248
+ float4x4 lo[D16/NW4];
3262
3249
3263
3250
// load heads from Q to shared memory
3264
3251
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
@@ -3273,7 +3260,7 @@ kernel void kernel_flash_attn_ext_vec(
3273
3260
3274
3261
// zero out lo
3275
3262
for (short i = 0 ; i < D16/NW4; i += NW4) {
3276
- lo[i] = half4x4 (0 .0h);
3263
+ lo[i] = float4x4 (0 .0h);
3277
3264
}
3278
3265
3279
3266
// zero out shared memory SH
@@ -3284,42 +3271,53 @@ kernel void kernel_flash_attn_ext_vec(
3284
3271
threadgroup_barrier (mem_flags::mem_threadgroup);
3285
3272
3286
3273
{
3287
- float S = { 0 .0h };
3288
- float M = { -FLT_MAX/2 };
3274
+ float S = 0 .0h;
3275
+ float M = -FLT_MAX/2 ;
3276
+
3277
+ // thread indices inside the simdgroup
3278
+ const short tx = tiisg%8 ;
3279
+ const short ty = tiisg/8 ;
3289
3280
3290
3281
// assume K and V are same shape
3291
3282
const short ne22 = ne12;
3292
3283
const short ne23 = ne13;
3293
3284
3294
- // broadcast
3285
+ // broadcast k
3295
3286
const short rk2 = ne02/ne12;
3296
3287
const short rk3 = ne03/ne13;
3297
3288
3289
+ const short ik2 = iq2/rk2;
3290
+ const short ik3 = iq3/rk3;
3291
+
3292
+ // broadcast v
3298
3293
const short rv2 = ne02/ne22;
3299
3294
const short rv3 = ne03/ne23;
3300
3295
3301
- // k indices
3302
- const short ik2 = iq2 / rk2;
3303
- const short ik3 = iq3 / rk3;
3304
-
3305
- // v indices
3306
- const short iv2 = iq2 / rv2;
3307
- const short iv3 = iq3 / rv3;
3296
+ const short iv2 = iq2/rv2;
3297
+ const short iv3 = iq3/rv3;
3308
3298
3309
3299
// load the queries from shared memory into local memory
3310
3300
float4x4 mq[D16/NW4];
3311
3301
3312
3302
for (short ii = 0 ; ii < D16; ii += NW4) {
3313
- short i = ii + tiisg%8 ;
3314
- mq[ii/NW4][0 ] = (float4) sq4[4 *i + 0 ];
3315
- mq[ii/NW4][1 ] = (float4) sq4[4 *i + 1 ];
3316
- mq[ii/NW4][2 ] = (float4) sq4[4 *i + 2 ];
3317
- mq[ii/NW4][3 ] = (float4) sq4[4 *i + 3 ];
3303
+ mq[ii/NW4] = (float4x4) sq44[ii + tx];
3318
3304
}
3319
3305
3320
3306
// pointer to the mask
3321
3307
device const half * mp = (device const half *) (mask + iq1*nb31);
3322
3308
3309
+ float slope = 1 .0f ;
3310
+
3311
+ // ALiBi
3312
+ if (max_bias > 0 .0f ) {
3313
+ const uint32_t h = iq2;
3314
+
3315
+ const float base = h < n_head_log2 ? m0 : m1;
3316
+ const int exp = h < n_head_log2 ? h + 1 : 2 *(h - n_head_log2) + 1 ;
3317
+
3318
+ slope = pow (base, exp );
3319
+ }
3320
+
3323
3321
// loop over the KV cache
3324
3322
// each simdgroup handles blocks of Q rows and C columns
3325
3323
for (int ic0 = 0 ; ic0 < ne11; ic0 += C*nsg) {
@@ -3331,18 +3329,16 @@ kernel void kernel_flash_attn_ext_vec(
3331
3329
// Q*K^T
3332
3330
{
3333
3331
// each simdgroup processes 1 query and 4 keys
3334
- const short j = tiisg/8 ;
3335
- #pragma unroll
3336
3332
for (short cc = 0 ; cc < C/4 ; ++cc) {
3337
3333
float mqk = 0.0 ;
3338
3334
3339
- device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + 4 *cc + j )*nb11 + ik2*nb12 + ik3*nb13));
3335
+ device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + 4 *cc + ty )*nb11 + ik2*nb12 + ik3*nb13));
3340
3336
3341
- float4x4 mk;
3342
3337
#pragma unroll
3343
3338
for (short ii = 0 ; ii < D16; ii += NW4) {
3344
- const short i = ii + tiisg% 8 ; // 0..7
3339
+ const short i = ii + tx;
3345
3340
3341
+ float4x4 mk;
3346
3342
dequantize_func (pk + i/nl, i%nl, mk);
3347
3343
3348
3344
mqk +=
@@ -3364,16 +3360,16 @@ kernel void kernel_flash_attn_ext_vec(
3364
3360
mqk += simd_shuffle_down (mqk, 1 );
3365
3361
3366
3362
// mqk = mqk*scale + mask*slope
3367
- if (tiisg% 8 == 0 ) {
3363
+ if (tx == 0 ) {
3368
3364
mqk *= scale;
3369
3365
3370
3366
if (logit_softcap != 0 .0f ) {
3371
3367
mqk = logit_softcap*precise::tanh (mqk);
3372
3368
}
3373
3369
3374
- mqk += (mask != q) ? ((float ) mp[ic + 4 *cc + j ])*slope : (float ) 0 .0f ;
3370
+ mqk += (mask != q) ? ((float ) mp[ic + 4 *cc + ty ])*slope : (float ) 0 .0f ;
3375
3371
3376
- ss[4 *cc + j ] = mqk;
3372
+ ss[4 *cc + ty ] = mqk;
3377
3373
}
3378
3374
}
3379
3375
}
@@ -3408,20 +3404,20 @@ kernel void kernel_flash_attn_ext_vec(
3408
3404
3409
3405
// O = O + (Q*K^T)*V
3410
3406
{
3411
- const short j = tiisg/8 ;
3412
3407
#pragma unroll
3413
3408
for (short cc = 0 ; cc < C/4 ; ++cc) {
3414
- device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 4 *cc + j)*nb21 + iv2*nb22 + iv3*nb23));
3409
+ device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 4 *cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
3410
+
3411
+ const float4x4 lss (ss[4 *cc + ty]);
3415
3412
3416
- float4x4 mv;
3417
- const float4x4 lss (ss[4 *cc + j]);
3418
3413
#pragma unroll
3419
3414
for (short ii = 0 ; ii < D16; ii += NW4) {
3420
- const short i = ii + tiisg% 8 ;
3415
+ const short i = ii + tx ;
3421
3416
3417
+ float4x4 mv;
3422
3418
dequantize_func (pv4 + i/nl, i%nl, mv);
3423
3419
3424
- lo[ii/NW4] += (half4x4)( mv*lss) ;
3420
+ lo[ii/NW4] += mv*lss;
3425
3421
}
3426
3422
}
3427
3423
}
@@ -3458,14 +3454,8 @@ kernel void kernel_flash_attn_ext_vec(
3458
3454
}
3459
3455
3460
3456
// store results to shared memory
3461
- for (short ii = 0 ; ii < D16; ii += NW4) {
3462
- short i = ii + tiisg;
3463
- if (tiisg < 8 ) {
3464
- sr4[4 *i + 0 ] = lo[ii/NW4][0 ];
3465
- sr4[4 *i + 1 ] = lo[ii/NW4][1 ];
3466
- sr4[4 *i + 2 ] = lo[ii/NW4][2 ];
3467
- sr4[4 *i + 3 ] = lo[ii/NW4][3 ];
3468
- }
3457
+ for (short i = tiisg; i < D16; i += NW4) {
3458
+ sr44[i] = lo[i/NW4];
3469
3459
}
3470
3460
3471
3461
threadgroup_barrier (mem_flags::mem_threadgroup);
@@ -3492,24 +3482,22 @@ kernel void kernel_flash_attn_ext_vec(
3492
3482
}
3493
3483
3494
3484
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
3495
- for (short ii = 0 ; ii < D4; ii += NW) {
3496
- short i = ii + tiisg;
3497
- sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
3485
+ for (short i = tiisg; i < D16; i += NW) {
3486
+ sr44[i] = sr44[i]*ms0 + sr44[i + r*D16]*ms1;
3498
3487
}
3499
3488
}
3500
3489
3501
3490
threadgroup_barrier (mem_flags::mem_threadgroup);
3502
3491
}
3503
3492
3504
- device float4 * dst4 = (device float4 *) dst;
3493
+ device float4x4 * dst44 = (device float4x4 *) dst;
3505
3494
3506
3495
// final rescale with 1/S and store to global memory
3507
3496
if (sgitg == 0 ) {
3508
3497
const float S = ss[0 ];
3509
3498
3510
- for (short ii = 0 ; ii < D4; ii += NW) {
3511
- short i = ii + tiisg;
3512
- dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S;
3499
+ for (short i = tiisg; i < D16; i += NW) {
3500
+ dst44[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = sr44[i]/S;
3513
3501
}
3514
3502
}
3515
3503
}
0 commit comments