
Commit f66d362

metal : use F16 precision in FA kernels

ggml-ci

1 parent 22a9311 · commit f66d362

File tree: 7 files changed, +482 -339 lines

Makefile (+5)

@@ -876,6 +876,11 @@ endif # GGML_HIPBLAS
 
 ifdef GGML_METAL
     MK_CPPFLAGS += -DGGML_USE_METAL
+
+ifdef GGML_METAL_FORCE_FATTN_PREC_F16
+    MK_CPPFLAGS += -DGGML_METAL_FORCE_FATTN_PREC_F16
+endif # GGML_METAL_FORCE_FATTN_PREC_F16
+
     MK_LDFLAGS += -framework Foundation -framework Metal -framework MetalKit
     OBJ_GGML += ggml/src/ggml-metal.o
 ifdef GGML_METAL_NDEBUG

examples/llama-bench/llama-bench.cpp (+3)

@@ -256,6 +256,9 @@ static ggml_type ggml_type_from_name(const std::string & s) {
     if (s == "f16") {
         return GGML_TYPE_F16;
     }
+    if (s == "bf16") {
+        return GGML_TYPE_BF16;
+    }
     if (s == "q8_0") {
         return GGML_TYPE_Q8_0;
     }
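With the hunk above, llama-bench's type-name parser also accepts "bf16" alongside the existing names (this parser is what llama-bench uses when a tensor type is given on the command line, e.g. for the KV-cache type options). A stand-alone C mirror of the extended lookup, for illustration only; the demo_* names below are hypothetical, while the real ggml_type_from_name returns ggml_type values and signals an unknown name with GGML_TYPE_COUNT:

// Illustrative mirror of the parser extended above (hypothetical names).
#include <stdio.h>
#include <string.h>

enum demo_type { DEMO_F32, DEMO_F16, DEMO_BF16, DEMO_Q8_0, DEMO_UNKNOWN };

static enum demo_type demo_type_from_name(const char * s) {
    if (strcmp(s, "f32")  == 0) return DEMO_F32;
    if (strcmp(s, "f16")  == 0) return DEMO_F16;
    if (strcmp(s, "bf16") == 0) return DEMO_BF16; // newly accepted name
    if (strcmp(s, "q8_0") == 0) return DEMO_Q8_0;
    return DEMO_UNKNOWN;
}

int main(void) {
    printf("bf16 -> %d\n", (int) demo_type_from_name("bf16")); // prints: bf16 -> 2
    return 0;
}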

ggml/CMakeLists.txt (+1)

@@ -153,6 +153,7 @@ option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation"
 option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF)
 option(GGML_KOMPUTE "ggml: use Kompute" OFF)
 option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT})
+option(GGML_METAL_FORCE_FATTN_PREC_F16 "ggml: force F16 accumulators for FA kernels" OFF)
 option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF)
 option(GGML_METAL_SHADER_DEBUG "ggml: compile Metal with -fno-fast-math" OFF)
 option(GGML_METAL_EMBED_LIBRARY "ggml: embed Metal library" ${GGML_METAL})

ggml/src/CMakeLists.txt (+4)

@@ -58,6 +58,10 @@ if (GGML_METAL)
         add_compile_definitions(GGML_METAL_NDEBUG)
     endif()
 
+    if (GGML_METAL_FORCE_FATTN_PREC_F16)
+        add_compile_definitions(GGML_METAL_FORCE_FATTN_PREC_F16)
+    endif()
+
     # copy ggml-common.h and ggml-metal.metal to bin directory
     configure_file(ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY)
     configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
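Taken together, the Makefile and CMake hunks wire up the same switch: building with GGML_METAL_FORCE_FATTN_PREC_F16=1 (make) or -DGGML_METAL_FORCE_FATTN_PREC_F16=ON (CMake) ends up defining GGML_METAL_FORCE_FATTN_PREC_F16 for the Metal backend. The code that consumes the define is not part of this excerpt; a minimal sketch of the intended gating, with a hypothetical helper name, might look like:

// Hypothetical sketch only: the real consumer lives in the Metal backend
// sources, which are not shown here. The compile definition forces the
// flash-attention (FA) kernels to accumulate in F16 regardless of what the
// op would otherwise request.
#include <stdbool.h>

static bool fattn_use_f16_accumulators(bool op_requests_f32) {
#ifdef GGML_METAL_FORCE_FATTN_PREC_F16
    (void) op_requests_f32;
    return true;             // build flag wins: always use F16 accumulators
#else
    return !op_requests_f32; // otherwise honor the per-op precision request
#endif
}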

ggml/src/ggml-metal.m (+55 -17)
@@ -269,6 +269,12 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
@@ -300,12 +306,14 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
@@ -585,6 +593,9 @@ @implementation GGMLMetalClass
         struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
         id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
         kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \
+        GGML_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
+                (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
+                (int) kernel->pipeline.threadExecutionWidth); \
         [metal_function release]; \
         if (error) { \
             GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
@@ -777,6 +788,12 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
@@ -808,12 +825,14 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
@@ -1111,7 +1130,7 @@ static void ggml_metal_encode_node(
     const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
     const uint64_t nb21 = src2 ? src2->nb[1] : 0;
     const uint64_t nb22 = src2 ? src2->nb[2] : 0;
-    const uint64_t nb23 = src2 ? src2->nb[3] : 0;
+    const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
 
     const int64_t ne0 = dst ? dst->ne[0] : 0;
     const int64_t ne1 = dst ? dst->ne[1] : 0;
@@ -3033,6 +3052,23 @@ static void ggml_metal_encode_node(
                             }
                     }
                 } break;
+            case GGML_TYPE_BF16:
+                {
+                    switch (ne00) {
+                        case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
+                        case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
+                        case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
+                        case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break;
+                        case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break;
+                        case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break;
+                        default:
+                            {
+                                GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                GGML_LOG_ERROR("add template specialization for this size\n");
+                                GGML_ABORT("add template specialization for this size");
+                            }
+                    }
+                } break;
             case GGML_TYPE_Q4_0:
                 {
                     switch (ne00) {
@@ -3133,6 +3169,7 @@ static void ggml_metal_encode_node(
                 {
                     switch (src1->type) {
                         case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
+                        case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128].pipeline; break;
                         case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break;
                         case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break;
                         case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break;
@@ -3150,6 +3187,7 @@ static void ggml_metal_encode_node(
                 {
                     switch (src1->type) {
                         case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+                        case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256].pipeline; break;
                         case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break;
                         case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break;
                         case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break;
@@ -3194,18 +3232,15 @@ static void ggml_metal_encode_node(
         [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
         [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
         [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
-        [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
-        [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
-        [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
-        [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
-        [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
-        [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
-        [encoder setBytes:&scale length:sizeof( float) atIndex:23];
-        [encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
-        [encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
-        [encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
-        [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
-        [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28];
+        [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:17];
+        [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:18];
+        [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:19];
+        [encoder setBytes:&scale length:sizeof( float) atIndex:20];
+        [encoder setBytes:&max_bias length:sizeof( float) atIndex:21];
+        [encoder setBytes:&m0 length:sizeof(m0) atIndex:22];
+        [encoder setBytes:&m1 length:sizeof(m1) atIndex:23];
+        [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:24];
+        [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25];
 
         if (!use_vec_kernel) {
             // half8x8 kernel
@@ -3216,11 +3251,14 @@
             GGML_ASSERT(nqptg % 8 == 0);
             GGML_ASSERT(ncpsg % 32 == 0);
 
+            // 2*(2*ncpsg + nqptg)*(nsg)
+            // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
+            //
             // 16*32*(nsg)
             // the shared memory needed for the simdgroups to load the KV cache
             // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
             //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
 
             int64_t nsgmax = 2;
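For reference, a worked instance of the updated FATTN_SMEM sizing for the simdgroup (block) kernel. The parameter values are illustrative only (the asserts above merely require nqptg to be a multiple of 8 and ncpsg a multiple of 32), and PAD stands in for ggml's GGML_PAD helper, which rounds up to a multiple of 16 bytes:

// Worked example of the new shared-memory formula, illustrative values only.
#include <stdio.h>

#define PAD(x, n) (((x) + (n) - 1) / (n) * (n))

int main(void) {
    const size_t ne00  = 128; // head size
    const size_t nqptg = 8;   // queries per threadgroup
    const size_t ncpsg = 32;  // cache values processed per simdgroup
    const size_t nsg   = 4;   // simdgroups per threadgroup

    // (sizeof(float)/2) == 2 bytes, i.e. one half-precision element
    const size_t smem = PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*nsg) + 16*32*nsg)*(sizeof(float)/2), 16);

    printf("smem = %zu bytes\n", smem); // prints: smem = 15360 bytes per threadgroup
    return 0;
}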

@@ -3256,10 +3294,10 @@ static void ggml_metal_encode_node(
             // for each query, we load it as f16 in shared memory (ne00)
             // and store the attention scores (nqptg x ncpsg) as f32
             //
-            // 2*ne00*(nsg)
-            // each simdgroup has a full f32 head vector in shared mem to accumulate results
+            // ne00*(nsg)
+            // each simdgroup has a full f16 head vector in shared mem to accumulate results
             //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + 2*ne00*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 4*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16))
 
             int64_t nsgmax = 2;
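The comment change captures the point of the commit for the vec kernel: the per-simdgroup accumulator vector in shared memory is now F16 (ne00 halves) instead of F32 (2*ne00 halves), while the ncpsg-related scratch term grows from 2*ncpsg to 4*ncpsg. A worked comparison under assumed values (nqptg = 1 and ncpsg = 32 are set outside this excerpt):

// Old vs new vec-kernel shared-memory sizing, illustrative values only.
#include <stdio.h>

#define PAD(x, n) (((x) + (n) - 1) / (n) * (n))

int main(void) {
    const size_t ne00 = 128, nqptg = 1, ncpsg = 32, nsg = 4;

    // old layout: full f32 head vector per simdgroup (2*ne00 half-sized slots)
    const size_t smem_old = PAD((nqptg*(ne00 + 2*ncpsg*nsg) + 2*ne00*nsg)*(sizeof(float)/2), 16);
    // new layout: full f16 head vector per simdgroup (ne00 slots), larger ncpsg scratch
    const size_t smem_new = PAD((nqptg*(ne00 + 4*ncpsg*nsg) + ne00*nsg)*(sizeof(float)/2), 16);

    printf("old = %zu bytes, new = %zu bytes\n", smem_old, smem_new); // old = 2816, new = 2304
    return 0;
}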
