diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cu b/ggml/src/ggml-cuda/fattn-mma-f16.cu
index 01e63541..539d7728 100644
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cu
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cu
@@ -43,6 +43,9 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context
         case 128:
             ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst);
             break;
+        case 192:
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, ncols2>(ctx, dst);
+            break;
         case 256:
             ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst);
             break;
@@ -88,5 +91,5 @@ bool ggml_cuda_fattn_mma_f16_is_supported([[maybe_unused]] ggml_backend_cuda_con
     auto K = dst->src[1];
     auto V = dst->src[1];
     if (K->ne[0] != V->ne[0]) return false;
-    return K->ne[0] == 64 || K->ne[0] == 80 || K->ne[0] == 96 || K->ne[0] == 112 || K->ne[0] == 128 || K->ne[0] == 256;
+    return K->ne[0] == 64 || K->ne[0] == 80 || K->ne[0] == 96 || K->ne[0] == 112 || K->ne[0] == 128 || K->ne[0] == 192 || K->ne[0] == 256;
 }
diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu
index b2285fdd..ef557209 100644
--- a/ggml/src/ggml-cuda/fattn-new-mma.cu
+++ b/ggml/src/ggml-cuda/fattn-new-mma.cu
@@ -315,6 +315,38 @@ struct fattn_mma_f16_config<192, 128> {
     }
 };
 
+template <>
+struct fattn_mma_f16_config<192, 192> {
+    static constexpr int  nbatch_fa      = 64;
+    static constexpr int  nwarps_max     = 4;
+    static constexpr bool Q_in_reg       = true;
+    static constexpr int  nstages_target = 1;
+
+    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
+        return 64;
+    }
+
+    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
+        return 64;
+    }
+
+    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
+        return 64;
+    }
+
+    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
+        return 64;
+    }
+
+    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+        return 32;
+    }
+
+    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+        return 32;
+    }
+};
+
 template <>
 struct fattn_mma_f16_config<576, 512> {
     static constexpr int nbatch_fa = 32;
@@ -2119,6 +2151,12 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
         ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 1>(ctx, dst);
         return;
     }
+    if (K->ne[0] == 192 && V->ne[0] == 192) {
+        GGML_ASSERT(Q->ne[0] == 192);
+        GGML_ASSERT(gqa_ratio == 1);
+        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 192, 1>(ctx, dst);
+        return;
+    }
     GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512);
     GGML_ASSERT(gqa_ratio % 16 == 0);
     ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);