support controlnet_model_function_wrapper for t2i-adapter

This commit is contained in:
lllyasviel
2024-02-07 21:42:56 -08:00
parent 50035ad414
commit 4c9db26541

View File

@@ -557,7 +557,18 @@ class T2IAdapter(ControlBase):
if self.control_input is None:
self.t2i_model.to(x_noisy.dtype)
self.t2i_model.to(self.device)
self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
controlnet_model_function_wrapper = to.get('controlnet_model_function_wrapper', None)
if controlnet_model_function_wrapper is not None:
wrapper_args = dict(hint=self.cond_hint.to(x_noisy.dtype))
wrapper_args['model'] = self
wrapper_args['inner_model'] = self.t2i_model
wrapper_args['inner_t2i_model'] = self.t2i_model
self.control_input = controlnet_model_function_wrapper(**wrapper_args)
else:
self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
self.t2i_model.cpu()
control_input = list(map(lambda a: None if a is None else a.clone(), self.control_input))