From a1670c536d5248f2a03b0dae30bb7b10281cdd2f Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Wed, 7 Feb 2024 20:06:23 -0800 Subject: [PATCH] Allow controlnet_model_function_wrapper for animatediff to manage controlnet batch slicing window --- ldm_patched/modules/controlnet.py | 10 +++++++++- modules_forge/unet_patcher.py | 11 +++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/ldm_patched/modules/controlnet.py b/ldm_patched/modules/controlnet.py index e25a4779..3bcc5da0 100644 --- a/ldm_patched/modules/controlnet.py +++ b/ldm_patched/modules/controlnet.py @@ -255,7 +255,15 @@ class ControlNet(ControlBase): timestep = self.model_sampling_current.timestep(t) x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) - control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y) + controlnet_model_function_wrapper = to.get('controlnet_model_function_wrapper', None) + + if controlnet_model_function_wrapper is not None: + wrapper_args = dict(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y) + wrapper_args['model'] = self + wrapper_args['inner_model'] = self.control_model + control = controlnet_model_function_wrapper(**wrapper_args) + else: + control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y) return self.control_merge(None, control, control_prev, output_dtype) def copy(self): diff --git a/modules_forge/unet_patcher.py b/modules_forge/unet_patcher.py index a46c03d2..6c0f0e61 100644 --- a/modules_forge/unet_patcher.py +++ b/modules_forge/unet_patcher.py @@ -98,6 +98,13 @@ class UnetPatcher(ModelPatcher): to[k].append(v) return + def set_transformer_option(self, k, v): + if 'transformer_options' not in self.model_options: + self.model_options['transformer_options'] = {} + + self.model_options['transformer_options'][k] = v + return + def add_conditioning_modifier(self, modifier, ensure_uniqueness=False): self.append_model_option('conditioning_modifiers', modifier, ensure_uniqueness) return @@ -150,6 +157,10 @@ class UnetPatcher(ModelPatcher): self.append_transformer_option('controlnet_conditioning_modifiers', modifier, ensure_uniqueness) return + def set_controlnet_model_function_wrapper(self, wrapper): + self.set_transformer_option('controlnet_model_function_wrapper', wrapper) + return + def set_model_replace_all(self, patch, target="attn1"): for block_name in ['input', 'middle', 'output']: for number in range(16):