diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu
index 8400a725..52a14639 100644
--- a/ggml/src/ggml-cuda/fattn-new-mma.cu
+++ b/ggml/src/ggml-cuda/fattn-new-mma.cu
@@ -2148,11 +2148,15 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
 
     GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
     const int gqa_ratio = Q->ne[2] / K->ne[2];
-    if (K->ne[0] == 128 && gqa_ratio == 12) {
+    if (K->ne[0] == 128 && (gqa_ratio == 12 || gqa_ratio == 6)) {
         GGML_ASSERT(Q->ne[0] == 128 && V->ne[0] == 128);
         //GGML_ASSERT(Q->ne[1] <= 4);
         //ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, 128, 16>(ctx, dst);
-        ggml_cuda_flash_attn_ext_mma_f16_case<128, 128, 1, 16>(ctx, dst);
+        if (gqa_ratio == 12) {
+            ggml_cuda_flash_attn_ext_mma_f16_case<128, 128, 1, 16>(ctx, dst);
+        } else {
+            ggml_cuda_flash_attn_ext_mma_f16_case<128, 128, 1, 8>(ctx, dst);
+        }
         return;
     }
     //if (K->ne[0] == 64 && V->ne[0] == 64) {
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index 6f1c766c..267968b0 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -90,7 +90,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 
     }
     if (new_mma_available(cc) && K->ne[0] == 128 && V->ne[0] == 128 && Q->ne[0] == 128 && Q->ne[1] == 1 &&
-        Q->ne[2] / K->ne[2] == 12) {
+        (Q->ne[2] / K->ne[2] == 12 || Q->ne[2] / K->ne[2] == 6)) {
         ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
         return;
     }