@@ -3145,12 +3145,16 @@ static void ggml_metal_encode_node(
GGML_ASSERT(nqptg % 8 == 0);
GGML_ASSERT(ncpsg % 32 == 0);

+ // 16*32*nsgmax
+ // the shared memory needed for the simdgroups to load the KV cache
+ // each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
+ //
+ #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
+
int64_t nsgmax = 2;

while (true) {
- // 16*32*nsgmax - the shared memory needed for the simdgroups to load the KV cache
- // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
- const size_t smem = (nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg)) + 16*32*nsgmax)*(sizeof(float)/2);
+ const size_t smem = FATTN_SMEM(nsgmax);
if (smem > device.maxThreadgroupMemoryLength) {
break;
}
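For intuition, here is a standalone sketch of what the new `FATTN_SMEM` macro evaluates to for the matrix kernel. The head size `ne00 = 128` is an illustrative assumption, not a value taken from this diff, and `GGML_PAD` is re-declared locally so the snippet compiles on its own:

```c
#include <stdio.h>
#include <stdint.h>

// local stand-in for ggml's GGML_PAD(x, n): round x up to a multiple of n
#define GGML_PAD(x, n) (((x) + (n) - 1) / (n) * (n))

int main(void) {
    const int64_t ne00  = 128; // assumed head size (not from the diff)
    const int64_t nqptg = 8;   // queries per threadgroup, matches the % 8 assert
    const int64_t ncpsg = 32;  // cache positions per step, matches the % 32 assert

    for (int64_t nsg = 1; nsg <= 8; nsg *= 2) {
        // same expression as FATTN_SMEM: f16 elements -> sizeof(float)/2 bytes,
        // padded to a 16-byte multiple, so setThreadgroupMemoryLength: can take it as-is
        const size_t smem = GGML_PAD((nqptg*(ne00 + 2*(ncpsg + nqptg)*nsg) + 16*32*nsg)*(sizeof(float)/2), 16);
        printf("nsg = %lld -> smem = %zu bytes\n", (long long) nsg, smem);
    }
    return 0;
}
```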
@@ -3161,13 +3165,12 @@ static void ggml_metal_encode_node(
// simdgroups per threadgroup (a.k.a. warps)
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;

- const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + 16*32*nsg)*(sizeof(float)/2);
+ const size_t smem = FATTN_SMEM(nsg);

//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
-
- [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
-
+ [encoder setThreadgroupMemoryLength:smem atIndex:0];
+ #undef FATTN_SMEM
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
} else {
// half4x4 kernel
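The `nsgmax` loop above (and the identical one added in the vec-kernel branch below) is a "largest power of two that still fits" search. A minimal C sketch of the same pattern, with the Metal device limit replaced by a plain parameter and a hypothetical `smem_for` model standing in for `FATTN_SMEM`:

```c
#include <stdio.h>
#include <stdint.h>

// hypothetical smem model for illustration only: 2 KiB fixed + 4 KiB per simdgroup
static size_t smem_for(int64_t nsg) {
    return 2048 + 4096 * (size_t) nsg;
}

// mirrors the loop in the diff: keep doubling nsgmax while the shared memory
// it implies still fits, then step back to the last value that fit
static int64_t fit_nsgmax(size_t max_tg_mem) {
    int64_t nsgmax = 2;
    while (smem_for(nsgmax) <= max_tg_mem) {
        nsgmax *= 2;
    }
    return nsgmax / 2; // the loop overshoots by one doubling
}

int main(void) {
    // 32 KiB is used here as an assumed maxThreadgroupMemoryLength
    printf("nsgmax = %lld\n", (long long) fit_nsgmax(32768));
    return 0;
}
```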
@@ -3178,21 +3181,41 @@ static void ggml_metal_encode_node(
GGML_ASSERT(nqptg % 1 == 0);
GGML_ASSERT(ncpsg % 32 == 0);

+ // ne00 + 2*ncpsg*(nsg)
+ // for each query, we load it as f16 in shared memory (ne00)
+ // and store the attention scores (nqptg x ncpsg) as f32
+ //
+ // 2*ne00*(nsg)
+ // each simdgroup has a full f32 head vector in shared mem to accumulate results
+ //
+ #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + 2*ne00*(nsg))*(sizeof(float)/2), 16))
+
+ int64_t nsgmax = 2;
+
+ while (true) {
+ const size_t smem = FATTN_SMEM(nsgmax);
+ if (smem > device.maxThreadgroupMemoryLength) {
+ break;
+ }
+ nsgmax *= 2;
+ }
+ nsgmax /= 2;
+
// simdgroups per threadgroup (a.k.a. warps)
- const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
+ const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));

int64_t nsg = 1;
while (nsg <= nsgt) {
nsg *= 2;
}
nsg /= 2;

- const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + 2*nsg*ne00)*(sizeof(float)/2);
+ const size_t smem = FATTN_SMEM(nsg);

- //printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
+ //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
- [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
-
+ [encoder setThreadgroupMemoryLength:smem atIndex:0];
+ #undef FATTN_SMEM
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
}
} break;
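And the same kind of back-of-the-envelope check for the vec-kernel `FATTN_SMEM` added in the last hunk, combined with its new `nsgmax` search. All three tensor dimensions and the 32 KiB limit are assumptions for illustration; in particular `nqptg = 1` is just a convenient choice the `% 1 == 0` assert permits:

```c
#include <stdio.h>
#include <stdint.h>
#include <stdbool.h>

#define GGML_PAD(x, n) (((x) + (n) - 1) / (n) * (n))

int main(void) {
    // illustrative assumptions, not values taken from the diff
    const int64_t ne00  = 128; // head size
    const int64_t nqptg = 1;   // queries per threadgroup
    const int64_t ncpsg = 32;  // cache positions per step

    const size_t max_tg_mem = 32768; // assumed maxThreadgroupMemoryLength

    int64_t nsgmax = 2;
    while (true) {
        // same expression as the vec-kernel FATTN_SMEM
        const size_t smem = GGML_PAD((nqptg*(ne00 + 2*ncpsg*nsgmax) + 2*ne00*nsgmax)*(sizeof(float)/2), 16);
        if (smem > max_tg_mem) {
            break;
        }
        nsgmax *= 2;
    }
    nsgmax /= 2;

    printf("nsgmax = %lld\n", (long long) nsgmax);
    return 0;
}
```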