diff --git a/exllamav3/modules/gated_delta_net.py b/exllamav3/modules/gated_delta_net.py index 5913fb8..29c325c 100644 --- a/exllamav3/modules/gated_delta_net.py +++ b/exllamav3/modules/gated_delta_net.py @@ -352,7 +352,14 @@ class GatedDeltaNet(Module): self.b_proj = None self.a_proj = None - self.o_proj = Linear(config, f"{key}.{key_o}", 2 * hidden_size, hidden_size, qmap = qmap + ".output", out_dtype = self.out_dtype) + self.o_proj = Linear( + config, + f"{key}.{key_o}", + self.v_head_dim * self.num_v_heads, + hidden_size, + qmap = qmap + ".output", + out_dtype = self.out_dtype + ) self.register_submodule(self.o_proj) self.norm = GatedRMSNorm(config, f"{key}.{key_norm}", self.rms_norm_eps, out_dtype = torch.half)