GatedDeltaNet: Fix output projection no. input features

This commit is contained in:
turboderp
2026-03-02 16:35:26 +01:00
parent e12e6bd759
commit 5cb91c5505

View File

@@ -352,7 +352,14 @@ class GatedDeltaNet(Module):
self.b_proj = None
self.a_proj = None
self.o_proj = Linear(config, f"{key}.{key_o}", 2 * hidden_size, hidden_size, qmap = qmap + ".output", out_dtype = self.out_dtype)
self.o_proj = Linear(
config,
f"{key}.{key_o}",
self.v_head_dim * self.num_v_heads,
hidden_size,
qmap = qmap + ".output",
out_dtype = self.out_dtype
)
self.register_submodule(self.o_proj)
self.norm = GatedRMSNorm(config, f"{key}.{key_norm}", self.rms_norm_eps, out_dtype = torch.half)