exaone4: prefer layer_types over pattern for SWA layer mapping

This commit is contained in:
lesj0610
2026-02-12 01:47:42 +09:00
parent 701afb9294
commit 5c076e5f2a

View File

@@ -26,12 +26,48 @@ class Exaone4Config(Config):
self.hidden_size = self.read_cfg(int, "hidden_size", no_default)
self.num_q_heads = self.read_cfg(int, "num_attention_heads", no_default)
self.num_kv_heads = self.read_cfg(int, "num_key_value_heads", self.num_q_heads)
self.num_hidden_layers = self.read_cfg(int, "num_hidden_layers", no_default)
if not self.head_dim:
self.head_dim = self.hidden_size // self.num_q_heads
self.sliding_window = self.read_cfg(int, "sliding_window", -1)
self.sliding_window_pattern = self.read_cfg(str, "sliding_window_pattern", None)
# Build self.swa_pattern: one entry per hidden layer, holding self.sliding_window
# for sliding-attention layers and -1 for full-attention layers. An explicit
# "layer_types" list takes precedence over the "sliding_window_pattern" string;
# with neither present, every layer uses full attention.
layer_types = self.read_cfg(list, "layer_types", None)
if layer_types:
    # Raise instead of assert: this validates user-supplied config and must
    # not disappear when Python runs with -O.
    if len(layer_types) != self.num_hidden_layers:
        raise ValueError(
            "Length of layer_types key doesn't match number of hidden layers"
        )
    self.swa_pattern = []
    for t in layer_types:
        if t == "sliding_attention":
            # A sliding layer needs a usable window size
            if self.sliding_window < 0:
                raise ValueError(
                    "layer_types requests sliding_attention but sliding_window is disabled"
                )
            self.swa_pattern.append(self.sliding_window)
        elif t == "full_attention":
            self.swa_pattern.append(-1)
        else:
            raise ValueError(f"Unknown layer type in layer_types: {t}")
elif self.sliding_window_pattern:
    if self.sliding_window < 0:
        raise ValueError(
            "sliding_window_pattern is set but sliding_window is disabled"
        )
    # The pattern string repeats cyclically over the layers; "L" marks a
    # sliding-window layer. The final layer is always full attention,
    # whatever the pattern says.
    pattern = self.sliding_window_pattern
    last_idx = self.num_hidden_layers - 1
    self.swa_pattern = [
        self.sliding_window
        if idx != last_idx and pattern[idx % len(pattern)] == "L"
        else -1
        for idx in range(self.num_hidden_layers)
    ]
else:
    # No layer mapping configured: full attention everywhere
    self.swa_pattern = [-1] * self.num_hidden_layers
# MLP params
self.assert_cfg(str, "hidden_act", "silu", True)
@@ -41,7 +77,6 @@ class Exaone4Config(Config):
self.rms_norm_eps = self.read_cfg(float, "rms_norm_eps", no_default)
# Layers
self.num_hidden_layers = self.read_cfg(int, "num_hidden_layers", no_default)
self.tie_word_embeddings = self.read_cfg(bool, "tie_word_embeddings", False)
# RoPE
@@ -69,14 +104,6 @@ class Exaone4Model(Model):
self.first_block_idx = len(self.modules)
is_local = [
bool(
idx != config.num_hidden_layers - 1
and config.sliding_window_pattern[idx % len(config.sliding_window_pattern)] == "L"
)
for idx in range(config.num_hidden_layers)
]
self.modules += [
TransformerBlock(
config = config,
@@ -89,9 +116,9 @@ class Exaone4Model(Model):
head_dim = config.head_dim,
num_q_heads = config.num_q_heads,
num_kv_heads = config.num_kv_heads,
rope_settings = config.rope_settings if is_local[idx] else None,
rope_settings = config.rope_settings if config.swa_pattern[idx] >= 0 else None,
sm_scale = None,
sliding_window = config.sliding_window if is_local[idx] else -1,
sliding_window = config.swa_pattern[idx],
key_q = "q_proj",
key_k = "k_proj",
key_v = "v_proj",
@@ -177,4 +204,4 @@ class Exaone4Model(Model):
p += f"[|user|]\n"
p += f"{prompt}[|endofturn|]\n"
p += f"[|assistant|]\n"
return p
return p