support pytorch 2.4 new normalization features

layerdiffusion
2024-08-28 09:08:26 -07:00
parent 0abb6c4686
commit 81d8f55bca


@@ -9,7 +9,7 @@ import torch
 from torch import nn
 from einops import rearrange, repeat
 from backend.attention import attention_function
-from backend.utils import fp16_fix
+from backend.utils import fp16_fix, tensor2parameter
 
 
 def attention(q, k, v, pe):
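Note: tensor2parameter comes from backend.utils, which is not touched by this commit. A minimal sketch of what such a helper presumably does (an assumption, not the repository's actual implementation):

import torch
from torch import nn

def tensor2parameter(t):
    # Re-wrap a plain tensor as an nn.Parameter so that assigning a cast copy
    # back to a module attribute (self.scale below) keeps it registered as a
    # parameter; requires_grad=False because the cast is inference-only.
    return t if isinstance(t, nn.Parameter) else nn.Parameter(t, requires_grad=False)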
@@ -98,16 +98,29 @@ class MLPEmbedder(nn.Module):
         return self.out_layer(x)
 
 
+if hasattr(torch, 'rms_norm'):
+    functional_rms_norm = torch.rms_norm
+else:
+    def functional_rms_norm(x, normalized_shape, weight, eps):
+        if x.dtype in [torch.bfloat16, torch.float32]:
+            n = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps) * weight
+        else:
+            n = torch.rsqrt(torch.mean(x.float() ** 2, dim=-1, keepdim=True) + eps).to(x.dtype) * weight
+        return x * n
+
+
 class RMSNorm(nn.Module):
     def __init__(self, dim):
         super().__init__()
+        self.weight = None  # to trigger module_profile
         self.scale = nn.Parameter(torch.ones(dim))
+        self.eps = 1e-6
+        self.normalized_shape = [dim]
 
     def forward(self, x):
-        to_args = dict(device=x.device, dtype=x.dtype)
-        x = x.float()
-        rrms = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + 1e-6)
-        return (x * rrms).to(**to_args) * self.scale.to(**to_args)
+        if self.scale.dtype != x.dtype:
+            self.scale = tensor2parameter(self.scale.to(dtype=x.dtype))
+        return functional_rms_norm(x, self.normalized_shape, self.scale, self.eps)
 
 
 class QKNorm(nn.Module):
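The hasattr dispatch above targets torch.rms_norm, which was added in PyTorch 2.4; older installs fall through to the hand-written path, and RMSNorm.forward now lazily casts self.scale to the activation dtype so functional_rms_norm is always called with matching dtypes. A standalone sketch (shapes and tolerances are illustrative) comparing the two paths on a PyTorch 2.4+ install:

import torch

# Float32/bfloat16 branch of the fallback, restated from the diff for a
# self-contained comparison.
def fallback_rms_norm(x, normalized_shape, weight, eps):
    n = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps) * weight
    return x * n

x = torch.randn(2, 64, 128)   # float32 activations
w = torch.ones(128)           # same init as RMSNorm.scale

y = fallback_rms_norm(x, [128], w, 1e-6)

# torch.rms_norm only exists on PyTorch >= 2.4, so guard the reference check.
if hasattr(torch, 'rms_norm'):
    y_ref = torch.rms_norm(x, [128], w, eps=1e-6)
    print(torch.allclose(y, y_ref, atol=1e-5, rtol=1e-5))  # expected: True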