mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-04-20 14:29:28 +00:00
Support Gemma3 (text)
This commit is contained in:
@@ -167,6 +167,7 @@ class ExLlamaV2ArchParams:
|
||||
# SWA required by architecture
|
||||
swa = False
|
||||
alternating_swa = False
|
||||
sliding_rope_theta = None
|
||||
|
||||
# Model only works with eager attention
|
||||
eager_attn_only = False
|
||||
@@ -476,6 +477,40 @@ class ExLlamaV2ArchParams:
|
||||
self.lm.alternating_swa = True
|
||||
self.lm.residual_stream_fp32 = True
|
||||
|
||||
# Gemma3
|
||||
|
||||
if arch_string == "Gemma3ForConditionalGeneration":
|
||||
arch_recognized = True
|
||||
self.lm.layer_keys += \
|
||||
layer_keys_gemma2_norms + \
|
||||
layer_keys_llama_attn + \
|
||||
layer_keys_llama_mlp
|
||||
self.lm.expect_keys += \
|
||||
expect_keys_gemma
|
||||
self.lm.keys.update({
|
||||
"lm_head": "model.embed_tokens",
|
||||
"norm_1": ".input_layernorm",
|
||||
"norm_1_post": ".post_attention_layernorm",
|
||||
"norm_2": ".pre_feedforward_layernorm",
|
||||
"norm_2_post": ".post_feedforward_layernorm",
|
||||
})
|
||||
self.lm_prefix = "language_model."
|
||||
self.lm.mlp_act_func = "gelu"
|
||||
self.lm.normalize_embeddings = True
|
||||
self.lm.norm_constant_bias = 1
|
||||
self.lm.requires_bos = True
|
||||
self.lm.alternating_swa = True
|
||||
self.lm.residual_stream_fp32 = True
|
||||
self.lm.sliding_rope_theta = 10000
|
||||
self.lm.default_vocab_size = 262208
|
||||
self.lm.default_rms_norm_eps = 1e-06
|
||||
self.lm.default_head_dim = 256
|
||||
self.lm.default_num_attention_heads = 8
|
||||
self.lm.default_num_key_value_heads = 4
|
||||
self.lm.default_use_qk_norm = True
|
||||
self.lm.default_sliding_window_pattern = 6
|
||||
self.lm.default_rope_theta = 1e6
|
||||
|
||||
# StarCoder2
|
||||
|
||||
if arch_string == "Starcoder2ForCausalLM":
|
||||
|
||||
@@ -99,6 +99,7 @@ class ExLlamaV2Config:
|
||||
norm_eps: float | None
|
||||
vocab_size: int
|
||||
rotary_embedding_base: float
|
||||
rotary_embedding_base_alt: float | None
|
||||
scale_long_factor: list[float] | None
|
||||
scale_short_factor: list[float] | None
|
||||
alt_rope_method: str | None
|
||||
@@ -352,6 +353,8 @@ class ExLlamaV2Config:
|
||||
opt_subkey = "text_config",
|
||||
)
|
||||
|
||||
self.rotary_embedding_base_alt = self.arch.lm.sliding_rope_theta
|
||||
|
||||
self.max_seq_len = read(
|
||||
read_config,
|
||||
int,
|
||||
|
||||
@@ -110,16 +110,18 @@ class ExLlamaV2:
|
||||
|
||||
if cfg.arch.lm.alternating_swa:
|
||||
swa = cfg.sliding_window if (layer_idx + 1) % cfg.sliding_window_pattern != 0 else 0
|
||||
if cfg.rotary_embedding_base_alt:
|
||||
rope_index = 1
|
||||
elif cfg.arch.lm.swa:
|
||||
swa = cfg.sliding_window
|
||||
else:
|
||||
swa = 0
|
||||
|
||||
if cfg.arch.lm.parallel_decoder_blocks:
|
||||
pd = ExLlamaV2ParallelDecoder(self, layer_key, layer_idx, sliding_window = swa)
|
||||
pd = ExLlamaV2ParallelDecoder(self, layer_key, layer_idx, sliding_window = swa, rope_index = rope_index)
|
||||
self.modules += [pd]
|
||||
else:
|
||||
attn = ExLlamaV2Attention(self, layer_key, layer_idx, sliding_window = swa)
|
||||
attn = ExLlamaV2Attention(self, layer_key, layer_idx, sliding_window = swa, rope_index = rope_index)
|
||||
if cfg.arch.lm.is_moe: mlp = ExLlamaV2MoEMLP(self, layer_key, layer_idx)
|
||||
else: mlp = ExLlamaV2MLP(self, layer_key, layer_idx)
|
||||
self.modules += [attn, mlp]
|
||||
|
||||
Reference in New Issue
Block a user