Optionally clamp hidden states (for Gemma2)
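The clamp bound used throughout this diff, 65504, is the largest finite float16 value, so the clamp keeps exllamav2's hidden states representable in fp16 and stops infs from propagating into later layers; only the Gemma2 section of ExLlamaV2ArchParams enables it in the hunks below. A minimal standalone sketch of what the in-place clamp does, with made-up values rather than anything taken from the diff:

import torch

# Illustrative only: 65504 is the largest finite float16 value, so anything
# beyond it (including inf left over from an earlier overflow) is pulled back
# into representable range before the tensor is used downstream.
x = torch.tensor([3.0, 1.0e5, float("inf"), -float("inf")])
x.clamp_(-65504, 65504)
# x is now [3.0, 65504.0, 65504.0, -65504.0]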
@@ -124,6 +124,7 @@ class ExLlamaV2ArchParams:
         self.alternating_swa = False

         self.eager_attn_only = False
+        self.clamp_hidden_states = False

         self.fused_qkv_altpack = False
@@ -370,6 +371,7 @@ class ExLlamaV2ArchParams:
             self.pre_post_layernorm = True
             self.alternating_swa = True
             self.eager_attn_only = True
+            self.clamp_hidden_states = True

         # StarCoder2
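Only the Gemma2 branch flips the new flag to True; every other architecture keeps the False default from the first hunk. A hedged sketch of inspecting the flag on a loaded config, assuming the ExLlamaV2Config constructor accepts a model directory as in recent exllamav2 releases (the path is a placeholder):

from exllamav2 import ExLlamaV2Config

# Placeholder path; config.arch.clamp_hidden_states mirrors the
# cfg.arch.clamp_hidden_states checks added in the attention/MLP hunks below.
config = ExLlamaV2Config("/path/to/gemma2-model")
print(config.arch.clamp_hidden_states)  # expected: True for Gemma2, False otherwise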
@@ -973,6 +973,9 @@ class ExLlamaV2Attention(ExLlamaV2Module):
                 pass_lora_temp
             )

+        if cfg.arch.clamp_hidden_states:
+            hidden_states.clamp_(-65504, 65504)
+
         return hidden_states
@@ -1081,6 +1084,9 @@ class ExLlamaV2Attention(ExLlamaV2Module):

         hidden_states = (attn_proj + residual) if self.has_residual else attn_proj

+        if cfg.arch.clamp_hidden_states:
+            hidden_states.clamp_(-65504, 65504)
+
         if intermediates:
             return {"post_norm": post_norm,
                     "attn_output": attn_output,
@@ -260,6 +260,8 @@ class ExLlamaV2MLP(ExLlamaV2Module):
                 loras: list[ExLlamaV2Lora] | None = None,
                 **kwargs) -> torch.Tensor | dict[str: torch.Tensor]:

+        cfg = self.model.config
+
         if self.q_handle is None or intermediates:
             return self.forward_torch(hidden_states, cache, attn_params, past_len, intermediates, loras = loras, **kwargs)
@@ -275,6 +277,9 @@ class ExLlamaV2MLP(ExLlamaV2Module):
                               pass_loras,
                               pass_lora_temp)

+        if cfg.arch.clamp_hidden_states:
+            hidden_states.clamp_(-65504, 65504)
+
         return hidden_states
@@ -314,6 +319,9 @@ class ExLlamaV2MLP(ExLlamaV2Module):
             down = self.post_layernorm.forward(down)
         hidden_states = down + residual if self.has_residual else down

+        if cfg.arch.clamp_hidden_states:
+            hidden_states = hidden_states.clamp(-65504, 65504)
+
         if intermediates:
             return {"post_norm": post_norm,
                     "pre_down": y,
@@ -120,8 +120,7 @@ class ExLlamaV2RMSNorm(ExLlamaV2Module):
                 loras = None,
                 **kwargs) -> torch.Tensor | dict[str: torch.Tensor]:

-        hidden_states[hidden_states == -float('inf')] = -65504.0
-        hidden_states[hidden_states == float('inf')] = 65504.0
+        hidden_states.clamp_(-65504.0, 65504.0)

         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim = True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
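The RMSNorm change swaps two masked assignments, which only rewrote values that were exactly +/-inf, for a single in-place clamp that also bounds any finite values outside the float16 range. A small illustrative comparison with made-up values:

import torch

x = torch.tensor([2.0, 1.0e5, float("inf"), -float("inf")])

# Old behaviour: only exact +/-inf replaced; 1e5 passes through untouched.
old = x.clone()
old[old == -float("inf")] = -65504.0
old[old == float("inf")] = 65504.0
# old -> [2.0, 100000.0, 65504.0, -65504.0]

# New behaviour: everything outside the float16 range is clamped.
new = x.clone().clamp_(-65504.0, 65504.0)
# new -> [2.0, 65504.0, 65504.0, -65504.0]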