From 5738bc62e52d379cab8065fcf72bd7d99070835e Mon Sep 17 00:00:00 2001
From: turboderp <11859846+turboderp@users.noreply.github.com>
Date: Tue, 3 Mar 2026 05:02:22 +0100
Subject: [PATCH] GatedDeltaNet: Skip redundant split+cat and some casts
 (Qwen3.5)

---
 exllamav3/modules/gated_delta_net.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/exllamav3/modules/gated_delta_net.py b/exllamav3/modules/gated_delta_net.py
index 319f856..8638fe2 100644
--- a/exllamav3/modules/gated_delta_net.py
+++ b/exllamav3/modules/gated_delta_net.py
@@ -579,15 +579,14 @@ class GatedDeltaNet(Module):
             )
         else:
             qkv = self.qkv_proj.forward(x, params)
-            z = self.z_proj.forward(x, params).view(bsz, seqlen, self.num_v_heads, self.v_head_dim).to(torch.bfloat16)
-            b = self.b_proj.forward(x, params).float()
-            a = self.a_proj.forward(x, params).float()
+            z = self.z_proj.forward(x, params).view(bsz, seqlen, self.num_v_heads, self.v_head_dim)
+            b = self.b_proj.forward(x, params)
+            a = self.a_proj.forward(x, params)
 
-            q, k, v = torch.split(qkv, [self.k_dim, self.k_dim, self.v_dim], dim = -1)
-            mixed_qkv = torch.cat((q, k, v), dim = -1).transpose(1, 2).contiguous().to(torch.bfloat16)
+            mixed_qkv = qkv.transpose(1, 2).to(torch.bfloat16).contiguous()
 
             dt_bias = self.dt_bias.float().view(1, 1, self.num_v_heads)
-            a_log = self.a_log.float().view(1, 1, self.num_v_heads)
+            a_log = self.a_log.view(1, 1, self.num_v_heads)
 
             beta = torch.sigmoid(b).to(torch.bfloat16)
             g = (-F.softplus(a + dt_bias) * torch.exp(a_log)).to(torch.float)