Respect norm_constant_bias in Q/K norms (Gemma3)

This commit is contained in:
turboderp
2025-03-14 23:17:50 +01:00
parent 4b5dbecdc1
commit b6c1912f29

View File

@@ -68,6 +68,9 @@ class ExLlamaV2HeadNorm(ExLlamaV2Module):
assert self.weight.shape == (self.num_heads, self.head_dim), "Head norm tensor shape mismatch"
if self.archparams.norm_constant_bias != 0:
self.weight += self.archparams.norm_constant_bias
def unload(self):
@@ -84,8 +87,12 @@ class ExLlamaV2HeadNorm(ExLlamaV2Module):
# NOTE(review): this span is a commit-diff export with the +/- line markers
# stripped, so BOTH the removed and the added versions of this method body
# appear below, flattened to one indentation level. Taken literally as
# Python, everything after the first `return self.weight` is unreachable —
# recover the real body from the original commit before treating this as
# runnable source code.
def get_weight(self) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
# Pre-change body (the removed diff lines): returned the stored weight
# tensor directly, paired with the bias when one is present.
if self.bias is not None: return self.weight, self.bias
return self.weight
# Post-change body (the added diff lines): report the weight with
# norm_constant_bias subtracted — apparently reversing the
# `self.weight += self.archparams.norm_constant_bias` applied at load time
# in the earlier hunk of this same commit, so the returned tensor matches
# the on-disk value. TODO(review): confirm against the repository whether
# the two trailing lines sit inside or after the `if`; their intended
# indentation is not recoverable from this export.
w = self.weight.data
if self.archparams.norm_constant_bias != 0:
return w - self.archparams.norm_constant_bias
if self.bias is not None: return w, self.bias
return w
def weight_footprint(self) -> int:
@@ -127,6 +134,7 @@ class ExLlamaV2HeadNorm(ExLlamaV2Module):
else:
return hidden_states
def forward_torch(
self,
hidden_states: torch.Tensor,