Support Gemma3 (text)

This commit is contained in:
turboderp
2025-03-14 23:45:48 +01:00
parent 565339101b
commit c0267e37fe
3 changed files with 42 additions and 2 deletions

View File

@@ -167,6 +167,7 @@ class ExLlamaV2ArchParams:
# SWA required by architecture
swa = False
alternating_swa = False
# Optional separate RoPE base (theta) for sliding-window layers; None means
# SWA layers share the global rotary base. Set to 10000 for Gemma3 below.
sliding_rope_theta = None
# Model only works with eager attention
eager_attn_only = False
@@ -476,6 +477,40 @@ class ExLlamaV2ArchParams:
self.lm.alternating_swa = True
self.lm.residual_stream_fp32 = True
# Gemma3 — text model of the multimodal Gemma3 checkpoint; its weights are
# stored under the "language_model." prefix (set via self.lm_prefix below).
if arch_string == "Gemma3ForConditionalGeneration":
arch_recognized = True
# Gemma2-style sandwich norms around both the attention and MLP blocks.
self.lm.layer_keys += \
layer_keys_gemma2_norms + \
layer_keys_llama_attn + \
layer_keys_llama_mlp
self.lm.expect_keys += \
expect_keys_gemma
self.lm.keys.update({
# Output head is mapped to the input embedding tensor (tied weights).
"lm_head": "model.embed_tokens",
"norm_1": ".input_layernorm",
"norm_1_post": ".post_attention_layernorm",
"norm_2": ".pre_feedforward_layernorm",
"norm_2_post": ".post_feedforward_layernorm",
})
self.lm_prefix = "language_model."
self.lm.mlp_act_func = "gelu"
self.lm.normalize_embeddings = True
self.lm.norm_constant_bias = 1
self.lm.requires_bos = True
# Layers alternate between sliding-window and global attention.
self.lm.alternating_swa = True
self.lm.residual_stream_fp32 = True
# Sliding-window layers use this separate (local) RoPE base; global
# layers use default_rope_theta (1e6) set below.
self.lm.sliding_rope_theta = 10000
# Fallback config values when the checkpoint's config omits them.
self.lm.default_vocab_size = 262208
self.lm.default_rms_norm_eps = 1e-06
self.lm.default_head_dim = 256
self.lm.default_num_attention_heads = 8
self.lm.default_num_key_value_heads = 4
self.lm.default_use_qk_norm = True
# Every 6th layer ((layer_idx + 1) % 6 == 0) is global attention; the
# other five use the sliding window (see model-build hunk below).
self.lm.default_sliding_window_pattern = 6
self.lm.default_rope_theta = 1e6
# StarCoder2
if arch_string == "Starcoder2ForCausalLM":

View File

@@ -99,6 +99,7 @@ class ExLlamaV2Config:
norm_eps: float | None
vocab_size: int
# Global RoPE base (theta).
rotary_embedding_base: float
# Alternate RoPE base for sliding-window layers (taken from
# arch.lm.sliding_rope_theta, e.g. Gemma3); None when the architecture
# defines no separate local theta.
rotary_embedding_base_alt: float | None
scale_long_factor: list[float] | None
scale_short_factor: list[float] | None
alt_rope_method: str | None
@@ -352,6 +353,8 @@ class ExLlamaV2Config:
opt_subkey = "text_config",
)
# Non-None only when the architecture defines sliding_rope_theta (Gemma3);
# consumed at model-build time to select the alternate RoPE (rope_index 1)
# for sliding-window layers.
self.rotary_embedding_base_alt = self.arch.lm.sliding_rope_theta
self.max_seq_len = read(
read_config,
int,

View File

@@ -110,16 +110,18 @@ class ExLlamaV2:
if cfg.arch.lm.alternating_swa:
# Layers where (layer_idx + 1) % sliding_window_pattern == 0 are global
# attention (swa = 0); all other layers use the sliding window.
swa = cfg.sliding_window if (layer_idx + 1) % cfg.sliding_window_pattern != 0 else 0
if cfg.rotary_embedding_base_alt:
# Select the alternate (local) RoPE base for this layer.
# NOTE(review): indentation is lost in this diff view — confirm in the
# repo whether this is further gated on swa != 0, since the local theta
# should presumably apply only to sliding-window layers.
rope_index = 1
elif cfg.arch.lm.swa:
swa = cfg.sliding_window
else:
swa = 0
# NOTE(review): rope_index is only assigned in the alternating-SWA branch
# within this hunk; confirm it is initialized (presumably to 0) above the
# hunk, otherwise non-SWA models would raise UnboundLocalError below.
if cfg.arch.lm.parallel_decoder_blocks:
# Diff view: the next two lines are the pre-/post-change versions of the
# same statement (rope_index parameter added by this commit).
pd = ExLlamaV2ParallelDecoder(self, layer_key, layer_idx, sliding_window = swa)
pd = ExLlamaV2ParallelDecoder(self, layer_key, layer_idx, sliding_window = swa, rope_index = rope_index)
self.modules += [pd]
else:
# Diff view: pre-/post-change pair of the same statement, as above.
attn = ExLlamaV2Attention(self, layer_key, layer_idx, sliding_window = swa)
attn = ExLlamaV2Attention(self, layer_key, layer_idx, sliding_window = swa, rope_index = rope_index)
if cfg.arch.lm.is_moe: mlp = ExLlamaV2MoEMLP(self, layer_key, layer_idx)
else: mlp = ExLlamaV2MLP(self, layer_key, layer_idx)
self.modules += [attn, mlp]