GatedDeltaNet: Skip redundant zeroing of buffers (Qwen3-Next)

This commit is contained in:
turboderp
2026-03-03 05:01:16 +01:00
parent 410a43df22
commit 2965eec919

View File

@@ -562,10 +562,10 @@ class GatedDeltaNet(Module):
qkvz = self.qkvz_proj.forward(x, params)
ba = self.ba_proj.forward(x, params)
mixed_qkv = torch.zeros((bsz, self.fdim_qkv, seqlen), dtype = torch.bfloat16, device = self.device)
z = torch.zeros((bsz, seqlen, self.num_v_heads, self.v_head_dim), dtype = torch.bfloat16, device = self.device)
beta = torch.zeros((bsz, seqlen, self.num_v_heads), dtype = torch.bfloat16, device = self.device)
g = torch.zeros((bsz, seqlen, self.num_v_heads), dtype = torch.float, device = self.device)
mixed_qkv = torch.empty((bsz, self.fdim_qkv, seqlen), dtype = torch.bfloat16, device = self.device)
z = torch.empty((bsz, seqlen, self.num_v_heads, self.v_head_dim), dtype = torch.bfloat16, device = self.device)
beta = torch.empty((bsz, seqlen, self.num_v_heads), dtype = torch.bfloat16, device = self.device)
g = torch.empty((bsz, seqlen, self.num_v_heads), dtype = torch.float, device = self.device)
ext.gated_delta_net_fused_op(
qkvz, ba,