From 4c9db26541bdca195cff89977dfbf601f8193ac6 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Wed, 7 Feb 2024 21:42:56 -0800 Subject: [PATCH] support controlnet_model_function_wrapper for t2i-adapter --- ldm_patched/modules/controlnet.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/ldm_patched/modules/controlnet.py b/ldm_patched/modules/controlnet.py index 41fa2b1c..19234c1c 100644 --- a/ldm_patched/modules/controlnet.py +++ b/ldm_patched/modules/controlnet.py @@ -557,7 +557,18 @@ class T2IAdapter(ControlBase): if self.control_input is None: self.t2i_model.to(x_noisy.dtype) self.t2i_model.to(self.device) - self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype)) + + controlnet_model_function_wrapper = to.get('controlnet_model_function_wrapper', None) + + if controlnet_model_function_wrapper is not None: + wrapper_args = dict(hint=self.cond_hint.to(x_noisy.dtype)) + wrapper_args['model'] = self + wrapper_args['inner_model'] = self.t2i_model + wrapper_args['inner_t2i_model'] = self.t2i_model + self.control_input = controlnet_model_function_wrapper(**wrapper_args) + else: + self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype)) + self.t2i_model.cpu() control_input = list(map(lambda a: None if a is None else a.clone(), self.control_input))