Mirror of https://github.com/turboderp-org/exllamav3.git, synced 2026-04-20 14:29:51 +00:00
exaone4: prefer layer_types over pattern for SWA layer mapping
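
The two configuration shapes this distinguishes, sketched as hypothetical config fragments (the keys are the ones read by the code below; the values are illustrative, not taken from any real checkpoint):

    # Hypothetical config fragments, for illustration only.
    cfg_with_layer_types = {
        "sliding_window": 4096,
        # One entry per hidden layer; this is the shape preferred by this commit.
        "layer_types": ["sliding_attention", "sliding_attention", "sliding_attention", "full_attention"],
    }

    cfg_with_pattern = {
        "sliding_window": 4096,
        # Fallback shape: the string repeats over the layers, and "L" marks a sliding (local) layer.
        "sliding_window_pattern": "LLLG",
    }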
@@ -26,12 +26,48 @@ class Exaone4Config(Config):
        self.hidden_size = self.read_cfg(int, "hidden_size", no_default)
        self.num_q_heads = self.read_cfg(int, "num_attention_heads", no_default)
        self.num_kv_heads = self.read_cfg(int, "num_key_value_heads", self.num_q_heads)
        self.num_hidden_layers = self.read_cfg(int, "num_hidden_layers", no_default)

        if not self.head_dim:
            self.head_dim = self.hidden_size // self.num_q_heads

        self.sliding_window = self.read_cfg(int, "sliding_window", -1)
        self.sliding_window_pattern = self.read_cfg(str, "sliding_window_pattern", None)
        layer_types = self.read_cfg(list, "layer_types", None)

        if layer_types:
            assert len(layer_types) == self.num_hidden_layers, \
                "Length of layer_types key doesn't match number of hidden layers"
            self.swa_pattern = []
            for t in layer_types:
                if t == "sliding_attention":
                    if self.sliding_window < 0:
                        raise ValueError(
                            "layer_types requests sliding_attention but sliding_window is disabled"
                        )
                    self.swa_pattern.append(self.sliding_window)
                elif t == "full_attention":
                    self.swa_pattern.append(-1)
                else:
                    raise ValueError(f"Unknown layer type in layer_types: {t}")

        elif self.sliding_window_pattern:
            if self.sliding_window < 0:
                raise ValueError(
                    "sliding_window_pattern is set but sliding_window is disabled"
                )
            self.swa_pattern = [
                self.sliding_window
                if (
                    idx != self.num_hidden_layers - 1
                    and self.sliding_window_pattern[idx % len(self.sliding_window_pattern)] == "L"
                )
                else -1
                for idx in range(self.num_hidden_layers)
            ]

        else:
            self.swa_pattern = [-1 for _ in range(self.num_hidden_layers)]

        # MLP params
        self.assert_cfg(str, "hidden_act", "silu", True)
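
For reference, a minimal standalone sketch of the per-layer mapping above, run on hypothetical values (4 layers, window size 4096); it mirrors the two branches but skips the error handling and is not part of the repository:

    # Hypothetical inputs, for illustration only.
    num_hidden_layers = 4
    sliding_window = 4096

    # Branch 1: explicit layer_types list (preferred).
    layer_types = ["sliding_attention", "sliding_attention", "sliding_attention", "full_attention"]
    swa_pattern = [sliding_window if t == "sliding_attention" else -1 for t in layer_types]
    print(swa_pattern)  # [4096, 4096, 4096, -1]

    # Branch 2: repeating pattern string; the last layer is always treated as full attention.
    sliding_window_pattern = "LLLG"
    swa_pattern = [
        sliding_window
        if (idx != num_hidden_layers - 1
            and sliding_window_pattern[idx % len(sliding_window_pattern)] == "L")
        else -1
        for idx in range(num_hidden_layers)
    ]
    print(swa_pattern)  # [4096, 4096, 4096, -1]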
@@ -41,7 +77,6 @@ class Exaone4Config(Config):
        self.rms_norm_eps = self.read_cfg(float, "rms_norm_eps", no_default)

        # Layers
        self.num_hidden_layers = self.read_cfg(int, "num_hidden_layers", no_default)
        self.tie_word_embeddings = self.read_cfg(bool, "tie_word_embeddings", False)

        # RoPE
@@ -69,14 +104,6 @@ class Exaone4Model(Model):

        self.first_block_idx = len(self.modules)

-        is_local = [
-            bool(
-                idx != config.num_hidden_layers - 1
-                and config.sliding_window_pattern[idx % len(config.sliding_window_pattern)] == "L"
-            )
-            for idx in range(config.num_hidden_layers)
-        ]
-
        self.modules += [
            TransformerBlock(
                config = config,
@@ -89,9 +116,9 @@ class Exaone4Model(Model):
                head_dim = config.head_dim,
                num_q_heads = config.num_q_heads,
                num_kv_heads = config.num_kv_heads,
-                rope_settings = config.rope_settings if is_local[idx] else None,
+                rope_settings = config.rope_settings if config.swa_pattern[idx] >= 0 else None,
                sm_scale = None,
-                sliding_window = config.sliding_window if is_local[idx] else -1,
+                sliding_window = config.swa_pattern[idx],
                key_q = "q_proj",
                key_k = "k_proj",
                key_v = "v_proj",
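
The per-layer list then drives both keyword arguments above: a non-negative entry means the layer uses the sliding window and keeps RoPE, while -1 means full attention with no RoPE. A small sketch with a hypothetical swa_pattern and a stand-in for config.rope_settings:

    # Hypothetical per-layer pattern and a stand-in rope settings object, for illustration only.
    swa_pattern = [4096, 4096, 4096, -1]
    rope_settings = object()

    for idx, window in enumerate(swa_pattern):
        layer_rope = rope_settings if window >= 0 else None
        print(f"layer {idx}: sliding_window = {window}, rope = {layer_rope is not None}")
    # layer 0: sliding_window = 4096, rope = True
    # ...
    # layer 3: sliding_window = -1, rope = False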
@@ -177,4 +204,4 @@ class Exaone4Model(Model):
        p += f"[|user|]\n"
        p += f"{prompt}[|endofturn|]\n"
        p += f"[|assistant|]\n"
-        return p
+        return p
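
For completeness, a quick sketch of the prompt string the lines above assemble for a hypothetical single-turn prompt (whatever the surrounding method prepends before the user turn is not shown in this excerpt):

    # Hypothetical prompt; only the turn markers visible in the hunk above are used.
    prompt = "What is sliding-window attention?"
    p = ""
    p += f"[|user|]\n"
    p += f"{prompt}[|endofturn|]\n"
    p += f"[|assistant|]\n"
    print(p)
    # [|user|]
    # What is sliding-window attention?[|endofturn|]
    # [|assistant|]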