Allow text model to use Q/K norm while vision model doesn't

This commit is contained in:
turboderp
2025-03-14 23:44:05 +01:00
parent 07afc90788
commit 565339101b

View File

@@ -193,7 +193,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
self.v_proj = ExLlamaV2Linear(model, key + km["attn_v"], hidden_size, self.num_key_value_heads * self.head_dim, ap.attention_bias_qkv, f_key = f_key, f_beg = f_c, f_end = f_d, altpack_qkv = ap.fused_qkv_altpack)
self.o_proj = ExLlamaV2Linear(model, key + km["attn_o"], self.num_attention_heads * self.head_dim, hidden_size, ap.attention_bias_o, prescale = cfg.scale_depth)
if cfg.use_qk_norm:
if cfg.use_qk_norm and not ap.is_vision:
self.q_norm = ExLlamaV2HeadNorm(model, key + ".self_attn.q_norm", self.num_attention_heads, self.head_dim)
self.k_norm = ExLlamaV2HeadNorm(model, key + ".self_attn.k_norm", self.num_key_value_heads, self.head_dim)
else:
@@ -210,7 +210,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
self.submodules += [self.pre_layernorm]
if self.post_layernorm:
self.submodules += [self.post_layernorm]
if cfg.use_qk_norm:
if self.q_norm is not None:
self.submodules += [self.q_norm, self.k_norm]
if cfg.attention_multiplier:
@@ -555,7 +555,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
q = q.view(batch_size, q_len, self.num_attention_heads, self.head_dim)
k = k.view(batch_size, q_len, self.num_key_value_heads, self.head_dim)
v = v.view(batch_size, q_len, self.num_key_value_heads, self.head_dim)
if cfg.use_qk_norm:
if self.q_norm is not None:
q = self.q_norm.forward(q)
k = self.k_norm.forward(k)
if self.archparams.rope_style != RopeStyle.NONE:
@@ -760,7 +760,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
q = [q_.view(batch_size, q_len, q_.shape[1] // self.head_dim, self.head_dim) for q_ in q]
k = [k_.view(batch_size, q_len, k_.shape[1] // self.head_dim, self.head_dim) for k_ in k]
v = [v_.view(batch_size, q_len, v_.shape[1] // self.head_dim, self.head_dim) for v_ in v]
if cfg.use_qk_norm:
if self.q_norm is not None:
assert False, "TP not implemented for QK norm" # TODO: ...
# q = self.q_norm.forward(q)
# k = self.k_norm.forward(k)
@@ -1431,7 +1431,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
# Apply Q/K norms
if cfg.use_qk_norm:
if self.q_norm is not None:
query_states = self.q_norm.forward(query_states)
key_states = self.k_norm.forward(key_states)