GLM-4.7-Flash support (#1168)

* GLM-4.7-Flash support

* Model type

* Make FA work for mla != 0
Kawrakow authored 2026-01-20 12:46:52 +02:00, committed by GitHub
parent ef5f17940c
commit 132a01d25d
3 changed files with 22 additions and 5 deletions

View File

@@ -1999,9 +1999,9 @@ static void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ct
const int nstages = cp_async_available(cc) ? c::nstages_target : 0;
constexpr int ncols = ncols1 * ncols2;
- constexpr int ntiles = ncols <= 8 ? 1 : 2; // Number of tiles per warp.
+ constexpr int ntiles = ncols <= 8 && DKQ < 576 ? 1 : 2; // Number of tiles per warp.
constexpr int cols_per_warp = ntiles * tile_B::I;
- constexpr int nwarps_max_x = ncols / cols_per_warp;
+ constexpr int nwarps_max_x = (ncols + cols_per_warp - 1) / cols_per_warp;
constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I;
constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max;
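
The two changes above work together: after the first one, ntiles is 2 for the 576-wide MLA head even when ncols <= 8, so cols_per_warp (= ntiles * tile_B::I) can exceed ncols, and the old plain division would then round nwarps_max_x down to zero. The second change switches to a ceiling division so at least one warp is always assigned along x. A minimal host-side sketch of just that arithmetic follows; the cols_per_warp value of 16 is an assumption used for illustration (ntiles = 2 with tile_B::I taken as 8), not a value read from the kernel headers.

#include <cstdio>

// Old formula: truncates toward zero, so ncols < cols_per_warp yields 0 warps along x.
static constexpr int nwarps_x_old(int ncols, int cols_per_warp) {
    return ncols / cols_per_warp;
}

// New formula: rounds up, so any ncols > 0 gets at least one warp along x.
static constexpr int nwarps_x_new(int ncols, int cols_per_warp) {
    return (ncols + cols_per_warp - 1) / cols_per_warp;
}

int main() {
    const int cols_per_warp = 16; // assumed: ntiles = 2, tile_B::I = 8
    const int tests[] = {4, 8, 16, 32};
    for (int ncols : tests) {
        std::printf("ncols=%2d  old=%d  new=%d\n", ncols,
                    nwarps_x_old(ncols, cols_per_warp), nwarps_x_new(ncols, cols_per_warp));
    }
    return 0;
}
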
@@ -2063,6 +2063,10 @@ template <int DKQ, int DV, int ncols2>
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
+ if constexpr (DKQ == 576 && ncols2 <= 4) {
+ ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 4, ncols2>(ctx, dst);
+ } else {
if constexpr (ncols2 <= 8) {
if (Q->ne[1] <= 8/ncols2) {
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 8/ncols2, ncols2>(ctx, dst);
@@ -2081,6 +2085,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
}
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 64/ncols2, ncols2>(ctx, dst);
+ }
}
template <int DKQ, int DV>
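
The two hunks above wrap the existing ncols1 ladder in an else branch: for the MLA head size (DKQ == 576) with ncols2 <= 4, ncols1 is pinned to 4 instead of being derived from Q->ne[1]. A condensed, compilable sketch of that control flow is below; launch_case and switch_ncols1_sketch are placeholder names standing in for ggml_cuda_flash_attn_ext_mma_f16_case and the real dispatcher, and the middle rungs of the ladder are elided.

#include <cstdio>

// Placeholder for the real kernel instantiation; here it only reports which
// compile-time configuration was selected.
template <int DKQ, int DV, int ncols1, int ncols2>
static void launch_case(int n_q) {
    std::printf("DKQ=%d DV=%d ncols1=%d ncols2=%d (n_q=%d)\n", DKQ, DV, ncols1, ncols2, n_q);
}

template <int DKQ, int DV, int ncols2>
static void switch_ncols1_sketch(int n_q /* = Q->ne[1] */) {
    if constexpr (DKQ == 576 && ncols2 <= 4) {
        launch_case<DKQ, DV, 4, ncols2>(n_q);         // MLA shape: always ncols1 = 4
    } else {
        if constexpr (ncols2 <= 8) {
            if (n_q <= 8/ncols2) {
                launch_case<DKQ, DV, 8/ncols2, ncols2>(n_q);
                return;
            }
        }
        // ... 16/ncols2 and 32/ncols2 rungs elided ...
        launch_case<DKQ, DV, 64/ncols2, ncols2>(n_q); // generic fallback
    }
}

int main() {
    switch_ncols1_sketch<576, 512, 4>(1);  // takes the new pinned branch
    switch_ncols1_sketch<128, 128, 8>(5);  // takes the pre-existing ladder
    return 0;
}
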
@@ -2156,8 +2161,15 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
return;
}
GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512);
- GGML_ASSERT(gqa_ratio % 16 == 0);
- ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+ if (gqa_ratio % 16 == 0) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+ } else if (gqa_ratio % 4 == 0) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
+ } else {
+ GGML_ABORT("Unsupported GQA 576 x 512 case");
+ }
+ //GGML_ASSERT(gqa_ratio % 16 == 0);
+ //ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
//switch (Q->ne[0]) {
// case 64:
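
The net effect of the hunk above is that the 576 x 512 MLA path no longer requires the GQA ratio (Q->ne[2]/K->ne[2]) to be a multiple of 16, which is the DeepSeek case; any multiple of 4 is now routed to the ncols2 = 4 instantiation, and anything else still fails hard. A minimal sketch of just that selection logic, with a plain function standing in for the templated launcher; the ratios in main are illustrative values only, not taken from any specific model.

#include <cstdio>
#include <cstdlib>

// Mirrors the new ncols2 selection for the 576 x 512 (MLA) head shape.
static int pick_ncols2_576(int gqa_ratio) {
    if (gqa_ratio % 16 == 0) return 16; // DeepSeek-style ratios keep the old path
    if (gqa_ratio %  4 == 0) return 4;  // newly accepted ratios (the case added for GLM-4.7-Flash)
    std::fprintf(stderr, "Unsupported GQA 576 x 512 case\n");
    std::abort();
}

int main() {
    const int ratios[] = {128, 20, 12}; // illustrative GQA ratios only
    for (int r : ratios) {
        std::printf("gqa_ratio=%3d -> ncols2=%d\n", r, pick_ncols2_576(r));
    }
    return 0;
}
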

View File

@@ -178,6 +178,10 @@ bool ggml_cuda_fattn_is_supported(ggml_backend_cuda_context & ctx, const ggml_te
}
if (new_mma_available(cc) && (Q->ne[0] == 576 || (K->ne[0] == 192 && V->ne[0] == 128 && mma_better_than_turing(cc)))) {
+ if (Q->ne[0] == 576) {
+ int gqa_ratio = Q->ne[2]/K->ne[2];
+ return (gqa_ratio % 4) == 0;
+ }
return true;
}

View File

@@ -750,7 +750,7 @@ void llm_load_hparams(
{
if (hparams.n_head_kv() == 1) {
int n_nead_kv = hparams.n_gqa();
- if (n_nead_kv%16 != 0 || hparams.n_embd_head_k != 576 || hparams.n_embd_head_v != 512 ||
+ if (n_nead_kv%4 != 0 || hparams.n_embd_head_k != 576 || hparams.n_embd_head_v != 512 ||
hparams.n_rot != 64) {
printf("==========================================================================\n");
printf("Detected incompatible DeepSeek model without a known way to fixc it.\n");
@@ -788,6 +788,7 @@ void llm_load_hparams(
switch (hparams.n_layer) {
case 27: model.type = e_model::MODEL_16B; break;
+ case 47: model.type = e_model::MODEL_30B_A3B; break; // GLM-4.7-Flash
case 60: model.type = e_model::MODEL_236B; break;
case 61: model.type = e_model::MODEL_671B; break;
default: model.type = e_model::MODEL_UNKNOWN;