mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-04-20 14:29:51 +00:00
Attn: Support headwise gate
This commit is contained in:
@@ -140,10 +140,11 @@ class Attention(Module):
|
||||
key_k: str | None = None,
|
||||
key_v: str | None = None,
|
||||
key_o: str | None = None,
|
||||
key_g: str | None = None,
|
||||
key_fused_qkv: str | None = None,
|
||||
qmap: str | None = None,
|
||||
out_dtype: torch.dtype | None = None,
|
||||
sliding_window: int = -1,
|
||||
sliding_window: int = -1,
|
||||
logit_softcapping: float = 0.0,
|
||||
q_norm: RMSNorm | LayerNorm | None = None,
|
||||
k_norm: RMSNorm | LayerNorm | None = None,
|
||||
@@ -181,6 +182,7 @@ class Attention(Module):
|
||||
if self.num_kv_heads == 0:
|
||||
return
|
||||
|
||||
# Create q, k, v projections
|
||||
if key_fused_qkv:
|
||||
assert not interleaved_gate, "Attn: interleaved_gate not implemented for fused QKV tensor"
|
||||
fkey = f"{key}.{key_fused_qkv}"
|
||||
@@ -216,6 +218,7 @@ class Attention(Module):
|
||||
self.register_submodule(self.k_proj)
|
||||
self.register_submodule(self.v_proj)
|
||||
|
||||
# Create o proj
|
||||
if key_o:
|
||||
self.o_proj = Linear(config, f"{key}.{key_o}", num_q_heads * head_dim, hidden_size, qmap = qmap + ".o", out_dtype = out_dtype, qbits_mod_key = "o")
|
||||
self.register_submodule(self.o_proj)
|
||||
@@ -224,6 +227,7 @@ class Attention(Module):
|
||||
self.o_proj = o_proj
|
||||
self.register_submodule(self.o_proj)
|
||||
|
||||
# Register q/k norms
|
||||
if q_norm:
|
||||
assert k_norm, "Must have both Q and K norms, or neither"
|
||||
self.q_norm = q_norm
|
||||
@@ -243,6 +247,17 @@ class Attention(Module):
|
||||
self.norm_eps = 1e-6
|
||||
self.norm_constant_bias = 0.0
|
||||
|
||||
# Register headwise gate
|
||||
if key_g:
|
||||
assert not interleaved_gate, \
|
||||
"Cannot apply both interleaved and headwise gate"
|
||||
self.g_proj = Linear(config, f"{key}.{key_g}", hidden_size, num_q_heads, qmap = None, out_dtype = torch.half, pad_to = 1)
|
||||
self.headwise_gate = True
|
||||
self.register_submodule(self.g_proj)
|
||||
else:
|
||||
self.g_proj = None
|
||||
self.headwise_gate = False
|
||||
|
||||
self.caps.update({
|
||||
"kv_cache": True
|
||||
})
|
||||
@@ -365,6 +380,8 @@ class Attention(Module):
|
||||
if self.interleaved_gate:
|
||||
q, g = torch.chunk(q.view(bsz, q_len, -1, self.head_dim * 2), 2, dim = -1)
|
||||
g = g.reshape(bsz, q_len, -1)
|
||||
elif self.g_proj:
|
||||
g = self.g_proj.forward(x, params)
|
||||
else:
|
||||
g = None
|
||||
|
||||
@@ -518,6 +535,9 @@ class Attention(Module):
|
||||
v = v.transpose(1, 2)
|
||||
o = F.scaled_dot_product_attention(q, k, v, is_causal = causal, enable_gqa = self.gqa, scale = self.sm_scale)
|
||||
o = o.transpose(1, 2)
|
||||
|
||||
if self.headwise_gate: o *= g.sigmoid().unsqueeze(-1)
|
||||
o = o.view((bsz, seqlen, self.num_q_heads * self.head_dim))
|
||||
if self.interleaved_gate: o *= g.sigmoid()
|
||||
|
||||
o = self.project_o(o, bsz, seqlen, params)
|
||||
@@ -591,6 +611,7 @@ class Attention(Module):
|
||||
softcap = self.logit_softcapping
|
||||
)
|
||||
|
||||
if self.headwise_gate: o *= g.sigmoid().unsqueeze(-1)
|
||||
o = o.view((bsz, seqlen, self.num_q_heads * self.head_dim))
|
||||
if self.interleaved_gate: o *= g.sigmoid()
|
||||
|
||||
@@ -667,6 +688,7 @@ class Attention(Module):
|
||||
else:
|
||||
cache.update_layer(self.layer_idx, cache_seqlens, block_table, cache_k, cache_v, seqlen)
|
||||
|
||||
if self.headwise_gate: o *= g.sigmoid().unsqueeze(-1)
|
||||
o = o.view((bsz, seqlen, self.num_q_heads * self.head_dim))
|
||||
if self.interleaved_gate: o *= g.sigmoid()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user