From 48e9804ffb67a0c879cff5dcb146069217982674 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Thu, 25 Jan 2024 16:22:29 -0800 Subject: [PATCH] Update forge_loader.py --- modules_forge/forge_loader.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/modules_forge/forge_loader.py b/modules_forge/forge_loader.py index e69b3321..057294b1 100644 --- a/modules_forge/forge_loader.py +++ b/modules_forge/forge_loader.py @@ -241,8 +241,21 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None): sd_model.decode_first_stage = patched_decode_first_stage sd_model.encode_first_stage = patched_encode_first_stage + patch_unet_forward(sd_model) + sd_model.clip = sd_model.cond_stage_model timer.record("forge finalize") sd_model.current_lora_hash = str([]) return sd_model + + +def patch_unet_forward(sd_model): + original_forward = sd_model.model.diffusion_model.forward + + def forge_unet_forward(self, *args, **kwargs): + return original_forward(self, *args, **kwargs) + + sd_model.model.diffusion_model.forward = forge_unet_forward + + return