support controlnet mask in backend

This commit is contained in:
lllyasviel
2024-01-31 23:11:12 -08:00
parent 61891e096c
commit 0f7c71b400

View File

@@ -15,7 +15,8 @@ def apply_controlnet_advanced(
positive_advanced_weighting=None,
negative_advanced_weighting=None,
advanced_frame_weighting=None,
advanced_sigma_weighting=None
advanced_sigma_weighting=None,
advanced_mask_weighting=None
):
"""
@@ -53,6 +54,12 @@ def apply_controlnet_advanced(
sigma_min = unet.model.model_sampling.sigma_min
advanced_sigma_weighting = lambda s: (s - sigma_min) / (sigma_max - sigma_min)
# advanced_mask_weighting
A mask can be applied to control signals.
This should be a tensor with shape B 1 H W where the H and W can be arbitrary.
This mask will be resized automatically to match the shape of all injection layers.
"""
cnet = controlnet.copy().set_cond_hint(image_bchw, strength, (start_percent, end_percent))
@@ -61,6 +68,13 @@ def apply_controlnet_advanced(
cnet.advanced_frame_weighting = advanced_frame_weighting
cnet.advanced_sigma_weighting = advanced_sigma_weighting
if advanced_mask_weighting is not None:
assert isinstance(advanced_mask_weighting, torch.Tensor)
B, C, H, W = advanced_mask_weighting.shape
assert B > 0 and C == 1 and H > 0 and W > 0
cnet.advanced_mask_weighting = advanced_mask_weighting
m = unet.clone()
m.add_patched_controlnet(cnet)
return m
@@ -72,10 +86,12 @@ def compute_controlnet_weighting(control, cnet):
negative_advanced_weighting = cnet.negative_advanced_weighting
advanced_frame_weighting = cnet.advanced_frame_weighting
advanced_sigma_weighting = cnet.advanced_sigma_weighting
advanced_mask_weighting = cnet.advanced_mask_weighting
transformer_options = cnet.transformer_options
if positive_advanced_weighting is None and negative_advanced_weighting is None \
and advanced_frame_weighting is None and advanced_sigma_weighting is None:
and advanced_frame_weighting is None and advanced_sigma_weighting is None \
and advanced_mask_weighting is None:
return control
cond_or_uncond = transformer_options['cond_or_uncond']
@@ -92,6 +108,9 @@ def compute_controlnet_weighting(control, cnet):
for k, v in control.items():
for i in range(len(v)):
control_signal = control[k][i]
B, C, H, W = control_signal.shape
positive_weight = 1.0
negative_weight = 1.0
sigma_weight = 1.0
@@ -112,6 +131,9 @@ def compute_controlnet_weighting(control, cnet):
final_weight = positive_weight * (1.0 - cond_mark) + negative_weight * cond_mark
final_weight = final_weight * sigma_weight * frame_weight
control[k][i] = control[k][i] * final_weight[:, None, None, None]
if isinstance(advanced_mask_weighting, torch.Tensor):
control_signal = control_signal * torch.nn.functional.interpolate(advanced_mask_weighting, size=(H, W), mode='bilinear')
control[k][i] = control_signal * final_weight[:, None, None, None]
return control