mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-02-10 09:59:57 +00:00
Allow controlnet_model_function_wrapper
for animatediff to manage controlnet batch slicing window
This commit is contained in:
@@ -255,7 +255,15 @@ class ControlNet(ControlBase):
|
||||
timestep = self.model_sampling_current.timestep(t)
|
||||
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
||||
|
||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
|
||||
controlnet_model_function_wrapper = to.get('controlnet_model_function_wrapper', None)
|
||||
|
||||
if controlnet_model_function_wrapper is not None:
|
||||
wrapper_args = dict(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
|
||||
wrapper_args['model'] = self
|
||||
wrapper_args['inner_model'] = self.control_model
|
||||
control = controlnet_model_function_wrapper(**wrapper_args)
|
||||
else:
|
||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
|
||||
return self.control_merge(None, control, control_prev, output_dtype)
|
||||
|
||||
def copy(self):
|
||||
|
||||
@@ -98,6 +98,13 @@ class UnetPatcher(ModelPatcher):
|
||||
to[k].append(v)
|
||||
return
|
||||
|
||||
def set_transformer_option(self, k, v):
|
||||
if 'transformer_options' not in self.model_options:
|
||||
self.model_options['transformer_options'] = {}
|
||||
|
||||
self.model_options['transformer_options'][k] = v
|
||||
return
|
||||
|
||||
def add_conditioning_modifier(self, modifier, ensure_uniqueness=False):
|
||||
self.append_model_option('conditioning_modifiers', modifier, ensure_uniqueness)
|
||||
return
|
||||
@@ -150,6 +157,10 @@ class UnetPatcher(ModelPatcher):
|
||||
self.append_transformer_option('controlnet_conditioning_modifiers', modifier, ensure_uniqueness)
|
||||
return
|
||||
|
||||
def set_controlnet_model_function_wrapper(self, wrapper):
|
||||
self.set_transformer_option('controlnet_model_function_wrapper', wrapper)
|
||||
return
|
||||
|
||||
def set_model_replace_all(self, patch, target="attn1"):
|
||||
for block_name in ['input', 'middle', 'output']:
|
||||
for number in range(16):
|
||||
|
||||
Reference in New Issue
Block a user