Mirror of https://github.com/turboderp-org/exllamav3.git (synced 2026-03-15 00:07:24 +00:00)
GatedDeltaNet: Skip redundant split+cat and some casts (Qwen3.5)
This commit is contained in:
@@ -579,15 +579,14 @@ class GatedDeltaNet(Module):
 )
 else:
 qkv = self.qkv_proj.forward(x, params)
-z = self.z_proj.forward(x, params).view(bsz, seqlen, self.num_v_heads, self.v_head_dim).to(torch.bfloat16)
-b = self.b_proj.forward(x, params).float()
-a = self.a_proj.forward(x, params).float()
+z = self.z_proj.forward(x, params).view(bsz, seqlen, self.num_v_heads, self.v_head_dim)
+b = self.b_proj.forward(x, params)
+a = self.a_proj.forward(x, params)
 
-q, k, v = torch.split(qkv, [self.k_dim, self.k_dim, self.v_dim], dim = -1)
-mixed_qkv = torch.cat((q, k, v), dim = -1).transpose(1, 2).contiguous().to(torch.bfloat16)
+mixed_qkv = qkv.transpose(1, 2).to(torch.bfloat16).contiguous()
 
 dt_bias = self.dt_bias.float().view(1, 1, self.num_v_heads)
-a_log = self.a_log.float().view(1, 1, self.num_v_heads)
+a_log = self.a_log.view(1, 1, self.num_v_heads)
 beta = torch.sigmoid(b).to(torch.bfloat16)
 g = (-F.softplus(a + dt_bias) * torch.exp(a_log)).to(torch.float)
 
Reference in New Issue
Block a user