diff --git a/ldm_patched/modules/controlnet.py b/ldm_patched/modules/controlnet.py index e25a4779..3bcc5da0 100644 --- a/ldm_patched/modules/controlnet.py +++ b/ldm_patched/modules/controlnet.py @@ -255,7 +255,15 @@ class ControlNet(ControlBase): timestep = self.model_sampling_current.timestep(t) x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) - control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y) + controlnet_model_function_wrapper = to.get('controlnet_model_function_wrapper', None) + + if controlnet_model_function_wrapper is not None: + wrapper_args = dict(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y) + wrapper_args['model'] = self + wrapper_args['inner_model'] = self.control_model + control = controlnet_model_function_wrapper(**wrapper_args) + else: + control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y) return self.control_merge(None, control, control_prev, output_dtype) def copy(self): diff --git a/modules_forge/unet_patcher.py b/modules_forge/unet_patcher.py index a46c03d2..6c0f0e61 100644 --- a/modules_forge/unet_patcher.py +++ b/modules_forge/unet_patcher.py @@ -98,6 +98,13 @@ class UnetPatcher(ModelPatcher): to[k].append(v) return + def set_transformer_option(self, k, v): + if 'transformer_options' not in self.model_options: + self.model_options['transformer_options'] = {} + + self.model_options['transformer_options'][k] = v + return + def add_conditioning_modifier(self, modifier, ensure_uniqueness=False): self.append_model_option('conditioning_modifiers', modifier, ensure_uniqueness) return @@ -150,6 +157,10 @@ class UnetPatcher(ModelPatcher): self.append_transformer_option('controlnet_conditioning_modifiers', modifier, ensure_uniqueness) return + def set_controlnet_model_function_wrapper(self, wrapper): + self.set_transformer_option('controlnet_model_function_wrapper', wrapper) + return + def set_model_replace_all(self, patch, target="attn1"): for block_name in ['input', 'middle', 'output']: for number in range(16):