
Commit c262bed

CUDA: Prefer vector flash decoding kernel for Gemma models (#12738)
* Prefer vector flash decoding kernel for Gemma models

  The vector flash decoding kernel was not being picked for models with head dimension 256. Gemma models are in this category. Removing this limit improves end-to-end performance by up to 12% in generation-phase throughput for Gemma models.

* Update ggml/src/ggml-cuda/fattn.cu

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
1 parent 5dd5d1a commit c262bed

1 file changed: +1 −1 lines changed

ggml/src/ggml-cuda/fattn.cu (+1 −1)
@@ -299,7 +299,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
     const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
     const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion;
-    const bool can_use_vector_kernel = (Q->ne[0] % (2*warp_size) == 0) && (prec == GGML_PREC_DEFAULT || Q->ne[0] <= 128);
+    const bool can_use_vector_kernel = Q->ne[0] % (2*warp_size) == 0;
     if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
         if (prec == GGML_PREC_DEFAULT) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
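For context, a minimal standalone sketch of how the old and new gates evaluate for a Gemma-like case. The constants here (head dimension 256, 32-wide warps, non-default attention precision requested) are assumptions for illustration, not values read from ggml:

// Standalone sketch (not part of ggml): evaluates the old and new
// can_use_vector_kernel gates for an assumed Gemma-like configuration.
#include <cstdio>

int main() {
    const int  head_dim        = 256;   // assumed Q->ne[0] for Gemma-family models
    const int  warp_size       = 32;    // typical CUDA warp size
    const bool prec_is_default = false; // assume higher (non-default) attention precision is requested

    // Old gate: also required default precision for head dimensions above 128.
    const bool old_gate = (head_dim % (2*warp_size) == 0) && (prec_is_default || head_dim <= 128);
    // New gate: only requires the head dimension to be a multiple of 2*warp_size.
    const bool new_gate = head_dim % (2*warp_size) == 0;

    printf("old gate: %d, new gate: %d\n", old_gate, new_gate); // prints "old gate: 0, new gate: 1"
    return 0;
}

Under these assumptions the old gate rejects the vector kernel while the new gate accepts it, which matches the behavior change described in the commit message.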

0 commit comments
