better cast

This commit is contained in:
layerdiffusion
2024-08-07 22:19:17 -07:00
parent 78c65708ea
commit c00b45aa71

View File

@@ -105,7 +105,7 @@ class RMSNorm(nn.Module):
to_args = dict(device=x.device, dtype=x.dtype)
x = x.float()
rrms = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + 1e-6)
return (x * rrms.to(x) * self.scale.to(x)).to(**to_args)
return (x * rrms).to(**to_args) * self.scale.to(**to_args)
class QKNorm(nn.Module):