diff --git a/exllamav2/architecture.py b/exllamav2/architecture.py
index b2bf903..066befe 100644
--- a/exllamav2/architecture.py
+++ b/exllamav2/architecture.py
@@ -124,6 +124,7 @@ class ExLlamaV2ArchParams:
 
         self.alternating_swa = False
         self.eager_attn_only = False
+        self.clamp_hidden_states = False
 
         self.fused_qkv_altpack = False
 
@@ -370,6 +371,7 @@ class ExLlamaV2ArchParams:
             self.pre_post_layernorm = True
             self.alternating_swa = True
             self.eager_attn_only = True
+            self.clamp_hidden_states = True
 
         # StarCoder2
 
diff --git a/exllamav2/attn.py b/exllamav2/attn.py
index 40f11b6..d8c508e 100644
--- a/exllamav2/attn.py
+++ b/exllamav2/attn.py
@@ -973,6 +973,9 @@ class ExLlamaV2Attention(ExLlamaV2Module):
                 pass_lora_temp
             )
 
+            if cfg.arch.clamp_hidden_states:
+                hidden_states.clamp_(-65504, 65504)
+
             return hidden_states
 
 
@@ -1081,6 +1084,9 @@ class ExLlamaV2Attention(ExLlamaV2Module):
 
         hidden_states = (attn_proj + residual) if self.has_residual else attn_proj
 
+        if cfg.arch.clamp_hidden_states:
+            hidden_states.clamp_(-65504, 65504)
+
         if intermediates:
             return {"post_norm": post_norm,
                     "attn_output": attn_output,
diff --git a/exllamav2/mlp.py b/exllamav2/mlp.py
index 51510d2..f41ea96 100644
--- a/exllamav2/mlp.py
+++ b/exllamav2/mlp.py
@@ -260,6 +260,8 @@ class ExLlamaV2MLP(ExLlamaV2Module):
                 loras: list[ExLlamaV2Lora] | None = None,
                 **kwargs) -> torch.Tensor | dict[str: torch.Tensor]:
 
+        cfg = self.model.config
+
         if self.q_handle is None or intermediates:
             return self.forward_torch(hidden_states, cache, attn_params, past_len, intermediates, loras = loras, **kwargs)
 
@@ -275,6 +277,9 @@ class ExLlamaV2MLP(ExLlamaV2Module):
                               pass_loras,
                               pass_lora_temp)
 
+        if cfg.arch.clamp_hidden_states:
+            hidden_states.clamp_(-65504, 65504)
+
         return hidden_states
 
 
@@ -314,6 +319,9 @@ class ExLlamaV2MLP(ExLlamaV2Module):
             down = self.post_layernorm.forward(down)
         hidden_states = down + residual if self.has_residual else down
 
+        if cfg.arch.clamp_hidden_states:
+            hidden_states = hidden_states.clamp(-65504, 65504)
+
         if intermediates:
             return {"post_norm": post_norm,
                     "pre_down": y,
diff --git a/exllamav2/rmsnorm.py b/exllamav2/rmsnorm.py
index 68518c5..fa3b7e2 100644
--- a/exllamav2/rmsnorm.py
+++ b/exllamav2/rmsnorm.py
@@ -120,8 +120,7 @@ class ExLlamaV2RMSNorm(ExLlamaV2Module):
                 loras = None,
                 **kwargs) -> torch.Tensor | dict[str: torch.Tensor]:
 
-        hidden_states[hidden_states == -float('inf')] = -65504.0
-        hidden_states[hidden_states == float('inf')] = 65504.0
+        hidden_states.clamp_(-65504.0, 65504.0)
 
         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim = True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
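
Note (reviewer illustration, not part of the patch): 65504 is the largest finite float16 value, so any activation past it overflows to inf, and the old rmsnorm.py masking only repaired values that were already exactly +/-inf. A minimal standalone sketch of the failure mode and of the clamp this patch gates behind cfg.arch.clamp_hidden_states; the example values are arbitrary:

    import torch

    # float16 saturates: any magnitude above 65504 overflows to inf.
    x = torch.tensor([1e5, -1e5, 30000.0], dtype = torch.float16)
    print(x)  # tensor([inf, -inf, 30000.], dtype=torch.float16)

    # A single inf poisons the RMSNorm variance reduction: the mean of
    # squares becomes inf, rsqrt(inf) is 0, and inf * 0 is nan, so the
    # inf positions turn to nan and the finite ones collapse to 0.
    variance = x.to(torch.float32).pow(2).mean(-1, keepdim = True)
    print(x * torch.rsqrt(variance))  # tensor([nan, nan, 0.])

    # In-place clamp to the finite float16 range, as the patched code does.
    # Unlike the two masked assignments it replaces, one clamp_ call also
    # pulls in any merely out-of-range values, not just exact infinities.
    x.clamp_(-65504, 65504)
    print(x)  # tensor([ 65504., -65504.,  30000.], dtype=torch.float16)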