From c185e39e59e08c7aaf507e03bdc59b06570e450b Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Tue, 6 Feb 2024 20:13:09 -0800 Subject: [PATCH] Backend: better controlnet mask batch broadcasting --- ldm_patched/modules/controlnet.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ldm_patched/modules/controlnet.py b/ldm_patched/modules/controlnet.py index 9414f8e7..3cd4d141 100644 --- a/ldm_patched/modules/controlnet.py +++ b/ldm_patched/modules/controlnet.py @@ -87,8 +87,10 @@ def compute_controlnet_weighting(control, cnet): final_weight = final_weight * sigma_weight * frame_weight if isinstance(advanced_mask_weighting, torch.Tensor): - if control_signal.shape[0] == 2 * advanced_mask_weighting.shape[0]: - advanced_mask_weighting = advanced_mask_weighting.repeat(2, 1, 1, 1) + if advanced_mask_weighting.shape[0] != 1: + k = int(control_signal.shape[0] // advanced_mask_weighting.shape[0]) + if control_signal.shape[0] == k * advanced_mask_weighting.shape[0]: + advanced_mask_weighting = advanced_mask_weighting.repeat(k, 1, 1, 1) control_signal = control_signal * torch.nn.functional.interpolate(advanced_mask_weighting.to(control_signal), size=(H, W), mode='bilinear') control[k][i] = control_signal * final_weight[:, None, None, None]