From c00b45aa7161f71c214cdc19918248fab2af0331 Mon Sep 17 00:00:00 2001 From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com> Date: Wed, 7 Aug 2024 22:19:17 -0700 Subject: [PATCH] better cast --- backend/nn/flux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/nn/flux.py b/backend/nn/flux.py index 54800f96..67f888e4 100644 --- a/backend/nn/flux.py +++ b/backend/nn/flux.py @@ -105,7 +105,7 @@ class RMSNorm(nn.Module): to_args = dict(device=x.device, dtype=x.dtype) x = x.float() rrms = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + 1e-6) - return (x * rrms.to(x) * self.scale.to(x)).to(**to_args) + return (x * rrms).to(**to_args) * self.scale.to(**to_args) class QKNorm(nn.Module):