Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git
Support PyTorch 2.4's new normalization features
@@ -9,7 +9,7 @@ import torch
 from torch import nn
 from einops import rearrange, repeat
 from backend.attention import attention_function
-from backend.utils import fp16_fix
+from backend.utils import fp16_fix, tensor2parameter
 
 
 def attention(q, k, v, pe):
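The only change in this hunk is the additional import of tensor2parameter, which RMSNorm.forward (second hunk below) uses to re-wrap the dtype-cast scale tensor as a module parameter. Its implementation is in backend.utils and is not part of this diff; the following is only a plausible minimal sketch, an assumption rather than the actual helper:

import torch
from torch import nn

def tensor2parameter(t: torch.Tensor) -> nn.Parameter:
    # Hypothetical sketch of backend.utils.tensor2parameter: wrap a plain
    # tensor back into an nn.Parameter so a cast such as
    # self.scale.to(dtype=x.dtype) stays registered on the module.
    return t if isinstance(t, nn.Parameter) else nn.Parameter(t, requires_grad=False)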
@@ -98,16 +98,29 @@ class MLPEmbedder(nn.Module):
         return self.out_layer(x)
 
 
+if hasattr(torch, 'rms_norm'):
+    functional_rms_norm = torch.rms_norm
+else:
+    def functional_rms_norm(x, normalized_shape, weight, eps):
+        if x.dtype in [torch.bfloat16, torch.float32]:
+            n = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps) * weight
+        else:
+            n = torch.rsqrt(torch.mean(x.float() ** 2, dim=-1, keepdim=True) + eps).to(x.dtype) * weight
+        return x * n
+
+
 class RMSNorm(nn.Module):
     def __init__(self, dim):
         super().__init__()
+        self.weight = None  # to trigger module_profile
         self.scale = nn.Parameter(torch.ones(dim))
+        self.eps = 1e-6
+        self.normalized_shape = [dim]
 
     def forward(self, x):
-        to_args = dict(device=x.device, dtype=x.dtype)
-        x = x.float()
-        rrms = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + 1e-6)
-        return (x * rrms).to(**to_args) * self.scale.to(**to_args)
+        if self.scale.dtype != x.dtype:
+            self.scale = tensor2parameter(self.scale.to(dtype=x.dtype))
+        return functional_rms_norm(x, self.normalized_shape, self.scale, self.eps)
 
 
 class QKNorm(nn.Module):
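PyTorch 2.4 added torch.rms_norm as a built-in (typically fused) op, so the commit dispatches to it when available and keeps a manual implementation as the fallback; note the fallback's else-branch upcasts fp16 inputs to float32 before the mean to avoid overflow in x ** 2, then casts back. The following is a minimal sketch, not part of the commit, checking on PyTorch >= 2.4 that the fallback math agrees with the fused op (the name fallback_rms_norm is just for illustration):

import torch

def fallback_rms_norm(x, normalized_shape, weight, eps):
    # Same math as the bfloat16/float32 branch of the fallback above;
    # normalized_shape is kept for signature parity but the mean is
    # taken over the last dimension, as in the diff.
    n = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps) * weight
    return x * n

x = torch.randn(2, 8, 64)
w = torch.ones(64) * 0.5
if hasattr(torch, 'rms_norm'):  # PyTorch >= 2.4
    fused = torch.rms_norm(x, [64], w, 1e-6)
    manual = fallback_rms_norm(x, [64], w, 1e-6)
    print(torch.allclose(fused, manual, atol=1e-5))  # expected: True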