mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-03-15 00:07:24 +00:00
GatedDeltaNet: Skip redundant zeroing of buffers (Qwen3-Next)
This commit is contained in:
@@ -562,10 +562,10 @@ class GatedDeltaNet(Module):
 qkvz = self.qkvz_proj.forward(x, params)
 ba = self.ba_proj.forward(x, params)
 
-mixed_qkv = torch.zeros((bsz, self.fdim_qkv, seqlen), dtype = torch.bfloat16, device = self.device)
-z = torch.zeros((bsz, seqlen, self.num_v_heads, self.v_head_dim), dtype = torch.bfloat16, device = self.device)
-beta = torch.zeros((bsz, seqlen, self.num_v_heads), dtype = torch.bfloat16, device = self.device)
-g = torch.zeros((bsz, seqlen, self.num_v_heads), dtype = torch.float, device = self.device)
+mixed_qkv = torch.empty((bsz, self.fdim_qkv, seqlen), dtype = torch.bfloat16, device = self.device)
+z = torch.empty((bsz, seqlen, self.num_v_heads, self.v_head_dim), dtype = torch.bfloat16, device = self.device)
+beta = torch.empty((bsz, seqlen, self.num_v_heads), dtype = torch.bfloat16, device = self.device)
+g = torch.empty((bsz, seqlen, self.num_v_heads), dtype = torch.float, device = self.device)
 
 ext.gated_delta_net_fused_op(
 qkvz, ba,
Reference in New Issue
Block a user