Update forge_loader.py

This commit is contained in:
lllyasviel
2024-01-25 16:22:29 -08:00
parent abf33dbb43
commit 48e9804ffb

View File

@@ -241,8 +241,21 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
sd_model.decode_first_stage = patched_decode_first_stage
sd_model.encode_first_stage = patched_encode_first_stage
patch_unet_forward(sd_model)
sd_model.clip = sd_model.cond_stage_model
timer.record("forge finalize")
sd_model.current_lora_hash = str([])
return sd_model
def patch_unet_forward(sd_model):
original_forward = sd_model.model.diffusion_model.forward
def forge_unet_forward(self, *args, **kwargs):
return original_forward(self, *args, **kwargs)
sd_model.model.diffusion_model.forward = forge_unet_forward
return