diff --git a/modules_forge/forge_loader.py b/modules_forge/forge_loader.py index d2a2197e..820b715d 100644 --- a/modules_forge/forge_loader.py +++ b/modules_forge/forge_loader.py @@ -251,9 +251,13 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None): original_forward = sd_model.model.diffusion_model.forward def forge_unet_forward(*args, **kwargs): + current_transformer_options = kwargs.get('transformer_options', {}) + current_transformer_options.update(dict(cond_or_uncond=sd_model.cond_or_uncond)) + current_transformer_options.update(sd_model.forge_objects.unet.model_options.get('transformer_options', {})) + kwargs.update(dict( control=sd_model.current_controlnet_signals, - transformer_options=sd_model.forge_objects.unet.model_options.get('transformer_options', {}) + transformer_options=current_transformer_options )) return original_forward(*args, **kwargs)