diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu index 0e7908c2..f5cb9854 100644 --- a/ggml/src/ggml-cuda/fattn-new-mma.cu +++ b/ggml/src/ggml-cuda/fattn-new-mma.cu @@ -2136,21 +2136,19 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); const int gqa_ratio = Q->ne[2] / K->ne[2]; - if (K->ne[0] == 128 && (gqa_ratio == 12 || gqa_ratio == 6)) { + if (K->ne[0] == 128) { GGML_ASSERT(Q->ne[0] == 128 && V->ne[0] == 128); - //GGML_ASSERT(Q->ne[1] <= 4); - //ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, 128, 16>(ctx, dst); if (gqa_ratio == 12) { ggml_cuda_flash_attn_ext_mma_f16_case<128, 128, 1, 16>(ctx, dst); - } else { + } else if (gqa_ratio == 6) { ggml_cuda_flash_attn_ext_mma_f16_case<128, 128, 1, 8>(ctx, dst); + } else if (gqa_ratio == 10) { + ggml_cuda_flash_attn_ext_mma_f16_case<128, 128, 1, 16>(ctx, dst); + } else { + GGML_ABORT("Not implemented"); } return; } - //if (K->ne[0] == 64 && V->ne[0] == 64) { - // ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<64, 64>(ctx, dst); - // return; - //} if (K->ne[0] == 192 && V->ne[0] == 128) { GGML_ASSERT(Q->ne[0] == 192); //GGML_ASSERT(gqa_ratio == 1); // Haha, this assert was for DeepSeek. But now we have Mimo2, which has GQA > 1 diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 267968b0..7d47a81d 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -90,7 +90,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } if (new_mma_available(cc) && K->ne[0] == 128 && V->ne[0] == 128 && Q->ne[0] == 128 && Q->ne[1] == 1 && - (Q->ne[2] / K->ne[2] == 12 || Q->ne[2] / K->ne[2] == 6)) { + (Q->ne[2] / K->ne[2] == 12 || Q->ne[2] / K->ne[2] == 6 || Q->ne[2] / K->ne[2] == 10)) { ggml_cuda_flash_attn_ext_mma_new(ctx, dst); return; }