diff --git a/README.md b/README.md index 6974569c..7d240720 100644 --- a/README.md +++ b/README.md @@ -462,10 +462,10 @@ class ControlNetExampleForge(scripts.Script): # The advanced_frame_weighting is a weight applied to each image in a batch. # The length of this list must be same with batch size - # For example, if batch size is 5, the below list is [0, 0.25, 0.5, 0.75, 1.0] + # For example, if batch size is 5, the below list is [0.2, 0.4, 0.6, 0.8, 1.0] # If you view the 5 images as 5 frames in a video, this will lead to # progressively stronger control over time. - advanced_frame_weighting = [float(i) / float(batch_size - 1) for i in range(batch_size)] + advanced_frame_weighting = [float(i + 1) / float(batch_size) for i in range(batch_size)] # The advanced_sigma_weighting allows you to dynamically compute control # weights given diffusion timestep (sigma). diff --git a/extensions-builtin/sd_forge_controlnet_example/scripts/sd_forge_controlnet_example.py b/extensions-builtin/sd_forge_controlnet_example/scripts/sd_forge_controlnet_example.py index e85a72a3..6249d066 100644 --- a/extensions-builtin/sd_forge_controlnet_example/scripts/sd_forge_controlnet_example.py +++ b/extensions-builtin/sd_forge_controlnet_example/scripts/sd_forge_controlnet_example.py @@ -112,10 +112,10 @@ class ControlNetExampleForge(scripts.Script): # The advanced_frame_weighting is a weight applied to each image in a batch. # The length of this list must be same with batch size - # For example, if batch size is 5, the below list is [0, 0.25, 0.5, 0.75, 1.0] + # For example, if batch size is 5, the below list is [0.2, 0.4, 0.6, 0.8, 1.0] # If you view the 5 images as 5 frames in a video, this will lead to # progressively stronger control over time. - advanced_frame_weighting = [float(i) / float(batch_size - 1) for i in range(batch_size)] + advanced_frame_weighting = [float(i + 1) / float(batch_size) for i in range(batch_size)] # The advanced_sigma_weighting allows you to dynamically compute control # weights given diffusion timestep (sigma). @@ -125,10 +125,10 @@ class ControlNetExampleForge(scripts.Script): advanced_sigma_weighting = lambda s: (s - sigma_min) / (sigma_max - sigma_min) # But in this simple example we do not use them - positive_advanced_weighting = None - negative_advanced_weighting = None - advanced_frame_weighting = None - advanced_sigma_weighting = None + # positive_advanced_weighting = None + # negative_advanced_weighting = None + # advanced_frame_weighting = None + # advanced_sigma_weighting = None unet = apply_controlnet_advanced(unet=unet, controlnet=self.model, image_bhwc=control_image, strength=0.6, start_percent=0.0, end_percent=0.8, diff --git a/modules_forge/controlnet.py b/modules_forge/controlnet.py index 4f58d525..f88d76f3 100644 --- a/modules_forge/controlnet.py +++ b/modules_forge/controlnet.py @@ -57,3 +57,14 @@ def apply_controlnet_advanced( m = unet.clone() m.add_patched_controlnet(cnet) return m + + +def compute_controlnet_weighting( + control, + positive_advanced_weighting, + negative_advanced_weighting, + advanced_frame_weighting, + advanced_sigma_weighting, + transformer_options +): + return control diff --git a/modules_forge/patch_basic.py b/modules_forge/patch_basic.py index 504c9ddb..d5ba3f89 100644 --- a/modules_forge/patch_basic.py +++ b/modules_forge/patch_basic.py @@ -4,6 +4,7 @@ import ldm_patched.modules.samplers from ldm_patched.modules.controlnet import ControlBase from ldm_patched.modules.samplers import get_area_and_mult, can_concat_cond, cond_cat from ldm_patched.modules import model_management +from modules_forge.controlnet import compute_controlnet_weighting def patched_control_merge(self, control_input, control_output, control_prev, output_dtype): @@ -38,11 +39,14 @@ def patched_control_merge(self, control_input, control_output, control_prev, out out[key].append(x) - if self.positive_advanced_weighting is not None or self.negative_advanced_weighting: - # TODO: Implement here - cond_or_uncond = self.current_cond_or_uncond - a = 0 - pass + out = compute_controlnet_weighting( + out, + positive_advanced_weighting=self.positive_advanced_weighting, + negative_advanced_weighting=self.negative_advanced_weighting, + advanced_frame_weighting=self.advanced_frame_weighting, + advanced_sigma_weighting=self.advanced_sigma_weighting, + transformer_options=self.transformer_options + ) if control_prev is not None: for x in ['input', 'middle', 'output']: @@ -129,9 +133,6 @@ def patched_calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_op c = cond_cat(c) timestep_ = torch.cat([timestep] * batch_chunks) - if control is not None: - c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond)) - transformer_options = {} if 'transformer_options' in model_options: transformer_options = model_options['transformer_options'].copy() @@ -154,6 +155,7 @@ def patched_calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_op if control is not None: control.transformer_options = transformer_options + c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond)) if 'model_function_wrapper' in model_options: output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)