diff --git a/backend/nn/flux.py b/backend/nn/flux.py
index 1ab874cc..8c5b0f48 100644
--- a/backend/nn/flux.py
+++ b/backend/nn/flux.py
@@ -9,7 +9,7 @@ import torch
 from torch import nn
 from einops import rearrange, repeat
 from backend.attention import attention_function
-from backend.utils import fp16_fix
+from backend.utils import fp16_fix, tensor2parameter
 
 
 def attention(q, k, v, pe):
@@ -98,16 +98,29 @@ class MLPEmbedder(nn.Module):
         return self.out_layer(x)
 
 
+if hasattr(torch, 'rms_norm'):
+    functional_rms_norm = torch.rms_norm
+else:
+    def functional_rms_norm(x, normalized_shape, weight, eps):
+        if x.dtype in [torch.bfloat16, torch.float32]:
+            n = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps) * weight
+        else:
+            n = torch.rsqrt(torch.mean(x.float() ** 2, dim=-1, keepdim=True) + eps).to(x.dtype) * weight
+        return x * n
+
+
 class RMSNorm(nn.Module):
     def __init__(self, dim):
         super().__init__()
+        self.weight = None  # to trigger module_profile
         self.scale = nn.Parameter(torch.ones(dim))
+        self.eps = 1e-6
+        self.normalized_shape = [dim]
 
     def forward(self, x):
-        to_args = dict(device=x.device, dtype=x.dtype)
-        x = x.float()
-        rrms = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + 1e-6)
-        return (x * rrms).to(**to_args) * self.scale.to(**to_args)
+        if self.scale.dtype != x.dtype:
+            self.scale = tensor2parameter(self.scale.to(dtype=x.dtype))
+        return functional_rms_norm(x, self.normalized_shape, self.scale, self.eps)
 
 
 class QKNorm(nn.Module):
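
For reference, a minimal standalone sketch (not part of the patch) of what the fallback path computes: the same RMS normalization that torch.rms_norm performs, with the mean of squares taken in float32 for dtypes other than bfloat16/float32 (e.g. float16) and the result cast back to the input dtype. The dim, x, and w values below are arbitrary example inputs, and the comparison against torch.rms_norm only runs on PyTorch builds that ship it.

# Sanity-check sketch (assumed test harness, not part of the diff): verify that
# the fallback matches torch.rms_norm where the latter is available.
import torch


def fallback_rms_norm(x, normalized_shape, weight, eps):
    # Mirrors the fallback in the patch: bfloat16/float32 inputs are normalized
    # in their own dtype; anything else (e.g. float16) computes the mean of
    # squares in float32 and casts the reciprocal RMS back to the input dtype.
    if x.dtype in [torch.bfloat16, torch.float32]:
        n = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps) * weight
    else:
        n = torch.rsqrt(torch.mean(x.float() ** 2, dim=-1, keepdim=True) + eps).to(x.dtype) * weight
    return x * n


if __name__ == "__main__":
    dim = 64                        # arbitrary example width
    x = torch.randn(2, 8, dim)
    w = torch.ones(dim)
    out = fallback_rms_norm(x, [dim], w, 1e-6)
    if hasattr(torch, "rms_norm"):  # only on PyTorch builds that provide it
        ref = torch.rms_norm(x, [dim], w, 1e-6)
        print(torch.allclose(out, ref, atol=1e-5))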