Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git (synced 2026-02-07 16:39:57 +00:00)
Commit: backend
@@ -11,9 +11,6 @@ import ldm_patched.controlnet.cldm
 import ldm_patched.t2ia.adapter
 
 
-compute_controlnet_weighting = None
-
-
 def broadcast_image_to(tensor, target_batch_size, batched_number):
     current_batch_size = tensor.shape[0]
     #print(current_batch_size, target_batch_size)
@@ -32,6 +29,71 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
     else:
         return torch.cat([tensor] * batched_number, dim=0)
 
+
+def get_at(array, index, default=None):
+    return array[index] if 0 <= index < len(array) else default
+
+
+def compute_controlnet_weighting(control, cnet):
+
+    positive_advanced_weighting = getattr(cnet, 'positive_advanced_weighting', None)
+    negative_advanced_weighting = getattr(cnet, 'negative_advanced_weighting', None)
+    advanced_frame_weighting = getattr(cnet, 'advanced_frame_weighting', None)
+    advanced_sigma_weighting = getattr(cnet, 'advanced_sigma_weighting', None)
+    advanced_mask_weighting = getattr(cnet, 'advanced_mask_weighting', None)
+
+    transformer_options = cnet.transformer_options
+
+    if positive_advanced_weighting is None and negative_advanced_weighting is None \
+            and advanced_frame_weighting is None and advanced_sigma_weighting is None \
+            and advanced_mask_weighting is None:
+        return control
+
+    cond_or_uncond = transformer_options['cond_or_uncond']
+    sigmas = transformer_options['sigmas']
+    cond_mark = transformer_options['cond_mark']
+
+    if advanced_frame_weighting is not None:
+        advanced_frame_weighting = torch.Tensor(advanced_frame_weighting * len(cond_or_uncond)).to(sigmas)
+        assert advanced_frame_weighting.shape[0] == cond_mark.shape[0], \
+            'Frame weighting list length is different from batch size!'
+
+    if advanced_sigma_weighting is not None:
+        advanced_sigma_weighting = torch.cat([advanced_sigma_weighting(sigmas)] * len(cond_or_uncond))
+
+    for k, v in control.items():
+        for i in range(len(v)):
+            control_signal = control[k][i]
+            B, C, H, W = control_signal.shape
+
+            positive_weight = 1.0
+            negative_weight = 1.0
+            sigma_weight = 1.0
+            frame_weight = 1.0
+
+            if positive_advanced_weighting is not None:
+                positive_weight = get_at(positive_advanced_weighting.get(k, []), i, 1.0)
+
+            if negative_advanced_weighting is not None:
+                negative_weight = get_at(negative_advanced_weighting.get(k, []), i, 1.0)
+
+            if advanced_sigma_weighting is not None:
+                sigma_weight = advanced_sigma_weighting
+
+            if advanced_frame_weighting is not None:
+                frame_weight = advanced_frame_weighting
+
+            final_weight = positive_weight * (1.0 - cond_mark) + negative_weight * cond_mark
+            final_weight = final_weight * sigma_weight * frame_weight
+
+            if isinstance(advanced_mask_weighting, torch.Tensor):
+                control_signal = control_signal * torch.nn.functional.interpolate(advanced_mask_weighting.to(control_signal), size=(H, W), mode='bilinear')
+
+            control[k][i] = control_signal * final_weight[:, None, None, None]
+
+    return control
+
+
 class ControlBase:
     def __init__(self, device=None):
         self.cond_hint_original = None
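The weighting rule added above can be exercised in isolation: a positive and a negative weight are blended per sample via cond_mark, then applied to each control layer. A minimal sketch, assuming torch is installed; the SimpleNamespace stand-in and every value below are hypothetical, not taken from this commit:

import torch
import types

# Hypothetical stand-in for a ControlBase-like object carrying the advanced
# weighting attributes that compute_controlnet_weighting reads via getattr.
cnet = types.SimpleNamespace(
    positive_advanced_weighting={'middle': [2.0]},  # weight for layer 0 of 'middle' on cond samples
    negative_advanced_weighting={'middle': [0.5]},  # weight for the same layer on uncond samples
    advanced_frame_weighting=None,
    advanced_sigma_weighting=None,
    advanced_mask_weighting=None,
    transformer_options={
        'cond_or_uncond': [0, 1],                   # one cond chunk, one uncond chunk
        'sigmas': torch.tensor([14.6, 14.6]),
        'cond_mark': torch.tensor([0.0, 1.0]),      # 0.0 marks cond samples, 1.0 marks uncond
    },
)

control = {'middle': [torch.ones(2, 4, 8, 8)]}      # one control signal, batch of 2
out = compute_controlnet_weighting(control, cnet)

# final_weight = 2.0 * (1 - cond_mark) + 0.5 * cond_mark = [2.0, 0.5],
# so the cond half is scaled by 2.0 and the uncond half by 0.5:
print(out['middle'][0][0].mean().item(), out['middle'][0][1].mean().item())  # 2.0 0.5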
@@ -40,6 +102,7 @@ class ControlBase:
         self.timestep_percent_range = (0.0, 1.0)
         self.global_average_pooling = False
         self.timestep_range = None
+        self.transformer_options = {}
 
         if device is None:
             device = ldm_patched.modules.model_management.get_torch_device()
@@ -118,8 +181,7 @@ class ControlBase:
 
                 out[key].append(x)
 
-        if compute_controlnet_weighting is not None:
-            out = compute_controlnet_weighting(out, self)
+        out = compute_controlnet_weighting(out, self)
 
         if control_prev is not None:
             for x in ['input', 'middle', 'output']:
@@ -149,6 +211,11 @@ class ControlNet(ControlBase):
         self.manual_cast_dtype = manual_cast_dtype
 
     def get_control(self, x_noisy, t, cond, batched_number):
+        to = self.transformer_options
+
+        for conditioning_modifier in to.get('controlnet_conditioning_modifiers', []):
+            x_noisy, t, cond, batched_number = conditioning_modifier(self, x_noisy, t, cond, batched_number)
+
         control_prev = None
         if self.previous_controlnet is not None:
             control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
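The 'controlnet_conditioning_modifiers' hook added here gives external code a chance to rewrite the sampler inputs before the ControlNet consumes them: each registered modifier receives the instance plus the four inputs and must return the same four values. A minimal sketch of registering one; the modifier name and its pass-through behavior are illustrative assumptions, not part of this commit:

def log_shapes_modifier(cnet, x_noisy, t, cond, batched_number):
    # Illustrative pass-through: inspect the inputs, then return them unchanged.
    print('controlnet input:', tuple(x_noisy.shape), 'batched_number:', batched_number)
    return x_noisy, t, cond, batched_number

# Registration goes through the transformer_options dict that this commit
# attaches to every ControlBase instance:
cnet.transformer_options.setdefault('controlnet_conditioning_modifiers', []).append(log_shapes_modifier)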
@@ -443,6 +510,11 @@ class T2IAdapter(ControlBase):
         return width, height
 
     def get_control(self, x_noisy, t, cond, batched_number):
+        to = self.transformer_options
+
+        for conditioning_modifier in to.get('controlnet_conditioning_modifiers', []):
+            x_noisy, t, cond, batched_number = conditioning_modifier(self, x_noisy, t, cond, batched_number)
+
         control_prev = None
         if self.previous_controlnet is not None:
             control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
@@ -244,7 +244,8 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
             while p is not None:
                 p.transformer_options = transformer_options
                 p = p.previous_controlnet
-            c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
+            control_cond = c.copy()  # get_control may change items in this dict, so we need to copy it
+            c['control'] = control.get_control(input_x, timestep_, control_cond, len(cond_or_uncond))
 
         if 'model_function_wrapper' in model_options:
             output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
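The added copy exists because get_control may write into the conditioning dict it receives; handing it a shallow copy keeps those writes out of the caller's c. A minimal sketch of the dict semantics involved, using hypothetical keys rather than the real conditioning contents:

c = {'c_crossattn': 'cond_tensor'}        # hypothetical conditioning entry
control_cond = c.copy()                   # new dict, but values are shared references
control_cond['control'] = 'internal use'  # a write like the ones get_control may perform
assert 'control' not in c                 # the caller's dict is unaffected

Note the copy is shallow: mutating a shared value in place would still be visible through both dicts; only key-level changes are isolated.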