Attn: Support headwise gate

turboderp
2026-03-01 03:12:03 +01:00
parent 447c8bb522
commit f7ccb524e7

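This commit adds an optional headwise output gate to the attention module: when a key_g tensor key is given, a small g_proj linear layer produces one gate logit per query head, and each head's attention output is scaled by the sigmoid of its logit before the heads are flattened into the output projection. The headwise gate is mutually exclusive with the existing interleaved gate. The sketches interspersed below are illustrative only, using standalone tensors and assumed dimensions rather than the repo's actual module wiring.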

@@ -140,10 +140,11 @@ class Attention(Module):
        key_k: str | None = None,
        key_v: str | None = None,
        key_o: str | None = None,
        key_g: str | None = None,
        key_fused_qkv: str | None = None,
        qmap: str | None = None,
        out_dtype: torch.dtype | None = None,
        sliding_window: int = -1,
        logit_softcapping: float = 0.0,
        q_norm: RMSNorm | LayerNorm | None = None,
        k_norm: RMSNorm | LayerNorm | None = None,
@@ -181,6 +182,7 @@ class Attention(Module):
        if self.num_kv_heads == 0:
            return

        # Create q, k, v projections
        if key_fused_qkv:
            assert not interleaved_gate, "Attn: interleaved_gate not implemented for fused QKV tensor"
            fkey = f"{key}.{key_fused_qkv}"
@@ -216,6 +218,7 @@ class Attention(Module):
            self.register_submodule(self.k_proj)
            self.register_submodule(self.v_proj)

        # Create o proj
        if key_o:
            self.o_proj = Linear(config, f"{key}.{key_o}", num_q_heads * head_dim, hidden_size, qmap = qmap + ".o", out_dtype = out_dtype, qbits_mod_key = "o")
            self.register_submodule(self.o_proj)
@@ -224,6 +227,7 @@ class Attention(Module):
            self.o_proj = o_proj
            self.register_submodule(self.o_proj)

        # Register q/k norms
        if q_norm:
            assert k_norm, "Must have both Q and K norms, or neither"
            self.q_norm = q_norm
@@ -243,6 +247,17 @@ class Attention(Module):
            self.norm_eps = 1e-6
            self.norm_constant_bias = 0.0

        # Register headwise gate
        if key_g:
            assert not interleaved_gate, \
                "Cannot apply both interleaved and headwise gate"
            self.g_proj = Linear(config, f"{key}.{key_g}", hidden_size, num_q_heads, qmap = None, out_dtype = torch.half, pad_to = 1)
            self.headwise_gate = True
            self.register_submodule(self.g_proj)
        else:
            self.g_proj = None
            self.headwise_gate = False

        self.caps.update({
            "kv_cache": True
        })
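For reference, a minimal standalone stand-in for the gate projection registered above (hypothetical names and dimensions, not the repo's Linear wrapper): hidden_size in, one gate logit per query head out. The diff keeps this projection unquantized (qmap = None) and in half precision (out_dtype = torch.half); float32 is used here just to keep the sketch portable.

import torch

hidden_size, num_q_heads = 4096, 32
g_proj = torch.nn.Linear(hidden_size, num_q_heads, bias = False)

x = torch.randn(1, 8, hidden_size)   # (bsz, seqlen, hidden_size)
g = g_proj(x)                        # (bsz, seqlen, num_q_heads)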
@@ -365,6 +380,8 @@ class Attention(Module):
        if self.interleaved_gate:
            q, g = torch.chunk(q.view(bsz, q_len, -1, self.head_dim * 2), 2, dim = -1)
            g = g.reshape(bsz, q_len, -1)
        elif self.g_proj:
            g = self.g_proj.forward(x, params)
        else:
            g = None
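The gate tensor g can thus come from two places with different shapes. A sketch of the contrast, with assumed illustrative dimensions:

import torch

bsz, q_len, num_q_heads, head_dim, hidden_size = 2, 8, 4, 64, 256

# Interleaved gate: the fused q tensor holds 2 * head_dim values per head;
# chunking the last dimension splits it into queries and a per-channel gate.
q_fused = torch.randn(bsz, q_len, num_q_heads, head_dim * 2)
q, g_interleaved = torch.chunk(q_fused, 2, dim = -1)
g_interleaved = g_interleaved.reshape(bsz, q_len, -1)      # (bsz, q_len, heads * head_dim)

# Headwise gate: a separate projection emits a single logit per head.
g_proj = torch.nn.Linear(hidden_size, num_q_heads, bias = False)
g_headwise = g_proj(torch.randn(bsz, q_len, hidden_size))  # (bsz, q_len, heads)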
@@ -518,6 +535,9 @@ class Attention(Module):
        v = v.transpose(1, 2)
        o = F.scaled_dot_product_attention(q, k, v, is_causal = causal, enable_gqa = self.gqa, scale = self.sm_scale)
        o = o.transpose(1, 2)

        if self.headwise_gate: o *= g.sigmoid().unsqueeze(-1)
        o = o.view((bsz, seqlen, self.num_q_heads * self.head_dim))
        if self.interleaved_gate: o *= g.sigmoid()

        o = self.project_o(o, bsz, seqlen, params)
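How the headwise gate lands in the SDPA path, as a self-contained sketch (random tensors and assumed shapes in place of the module's projections): g.sigmoid().unsqueeze(-1) broadcasts one value per head over head_dim, scaling each head's output by a factor in (0, 1) before the heads are flattened.

import torch
import torch.nn.functional as F

bsz, seqlen, num_q_heads, head_dim, hidden_size = 2, 8, 4, 64, 256

x = torch.randn(bsz, seqlen, hidden_size)
g_proj = torch.nn.Linear(hidden_size, num_q_heads, bias = False)

# Stand-ins for the module's q/k/v projections.
q = torch.randn(bsz, num_q_heads, seqlen, head_dim)
k = torch.randn(bsz, num_q_heads, seqlen, head_dim)
v = torch.randn(bsz, num_q_heads, seqlen, head_dim)

o = F.scaled_dot_product_attention(q, k, v, is_causal = True)
o = o.transpose(1, 2)                 # (bsz, seqlen, num_q_heads, head_dim)

g = g_proj(x)                         # (bsz, seqlen, num_q_heads)
o = o * g.sigmoid().unsqueeze(-1)     # per-head scale, broadcast over head_dim

o = o.reshape(bsz, seqlen, num_q_heads * head_dim)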
@@ -591,6 +611,7 @@ class Attention(Module):
            softcap = self.logit_softcapping
        )
        if self.headwise_gate: o *= g.sigmoid().unsqueeze(-1)
        o = o.view((bsz, seqlen, self.num_q_heads * self.head_dim))
        if self.interleaved_gate: o *= g.sigmoid()
@@ -667,6 +688,7 @@ class Attention(Module):
        else:
            cache.update_layer(self.layer_idx, cache_seqlens, block_table, cache_k, cache_v, seqlen)
        if self.headwise_gate: o *= g.sigmoid().unsqueeze(-1)
        o = o.view((bsz, seqlen, self.num_q_heads * self.head_dim))
        if self.interleaved_gate: o *= g.sigmoid()
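Note that the same one-line gate application is inserted in what appear to be each of the module's attention paths above, and always before the o.view(...) that flattens the heads: since g carries one value per query head, it can only broadcast against o while the head dimension is still explicit.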