Mirror of https://github.com/ikawrakow/ik_llama.cpp.git
GLM-5 support (#1268)
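This commit adds GLM-5 support via the LLM_ARCH_GLM_DSA architecture, which is now treated, like LLM_ARCH_DEEPSEEK2, as an MLA-attention model: KV-cache initialization, llm_prepare_mla, hyperparameter logging, the mla_attn context parameter, K-shift handling, and RoPE type selection all test for both architectures.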
@@ -677,8 +677,10 @@ static bool llama_kv_cache_init(
         }
     }

+    bool is_mla_attn = model.arch == LLM_ARCH_DEEPSEEK2 || model.arch == LLM_ARCH_GLM_DSA;
+
     bool split_cache = false;
-    if ((model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN) && model.arch != LLM_ARCH_DEEPSEEK2 && offload) {
+    if ((model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN) && !is_mla_attn && offload) {
         cache.split_k_l.reserve(n_layer);
         cache.split_v_l.reserve(n_layer);
         split_cache = true;
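The same two-architecture test recurs throughout the hunks below. As a hypothetical refactor, not part of this commit, the check could live in one small helper built on the llm_arch enum that llama.cpp already defines:

// Hypothetical helper, not in this commit: keeps the list of
// MLA-capable architectures in a single place.
static inline bool llm_arch_uses_mla(llm_arch arch) {
    return arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_GLM_DSA;
}

Each `model.arch == LLM_ARCH_DEEPSEEK2 || model.arch == LLM_ARCH_GLM_DSA` occurrence would then read llm_arch_uses_mla(model.arch), and adding a third MLA architecture would touch a single line.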
@@ -718,7 +720,7 @@ static bool llama_kv_cache_init(
         cache.ctxs.push_back(ctx);
     }

-    if (model.arch == LLM_ARCH_DEEPSEEK2) {
+    if (is_mla_attn) {
         bool have_wkv_b = true;
         for (auto& l : model.layers) {
             if (!l.wkv_b) {
@@ -744,7 +746,7 @@ static bool llama_kv_cache_init(

     bool needs_v_cache = true;
     cache.k_l.reserve(n_layer);
-    if (model.arch == LLM_ARCH_DEEPSEEK2 && cparams.mla_attn) {
+    if (is_mla_attn && cparams.mla_attn) {
         needs_v_cache = cparams.mla_attn == 1 && !cparams.flash_attn;
     }
     if (needs_v_cache) cache.v_l.reserve(n_layer);
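Note that with MLA active, a separate V cache is reserved only when mla_attn == 1 and flash attention is off; that rule now covers GLM_DSA as well as DeepSeek2.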
@@ -760,7 +762,7 @@ static bool llama_kv_cache_init(
         struct ggml_context * ctx = split_cache ? ctx_map.at(model.buft_layer[i].buft_matrix) : offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
         ggml_tensor * k;
         ggml_tensor * v;
-        if (model.arch == LLM_ARCH_DEEPSEEK2 && cparams.mla_attn) {
+        if (is_mla_attn && cparams.mla_attn) {
             // DeepSeek MLA
             const uint32_t n_embd_head_qk_rope = hparams.n_rot;
             const uint32_t kv_lora_rank = hparams.n_lora_kv;
@@ -841,7 +843,7 @@ static bool llama_kv_cache_init(
             }
         }
     }
-    if (model.arch == LLM_ARCH_DEEPSEEK2 && cparams.mla_attn && n_mla < n_layer && n_mla > 0) {
+    if (is_mla_attn && cparams.mla_attn && n_mla < n_layer && n_mla > 0) {
         LLAMA_LOG_ERROR("%s: unexpected situation with %d out of %d layers having MLA enabled\n", __func__, n_mla, int(n_layer));
         LLAMA_LOG_ERROR("%s: bailing out\n", __func__);
         GGML_ABORT("fatal error");
@@ -1379,7 +1381,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     // general kv
     LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, model.name.c_str());

-    if (model.arch == LLM_ARCH_DEEPSEEK2) {
+    if (model.arch == LLM_ARCH_DEEPSEEK2 || model.arch == LLM_ARCH_GLM_DSA) {
         LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
         LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q);
         LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv);
@@ -1424,7 +1426,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
 }

 static void llm_prepare_mla(llama_model & model, int mla) {
-    if (model.arch != LLM_ARCH_DEEPSEEK2) return;
+    if (model.arch != LLM_ARCH_DEEPSEEK2 && model.arch != LLM_ARCH_GLM_DSA) return;
     const auto& hparams = model.hparams;
     const int n_layer = model.layers.size();
     int n_to_compute = 0;
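With this change llm_prepare_mla no longer bails out for GLM_DSA models, so the MLA preparation it performs applies to both architectures; the llm_load_tensors hunk below calls it under the same widened condition.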
@@ -2048,7 +2050,7 @@ static bool llm_load_tensors(
         }
     }

-    if (model.arch == LLM_ARCH_DEEPSEEK2) {
+    if (model.arch == LLM_ARCH_DEEPSEEK2 || model.arch == LLM_ARCH_GLM_DSA) {
         llm_prepare_mla(model, mla_attn);
     }

@@ -3735,7 +3737,7 @@ static int32_t llama_kv_cache_update_internal(struct llama_context & lctx) {

     // apply K-shift if needed
     if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
-        if (lctx.model.arch == LLM_ARCH_DEEPSEEK2) { // not supported due to MLA
+        if (lctx.model.arch == LLM_ARCH_DEEPSEEK2 || lctx.model.arch == LLM_ARCH_GLM_DSA) { // not supported due to MLA
             return 1;
         }

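K-shift (RoPE-based cache shifting) stays unsupported for both MLA architectures; the function reports this by returning 1 early, per the "not supported due to MLA" comment.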
@@ -4542,20 +4544,10 @@ struct llama_context * llama_init_from_model(
         params.seed = time(NULL);
     }

-    if (model->arch != LLM_ARCH_DEEPSEEK2 && cparams.mla_attn != 0) {
-        //LLAMA_LOG_WARN("=====================================================================\n");
-        //LLAMA_LOG_WARN(" MLA is only available for LLM_ARCH_DEEPSEEK2 -> turning off MLA\n");
-        //LLAMA_LOG_WARN("=====================================================================\n");
+    if (model->arch != LLM_ARCH_DEEPSEEK2 && model->arch != LLM_ARCH_GLM_DSA && cparams.mla_attn != 0) {
         cparams.mla_attn = 0;
     }
     if (model->arch == LLM_ARCH_OPENAI_MOE && model->split_mode == LLAMA_SPLIT_MODE_GRAPH) {
-        //if (cparams.split_mode_f16) {
-        //    LLAMA_LOG_WARN("=====================================================================\n");
-        //    LLAMA_LOG_WARN("GPT-OSS with split mode graph requires f32 precision\n");
-        //    LLAMA_LOG_WARN(" => changing cparams.split_mode_f16 to 'false'\n");
-        //    LLAMA_LOG_WARN("=====================================================================\n");
-        //    cparams.split_mode_f16 = false;
-        //}
+        if (cparams.reduce_type == GGML_TYPE_F16) {
             LLAMA_LOG_WARN("=====================================================================\n");
             LLAMA_LOG_WARN("GPT-OSS with split mode graph requires f32 precision\n");
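Two independent cleanups here: mla_attn is silently forced to 0 for any architecture other than DEEPSEEK2 and GLM_DSA (the commented-out warning banner is dropped), and the GPT-OSS split-graph guard now tests cparams.reduce_type == GGML_TYPE_F16 directly instead of carrying the dead split_mode_f16 block.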
@@ -4569,7 +4561,7 @@ struct llama_context * llama_init_from_model(
     LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
     LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
-    if (model->arch == LLM_ARCH_DEEPSEEK2) {
+    if (model->arch == LLM_ARCH_DEEPSEEK2 || model->arch == LLM_ARCH_GLM_DSA) {
         LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn);
     }
     LLAMA_LOG_INFO("%s: attn_max_b = %d\n", __func__, cparams.attn_max_batch);
@@ -5020,6 +5012,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_ERNIE4_5_MOE:
         case LLM_ARCH_SMOLLM3:
         case LLM_ARCH_MISTRAL3:
+        case LLM_ARCH_GLM_DSA:
             return LLAMA_ROPE_TYPE_NORM;

         // the pairs of head values are offset by n_rot/2
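Finally, GLM_DSA joins the group of architectures that report LLAMA_ROPE_TYPE_NORM.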