Separate transpose options for fused expert weights (account for differences between Qwen3Moe and Qwen3_5Moe)

This commit is contained in:
turboderp
2026-03-11 21:43:45 +01:00
parent e05f4636ee
commit 9db029ded5
3 changed files with 8 additions and 2 deletions

View File

@@ -381,7 +381,7 @@ class Qwen3_5BaseModel(Model):
key_down_split = "experts.down_proj",
key_routing_gate = "gate",
key_shared_gate = "shared_expert_gate",
- transposed_load = False,
+ transpose_fused_weights = False,
qmap = "block.mlp",
interm_dtype = torch.half,
out_dtype = torch.float,

View File

@@ -188,6 +188,7 @@ class BlockSparseMLP(Module):
routing_last: int | None = None,
routing_device: int | None = None,
transposed_load: bool = True,
transpose_fused_weights: bool = True,
):
super().__init__(config, key, None)
@@ -285,6 +286,7 @@ class BlockSparseMLP(Module):
qmap = qmap + ".input",
out_dtype = self.interm_dtype,
transposed_load = transposed_load,
transpose_fused_weights = transpose_fused_weights,
)
up = Linear(
config = config,
@@ -297,6 +299,7 @@ class BlockSparseMLP(Module):
qmap = qmap + ".input",
out_dtype = self.interm_dtype,
transposed_load = transposed_load,
transpose_fused_weights = transpose_fused_weights,
)
down = Linear(
config = config,
@@ -309,6 +312,7 @@ class BlockSparseMLP(Module):
out_dtype = self.out_dtype,
allow_input_padding = True,
transposed_load = transposed_load,
transpose_fused_weights = transpose_fused_weights,
)
self.ups.append(up)

View File

@@ -38,6 +38,7 @@ class Linear(Module):
allow_input_padding: bool = False,
post_scale: float = 1.0,
transposed_load: bool = True,
transpose_fused_weights: bool = True,
):
super().__init__(config, key, qmap)
@@ -62,6 +63,7 @@ class Linear(Module):
self.out_dtype = out_dtype
self.post_scale = post_scale
self.transposed_load = transposed_load
self.transpose_fused_weights = transpose_fused_weights
assert self.in_features_unpadded == self.in_features or allow_input_padding, \
f"Input padding is not allowed for {self.key}, in_dim: {self.in_features_unpadded}, pad_to: {pad_to}"
@@ -168,7 +170,7 @@ class Linear(Module):
weight = self.config.stc.get_tensor(
self.fkey,
self.device,
- transpose = self.transposed_load,
+ transpose = self.transpose_fused_weights,
no_defer = True,
fidx = self.fidx
)