GatedDeltaNet: Skip redundant split+cat and some casts (Qwen3.5)

This commit is contained in:
turboderp
2026-03-03 05:02:22 +01:00
parent 2965eec919
commit 5738bc62e5

View File

@@ -579,15 +579,14 @@ class GatedDeltaNet(Module):
 )
 else:
 qkv = self.qkv_proj.forward(x, params)
-z = self.z_proj.forward(x, params).view(bsz, seqlen, self.num_v_heads, self.v_head_dim).to(torch.bfloat16)
-b = self.b_proj.forward(x, params).float()
-a = self.a_proj.forward(x, params).float()
+z = self.z_proj.forward(x, params).view(bsz, seqlen, self.num_v_heads, self.v_head_dim)
+b = self.b_proj.forward(x, params)
+a = self.a_proj.forward(x, params)
-q, k, v = torch.split(qkv, [self.k_dim, self.k_dim, self.v_dim], dim = -1)
-mixed_qkv = torch.cat((q, k, v), dim = -1).transpose(1, 2).contiguous().to(torch.bfloat16)
+mixed_qkv = qkv.transpose(1, 2).to(torch.bfloat16).contiguous()
 dt_bias = self.dt_bias.float().view(1, 1, self.num_v_heads)
-a_log = self.a_log.float().view(1, 1, self.num_v_heads)
+a_log = self.a_log.view(1, 1, self.num_v_heads)
 beta = torch.sigmoid(b).to(torch.bfloat16)
 g = (-F.softplus(a + dt_bias) * torch.exp(a_log)).to(torch.float)