From 03c0629b3c7bbcc642decea8757fdc3cf1501dad Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Tue, 20 Jan 2026 07:58:31 +0000 Subject: [PATCH] Make FA work for mla != 0 --- ggml/src/ggml-cuda/fattn-new-mma.cu | 20 ++++++++++++++++---- ggml/src/ggml-cuda/fattn.cu | 2 +- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu index 21253e0b..72a5f2d6 100644 --- a/ggml/src/ggml-cuda/fattn-new-mma.cu +++ b/ggml/src/ggml-cuda/fattn-new-mma.cu @@ -1999,9 +1999,9 @@ static void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ct const int nstages = cp_async_available(cc) ? c::nstages_target : 0; constexpr int ncols = ncols1 * ncols2; - constexpr int ntiles = ncols <= 8 ? 1 : 2; // Number of tiles per warp. + constexpr int ntiles = ncols <= 8 && DKQ < 576 ? 1 : 2; // Number of tiles per warp. constexpr int cols_per_warp = ntiles * tile_B::I; - constexpr int nwarps_max_x = ncols / cols_per_warp; + constexpr int nwarps_max_x = (ncols + cols_per_warp - 1) / cols_per_warp; constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I; constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max; @@ -2063,6 +2063,10 @@ template static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; + if constexpr (DKQ == 576 && ncols2 <= 4) { + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + } else { + if constexpr (ncols2 <= 8) { if (Q->ne[1] <= 8/ncols2) { ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); @@ -2081,6 +2085,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con } ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + } } template @@ -2156,8 +2161,15 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens return; } GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512); - GGML_ASSERT(gqa_ratio % 16 == 0); - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + if (gqa_ratio % 16 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + } else if (gqa_ratio % 4 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); + } else { + GGML_ABORT("Unsupported GQA 576 x 512 case"); + } + //GGML_ASSERT(gqa_ratio % 16 == 0); + //ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); //switch (Q->ne[0]) { // case 64: diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 1efa69d7..ae5a3507 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -180,7 +180,7 @@ bool ggml_cuda_fattn_is_supported(ggml_backend_cuda_context & ctx, const ggml_te if (new_mma_available(cc) && (Q->ne[0] == 576 || (K->ne[0] == 192 && V->ne[0] == 128 && mma_better_than_turing(cc)))) { if (Q->ne[0] == 576) { int gqa_ratio = Q->ne[2]/K->ne[2]; - return (gqa_ratio % 16) == 0; + return (gqa_ratio % 4) == 0; } return true; }