mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-04-20 06:19:00 +00:00
Respect norm_constant_bias in Q/K norms (Gemma3)
This commit is contained in:
@@ -68,6 +68,9 @@ class ExLlamaV2HeadNorm(ExLlamaV2Module):
|
||||
|
||||
assert self.weight.shape == (self.num_heads, self.head_dim), "Head norm tensor shape mismatch"
|
||||
|
||||
if self.archparams.norm_constant_bias != 0:
|
||||
self.weight += self.archparams.norm_constant_bias
|
||||
|
||||
|
||||
def unload(self):
|
||||
|
||||
@@ -84,8 +87,12 @@ class ExLlamaV2HeadNorm(ExLlamaV2Module):
|
||||
|
||||
def get_weight(self) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
if self.bias is not None: return self.weight, self.bias
|
||||
return self.weight
|
||||
w = self.weight.data
|
||||
if self.archparams.norm_constant_bias != 0:
|
||||
return w - self.archparams.norm_constant_bias
|
||||
|
||||
if self.bias is not None: return w, self.bias
|
||||
return w
|
||||
|
||||
|
||||
def weight_footprint(self) -> int:
|
||||
@@ -127,6 +134,7 @@ class ExLlamaV2HeadNorm(ExLlamaV2Module):
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
def forward_torch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user