Add GLM4 architecture

This commit is contained in:
turboderp
2025-04-15 18:57:29 +02:00
parent b148bb42b8
commit de19cbcc59
2 changed files with 72 additions and 1 deletion

View File

@@ -674,6 +674,50 @@ class PromptFormat_cohere(PromptFormat):
return True
class PromptFormat_glm(PromptFormat):
    """Chat prompt format for GLM4 models.

    The templates below contain the literal placeholders <|system_prompt|>
    and <|user_prompt|>, which the surrounding chat loop substitutes with
    the actual system/user text before tokenization.
    """

    description = "GLM4"

    def __init__(self):
        super().__init__()

    def default_system_prompt(self):
        # Plain literal; the original used an f-string with nothing to
        # interpolate.
        return "You are a helpful AI assistant."

    def first_prompt(self, sysprompt):
        # Every GLM4 context starts with the [gMASK]<sop> prefix; the
        # system block is emitted only when a system prompt is in use.
        r = "[gMASK]<sop>"
        if sysprompt:
            r += "<|system|>\n<|system_prompt|>"
        r += "<|user|>\n<|user_prompt|><|assistant|>\n"
        return r

    def subs_prompt(self):
        # Template for every follow-up user turn (no BOS/system block).
        return "<|user|>\n<|user_prompt|><|assistant|>\n"

    def stop_conditions(self, tokenizer):
        # Stop on EOS, on the <|user|> special token ID, and on the
        # literal "<|user|>" string in case it appears as plain text.
        return [
            tokenizer.eos_token_id,
            tokenizer.single_id("<|user|>"),
            "<|user|>",
        ]

    def encoding_options(self):
        # NOTE(review): assumed to be (add_bos, add_eos,
        # encode_special_tokens) based on sibling formats — confirm
        # against the PromptFormat caller.
        return True, False, True

    def print_extra_newline(self):
        return True
prompt_formats = \
{
"raw": PromptFormat_raw,
@@ -693,4 +737,5 @@ prompt_formats = \
"phi3": PromptFormat_phi3,
"granite": PromptFormat_granite,
"granite3": PromptFormat_granite3,
"glm": PromptFormat_glm
}

View File

@@ -16,6 +16,10 @@ layer_keys_gemma2_norms = [["input_layernorm"],
["post_feedforward_layernorm"]]
# Per-layer norm key patterns used to recognize a checkpoint's layout
# (one inner list of candidate tensor-name suffixes per expected key).
layer_keys_internlm2_norms = [["attention_norm"],
["ffn_norm"]]
# GLM4 decoder layers carry four norms — the key names indicate a norm
# before and after the attention block and before and after the MLP.
layer_keys_glm4_norms = [["input_layernorm"],
["post_self_attn_layernorm"],
["post_attention_layernorm"],
["post_mlp_layernorm"]]
layer_keys_llama_attn = [["self_attn.q_proj"],
["self_attn.k_proj"],
["self_attn.v_proj"],
@@ -808,6 +812,28 @@ class ExLlamaV2ArchParams:
self.lm.expect_keys += \
expect_keys_llama
# GLM4
# Recognize the HF "Glm4ForCausalLM" architecture and configure the
# language-model params accordingly.
if arch_string == "Glm4ForCausalLM":
arch_recognized = True
# Expected per-layer tensors: the four GLM4 norms plus Llama-style
# split q/k/v attention projections and a Phi3-style fused MLP.
self.lm.layer_keys += \
layer_keys_glm4_norms + \
layer_keys_llama_attn + \
layer_keys_phi3_mlp
self.lm.expect_keys += \
expect_keys_llama
self.lm.supports_tp = True
# GLM4 uses GPT-J-style rotary embeddings rather than the default.
self.lm.rope_style = RopeStyle.GPTJ
self.lm.keys.update({
# Fused gate/up MLP projection tensor name.
"fused_mlp_12": "gate_up_proj",
# lm_head resolved to the embedding table — presumably weight
# tying (no separate output head); confirm against GLM4 weights.
"lm_head": "model.embed_tokens",
# Map the four GLM4 norm names onto the generic norm slots.
"norm_1": ".input_layernorm",
"norm_1_post": ".post_self_attn_layernorm",
"norm_2": ".post_attention_layernorm",
"norm_2_post": ".post_mlp_layernorm",
})
# GLM4 q/k/v projections carry bias terms (but not the o projection).
self.lm.attention_bias_qkv = True
# Llama (default + fallback)
if arch_string != "LlamaForCausalLM" and not arch_recognized:
@@ -825,7 +851,7 @@ class ExLlamaV2ArchParams:
# Arch overrides
if read_config.get("attention_bias", False):
if read_config.get("attention_bias", False) and not (self.lm.attention_bias_qkv or self.lm.attention_bias_o):
self.lm.attention_bias_qkv = True
self.lm.attention_bias_o = True