diff --git a/exllamav2/architecture.py b/exllamav2/architecture.py
index 46e87d0..271c40d 100644
--- a/exllamav2/architecture.py
+++ b/exllamav2/architecture.py
@@ -9,6 +9,10 @@ layer_keys_gpt2_norms = [["ln_1"],
                          ["ln_2"]]
 layer_keys_yi_norms = [["ln1", "input_layernorm"],
                        ["ln2", "post_attention_layernorm"]]
+layer_keys_gemma2_norms = [["input_layernorm"],
+                           ["post_attention_layernorm"],
+                           ["pre_feedforward_layernorm"],
+                           ["post_feedforward_layernorm"]]
 layer_keys_internlm2_norms = [["attention_norm"],
                               ["ffn_norm"]]
 layer_keys_llama_attn = [["self_attn.q_proj"],
@@ -326,6 +330,44 @@ class ExLlamaV2ArchParams:
             self.mqa = False
             self.scale_attn_weights = False
 
+        # Gemma2
+
+        if arch_string == "Gemma2ForCausalLM":
+            arch_recognized = True
+            self.layer_keys += \
+                layer_keys_gemma2_norms + \
+                layer_keys_llama_attn + \
+                layer_keys_llama_mlp
+            self.expect_keys += \
+                expect_keys_gemma
+            self.norm_eps_key = "rms_norm_eps"
+            self.attention_bias_qkv = False
+            self.attention_bias_o = False
+            self.mlp_bias = False
+            self.mlp_gate = True
+            self.mlp_key_gate = ".mlp.gate_proj"
+            self.mlp_key_up = ".mlp.up_proj"
+            self.mlp_key_down = ".mlp.down_proj"
+            self.mlp_act_func = "gelu"
+            self.is_moe = False
+            self.norm = "rmsnorm"
+            self.lm_head_key = "model.embed_tokens"
+            self.normalize_embeddings = True
+            self.norm_key_1 = ".input_layernorm"
+            self.norm_key_1_post = ".post_attention_layernorm"
+            self.norm_key_2 = ".pre_feedforward_layernorm"
+            self.norm_key_2_post = ".post_feedforward_layernorm"
+            self.norm_constant_bias = 1
+            self.parallel_decoder_blocks = False
+            self.requires_bos = True
+            self.rope_style = RopeStyle.NEOX
+            self.keymap = None
+            self.fused_qkv_key = None
+            self.mqa = False
+            self.scale_attn_weights = False
+            self.pre_post_layernorm = True
+            self.alternating_swa = True
+
         # StarCoder2
 
         if arch_string == "Starcoder2ForCausalLM":
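
Note for reviewers: the Gemma2-specific flags (norm_constant_bias, pre_post_layernorm, alternating_swa) map onto the model's published architecture roughly as sketched below. This is a minimal illustration under stated assumptions, not exllamav2's implementation; the tensor names in `norms` and the `attn`/`mlp` callables are placeholders for the example.

# Sketch of what the Gemma2 flags above imply at inference time.
# norm_constant_bias = 1: Gemma checkpoints store RMSNorm weights as (w - 1),
# so the constant is added back when the norm is applied.
# pre_post_layernorm = True: each sublayer sits in a norm "sandwich", which is
# why a layer needs all four keys listed in layer_keys_gemma2_norms.

import torch

def gemma_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    var = x.float().pow(2).mean(-1, keepdim=True)
    x_normed = x.float() * torch.rsqrt(var + eps)
    return (x_normed * (weight.float() + 1.0)).type_as(x)  # +1.0 = norm_constant_bias

def gemma2_block(x, attn, mlp, norms):
    # Attention sublayer: input_layernorm before, post_attention_layernorm after.
    h = gemma_rmsnorm(x, norms["input_layernorm"])
    h = attn(h)
    h = gemma_rmsnorm(h, norms["post_attention_layernorm"])
    x = x + h
    # MLP sublayer: pre/post_feedforward_layernorm around a gated GELU MLP
    # (mlp_gate = True, mlp_act_func = "gelu").
    h = gemma_rmsnorm(x, norms["pre_feedforward_layernorm"])
    h = mlp(h)
    h = gemma_rmsnorm(h, norms["post_feedforward_layernorm"])
    return x + h

def layer_is_sliding(layer_idx: int) -> bool:
    # alternating_swa = True: Gemma2 alternates sliding-window and global
    # attention from layer to layer (even-indexed layers use the window in
    # the reference Hugging Face implementation).
    return layer_idx % 2 == 0

Relatedly, normalize_embeddings = True and lm_head_key = "model.embed_tokens" reflect that Gemma2 scales token embeddings by sqrt(hidden_size) and ties the output head to the embedding matrix, so no separate lm_head tensor is expected in the checkpoint.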