Step 3.5: Fix TP split

turboderp committed 2026-03-01 21:32:59 +01:00
parent 6386de7a9b
commit 08ca454ec0
4 changed files with 27 additions and 10 deletions


@@ -145,7 +145,8 @@ class Step3_5Model(Model):
                 rms_norm_eps = config.rms_norm_eps,
                 constant_bias = 1.0,
             ),
-            out_dtype = torch.float
+            out_dtype = torch.float,
+            tp_split_norm = False
         ),
         mlp_norm = RMSNorm(
             config = config,
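
Why Step 3.5 opts out: the norm_q_split / norm_k_split tuples used during TP import (second file, below) slice a norm's parameters by head index, which only makes sense when the norm holds distinct weights per head. Step 3.5's q/k norms presumably share a single head_dim-sized weight across all heads, so there is nothing head-indexed to slice and every rank needs the same complete (small) module; tp_split_norm = False requests exactly that. A minimal sketch of the distinction, with illustrative shapes rather than the library's actual classes:

import torch

num_heads, head_dim = 8, 32
q = torch.randn(2, num_heads, head_dim)   # a TP rank may hold only a subset of heads

def rms_norm(t, w, eps = 1e-6):
    # normalize each head over head_dim, then apply the learned scale
    return t * torch.rsqrt(t.pow(2).mean(-1, keepdim = True) + eps) * w

# Shared norm: one weight vector used by every head. A rank holding heads
# [2, 4) applies the same full weight, so the module is replicated, not split.
w_shared = torch.randn(head_dim)
y_shared = rms_norm(q[:, 2:4], w_shared)

# Per-head norm: a distinct weight row per head. The same rank must hold
# exactly rows [2, 4) of the weight, so the module is split by head index.
w_per_head = torch.randn(num_heads, head_dim)
y_split = rms_norm(q[:, 2:4], w_per_head[2:4])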


@@ -153,9 +153,11 @@ class Attention(Module):
         v_proj: Linear | Module | None = None,
         kv_proj: Linear | Module | None = None,
         o_proj: Linear | Module | None = None,
+        g_proj: Linear | Module | None = None,
         interleaved_gate: bool = False,
         use_cu_seqlens: bool = False,
-        post_rope_norm: bool = False
+        post_rope_norm: bool = False,
+        tp_split_norm: bool = True
     ):
         super().__init__(config, key, None)
@@ -174,6 +176,7 @@ class Attention(Module):
         self.interleaved_gate = interleaved_gate
         self.use_cu_seqlens = use_cu_seqlens
         self.post_rope_norm = post_rope_norm
+        self.tp_split_norm = tp_split_norm
         if post_rope_norm:
             assert q_norm is None and k_norm is None, \
@@ -255,8 +258,13 @@ class Attention(Module):
-            self.headwise_gate = True
-            self.register_submodule(self.g_proj)
-        else:
-            self.g_proj = None
-            self.headwise_gate = False
+        if g_proj:
+            self.g_proj = g_proj
+            self.headwise_gate = True
+            self.register_submodule(self.g_proj)
+        else:
+            self.g_proj = None
+            self.headwise_gate = False
         self.caps.update({
             "kv_cache": True
@@ -769,6 +777,7 @@ class Attention(Module):
"sliding_window": self.sliding_window,
"logit_softcapping": self.logit_softcapping,
"post_rope_norm": self.post_rope_norm,
"tp_split_norm": self.tp_split_norm,
},
"num_kv_heads": self.num_kv_heads,
**{name: _export(getattr(self, name, None)) for name in (
@@ -779,6 +788,7 @@ class Attention(Module):
"v_proj",
"kv_proj",
"o_proj",
"g_proj",
)},
"device": self.device,
"cache_layers": [
@@ -798,6 +808,7 @@ class Attention(Module):
         head_dim = exported["kwargs"]["head_dim"]
         n_gqa = exported["n_gqa"]
         device = local_context["device"]
+        tp_split_norm = exported["kwargs"]["tp_split_norm"]
         first, last, unit = plan[key]
         assert unit == "heads"
         num_kv_heads = last - first
@@ -805,6 +816,8 @@ class Attention(Module):
         q_split = (True, first * head_dim * n_gqa, last * head_dim * n_gqa) \
             if num_kv_heads else None
+        qh_split = (True, first * n_gqa, last * n_gqa) \
+            if num_kv_heads else None
         kv_split = (True, first * head_dim, last * head_dim) \
             if num_kv_heads else None
         o_split = (False, first * head_dim * n_gqa, last * head_dim * n_gqa) \
             if num_kv_heads else None
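
The new qh_split mirrors the existing splits but counts heads instead of features, which is the right granularity for g_proj's one-output-per-head weight. A worked example with made-up sizes (head_dim = 128, n_gqa = 4, and a rank assigned KV heads [2, 4)); the leading bool in the real tuples, which presumably selects the weight dimension to slice, is omitted here:

first, last, head_dim, n_gqa = 2, 4, 128, 4

q_split  = (first * head_dim * n_gqa, last * head_dim * n_gqa)   # (1024, 2048): q output features
qh_split = (first * n_gqa, last * n_gqa)                         # (8, 16): query heads, for g_proj
kv_split = (first * head_dim, last * head_dim)                   # (256, 512): k/v output features
o_split  = (first * head_dim * n_gqa, last * head_dim * n_gqa)   # (1024, 2048): o_proj input features

print(q_split, qh_split, kv_split, o_split)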
@@ -825,10 +838,10 @@ class Attention(Module):
         norm_k_split = (first, last) \
             if num_kv_heads else None
-        # def _import(name):
-        #     nonlocal exported, plan
-        #     return exported[name]["cls"].tp_import(local_context, exported[name], plan) \
-        #         if exported.get(name) else None
+        def _import(name):
+            nonlocal exported, plan
+            return exported[name]["cls"].tp_import(local_context, exported[name], plan) \
+                if exported.get(name) else None
         def _import_split(name, split):
             nonlocal exported, plan
@@ -840,13 +853,14 @@ class Attention(Module):
             **exported["kwargs"],
             num_q_heads = num_q_heads,
             num_kv_heads = num_kv_heads,
-            q_norm = _import_split("q_norm", norm_q_split),
-            k_norm = _import_split("k_norm", norm_k_split),
+            q_norm = _import_split("q_norm", norm_q_split) if tp_split_norm else _import("q_norm"),
+            k_norm = _import_split("k_norm", norm_k_split) if tp_split_norm else _import("k_norm"),
             q_proj = _import_split("q_proj", q_split),
             k_proj = _import_split("k_proj", kv_split),
             v_proj = _import_split("v_proj", kv_split),
             kv_proj = _import_split("kv_proj", kv_split),
             o_proj = _import_split("o_proj", o_split),
+            g_proj = _import_split("g_proj", qh_split),
         )
         if num_kv_heads:
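
The core of the fix is the q_norm / k_norm dispatch above: when the norm's weights are head-indexed, each rank imports only its slice; when they are shared, the reinstated _import helper loads the whole module on every rank. The same decision in miniature, as a standalone sketch with placeholder helper names:

def load_norm(name, head_range, tp_split_norm, import_split, import_whole):
    if tp_split_norm:
        # head-indexed weights: keep only this rank's rows [first, last)
        return import_split(name, head_range)
    # shared weights: replicate the full module on every rank
    return import_whole(name)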


@@ -830,6 +830,7 @@ class BlockSparseMLP(Module):
"routed_scaling_factor": self.routed_scaling_factor,
"n_group": self.n_group,
"topk_group": self.topk_group,
"act_limit": self.act_limit
},
"routing_gate": _export(self.routing_gate),
"e_score_correction_bias": producer.send(self.e_score_correction_bias),


@@ -718,6 +718,7 @@ class GatedMLP(Module):
"out_dtype": self.out_dtype,
"interm_dtype": self.interm_dtype,
"intermediate_split_size": self.intermediate_split_size,
"act_limit": self.act_limit
},
"gates": [_export(self.gates[i]) for i in range(self.num_slices)],
"ups": [_export(self.ups[i]) for i in range(self.num_slices)],