add fp16_fix

layerdiffusion
2024-08-14 17:10:03 -07:00
parent aadc0f04c4
commit b09c24ef51
3 changed files with 21 additions and 2 deletions

@@ -75,3 +75,13 @@ def calculate_parameters(sd, prefix=""):
         if k.startswith(prefix):
             params += sd[k].nelement()
     return params
+
+
+def fp16_fix(x):
+    # An interesting trick to avoid fp16 overflow
+    # Source: https://github.com/lllyasviel/stable-diffusion-webui-forge/issues/1114
+    # Related: https://github.com/comfyanonymous/ComfyUI/blob/f1d6cef71c70719cc3ed45a2455a4e5ac910cd5e/comfy/ldm/flux/layers.py#L180
+
+    if x.dtype == torch.float16:
+        return x.clip(-16384.0, 16384.0)
+    return x
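
For context, a minimal standalone sketch of why the clip helps (the demo values and call sites below are illustrative assumptions, not part of the commit): fp16's largest finite value is 65504, so any intermediate result past that overflows to inf, and clipping to +/-16384 leaves headroom for later additions or multiplications.

    import torch

    # Sketch reproducing the helper from the diff.
    def fp16_fix(x):
        # Clip fp16 tensors well below the fp16 maximum (65504) so
        # downstream ops have headroom before overflowing to inf.
        if x.dtype == torch.float16:
            return x.clip(-16384.0, 16384.0)
        return x

    # 40000 is representable in fp16, but 40000 + 40000 = 80000 is not.
    x = torch.tensor([40000.0], dtype=torch.float16)
    print(x + x)                       # tensor([inf], dtype=torch.float16)

    # Clipping first keeps the sum finite: 16384 + 16384 = 32768 < 65504.
    print(fp16_fix(x) + fp16_fix(x))   # tensor([32768.], dtype=torch.float16)

Non-fp16 inputs pass through unchanged, so the helper can be dropped in front of overflow-prone ops without affecting fp32/bf16 paths.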