GLM-5 support (#1268)

This commit is contained in:
Kawrakow
2026-02-15 07:49:44 +01:00
committed by GitHub
parent f5fe33b7a9
commit 528cadb07b
9 changed files with 270 additions and 24 deletions

View File

@@ -677,8 +677,10 @@ static bool llama_kv_cache_init(
}
}
bool is_mla_attn = model.arch == LLM_ARCH_DEEPSEEK2 || model.arch == LLM_ARCH_GLM_DSA;
bool split_cache = false;
if ((model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN) && model.arch != LLM_ARCH_DEEPSEEK2 && offload) {
if ((model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN) && !is_mla_attn && offload) {
cache.split_k_l.reserve(n_layer);
cache.split_v_l.reserve(n_layer);
split_cache = true;
@@ -718,7 +720,7 @@ static bool llama_kv_cache_init(
cache.ctxs.push_back(ctx);
}
if (model.arch == LLM_ARCH_DEEPSEEK2) {
if (is_mla_attn) {
bool have_wkv_b = true;
for (auto& l : model.layers) {
if (!l.wkv_b) {
@@ -744,7 +746,7 @@ static bool llama_kv_cache_init(
bool needs_v_cache = true;
cache.k_l.reserve(n_layer);
if (model.arch == LLM_ARCH_DEEPSEEK2 && cparams.mla_attn) {
if (is_mla_attn && cparams.mla_attn) {
needs_v_cache = cparams.mla_attn == 1 && !cparams.flash_attn;
}
if (needs_v_cache) cache.v_l.reserve(n_layer);
@@ -760,7 +762,7 @@ static bool llama_kv_cache_init(
struct ggml_context * ctx = split_cache ? ctx_map.at(model.buft_layer[i].buft_matrix) : offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
ggml_tensor * k;
ggml_tensor * v;
if (model.arch == LLM_ARCH_DEEPSEEK2 && cparams.mla_attn) {
if (is_mla_attn && cparams.mla_attn) {
// DeepSeek MLA
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
const uint32_t kv_lora_rank = hparams.n_lora_kv;
@@ -841,7 +843,7 @@ static bool llama_kv_cache_init(
}
}
}
if (model.arch == LLM_ARCH_DEEPSEEK2 && cparams.mla_attn && n_mla < n_layer && n_mla > 0) {
if (is_mla_attn && cparams.mla_attn && n_mla < n_layer && n_mla > 0) {
LLAMA_LOG_ERROR("%s: unexpected situation with %d out of %d layers having MLA enabled\n", __func__, n_mla, int(n_layer));
LLAMA_LOG_ERROR("%s: bailing out\n", __func__);
GGML_ABORT("fatal error");
@@ -1379,7 +1381,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
// general kv
LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, model.name.c_str());
if (model.arch == LLM_ARCH_DEEPSEEK2) {
if (model.arch == LLM_ARCH_DEEPSEEK2 || model.arch == LLM_ARCH_GLM_DSA) {
LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q);
LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv);
@@ -1424,7 +1426,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
}
static void llm_prepare_mla(llama_model & model, int mla) {
if (model.arch != LLM_ARCH_DEEPSEEK2) return;
if (model.arch != LLM_ARCH_DEEPSEEK2 && model.arch != LLM_ARCH_GLM_DSA) return;
const auto& hparams = model.hparams;
const int n_layer = model.layers.size();
int n_to_compute = 0;
@@ -2048,7 +2050,7 @@ static bool llm_load_tensors(
}
}
if (model.arch == LLM_ARCH_DEEPSEEK2) {
if (model.arch == LLM_ARCH_DEEPSEEK2 || model.arch == LLM_ARCH_GLM_DSA) {
llm_prepare_mla(model, mla_attn);
}
@@ -3735,7 +3737,7 @@ static int32_t llama_kv_cache_update_internal(struct llama_context & lctx) {
// apply K-shift if needed
if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
if (lctx.model.arch == LLM_ARCH_DEEPSEEK2) { // not supported due to MLA
if (lctx.model.arch == LLM_ARCH_DEEPSEEK2 || lctx.model.arch == LLM_ARCH_GLM_DSA) { // not supported due to MLA
return 1;
}
@@ -4542,20 +4544,10 @@ struct llama_context * llama_init_from_model(
params.seed = time(NULL);
}
if (model->arch != LLM_ARCH_DEEPSEEK2 && cparams.mla_attn != 0) {
//LLAMA_LOG_WARN("=====================================================================\n");
//LLAMA_LOG_WARN(" MLA is only available for LLM_ARCH_DEEPSEEK2 -> turning off MLA\n");
//LLAMA_LOG_WARN("=====================================================================\n");
if (model->arch != LLM_ARCH_DEEPSEEK2 && model->arch != LLM_ARCH_GLM_DSA && cparams.mla_attn != 0) {
cparams.mla_attn = 0;
}
if (model->arch == LLM_ARCH_OPENAI_MOE && model->split_mode == LLAMA_SPLIT_MODE_GRAPH) {
//if (cparams.split_mode_f16) {
// LLAMA_LOG_WARN("=====================================================================\n");
// LLAMA_LOG_WARN("GPT-OSS with split mode graph requires f32 precision\n");
// LLAMA_LOG_WARN(" => changing cparams.split_mode_f16 to 'false'\n");
// LLAMA_LOG_WARN("=====================================================================\n");
// cparams.split_mode_f16 = false;
//}
if (cparams.reduce_type == GGML_TYPE_F16) {
LLAMA_LOG_WARN("=====================================================================\n");
LLAMA_LOG_WARN("GPT-OSS with split mode graph requires f32 precision\n");
@@ -4569,7 +4561,7 @@ struct llama_context * llama_init_from_model(
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
if (model->arch == LLM_ARCH_DEEPSEEK2) {
if (model->arch == LLM_ARCH_DEEPSEEK2 || model->arch == LLM_ARCH_GLM_DSA) {
LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn);
}
LLAMA_LOG_INFO("%s: attn_max_b = %d\n", __func__, cparams.attn_max_batch);
@@ -5020,6 +5012,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_ERNIE4_5_MOE:
case LLM_ARCH_SMOLLM3:
case LLM_ARCH_MISTRAL3:
case LLM_ARCH_GLM_DSA:
return LLAMA_ROPE_TYPE_NORM;
// the pairs of head values are offset by n_rot/2