@@ -12,9 +12,6 @@
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
-// TODO: for now, always use F32 for flash attention to avoid compiling 2 sets of kernels
-#define GGML_METAL_FORCE_FATTN_PREC_F32
-
 // max memory buffers that can be mapped to the device
 #define GGML_METAL_MAX_BUFFERS 64
 
@@ -483,9 +480,8 @@ @implementation GGMLMetalClass
         // dictionary of preprocessor macros
         NSMutableDictionary * prep = [NSMutableDictionary dictionary];
 
-        // add GGML_METAL_FORCE_FATTN_PREC_F32
-#if defined(GGML_METAL_FORCE_FATTN_PREC_F32)
-        [prep setObject:@"1" forKey:@"GGML_METAL_FORCE_FATTN_PREC_F32"];
+#if defined(GGML_METAL_FORCE_FATTN_PREC_F16)
+        [prep setObject:@"1" forKey:@"GGML_METAL_FORCE_FATTN_PREC_F16"];
 #endif
 
         MTLCompileOptions* options = [MTLCompileOptions new];
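For context, the `prep` dictionary built above is presumably handed to the Metal compiler through `MTLCompileOptions.preprocessorMacros`, which is what makes `GGML_METAL_FORCE_FATTN_PREC_F16` visible inside the Metal shader source as well. A minimal sketch of that wiring follows; the surrounding names `src`, `device`, and `error` are assumptions for illustration, not taken from this diff:

    // sketch only: forward the host-side macro dictionary to the Metal compiler
    MTLCompileOptions * options = [MTLCompileOptions new];
    options.preprocessorMacros = prep; // NSDictionary<NSString *, NSObject *> *

    // compile the embedded .metal source with those macros defined
    NSError * error = nil;
    id<MTLLibrary> lib = [device newLibraryWithSource:src options:options error:&error];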
@@ -538,9 +534,14 @@ @implementation GGMLMetalClass
         }
     }
 
-    GGML_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx_dev->support_simdgroup_reduction ? "true" : "false");
-    GGML_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx_dev->support_simdgroup_mm ? "true" : "false");
-    GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
+#if defined(GGML_METAL_FORCE_FATTN_PREC_F16)
+    GGML_LOG_INFO("%s: GGML_METAL_FORCE_FATTN_PREC_F16 = yes\n", __func__);
+#else
+    GGML_LOG_INFO("%s: GGML_METAL_FORCE_FATTN_PREC_F16 = no\n", __func__);
+#endif
+    GGML_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx_dev->support_simdgroup_reduction ? "true" : "false");
+    GGML_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx_dev->support_simdgroup_mm ? "true" : "false");
+    GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
 
     ctx->capture_next_compute = false;
     ctx->capture_started = false;
@@ -3153,10 +3154,12 @@ static void ggml_metal_encode_node(
                 GGML_ASSERT(nqptg % 8 == 0);
                 GGML_ASSERT(ncpsg % 32 == 0);
 
-#ifdef GGML_METAL_FORCE_FATTN_PREC_F32
+#ifdef GGML_METAL_FORCE_FATTN_PREC_F16
                 const enum ggml_prec prec = GGML_PREC_DEFAULT;
 #else
-                const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(dst);
+                // TODO: support both precisions
+                const enum ggml_prec prec = GGML_PREC_F32;
+                //const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(dst);
 #endif
 
                 const int64_t nhalfs = prec == GGML_PREC_DEFAULT ? 1 : 2;
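Taken together, the host-side precision selection after this change reduces to the sketch below. This is a condensed restatement of the hunks above, not the literal file contents, and the comment on `nhalfs` is an interpretation of the name rather than something stated in the diff:

    // condensed sketch: GGML_METAL_FORCE_FATTN_PREC_F16 must be defined at build
    // time (it is also forwarded to the Metal source via the prep dictionary above)
    #ifdef GGML_METAL_FORCE_FATTN_PREC_F16
        const enum ggml_prec prec = GGML_PREC_DEFAULT; // take the F16 flash-attention kernels
    #else
        // per the TODO above, only the F32 path is wired up for now
        const enum ggml_prec prec = GGML_PREC_F32;
    #endif

        // presumably F32 accumulation needs twice the half-precision storage per element
        const int64_t nhalfs = prec == GGML_PREC_DEFAULT ? 1 : 2;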