diff --git a/ldm_patched/modules/controlnet.py b/ldm_patched/modules/controlnet.py index 41fa2b1c..19234c1c 100644 --- a/ldm_patched/modules/controlnet.py +++ b/ldm_patched/modules/controlnet.py @@ -557,7 +557,18 @@ class T2IAdapter(ControlBase): if self.control_input is None: self.t2i_model.to(x_noisy.dtype) self.t2i_model.to(self.device) - self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype)) + + controlnet_model_function_wrapper = to.get('controlnet_model_function_wrapper', None) + + if controlnet_model_function_wrapper is not None: + wrapper_args = dict(hint=self.cond_hint.to(x_noisy.dtype)) + wrapper_args['model'] = self + wrapper_args['inner_model'] = self.t2i_model + wrapper_args['inner_t2i_model'] = self.t2i_model + self.control_input = controlnet_model_function_wrapper(**wrapper_args) + else: + self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype)) + self.t2i_model.cpu() control_input = list(map(lambda a: None if a is None else a.clone(), self.control_input))