Update forge_loader.py

This commit is contained in:
lllyasviel
2024-01-25 16:32:20 -08:00
parent a2a2b6023f
commit de5a75b130

View File

@@ -241,9 +241,19 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
sd_model.decode_first_stage = patched_decode_first_stage
sd_model.encode_first_stage = patched_encode_first_stage
sd_model.current_control_signals = {
'input': [],
'middle': [],
'output': []
}
original_forward = sd_model.model.diffusion_model.forward
def forge_unet_forward(*args, **kwargs):
kwargs.update(dict(
control=sd_model.current_control_signals,
transformer_options=sd_model.forge_objects.unet.model_options.get('transformer_options', {})
))
return original_forward(*args, **kwargs)
sd_model.model.diffusion_model.forward = forge_unet_forward