Allow controlnet_model_function_wrapper

for animatediff to manage controlnet batch slicing window
This commit is contained in:
lllyasviel
2024-02-07 20:06:23 -08:00
parent 383aaca1eb
commit a1670c536d
2 changed files with 20 additions and 1 deletions

View File

@@ -255,7 +255,15 @@ class ControlNet(ControlBase):
timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
controlnet_model_function_wrapper = to.get('controlnet_model_function_wrapper', None)
if controlnet_model_function_wrapper is not None:
wrapper_args = dict(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
wrapper_args['model'] = self
wrapper_args['inner_model'] = self.control_model
control = controlnet_model_function_wrapper(**wrapper_args)
else:
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
return self.control_merge(None, control, control_prev, output_dtype)
def copy(self):

View File

@@ -98,6 +98,13 @@ class UnetPatcher(ModelPatcher):
to[k].append(v)
return
def set_transformer_option(self, k, v):
if 'transformer_options' not in self.model_options:
self.model_options['transformer_options'] = {}
self.model_options['transformer_options'][k] = v
return
def add_conditioning_modifier(self, modifier, ensure_uniqueness=False):
self.append_model_option('conditioning_modifiers', modifier, ensure_uniqueness)
return
@@ -150,6 +157,10 @@ class UnetPatcher(ModelPatcher):
self.append_transformer_option('controlnet_conditioning_modifiers', modifier, ensure_uniqueness)
return
def set_controlnet_model_function_wrapper(self, wrapper):
self.set_transformer_option('controlnet_model_function_wrapper', wrapper)
return
def set_model_replace_all(self, patch, target="attn1"):
for block_name in ['input', 'middle', 'output']:
for number in range(16):