Skip to content

Commit a797e5d

Browse files
committed
metal : add GGML_METAL_FORCE_FATTN_PREC_F16
ggml-ci
1 parent d0cff71 commit a797e5d

File tree

5 files changed

+119
-87
lines changed

5 files changed

+119
-87
lines changed

Makefile

+5
Original file line numberDiff line numberDiff line change
@@ -876,6 +876,11 @@ endif # GGML_HIPBLAS
876876

877877
ifdef GGML_METAL
878878
MK_CPPFLAGS += -DGGML_USE_METAL
879+
880+
ifdef GGML_METAL_FORCE_FATTN_PREC_F16
881+
MK_CPPFLAGS += -DGGML_METAL_FORCE_FATTN_PREC_F16
882+
endif # GGML_METAL_FORCE_FATTN_PREC_F16
883+
879884
MK_LDFLAGS += -framework Foundation -framework Metal -framework MetalKit
880885
OBJ_GGML += ggml/src/ggml-metal.o
881886
ifdef GGML_METAL_NDEBUG

ggml/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation"
153153
option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF)
154154
option(GGML_KOMPUTE "ggml: use Kompute" OFF)
155155
option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT})
156+
option(GGML_METAL_FORCE_FATTN_PREC_F16 "ggml: force F16 accumulators for FA kernels" OFF)
156157
option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF)
157158
option(GGML_METAL_SHADER_DEBUG "ggml: compile Metal with -fno-fast-math" OFF)
158159
option(GGML_METAL_EMBED_LIBRARY "ggml: embed Metal library" ${GGML_METAL})

ggml/src/CMakeLists.txt

+4
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ if (GGML_METAL)
5858
add_compile_definitions(GGML_METAL_NDEBUG)
5959
endif()
6060

61+
if (GGML_METAL_FORCE_FATTN_PREC_F16)
62+
add_compile_definitions(GGML_METAL_FORCE_FATTN_PREC_F16)
63+
endif()
64+
6165
# copy ggml-common.h and ggml-metal.metal to bin directory
6266
configure_file(ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY)
6367
configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)

ggml/src/ggml-metal.m

+14-11
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,6 @@
1212
#define MIN(a, b) ((a) < (b) ? (a) : (b))
1313
#define MAX(a, b) ((a) > (b) ? (a) : (b))
1414

15-
// TODO: for now, always use F32 for flash attention to avoid compiling 2 sets of kernels
16-
#define GGML_METAL_FORCE_FATTN_PREC_F32
17-
1815
// max memory buffers that can be mapped to the device
1916
#define GGML_METAL_MAX_BUFFERS 64
2017

@@ -483,9 +480,8 @@ @implementation GGMLMetalClass
483480
// dictionary of preprocessor macros
484481
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
485482

486-
// add GGML_METAL_FORCE_FATTN_PREC_F32
487-
#if defined(GGML_METAL_FORCE_FATTN_PREC_F32)
488-
[prep setObject:@"1" forKey:@"GGML_METAL_FORCE_FATTN_PREC_F32"];
483+
#if defined(GGML_METAL_FORCE_FATTN_PREC_F16)
484+
[prep setObject:@"1" forKey:@"GGML_METAL_FORCE_FATTN_PREC_F16"];
489485
#endif
490486

491487
MTLCompileOptions* options = [MTLCompileOptions new];
@@ -538,9 +534,14 @@ @implementation GGMLMetalClass
538534
}
539535
}
540536

541-
GGML_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx_dev->support_simdgroup_reduction ? "true" : "false");
542-
GGML_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx_dev->support_simdgroup_mm ? "true" : "false");
543-
GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
537+
#if defined(GGML_METAL_FORCE_FATTN_PREC_F16)
538+
GGML_LOG_INFO("%s: GGML_METAL_FORCE_FATTN_PREC_F16 = yes\n", __func__);
539+
#else
540+
GGML_LOG_INFO("%s: GGML_METAL_FORCE_FATTN_PREC_F16 = no\n", __func__);
541+
#endif
542+
GGML_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx_dev->support_simdgroup_reduction ? "true" : "false");
543+
GGML_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx_dev->support_simdgroup_mm ? "true" : "false");
544+
GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
544545

545546
ctx->capture_next_compute = false;
546547
ctx->capture_started = false;
@@ -3153,10 +3154,12 @@ static void ggml_metal_encode_node(
31533154
GGML_ASSERT(nqptg % 8 == 0);
31543155
GGML_ASSERT(ncpsg % 32 == 0);
31553156

3156-
#ifdef GGML_METAL_FORCE_FATTN_PREC_F32
3157+
#ifdef GGML_METAL_FORCE_FATTN_PREC_F16
31573158
const enum ggml_prec prec = GGML_PREC_DEFAULT;
31583159
#else
3159-
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(dst);
3160+
// TODO: support both precisions
3161+
const enum ggml_prec prec = GGML_PREC_F32;
3162+
//const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(dst);
31603163
#endif
31613164

31623165
const int64_t nhalfs = prec == GGML_PREC_DEFAULT ? 1 : 2;

0 commit comments

Comments
 (0)