Fix FP32 residual for paged attn

turboderp
2025-03-14 23:09:31 +01:00
parent d8fa1a8250
commit 23395dfa42


@@ -628,7 +628,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
         else:
             hidden_states = self.o_proj.forward(attn_output, loras = loras)
         if self.post_layernorm:
-            hidden_states = self.post_layernorm.forward(hidden_states)
+            hidden_states = self.post_layernorm.forward(hidden_states, output_fp32 = self.archparams.residual_stream_fp32)
         if self.has_residual:
             hidden_states += residual
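
Context for the change: when the architecture keeps its residual stream in FP32, the post-layernorm output needs to come back as FP32 so the following `hidden_states += residual` accumulates at full precision instead of being downcast to half precision. Below is a minimal, standalone sketch of that behavior; the function name, signature, and RMS-norm body are hypothetical illustrations, not ExLlamaV2's actual module code, only the `output_fp32` idea mirrors the diff above.

```python
import torch

def post_layernorm_forward(hidden_states: torch.Tensor,
                           weight: torch.Tensor,
                           eps: float = 1e-6,
                           output_fp32: bool = False) -> torch.Tensor:
    # Hypothetical RMS-norm style post-layernorm, computed in FP32 for stability.
    x = hidden_states.to(torch.float32)
    x = x * torch.rsqrt(x.pow(2).mean(dim = -1, keepdim = True) + eps)
    x = x * weight.to(torch.float32)
    # If the residual stream is kept in FP32, return FP32 so the subsequent
    # `hidden_states += residual` accumulates at full precision; otherwise
    # cast back to the input's half-precision dtype.
    return x if output_fp32 else x.to(hidden_states.dtype)

# Usage sketch: with an FP32 residual, the addition no longer downcasts.
residual = torch.zeros(1, 8, 16, dtype = torch.float32)
hidden = torch.randn(1, 8, 16, dtype = torch.float16)
weight = torch.ones(16, dtype = torch.float16)
out = post_layernorm_forward(hidden, weight, output_fp32 = True)
out += residual  # stays in torch.float32
```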