advanced controlnet apply finished

This commit is contained in:
lllyasviel
2024-01-28 08:52:15 -08:00
parent 31e057e7b3
commit b2569c0183
2 changed files with 44 additions and 8 deletions

View File

@@ -105,9 +105,9 @@ class ControlNetExampleForge(scripts.Script):
'output': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2]
}
negative_advanced_weighting = {
'input': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2],
'middle': [1.0],
'output': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2]
'input': [0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85, 0.95, 1.05, 1.15, 1.25],
'middle': [1.05],
'output': [0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85, 0.95, 1.05, 1.15, 1.25]
}
# The advanced_frame_weighting is a weight applied to each image in a batch.
@@ -125,10 +125,10 @@ class ControlNetExampleForge(scripts.Script):
advanced_sigma_weighting = lambda s: (s - sigma_min) / (sigma_max - sigma_min)
# But in this simple example we do not use them
# positive_advanced_weighting = None
# negative_advanced_weighting = None
# advanced_frame_weighting = None
# advanced_sigma_weighting = None
positive_advanced_weighting = None
negative_advanced_weighting = None
advanced_frame_weighting = None
advanced_sigma_weighting = None
unet = apply_controlnet_advanced(unet=unet, controlnet=self.model, image_bhwc=control_image,
strength=0.6, start_percent=0.0, end_percent=0.8,

View File

@@ -1,3 +1,10 @@
import torch
def get_at(array, index, default=None):
return array[index] if 0 <= index < len(array) else default
def apply_controlnet_advanced(
unet,
controlnet,
@@ -75,7 +82,36 @@ def compute_controlnet_weighting(
sigmas = transformer_options['sigmas']
cond_mark = transformer_options['cond_mark']
if advanced_frame_weighting is not None:
advanced_frame_weighting = torch.Tensor(advanced_frame_weighting * len(cond_or_uncond)).to(sigmas)
assert advanced_frame_weighting.shape[0] == cond_mark.shape[0], \
'Frame weighting list length is different from batch size!'
if advanced_sigma_weighting is not None:
advanced_sigma_weighting = advanced_sigma_weighting(sigmas)
advanced_sigma_weighting = torch.cat([advanced_sigma_weighting(sigmas)] * len(cond_or_uncond))
for k, v in control.items():
for i in range(len(v)):
positive_weight = 1.0
negative_weight = 1.0
sigma_weight = 1.0
frame_weight = 1.0
if positive_advanced_weighting is not None:
positive_weight = get_at(positive_advanced_weighting.get(k, []), i, 1.0)
if negative_advanced_weighting is not None:
negative_weight = get_at(negative_advanced_weighting.get(k, []), i, 1.0)
if advanced_sigma_weighting is not None:
sigma_weight = advanced_sigma_weighting
if advanced_frame_weighting is not None:
frame_weight = advanced_frame_weighting
final_weight = positive_weight * (1.0 - cond_mark) + negative_weight * cond_mark
final_weight = final_weight * sigma_weight * frame_weight
control[k][i] = control[k][i] * final_weight[:, None, None, None]
return control