BlockSparseMLP: Allow loading combined expert tensors even when gate and up are not fused

turboderp
2026-03-01 03:13:56 +01:00
parent 4bdd22ea77
commit 489b3aab12


@@ -159,6 +159,8 @@ class BlockSparseMLP(Module):
         key_up: str | None = None,
         key_gate: str | None = None,
         key_down: str | None = None,
+        key_gate_split: str | None = None,
+        key_up_split: str | None = None,
         key_gate_up_split: str | None = None,
         key_down_split: str | None = None,
         key_routing_gate: str | None = None,
@@ -181,6 +183,7 @@ class BlockSparseMLP(Module):
         routing_first: int | None = None,
         routing_last: int | None = None,
         routing_device: int | None = None,
+        transposed_load: bool = True,
     ):
         super().__init__(config, key, None)
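
Before this change, a combined (all-experts-in-one) tensor could only be loaded when the gate and up projections were fused into a single gate_up tensor; the new key_gate_split and key_up_split keys cover checkpoints that store one combined tensor per projection. A minimal sketch of the two layouts follows; the [num_experts, out_features, in_features] storage order, shapes, and names are all illustrative assumptions (the real layout also depends on transposed_load):

```python
import torch

# Illustrative shapes only; the actual storage layout is an assumption here.
num_experts, hidden_size, intermediate_size = 8, 1024, 2816
idx = 3  # expert index, passed to each Linear as fidx

# Fused case (key_gate_up_split): gate and up are concatenated along the
# output dimension, so each Linear carves out its half with frange.
gate_up = torch.randn(num_experts, 2 * intermediate_size, hidden_size)
gate_w = gate_up[idx, 0:intermediate_size]                    # frange = (0, intermediate_size)
up_w = gate_up[idx, intermediate_size:2 * intermediate_size]  # frange = (intermediate_size, intermediate_size * 2)

# Unfused case (key_gate_split / key_up_split): one combined tensor per
# projection, so the expert index alone suffices and frange stays None.
gate_all = torch.randn(num_experts, intermediate_size, hidden_size)
up_all = torch.randn(num_experts, intermediate_size, hidden_size)
gate_w, up_w = gate_all[idx], up_all[idx]
```
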
@@ -252,38 +255,52 @@ class BlockSparseMLP(Module):
         for idx in range(self.num_local_experts):
+            fkey_gate, fkey_up, fkey_down = (
+                f"{key}.{key_gate_up_split}" if key_gate_up_split else
+                f"{key}.{key_gate_split}" if key_gate_split else
+                None,
+                f"{key}.{key_gate_up_split}" if key_gate_up_split else
+                f"{key}.{key_up_split}" if key_up_split else
+                None,
+                f"{key}.{key_down_split}" if key_down_split else
+                None
+            )
             gate = Linear(
                 config = config,
                 key = f"{key}.{key_gate}".replace("{expert_idx}", str(idx)),
-                fkey = f"{key}.{key_gate_up_split}",
+                fkey = fkey_gate,
                 fidx = idx,
-                frange = (0, intermediate_size),
+                frange = (0, intermediate_size) if key_gate_up_split else None,
                 in_features = hidden_size,
                 out_features = intermediate_size,
                 qmap = qmap + ".input",
-                out_dtype = self.interm_dtype
+                out_dtype = self.interm_dtype,
+                transposed_load = transposed_load,
             )
             up = Linear(
                 config = config,
                 key = f"{key}.{key_up}".replace("{expert_idx}", str(idx)),
-                fkey = f"{key}.{key_gate_up_split}",
+                fkey = fkey_up,
                 fidx = idx,
-                frange = (intermediate_size, intermediate_size * 2),
+                frange = (intermediate_size, intermediate_size * 2) if key_gate_up_split else None,
                 in_features = hidden_size,
                 out_features = intermediate_size,
                 qmap = qmap + ".input",
-                out_dtype = self.interm_dtype
+                out_dtype = self.interm_dtype,
+                transposed_load = transposed_load,
             )
             down = Linear(
                 config = config,
                 key = f"{key}.{key_down}".replace("{expert_idx}", str(idx)),
-                fkey = f"{key}.{key_down_split}",
+                fkey = fkey_down,
                 fidx = idx,
                 in_features = intermediate_size,
                 out_features = hidden_size,
                 qmap = qmap + f".{idx}.down",
                 out_dtype = self.out_dtype,
                 allow_input_padding = True,
+                transposed_load = transposed_load,
             )
             self.ups.append(up)
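
The fkey selection added above can be read standalone: the fused gate_up key takes precedence, then the per-projection split keys, and with no combined keys set the Linear falls back to per-expert loading (fkey = None). A runnable restatement of that precedence; the key strings in the asserts are illustrative, not from any real checkpoint:

```python
# Standalone restatement of the fkey selection in the loop above.
def select_fkeys(key, key_gate_up_split, key_gate_split, key_up_split, key_down_split):
    fkey_gate = (
        f"{key}.{key_gate_up_split}" if key_gate_up_split else
        f"{key}.{key_gate_split}" if key_gate_split else
        None
    )
    fkey_up = (
        f"{key}.{key_gate_up_split}" if key_gate_up_split else
        f"{key}.{key_up_split}" if key_up_split else
        None
    )
    fkey_down = f"{key}.{key_down_split}" if key_down_split else None
    return fkey_gate, fkey_up, fkey_down

# Fused: gate and up resolve to the same combined tensor, sliced by frange.
assert select_fkeys("mlp", "gate_up", None, None, "down") == \
    ("mlp.gate_up", "mlp.gate_up", "mlp.down")
# New unfused case: each projection gets its own combined tensor, no frange.
assert select_fkeys("mlp", None, "gate", "up", "down") == \
    ("mlp.gate", "mlp.up", "mlp.down")
# No combined tensors at all: fall back to per-expert keys (fkey = None).
assert select_fkeys("mlp", None, None, None, None) == (None, None, None)
```
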