diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu
index 21253e0b..72a5f2d6 100644
--- a/ggml/src/ggml-cuda/fattn-new-mma.cu
+++ b/ggml/src/ggml-cuda/fattn-new-mma.cu
@@ -1999,9 +1999,9 @@ static void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ct
     const int nstages = cp_async_available(cc) ? c::nstages_target : 0;
 
     constexpr int ncols = ncols1 * ncols2;
-    constexpr int ntiles = ncols <= 8 ? 1 : 2; // Number of tiles per warp.
+    constexpr int ntiles = ncols <= 8 && DKQ < 576 ? 1 : 2; // Number of tiles per warp.
     constexpr int cols_per_warp = ntiles * tile_B::I;
-    constexpr int nwarps_max_x = ncols / cols_per_warp;
+    constexpr int nwarps_max_x = (ncols + cols_per_warp - 1) / cols_per_warp;
     constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I;
     constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max;
@@ -2063,6 +2063,10 @@ template
 static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
 
+    if constexpr (DKQ == 576 && ncols2 <= 4) {
+        ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst);
+    } else {
+
     if constexpr (ncols2 <= 8) {
         if (Q->ne[1] <= 8/ncols2) {
             ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst);
@@ -2081,6 +2085,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
     }
 
     ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst);
+    }
 }
 
 template
@@ -2156,8 +2161,15 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
         return;
     }
     GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512);
-    GGML_ASSERT(gqa_ratio % 16 == 0);
-    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+    if (gqa_ratio % 16 == 0) {
+        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+    } else if (gqa_ratio % 4 == 0) {
+        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
+    } else {
+        GGML_ABORT("Unsupported GQA 576 x 512 case");
+    }
+    //GGML_ASSERT(gqa_ratio % 16 == 0);
+    //ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
 
     //switch (Q->ne[0]) {
     //    case 64:
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index ab3353db..ae5a3507 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -178,6 +178,10 @@ bool ggml_cuda_fattn_is_supported(ggml_backend_cuda_context & ctx, const ggml_te
     }
 
     if (new_mma_available(cc) && (Q->ne[0] == 576 || (K->ne[0] == 192 && V->ne[0] == 128 && mma_better_than_turing(cc)))) {
+        if (Q->ne[0] == 576) {
+            int gqa_ratio = Q->ne[2]/K->ne[2];
+            return (gqa_ratio % 4) == 0;
+        }
         return true;
     }
diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
index 4cef9236..79e0a1d0 100644
--- a/src/llama-hparams.cpp
+++ b/src/llama-hparams.cpp
@@ -750,7 +750,7 @@ void llm_load_hparams(
     {
         if (hparams.n_head_kv() == 1) {
            int n_nead_kv = hparams.n_gqa();
-           if (n_nead_kv%16 != 0 || hparams.n_embd_head_k != 576 || hparams.n_embd_head_v != 512 ||
+           if (n_nead_kv%4 != 0 || hparams.n_embd_head_k != 576 || hparams.n_embd_head_v != 512 ||
                hparams.n_rot != 64) {
                printf("==========================================================================\n");
                printf("Detected incompatible DeepSeek model without a known way to fixc it.\n");
@@ -788,6 +788,7 @@ void llm_load_hparams(
            switch (hparams.n_layer) {
                case 27: model.type = e_model::MODEL_16B; break;
+               case 47: model.type = e_model::MODEL_30B_A3B; break; // GLM-4.7-Flash
                case 60: model.type = e_model::MODEL_236B; break;
                case 61: model.type = e_model::MODEL_671B; break;
                default: model.type = e_model::MODEL_UNKNOWN;
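
For reference, a minimal standalone sketch (not part of the patch) of the dispatch rule the changes above introduce for the 576/512 (DeepSeek-style MLA) head sizes: a GQA ratio divisible by 16 keeps the original ncols2 = 16 path, a ratio only divisible by 4 takes the new ncols2 = 4 path, and anything else is rejected. The helper name pick_ncols2 is made up for illustration; in the actual code the selection happens inline in ggml_cuda_flash_attn_ext_mma_new and ggml_cuda_fattn_is_supported, with gqa_ratio = Q->ne[2] / K->ne[2].

// Illustrative sketch only; pick_ncols2 is a hypothetical helper that mirrors
// the gqa_ratio dispatch added in this patch for the 576/512 head-size case.
#include <cstdio>

static int pick_ncols2(int gqa_ratio) {
    if (gqa_ratio % 16 == 0) return 16; // original path, kept as the preferred case
    if (gqa_ratio %  4 == 0) return  4; // newly supported path
    return -1;                          // unsupported -> GGML_ABORT in the real code
}

int main() {
    const int ratios[] = {16, 20, 4, 6};
    for (int r : ratios) {
        printf("gqa_ratio=%2d -> ncols2=%d\n", r, pick_ncols2(r));
    }
    return 0;
}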