mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-30 03:01:15 +00:00
better cast
This commit is contained in:
@@ -105,7 +105,7 @@ class RMSNorm(nn.Module):
|
|||||||
to_args = dict(device=x.device, dtype=x.dtype)
|
to_args = dict(device=x.device, dtype=x.dtype)
|
||||||
x = x.float()
|
x = x.float()
|
||||||
rrms = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + 1e-6)
|
rrms = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + 1e-6)
|
||||||
return (x * rrms.to(x) * self.scale.to(x)).to(**to_args)
|
return (x * rrms).to(**to_args) * self.scale.to(**to_args)
|
||||||
|
|
||||||
|
|
||||||
class QKNorm(nn.Module):
|
class QKNorm(nn.Module):
|
||||||
|
|||||||
Reference in New Issue
Block a user