mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-24 08:19:13 +00:00
qol
This commit is contained in:
@@ -1,10 +1,6 @@
|
||||
import torch
|
||||
|
||||
|
||||
def get_at(array, index, default=None):
|
||||
return array[index] if 0 <= index < len(array) else default
|
||||
|
||||
|
||||
def apply_controlnet_advanced(
|
||||
unet,
|
||||
controlnet,
|
||||
@@ -79,61 +75,3 @@ def apply_controlnet_advanced(
|
||||
m.add_patched_controlnet(cnet)
|
||||
return m
|
||||
|
||||
|
||||
def compute_controlnet_weighting(control, cnet):
|
||||
|
||||
positive_advanced_weighting = cnet.positive_advanced_weighting
|
||||
negative_advanced_weighting = cnet.negative_advanced_weighting
|
||||
advanced_frame_weighting = cnet.advanced_frame_weighting
|
||||
advanced_sigma_weighting = cnet.advanced_sigma_weighting
|
||||
advanced_mask_weighting = cnet.advanced_mask_weighting
|
||||
transformer_options = cnet.transformer_options
|
||||
|
||||
if positive_advanced_weighting is None and negative_advanced_weighting is None \
|
||||
and advanced_frame_weighting is None and advanced_sigma_weighting is None \
|
||||
and advanced_mask_weighting is None:
|
||||
return control
|
||||
|
||||
cond_or_uncond = transformer_options['cond_or_uncond']
|
||||
sigmas = transformer_options['sigmas']
|
||||
cond_mark = transformer_options['cond_mark']
|
||||
|
||||
if advanced_frame_weighting is not None:
|
||||
advanced_frame_weighting = torch.Tensor(advanced_frame_weighting * len(cond_or_uncond)).to(sigmas)
|
||||
assert advanced_frame_weighting.shape[0] == cond_mark.shape[0], \
|
||||
'Frame weighting list length is different from batch size!'
|
||||
|
||||
if advanced_sigma_weighting is not None:
|
||||
advanced_sigma_weighting = torch.cat([advanced_sigma_weighting(sigmas)] * len(cond_or_uncond))
|
||||
|
||||
for k, v in control.items():
|
||||
for i in range(len(v)):
|
||||
control_signal = control[k][i]
|
||||
B, C, H, W = control_signal.shape
|
||||
|
||||
positive_weight = 1.0
|
||||
negative_weight = 1.0
|
||||
sigma_weight = 1.0
|
||||
frame_weight = 1.0
|
||||
|
||||
if positive_advanced_weighting is not None:
|
||||
positive_weight = get_at(positive_advanced_weighting.get(k, []), i, 1.0)
|
||||
|
||||
if negative_advanced_weighting is not None:
|
||||
negative_weight = get_at(negative_advanced_weighting.get(k, []), i, 1.0)
|
||||
|
||||
if advanced_sigma_weighting is not None:
|
||||
sigma_weight = advanced_sigma_weighting
|
||||
|
||||
if advanced_frame_weighting is not None:
|
||||
frame_weight = advanced_frame_weighting
|
||||
|
||||
final_weight = positive_weight * (1.0 - cond_mark) + negative_weight * cond_mark
|
||||
final_weight = final_weight * sigma_weight * frame_weight
|
||||
|
||||
if isinstance(advanced_mask_weighting, torch.Tensor):
|
||||
control_signal = control_signal * torch.nn.functional.interpolate(advanced_mask_weighting, size=(H, W), mode='bilinear')
|
||||
|
||||
control[k][i] = control_signal * final_weight[:, None, None, None]
|
||||
|
||||
return control
|
||||
|
||||
Reference in New Issue
Block a user