Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git, synced 2026-04-30 11:11:15 +00:00

Commit: support pytorch 2.4 new normalization features

This commit routes RMSNorm through the native torch.rms_norm op when it exists (added in PyTorch 2.4) and keeps a manual fallback for older versions; it also lazily casts the scale parameter to the input dtype instead of upcasting activations on every call.
@@ -9,7 +9,7 @@ import torch
 from torch import nn
 from einops import rearrange, repeat
 from backend.attention import attention_function
-from backend.utils import fp16_fix
+from backend.utils import fp16_fix, tensor2parameter
 
 
 def attention(q, k, v, pe):
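The newly imported tensor2parameter comes from backend/utils.py in this repo. A minimal sketch of what such a helper plausibly does, so the forward() change in the next hunk reads on its own (an assumption about the helper, not the repo's verbatim code):

import torch

def tensor2parameter(t):
    # Assumed behavior: re-wrap a plain tensor as nn.Parameter so that
    # assigning it back onto a module (self.scale = ...) keeps it
    # registered as a parameter; grad tracking is off for inference use.
    if isinstance(t, torch.nn.Parameter):
        return t
    return torch.nn.Parameter(t, requires_grad=False)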
@@ -98,16 +98,29 @@ class MLPEmbedder(nn.Module):
         return self.out_layer(x)
 
 
+if hasattr(torch, 'rms_norm'):
+    functional_rms_norm = torch.rms_norm
+else:
+    def functional_rms_norm(x, normalized_shape, weight, eps):
+        if x.dtype in [torch.bfloat16, torch.float32]:
+            n = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps) * weight
+        else:
+            n = torch.rsqrt(torch.mean(x.float() ** 2, dim=-1, keepdim=True) + eps).to(x.dtype) * weight
+        return x * n
+
+
 class RMSNorm(nn.Module):
     def __init__(self, dim):
         super().__init__()
+        self.weight = None  # to trigger module_profile
         self.scale = nn.Parameter(torch.ones(dim))
+        self.eps = 1e-6
+        self.normalized_shape = [dim]
 
     def forward(self, x):
-        to_args = dict(device=x.device, dtype=x.dtype)
-        x = x.float()
-        rrms = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + 1e-6)
-        return (x * rrms).to(**to_args) * self.scale.to(**to_args)
+        if self.scale.dtype != x.dtype:
+            self.scale = tensor2parameter(self.scale.to(dtype=x.dtype))
+        return functional_rms_norm(x, self.normalized_shape, self.scale, self.eps)
 
 
 class QKNorm(nn.Module):
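On a PyTorch 2.4+ install, the fallback path above can be sanity-checked against the native op; a minimal sketch (shapes and tolerances are arbitrary choices, not from the commit):

import torch

def fallback_rms_norm(x, normalized_shape, weight, eps):
    # Same math as the diff's bf16/fp32 branch.
    n = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps) * weight
    return x * n

x = torch.randn(2, 8, 64)
w = torch.ones(64)

if hasattr(torch, 'rms_norm'):  # native op exists on PyTorch >= 2.4
    native = torch.rms_norm(x, [64], w, 1e-6)
    manual = fallback_rms_norm(x, [64], w, 1e-6)
    # For fp32 inputs the two should agree to within float rounding.
    print(torch.allclose(native, manual, atol=1e-5))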