add fp16_fix

This commit is contained in:
layerdiffusion
2024-08-14 17:10:03 -07:00
parent aadc0f04c4
commit b09c24ef51
3 changed files with 21 additions and 2 deletions

View File

@@ -9,6 +9,7 @@ import torch
from torch import nn
from einops import rearrange, repeat
from backend.attention import attention_function
from backend.utils import fp16_fix
def attention(q, k, v, pe):
@@ -242,6 +243,8 @@ class DoubleStreamBlock(nn.Module):
txt = txt + txt_mod2_gate * self.txt_mlp((1 + txt_mod2_scale) * self.txt_norm2(txt) + txt_mod2_shift)
del txt_mod2_gate, txt_mod2_scale, txt_mod2_shift
txt = fp16_fix(txt)
return img, txt
@@ -279,7 +282,13 @@ class SingleStreamBlock(nn.Module):
del q, k, v, pe
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), dim=2))
del attn, mlp
return x + mod_gate * output
x = x + mod_gate * output
del mod_gate, output
x = fp16_fix(x)
return x
class LastLayer(nn.Module):

View File

@@ -75,3 +75,13 @@ def calculate_parameters(sd, prefix=""):
if k.startswith(prefix):
params += sd[k].nelement()
return params
def fp16_fix(x):
# An interesting trick to avoid fp16 overflow
# Source: https://github.com/lllyasviel/stable-diffusion-webui-forge/issues/1114
# Related: https://github.com/comfyanonymous/ComfyUI/blob/f1d6cef71c70719cc3ed45a2455a4e5ac910cd5e/comfy/ldm/flux/layers.py#L180
if x.dtype == torch.float16:
return x.clip(-16384.0, 16384.0)
return x

View File

@@ -401,7 +401,7 @@ def prepare_environment():
# stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
# stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
# k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "3f96b28763515dbe609792135df3615a440c66dc")
huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "84826248b49bb7ca754c73293299c4d4e23a548d")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
try: