diff --git a/modules_forge/controlnet.py b/modules_forge/controlnet.py index 8ad88311..b0ca32d1 100644 --- a/modules_forge/controlnet.py +++ b/modules_forge/controlnet.py @@ -15,7 +15,8 @@ def apply_controlnet_advanced( positive_advanced_weighting=None, negative_advanced_weighting=None, advanced_frame_weighting=None, - advanced_sigma_weighting=None + advanced_sigma_weighting=None, + advanced_mask_weighting=None ): """ @@ -53,6 +54,12 @@ def apply_controlnet_advanced( sigma_min = unet.model.model_sampling.sigma_min advanced_sigma_weighting = lambda s: (s - sigma_min) / (sigma_max - sigma_min) + # advanced_mask_weighting + + A mask can be applied to control signals. + This should be a tensor with shape B 1 H W where the H and W can be arbitrary. + This mask will be resized automatically to match the shape of all injection layers. + """ cnet = controlnet.copy().set_cond_hint(image_bchw, strength, (start_percent, end_percent)) @@ -61,6 +68,13 @@ def apply_controlnet_advanced( cnet.advanced_frame_weighting = advanced_frame_weighting cnet.advanced_sigma_weighting = advanced_sigma_weighting + if advanced_mask_weighting is not None: + assert isinstance(advanced_mask_weighting, torch.Tensor) + B, C, H, W = advanced_mask_weighting.shape + assert B > 0 and C == 1 and H > 0 and W > 0 + + cnet.advanced_mask_weighting = advanced_mask_weighting + m = unet.clone() m.add_patched_controlnet(cnet) return m @@ -72,10 +86,12 @@ def compute_controlnet_weighting(control, cnet): negative_advanced_weighting = cnet.negative_advanced_weighting advanced_frame_weighting = cnet.advanced_frame_weighting advanced_sigma_weighting = cnet.advanced_sigma_weighting + advanced_mask_weighting = cnet.advanced_mask_weighting transformer_options = cnet.transformer_options if positive_advanced_weighting is None and negative_advanced_weighting is None \ - and advanced_frame_weighting is None and advanced_sigma_weighting is None: + and advanced_frame_weighting is None and advanced_sigma_weighting is None \ + and advanced_mask_weighting is None: return control cond_or_uncond = transformer_options['cond_or_uncond'] @@ -92,6 +108,9 @@ def compute_controlnet_weighting(control, cnet): for k, v in control.items(): for i in range(len(v)): + control_signal = control[k][i] + B, C, H, W = control_signal.shape + positive_weight = 1.0 negative_weight = 1.0 sigma_weight = 1.0 @@ -112,6 +131,9 @@ def compute_controlnet_weighting(control, cnet): final_weight = positive_weight * (1.0 - cond_mark) + negative_weight * cond_mark final_weight = final_weight * sigma_weight * frame_weight - control[k][i] = control[k][i] * final_weight[:, None, None, None] + if isinstance(advanced_mask_weighting, torch.Tensor): + control_signal = control_signal * torch.nn.functional.interpolate(advanced_mask_weighting, size=(H, W), mode='bilinear') + + control[k][i] = control_signal * final_weight[:, None, None, None] return control