From b09c24ef5156674a7a26c605aee761f36e31225e Mon Sep 17 00:00:00 2001 From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com> Date: Wed, 14 Aug 2024 17:10:03 -0700 Subject: [PATCH] add fp16_fix --- backend/nn/flux.py | 11 ++++++++++- backend/utils.py | 10 ++++++++++ modules/launch_utils.py | 2 +- 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/backend/nn/flux.py b/backend/nn/flux.py index c1514082..1ab874cc 100644 --- a/backend/nn/flux.py +++ b/backend/nn/flux.py @@ -9,6 +9,7 @@ import torch from torch import nn from einops import rearrange, repeat from backend.attention import attention_function +from backend.utils import fp16_fix def attention(q, k, v, pe): @@ -242,6 +243,8 @@ class DoubleStreamBlock(nn.Module): txt = txt + txt_mod2_gate * self.txt_mlp((1 + txt_mod2_scale) * self.txt_norm2(txt) + txt_mod2_shift) del txt_mod2_gate, txt_mod2_scale, txt_mod2_shift + txt = fp16_fix(txt) + return img, txt @@ -279,7 +282,13 @@ class SingleStreamBlock(nn.Module): del q, k, v, pe output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), dim=2)) del attn, mlp - return x + mod_gate * output + + x = x + mod_gate * output + del mod_gate, output + + x = fp16_fix(x) + + return x class LastLayer(nn.Module): diff --git a/backend/utils.py b/backend/utils.py index de13f680..335ffa73 100644 --- a/backend/utils.py +++ b/backend/utils.py @@ -75,3 +75,13 @@ def calculate_parameters(sd, prefix=""): if k.startswith(prefix): params += sd[k].nelement() return params + + +def fp16_fix(x): + # An interesting trick to avoid fp16 overflow + # Source: https://github.com/lllyasviel/stable-diffusion-webui-forge/issues/1114 + # Related: https://github.com/comfyanonymous/ComfyUI/blob/f1d6cef71c70719cc3ed45a2455a4e5ac910cd5e/comfy/ldm/flux/layers.py#L180 + + if x.dtype == torch.float16: + return x.clip(-16384.0, 16384.0) + return x diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 17202876..9e323708 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -401,7 +401,7 @@ def prepare_environment(): # stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf") # stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f") # k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c") - huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "3f96b28763515dbe609792135df3615a440c66dc") + huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "84826248b49bb7ca754c73293299c4d4e23a548d") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") try: