Support Gemma2

Author: turboderp
Date: 2024-07-06 07:14:47 +02:00
Parent: 01ce7bbb6e
Commit: bfc3cd9cf3


@@ -9,6 +9,10 @@ layer_keys_gpt2_norms = [["ln_1"],
                         ["ln_2"]]
layer_keys_yi_norms = [["ln1", "input_layernorm"],
                       ["ln2", "post_attention_layernorm"]]
layer_keys_gemma2_norms = [["input_layernorm"],
                           ["post_attention_layernorm"],
                           ["pre_feedforward_layernorm"],
                           ["post_feedforward_layernorm"]]
layer_keys_internlm2_norms = [["attention_norm"],
                              ["ffn_norm"]]
layer_keys_llama_attn = [["self_attn.q_proj"],
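Each inner list above groups the accepted aliases for one per-layer tensor (the Yi entry, for example, accepts either "ln1" or "input_layernorm"), and the four Gemma2 groups reflect its sandwich-norm decoder layout: one norm before and after attention, and one before and after the MLP. As a rough sketch, not part of the commit and assuming the usual "model.layers.{i}.<key>.weight" naming of Llama-style checkpoints, the groups expand into the tensor names a Gemma2 checkpoint would be checked for:

# Sketch only, not from this commit: expand the alias groups above into the
# concrete tensor names expected for one decoder layer. The
# "model.layers.{i}.<key>.weight" pattern is an assumption for illustration.

layer_keys_gemma2_norms = [["input_layernorm"],
                           ["post_attention_layernorm"],
                           ["pre_feedforward_layernorm"],
                           ["post_feedforward_layernorm"]]

def expected_layer_tensors(layer_idx, layer_keys):
    # For each alias group, list the tensor names that could satisfy it
    prefix = f"model.layers.{layer_idx}."
    return [[prefix + alias + ".weight" for alias in group] for group in layer_keys]

print(expected_layer_tensors(0, layer_keys_gemma2_norms))
# [['model.layers.0.input_layernorm.weight'],
#  ['model.layers.0.post_attention_layernorm.weight'],
#  ['model.layers.0.pre_feedforward_layernorm.weight'],
#  ['model.layers.0.post_feedforward_layernorm.weight']]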
@@ -326,6 +330,44 @@ class ExLlamaV2ArchParams:
            self.mqa = False
            self.scale_attn_weights = False

        # Gemma2

        if arch_string == "Gemma2ForCausalLM":
            arch_recognized = True
            self.layer_keys += \
                layer_keys_gemma2_norms + \
                layer_keys_llama_attn + \
                layer_keys_llama_mlp
            self.expect_keys += \
                expect_keys_gemma
            self.norm_eps_key = "rms_norm_eps"
            self.attention_bias_qkv = False
            self.attention_bias_o = False
            self.mlp_bias = False
            self.mlp_gate = True
            self.mlp_key_gate = ".mlp.gate_proj"
            self.mlp_key_up = ".mlp.up_proj"
            self.mlp_key_down = ".mlp.down_proj"
            self.mlp_act_func = "gelu"
            self.is_moe = False
            self.norm = "rmsnorm"
            self.lm_head_key = "model.embed_tokens"
            self.normalize_embeddings = True
            self.norm_key_1 = ".input_layernorm"
            self.norm_key_1_post = ".post_attention_layernorm"
            self.norm_key_2 = ".pre_feedforward_layernorm"
            self.norm_key_2_post = ".post_feedforward_layernorm"
            self.norm_constant_bias = 1
            self.parallel_decoder_blocks = False
            self.requires_bos = True
            self.rope_style = RopeStyle.NEOX
            self.keymap = None
            self.fused_qkv_key = None
            self.mqa = False
            self.scale_attn_weights = False
            self.pre_post_layernorm = True
            self.alternating_swa = True

        # StarCoder2

        if arch_string == "Starcoder2ForCausalLM":