From ffc9e48a6f527e58037ec944aaa6aba4b868d466 Mon Sep 17 00:00:00 2001
From: Kawrakow
Date: Thu, 29 Jan 2026 17:06:47 +0000
Subject: [PATCH] CUDA FA

---
 ggml/src/ggml-cuda/fattn-new-mma.cu | 55 +++++++++++++++++++++++++++++
 ggml/src/ggml-cuda/fattn.cu         |  7 ++--
 2 files changed, 59 insertions(+), 3 deletions(-)

diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu
index 0e7908c2..38816bc4 100644
--- a/ggml/src/ggml-cuda/fattn-new-mma.cu
+++ b/ggml/src/ggml-cuda/fattn-new-mma.cu
@@ -378,6 +378,47 @@ struct fattn_mma_f16_config<576, 512> {
     }
 };
 
+template <> // Specialization for <1088, 1024>; presumably the config read by the ..._mma_f16_case<1088, 1024, ...> kernels - confirm.
+struct fattn_mma_f16_config<1088, 1024> {
+    static constexpr int nbatch_fa = 32;
+    static constexpr int nwarps_max = 8;
+    static constexpr bool Q_in_reg = false;
+    static constexpr int nstages_target = 1;
+    // NOTE(review): the four constants above are read by code outside this hunk; confirm their exact meaning there.
+    static int get_nbatch_K2_host([[maybe_unused]] const int cc, [[maybe_unused]] const int ncols) {
+        return 64;
+    }
+    // get_nbatch_K2_*: host (above) and device (below) getters both return a fixed 64 and ignore their arguments.
+    static constexpr __device__ int get_nbatch_K2_device([[maybe_unused]] int ncols) {
+        return 64;
+    }
+    // get_nbatch_V2_host: also a fixed 64; the cc/ncols-dependent selection is kept below as disabled code.
+    static int get_nbatch_V2_host([[maybe_unused]] const int cc, [[maybe_unused]] const int ncols) {
+        return 64;
+        //if (ggml_cuda_highest_compiled_arch(cc) == CC_TURING) {
+        //    return ncols <= 16 ? 64 : 128;
+        //}
+        //return ncols <= 16 ? 256 : 128;
+    }
+    // get_nbatch_V2_device: fixed 64, matching the host getter; the __CUDA_ARCH__-dependent variant is disabled.
+    static constexpr __device__ int get_nbatch_V2_device([[maybe_unused]] int ncols) {
+        return 64;
+//#if __CUDA_ARCH__ == CC_TURING
+//    return ncols <= 16 ? 64 : 128;
+//#else
+//    return ncols <= 16 ? 256 : 128;
+//#endif // __CUDA_ARCH__ == CC_TURING
+    }
+    // get_nbatch_combine_*: fixed 64 on host and device; the trailing "//128" looks like an alternative value - confirm.
+    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+        return 64; //128;
+    }
+
+    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+        return 64; //128;
+    }
+};
+
 // ------------------------------------------------------------------------------------------------------------------
 
 // The compiler is always able to unroll loops if they contain continue expressions.
@@ -2165,6 +2206,20 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 192, 1>(ctx, dst); return; } + if (Q->ne[0] == 1088 && K->ne[0] == 1088 && V->ne[0] == 1024) { + GGML_ASSERT(gqa_ratio == 20); + if (Q->ne[1] <= 4) { + if (ggml_cuda_info().devices[ctx.device].cc >= CC_ADA_LOVELACE) { + ggml_cuda_flash_attn_ext_mma_f16_case<1088, 1024, 1, 16>(ctx, dst); + } else { + ggml_cuda_flash_attn_ext_mma_f16_case<1088, 1024, 1, 32>(ctx, dst); + } + return; + } + //ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<1088, 1024, 4>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_case<1088, 1024, 4, 4>(ctx, dst); + return; + } GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512); if (gqa_ratio == 20 && Q->ne[1] <= 4 && K->ne[1] >= 2048) { if (ggml_cuda_info().devices[ctx.device].cc >= CC_ADA_LOVELACE) { diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 267968b0..9413588b 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -114,7 +114,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst // so no other implementation works. 
// - if (new_mma_available(cc) && ((K->ne[0] == 576 && V->ne[0] == 512) || (K->ne[0] == 192 && V->ne[0] == 128 && mma_better_than_turing(cc)))) { + if (new_mma_available(cc) && ((K->ne[0] == 576 && V->ne[0] == 512) || (K->ne[0] == 1088 && V->ne[0] == 1024) || + (K->ne[0] == 192 && V->ne[0] == 128 && mma_better_than_turing(cc)))) { //printf("Using ggml_cuda_flash_attn_ext_mma_new\n"); ggml_cuda_flash_attn_ext_mma_new(ctx, dst); return; @@ -185,8 +186,8 @@ bool ggml_cuda_fattn_is_supported(ggml_backend_cuda_context & ctx, const ggml_te return ggml_cuda_fattn_vec_f32_is_supported(ctx, dst); } - if (new_mma_available(cc) && (Q->ne[0] == 576 || (K->ne[0] == 192 && V->ne[0] == 128 && mma_better_than_turing(cc)))) { - if (Q->ne[0] == 576) { + if (new_mma_available(cc) && (Q->ne[0] == 576 || Q->ne[0] == 1088 || (K->ne[0] == 192 && V->ne[0] == 128 && mma_better_than_turing(cc)))) { + if (Q->ne[0] == 576 || Q->ne[0] == 1088) { int gqa_ratio = Q->ne[2]/K->ne[2]; return (gqa_ratio % 4) == 0; }