Backend: better controlnet mask batch broadcasting

This commit is contained in:
lllyasviel
2024-02-06 20:13:09 -08:00
parent 1110183943
commit c185e39e59

View File

@@ -87,8 +87,10 @@ def compute_controlnet_weighting(control, cnet):
final_weight = final_weight * sigma_weight * frame_weight
if isinstance(advanced_mask_weighting, torch.Tensor):
if control_signal.shape[0] == 2 * advanced_mask_weighting.shape[0]:
advanced_mask_weighting = advanced_mask_weighting.repeat(2, 1, 1, 1)
if advanced_mask_weighting.shape[0] != 1:
k = int(control_signal.shape[0] // advanced_mask_weighting.shape[0])
if control_signal.shape[0] == k * advanced_mask_weighting.shape[0]:
advanced_mask_weighting = advanced_mask_weighting.repeat(k, 1, 1, 1)
control_signal = control_signal * torch.nn.functional.interpolate(advanced_mask_weighting.to(control_signal), size=(H, W), mode='bilinear')
control[k][i] = control_signal * final_weight[:, None, None, None]