This commit is contained in:
lllyasviel
2024-01-28 07:56:42 -08:00
parent 8158e31d80
commit 3d6d19a893
4 changed files with 29 additions and 16 deletions

View File

@@ -462,10 +462,10 @@ class ControlNetExampleForge(scripts.Script):
# The advanced_frame_weighting is a weight applied to each image in a batch.
# The length of this list must be same with batch size
# For example, if batch size is 5, the below list is [0, 0.25, 0.5, 0.75, 1.0]
# For example, if batch size is 5, the below list is [0.2, 0.4, 0.6, 0.8, 1.0]
# If you view the 5 images as 5 frames in a video, this will lead to
# progressively stronger control over time.
advanced_frame_weighting = [float(i) / float(batch_size - 1) for i in range(batch_size)]
advanced_frame_weighting = [float(i + 1) / float(batch_size) for i in range(batch_size)]
# The advanced_sigma_weighting allows you to dynamically compute control
# weights given diffusion timestep (sigma).

View File

@@ -112,10 +112,10 @@ class ControlNetExampleForge(scripts.Script):
# The advanced_frame_weighting is a weight applied to each image in a batch.
# The length of this list must be same with batch size
# For example, if batch size is 5, the below list is [0, 0.25, 0.5, 0.75, 1.0]
# For example, if batch size is 5, the below list is [0.2, 0.4, 0.6, 0.8, 1.0]
# If you view the 5 images as 5 frames in a video, this will lead to
# progressively stronger control over time.
advanced_frame_weighting = [float(i) / float(batch_size - 1) for i in range(batch_size)]
advanced_frame_weighting = [float(i + 1) / float(batch_size) for i in range(batch_size)]
# The advanced_sigma_weighting allows you to dynamically compute control
# weights given diffusion timestep (sigma).
@@ -125,10 +125,10 @@ class ControlNetExampleForge(scripts.Script):
advanced_sigma_weighting = lambda s: (s - sigma_min) / (sigma_max - sigma_min)
# But in this simple example we do not use them
positive_advanced_weighting = None
negative_advanced_weighting = None
advanced_frame_weighting = None
advanced_sigma_weighting = None
# positive_advanced_weighting = None
# negative_advanced_weighting = None
# advanced_frame_weighting = None
# advanced_sigma_weighting = None
unet = apply_controlnet_advanced(unet=unet, controlnet=self.model, image_bhwc=control_image,
strength=0.6, start_percent=0.0, end_percent=0.8,

View File

@@ -57,3 +57,14 @@ def apply_controlnet_advanced(
m = unet.clone()
m.add_patched_controlnet(cnet)
return m
def compute_controlnet_weighting(
control,
positive_advanced_weighting,
negative_advanced_weighting,
advanced_frame_weighting,
advanced_sigma_weighting,
transformer_options
):
return control

View File

@@ -4,6 +4,7 @@ import ldm_patched.modules.samplers
from ldm_patched.modules.controlnet import ControlBase
from ldm_patched.modules.samplers import get_area_and_mult, can_concat_cond, cond_cat
from ldm_patched.modules import model_management
from modules_forge.controlnet import compute_controlnet_weighting
def patched_control_merge(self, control_input, control_output, control_prev, output_dtype):
@@ -38,11 +39,14 @@ def patched_control_merge(self, control_input, control_output, control_prev, out
out[key].append(x)
if self.positive_advanced_weighting is not None or self.negative_advanced_weighting:
# TODO: Implement here
cond_or_uncond = self.current_cond_or_uncond
a = 0
pass
out = compute_controlnet_weighting(
out,
positive_advanced_weighting=self.positive_advanced_weighting,
negative_advanced_weighting=self.negative_advanced_weighting,
advanced_frame_weighting=self.advanced_frame_weighting,
advanced_sigma_weighting=self.advanced_sigma_weighting,
transformer_options=self.transformer_options
)
if control_prev is not None:
for x in ['input', 'middle', 'output']:
@@ -129,9 +133,6 @@ def patched_calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_op
c = cond_cat(c)
timestep_ = torch.cat([timestep] * batch_chunks)
if control is not None:
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
transformer_options = {}
if 'transformer_options' in model_options:
transformer_options = model_options['transformer_options'].copy()
@@ -154,6 +155,7 @@ def patched_calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_op
if control is not None:
control.transformer_options = transformer_options
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)