diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 4f53ef48..f6c20d09 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -6,6 +6,7 @@ #include "fattn-wmma-f16.cuh" #include "fattn.cuh" #include "fattn-compat.cuh" +//#include "fattn-prev-mma-f16-interface.cuh" template static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -87,19 +88,24 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con if constexpr (DKQ == 128 && DV == 128) { if (use_gqa_opt && gqa_ratio == 12) { - if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 1) { - ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + if (Q->ne[1] <= 8) { + if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] == 1) { + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + return; + } + + if (ggml_cuda_highest_compiled_arch(cc) == CC_TURING || amd_wmma_available(cc) || Q->ne[1] <= 2) { + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + return; + } + + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } - - if (ggml_cuda_highest_compiled_arch(cc) == CC_TURING || amd_wmma_available(cc) || Q->ne[1] <= 2) { - ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); - return; - } - - ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); - //ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); - return; + //else if (ggml_cuda_fattn_prev_mma_f16_is_supported(ctx, dst)) { + // ggml_cuda_flash_attn_ext_prev_mma_f16(ctx, dst); + // return; + //} } } @@ -485,7 +491,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst ggml_cuda_set_device(ctx.device); const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; - const int32_t precision = KQV->op_params[3]; const int32_t n_swa = KQV->op_params[4]; ggml_tensor local_dst, Kl, Vl, Ml;