mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-02-20 23:03:58 +00:00
add fp16_fix
This commit is contained in:
@@ -9,6 +9,7 @@ import torch
|
||||
from torch import nn
|
||||
from einops import rearrange, repeat
|
||||
from backend.attention import attention_function
|
||||
from backend.utils import fp16_fix
|
||||
|
||||
|
||||
def attention(q, k, v, pe):
|
||||
@@ -242,6 +243,8 @@ class DoubleStreamBlock(nn.Module):
|
||||
txt = txt + txt_mod2_gate * self.txt_mlp((1 + txt_mod2_scale) * self.txt_norm2(txt) + txt_mod2_shift)
|
||||
del txt_mod2_gate, txt_mod2_scale, txt_mod2_shift
|
||||
|
||||
txt = fp16_fix(txt)
|
||||
|
||||
return img, txt
|
||||
|
||||
|
||||
@@ -279,7 +282,13 @@ class SingleStreamBlock(nn.Module):
|
||||
del q, k, v, pe
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), dim=2))
|
||||
del attn, mlp
|
||||
return x + mod_gate * output
|
||||
|
||||
x = x + mod_gate * output
|
||||
del mod_gate, output
|
||||
|
||||
x = fp16_fix(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
|
||||
|
||||
@@ -75,3 +75,13 @@ def calculate_parameters(sd, prefix=""):
|
||||
if k.startswith(prefix):
|
||||
params += sd[k].nelement()
|
||||
return params
|
||||
|
||||
|
||||
def fp16_fix(x):
|
||||
# An interesting trick to avoid fp16 overflow
|
||||
# Source: https://github.com/lllyasviel/stable-diffusion-webui-forge/issues/1114
|
||||
# Related: https://github.com/comfyanonymous/ComfyUI/blob/f1d6cef71c70719cc3ed45a2455a4e5ac910cd5e/comfy/ldm/flux/layers.py#L180
|
||||
|
||||
if x.dtype == torch.float16:
|
||||
return x.clip(-16384.0, 16384.0)
|
||||
return x
|
||||
|
||||
@@ -401,7 +401,7 @@ def prepare_environment():
|
||||
# stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
|
||||
# stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
|
||||
# k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
|
||||
huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "3f96b28763515dbe609792135df3615a440c66dc")
|
||||
huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "84826248b49bb7ca754c73293299c4d4e23a548d")
|
||||
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
||||
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user