mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-03-15 00:07:24 +00:00
Step 3.5: Fix TP split
This commit is contained in:
@@ -145,7 +145,8 @@ class Step3_5Model(Model):
|
||||
rms_norm_eps = config.rms_norm_eps,
|
||||
constant_bias = 1.0,
|
||||
),
|
||||
out_dtype = torch.float
|
||||
out_dtype = torch.float,
|
||||
tp_split_norm = False
|
||||
),
|
||||
mlp_norm = RMSNorm(
|
||||
config = config,
|
||||
|
||||
@@ -153,9 +153,11 @@ class Attention(Module):
|
||||
v_proj: Linear | Module | None = None,
|
||||
kv_proj: Linear | Module | None = None,
|
||||
o_proj: Linear | Module | None = None,
|
||||
g_proj: Linear | Module | None = None,
|
||||
interleaved_gate: bool = False,
|
||||
use_cu_seqlens: bool = False,
|
||||
post_rope_norm: bool = False
|
||||
post_rope_norm: bool = False,
|
||||
tp_split_norm: bool = True
|
||||
):
|
||||
super().__init__(config, key, None)
|
||||
|
||||
@@ -174,6 +176,7 @@ class Attention(Module):
|
||||
self.interleaved_gate = interleaved_gate
|
||||
self.use_cu_seqlens = use_cu_seqlens
|
||||
self.post_rope_norm = post_rope_norm
|
||||
self.tp_split_norm = tp_split_norm
|
||||
|
||||
if post_rope_norm:
|
||||
assert q_norm is None and k_norm is None, \
|
||||
@@ -255,8 +258,13 @@ class Attention(Module):
|
||||
self.headwise_gate = True
|
||||
self.register_submodule(self.g_proj)
|
||||
else:
|
||||
self.g_proj = None
|
||||
self.headwise_gate = False
|
||||
if g_proj:
|
||||
self.g_proj = g_proj
|
||||
self.headwise_gate = True
|
||||
self.register_submodule(self.g_proj)
|
||||
else:
|
||||
self.g_proj = None
|
||||
self.headwise_gate = False
|
||||
|
||||
self.caps.update({
|
||||
"kv_cache": True
|
||||
@@ -769,6 +777,7 @@ class Attention(Module):
|
||||
"sliding_window": self.sliding_window,
|
||||
"logit_softcapping": self.logit_softcapping,
|
||||
"post_rope_norm": self.post_rope_norm,
|
||||
"tp_split_norm": self.tp_split_norm,
|
||||
},
|
||||
"num_kv_heads": self.num_kv_heads,
|
||||
**{name: _export(getattr(self, name, None)) for name in (
|
||||
@@ -779,6 +788,7 @@ class Attention(Module):
|
||||
"v_proj",
|
||||
"kv_proj",
|
||||
"o_proj",
|
||||
"g_proj",
|
||||
)},
|
||||
"device": self.device,
|
||||
"cache_layers": [
|
||||
@@ -798,6 +808,7 @@ class Attention(Module):
|
||||
head_dim = exported["kwargs"]["head_dim"]
|
||||
n_gqa = exported["n_gqa"]
|
||||
device = local_context["device"]
|
||||
tp_split_norm = exported["kwargs"]["tp_split_norm"]
|
||||
first, last, unit = plan[key]
|
||||
assert unit == "heads"
|
||||
num_kv_heads = last - first
|
||||
@@ -805,6 +816,8 @@ class Attention(Module):
|
||||
|
||||
q_split = (True, first * head_dim * n_gqa, last * head_dim * n_gqa) \
|
||||
if num_kv_heads else None
|
||||
qh_split = (True, first * n_gqa, last * n_gqa) \
|
||||
if num_kv_heads else None
|
||||
kv_split = (True, first * head_dim, last * head_dim) \
|
||||
if num_kv_heads else None
|
||||
o_split = (False, first * head_dim * n_gqa, last * head_dim * n_gqa) \
|
||||
@@ -825,10 +838,10 @@ class Attention(Module):
|
||||
norm_k_split = (first, last) \
|
||||
if num_kv_heads else None
|
||||
|
||||
# def _import(name):
|
||||
# nonlocal exported, plan
|
||||
# return exported[name]["cls"].tp_import(local_context, exported[name], plan) \
|
||||
# if exported.get(name) else None
|
||||
def _import(name):
|
||||
nonlocal exported, plan
|
||||
return exported[name]["cls"].tp_import(local_context, exported[name], plan) \
|
||||
if exported.get(name) else None
|
||||
|
||||
def _import_split(name, split):
|
||||
nonlocal exported, plan
|
||||
@@ -840,13 +853,14 @@ class Attention(Module):
|
||||
**exported["kwargs"],
|
||||
num_q_heads = num_q_heads,
|
||||
num_kv_heads = num_kv_heads,
|
||||
q_norm = _import_split("q_norm", norm_q_split),
|
||||
k_norm = _import_split("k_norm", norm_k_split),
|
||||
q_norm = _import_split("q_norm", norm_q_split) if tp_split_norm else _import("q_norm"),
|
||||
k_norm = _import_split("k_norm", norm_k_split) if tp_split_norm else _import("k_norm"),
|
||||
q_proj = _import_split("q_proj", q_split),
|
||||
k_proj = _import_split("k_proj", kv_split),
|
||||
v_proj = _import_split("v_proj", kv_split),
|
||||
kv_proj = _import_split("kv_proj", kv_split),
|
||||
o_proj = _import_split("o_proj", o_split),
|
||||
g_proj = _import_split("g_proj", qh_split),
|
||||
)
|
||||
|
||||
if num_kv_heads:
|
||||
|
||||
@@ -830,6 +830,7 @@ class BlockSparseMLP(Module):
|
||||
"routed_scaling_factor": self.routed_scaling_factor,
|
||||
"n_group": self.n_group,
|
||||
"topk_group": self.topk_group,
|
||||
"act_limit": self.act_limit
|
||||
},
|
||||
"routing_gate": _export(self.routing_gate),
|
||||
"e_score_correction_bias": producer.send(self.e_score_correction_bias),
|
||||
|
||||
@@ -718,6 +718,7 @@ class GatedMLP(Module):
|
||||
"out_dtype": self.out_dtype,
|
||||
"interm_dtype": self.interm_dtype,
|
||||
"intermediate_split_size": self.intermediate_split_size,
|
||||
"act_limit": self.act_limit
|
||||
},
|
||||
"gates": [_export(self.gates[i]) for i in range(self.num_slices)],
|
||||
"ups": [_export(self.ups[i]) for i in range(self.num_slices)],
|
||||
|
||||
Reference in New Issue
Block a user