From ed5990712d9e035cc4efa91d36b83bb53b0fff53 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Tue, 6 May 2025 09:32:02 +0300
Subject: [PATCH] CUDA WIP: support for FlashMLA-3

---
 ggml/src/ggml-cuda.cu       | 5 +++++
 ggml/src/ggml-cuda/fattn.cu | 7 +++++--
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 1f62b882..ff6e064c 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -3587,6 +3587,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
             return (op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) ||
                    (op->src[1]->type == GGML_TYPE_Q8_0 && op->src[2]->type == GGML_TYPE_Q8_0);
         }
+        if (op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512) {
+            const int cc = ggml_cuda_info().devices[cuda_ctx->device].cc;
+            int gqa = op->src[0]->ne[2]/op->src[1]->ne[2];
+            return (new_mma_available(cc) && cc >= CC_AMPERE && op->src[3] && gqa%16 == 0);
+        }
         if (op->src[1]->ne[0] > 256) {
             return false;
         }
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index ea52fa02..0522cf04 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -13,6 +13,7 @@
 #include "fattn-vec-f32.cuh"
 #include "fattn-wmma-f16.cuh"
 #include "fattn-mma-f16.cuh"
+#include "fattn-new-mma.cuh"
 #include "fattn.cuh"
 
 #include <cstdint>
@@ -519,10 +520,12 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 
     // We need this because I haven't adapted the MMA kernels to work for different
     // K and V head sizes.
-    if (K->ne[0] != V->ne[0]) {
+    //if (K->ne[0] != V->ne[0]) {
+    if (!new_mma_available(cc)) {
         ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
         return;
     }
 
-    ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
+    //ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
+    ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
 }