Optionally clamp hidden states (for Gemma2)

turboderp
2024-07-06 11:55:23 +02:00
parent 8f5680dfca
commit adefba1973
4 changed files with 17 additions and 2 deletions

exllamav2/architecture.py

@@ -124,6 +124,7 @@ class ExLlamaV2ArchParams:
         self.alternating_swa = False
         self.eager_attn_only = False
+        self.clamp_hidden_states = False
         self.fused_qkv_altpack = False
@@ -370,6 +371,7 @@ class ExLlamaV2ArchParams:
             self.pre_post_layernorm = True
             self.alternating_swa = True
             self.eager_attn_only = True
+            self.clamp_hidden_states = True
         # StarCoder2
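The new clamp_hidden_states flag defaults to False and is enabled in the Gemma2 settings block of ExLlamaV2ArchParams. For context (not part of the diff): the bound 65504 used by the clamps below is the largest finite float16 value, so clamping keeps half-precision activations representable.

    import torch

    # Illustration only: 65504 is the float16 maximum, which is the bound
    # every clamp added in this commit uses.
    print(torch.finfo(torch.float16).max)   # 65504.0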

exllamav2/attention.py

@@ -973,6 +973,9 @@ class ExLlamaV2Attention(ExLlamaV2Module):
             pass_lora_temp
         )
+        if cfg.arch.clamp_hidden_states:
+            hidden_states.clamp_(-65504, 65504)
         return hidden_states
@@ -1081,6 +1084,9 @@ class ExLlamaV2Attention(ExLlamaV2Module):
         hidden_states = (attn_proj + residual) if self.has_residual else attn_proj
+        if cfg.arch.clamp_hidden_states:
+            hidden_states.clamp_(-65504, 65504)
         if intermediates:
             return {"post_norm": post_norm,
                     "attn_output": attn_output,

exllamav2/mlp.py

@@ -260,6 +260,8 @@ class ExLlamaV2MLP(ExLlamaV2Module):
                 loras: list[ExLlamaV2Lora] | None = None,
                 **kwargs) -> torch.Tensor | dict[str: torch.Tensor]:
+        cfg = self.model.config
         if self.q_handle is None or intermediates:
             return self.forward_torch(hidden_states, cache, attn_params, past_len, intermediates, loras = loras, **kwargs)
@@ -275,6 +277,9 @@ class ExLlamaV2MLP(ExLlamaV2Module):
                              pass_loras,
                              pass_lora_temp)
+        if cfg.arch.clamp_hidden_states:
+            hidden_states.clamp_(-65504, 65504)
         return hidden_states
@@ -314,6 +319,9 @@ class ExLlamaV2MLP(ExLlamaV2Module):
             down = self.post_layernorm.forward(down)
         hidden_states = down + residual if self.has_residual else down
+        if cfg.arch.clamp_hidden_states:
+            hidden_states = hidden_states.clamp(-65504, 65504)
         if intermediates:
             return {"post_norm": post_norm,
                     "pre_down": y,

exllamav2/rmsnorm.py

@@ -120,8 +120,7 @@ class ExLlamaV2RMSNorm(ExLlamaV2Module):
                       loras = None,
                       **kwargs) -> torch.Tensor | dict[str: torch.Tensor]:
-        hidden_states[hidden_states == -float('inf')] = -65504.0
-        hidden_states[hidden_states == float('inf')] = 65504.0
+        hidden_states.clamp_(-65504.0, 65504.0)
         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim = True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
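In the RMSNorm Torch path, the explicit ±inf masking is replaced by the same clamp before the variance is computed. The clamp covers the old behavior and additionally bounds finite values that fall outside the float16 range, in a single in-place op (sketch, not from the repo):

    import torch

    h = torch.tensor([float('-inf'), -70000.0, 0.0, 70000.0, float('inf')])
    h.clamp_(-65504.0, 65504.0)
    print(h)   # tensor([-65504., -65504., 0., 65504., 65504.])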