diff --git a/extensions-builtin/sd_forge_controlnet_example/scripts/sd_forge_controlnet_example.py b/extensions-builtin/sd_forge_controlnet_example/scripts/sd_forge_controlnet_example.py index e138fb5c..9a41ddb5 100644 --- a/extensions-builtin/sd_forge_controlnet_example/scripts/sd_forge_controlnet_example.py +++ b/extensions-builtin/sd_forge_controlnet_example/scripts/sd_forge_controlnet_example.py @@ -105,9 +105,9 @@ class ControlNetExampleForge(scripts.Script): 'output': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2] } negative_advanced_weighting = { - 'input': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2], - 'middle': [1.0], - 'output': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2] + 'input': [0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85, 0.95, 1.05, 1.15, 1.25], + 'middle': [1.05], + 'output': [0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85, 0.95, 1.05, 1.15, 1.25] } # The advanced_frame_weighting is a weight applied to each image in a batch. @@ -125,10 +125,10 @@ class ControlNetExampleForge(scripts.Script): advanced_sigma_weighting = lambda s: (s - sigma_min) / (sigma_max - sigma_min) # But in this simple example we do not use them - # positive_advanced_weighting = None - # negative_advanced_weighting = None - # advanced_frame_weighting = None - # advanced_sigma_weighting = None + positive_advanced_weighting = None + negative_advanced_weighting = None + advanced_frame_weighting = None + advanced_sigma_weighting = None unet = apply_controlnet_advanced(unet=unet, controlnet=self.model, image_bhwc=control_image, strength=0.6, start_percent=0.0, end_percent=0.8, diff --git a/modules_forge/controlnet.py b/modules_forge/controlnet.py index 4fa50ed8..3d4959b6 100644 --- a/modules_forge/controlnet.py +++ b/modules_forge/controlnet.py @@ -1,3 +1,10 @@ +import torch + + +def get_at(array, index, default=None): + return array[index] if 0 <= index < len(array) else default + + def apply_controlnet_advanced( unet, controlnet, @@ -75,7 +82,36 @@ def compute_controlnet_weighting( sigmas = transformer_options['sigmas'] cond_mark = transformer_options['cond_mark'] + if advanced_frame_weighting is not None: + advanced_frame_weighting = torch.Tensor(advanced_frame_weighting * len(cond_or_uncond)).to(sigmas) + assert advanced_frame_weighting.shape[0] == cond_mark.shape[0], \ + 'Frame weighting list length is different from batch size!' + if advanced_sigma_weighting is not None: - advanced_sigma_weighting = advanced_sigma_weighting(sigmas) + advanced_sigma_weighting = torch.cat([advanced_sigma_weighting(sigmas)] * len(cond_or_uncond)) + + for k, v in control.items(): + for i in range(len(v)): + positive_weight = 1.0 + negative_weight = 1.0 + sigma_weight = 1.0 + frame_weight = 1.0 + + if positive_advanced_weighting is not None: + positive_weight = get_at(positive_advanced_weighting.get(k, []), i, 1.0) + + if negative_advanced_weighting is not None: + negative_weight = get_at(negative_advanced_weighting.get(k, []), i, 1.0) + + if advanced_sigma_weighting is not None: + sigma_weight = advanced_sigma_weighting + + if advanced_frame_weighting is not None: + frame_weight = advanced_frame_weighting + + final_weight = positive_weight * (1.0 - cond_mark) + negative_weight * cond_mark + final_weight = final_weight * sigma_weight * frame_weight + + control[k][i] = control[k][i] * final_weight[:, None, None, None] return control