diff --git a/ldm_patched/modules/controlnet.py b/ldm_patched/modules/controlnet.py index 9414f8e7..3cd4d141 100644 --- a/ldm_patched/modules/controlnet.py +++ b/ldm_patched/modules/controlnet.py @@ -87,8 +87,10 @@ def compute_controlnet_weighting(control, cnet): final_weight = final_weight * sigma_weight * frame_weight if isinstance(advanced_mask_weighting, torch.Tensor): - if control_signal.shape[0] == 2 * advanced_mask_weighting.shape[0]: - advanced_mask_weighting = advanced_mask_weighting.repeat(2, 1, 1, 1) + if advanced_mask_weighting.shape[0] != 1: + k = int(control_signal.shape[0] // advanced_mask_weighting.shape[0]) + if control_signal.shape[0] == k * advanced_mask_weighting.shape[0]: + advanced_mask_weighting = advanced_mask_weighting.repeat(k, 1, 1, 1) control_signal = control_signal * torch.nn.functional.interpolate(advanced_mask_weighting.to(control_signal), size=(H, W), mode='bilinear') control[k][i] = control_signal * final_weight[:, None, None, None]