From 01e610f2ece1ee9efa5af2c2749bd78cf3194dbf Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Fri, 2 Feb 2024 14:32:00 -0800 Subject: [PATCH] backend --- ldm_patched/modules/controlnet.py | 82 +++++++++++++++++++++++++++++-- ldm_patched/modules/samplers.py | 3 +- 2 files changed, 79 insertions(+), 6 deletions(-) diff --git a/ldm_patched/modules/controlnet.py b/ldm_patched/modules/controlnet.py index 67d9c9ee..9ddd3c6b 100644 --- a/ldm_patched/modules/controlnet.py +++ b/ldm_patched/modules/controlnet.py @@ -11,9 +11,6 @@ import ldm_patched.controlnet.cldm import ldm_patched.t2ia.adapter -compute_controlnet_weighting = None - - def broadcast_image_to(tensor, target_batch_size, batched_number): current_batch_size = tensor.shape[0] #print(current_batch_size, target_batch_size) @@ -32,6 +29,71 @@ def broadcast_image_to(tensor, target_batch_size, batched_number): else: return torch.cat([tensor] * batched_number, dim=0) + +def get_at(array, index, default=None): + return array[index] if 0 <= index < len(array) else default + + +def compute_controlnet_weighting(control, cnet): + + positive_advanced_weighting = getattr(cnet, 'positive_advanced_weighting', None) + negative_advanced_weighting = getattr(cnet, 'negative_advanced_weighting', None) + advanced_frame_weighting = getattr(cnet, 'advanced_frame_weighting', None) + advanced_sigma_weighting = getattr(cnet, 'advanced_sigma_weighting', None) + advanced_mask_weighting = getattr(cnet, 'advanced_mask_weighting', None) + + transformer_options = cnet.transformer_options + + if positive_advanced_weighting is None and negative_advanced_weighting is None \ + and advanced_frame_weighting is None and advanced_sigma_weighting is None \ + and advanced_mask_weighting is None: + return control + + cond_or_uncond = transformer_options['cond_or_uncond'] + sigmas = transformer_options['sigmas'] + cond_mark = transformer_options['cond_mark'] + + if advanced_frame_weighting is not None: + advanced_frame_weighting = torch.Tensor(advanced_frame_weighting * len(cond_or_uncond)).to(sigmas) + assert advanced_frame_weighting.shape[0] == cond_mark.shape[0], \ + 'Frame weighting list length is different from batch size!' 
+
+    if advanced_sigma_weighting is not None:
+        advanced_sigma_weighting = torch.cat([advanced_sigma_weighting(sigmas)] * len(cond_or_uncond))
+
+    for k, v in control.items():
+        for i in range(len(v)):
+            control_signal = control[k][i]
+            B, C, H, W = control_signal.shape
+
+            positive_weight = 1.0
+            negative_weight = 1.0
+            sigma_weight = 1.0
+            frame_weight = 1.0
+
+            if positive_advanced_weighting is not None:
+                positive_weight = get_at(positive_advanced_weighting.get(k, []), i, 1.0)
+
+            if negative_advanced_weighting is not None:
+                negative_weight = get_at(negative_advanced_weighting.get(k, []), i, 1.0)
+
+            if advanced_sigma_weighting is not None:
+                sigma_weight = advanced_sigma_weighting
+
+            if advanced_frame_weighting is not None:
+                frame_weight = advanced_frame_weighting
+
+            final_weight = positive_weight * (1.0 - cond_mark) + negative_weight * cond_mark
+            final_weight = final_weight * sigma_weight * frame_weight
+
+            if isinstance(advanced_mask_weighting, torch.Tensor):
+                control_signal = control_signal * torch.nn.functional.interpolate(advanced_mask_weighting.to(control_signal), size=(H, W), mode='bilinear')
+
+            control[k][i] = control_signal * final_weight[:, None, None, None]
+
+    return control
+
+
 class ControlBase:
     def __init__(self, device=None):
         self.cond_hint_original = None
@@ -40,6 +102,7 @@ class ControlBase:
         self.timestep_percent_range = (0.0, 1.0)
         self.global_average_pooling = False
         self.timestep_range = None
+        self.transformer_options = {}
 
         if device is None:
             device = ldm_patched.modules.model_management.get_torch_device()
@@ -118,8 +181,7 @@ class ControlBase:
 
                     out[key].append(x)
 
-        if compute_controlnet_weighting is not None:
-            out = compute_controlnet_weighting(out, self)
+        out = compute_controlnet_weighting(out, self)
 
         if control_prev is not None:
             for x in ['input', 'middle', 'output']:
@@ -149,6 +211,11 @@ class ControlNet(ControlBase):
         self.manual_cast_dtype = manual_cast_dtype
 
     def get_control(self, x_noisy, t, cond, batched_number):
+        to = self.transformer_options
+
+        for conditioning_modifier in to.get('controlnet_conditioning_modifiers', []):
+            x_noisy, t, cond, batched_number = conditioning_modifier(self, x_noisy, t, cond, batched_number)
+
         control_prev = None
         if self.previous_controlnet is not None:
             control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
@@ -443,6 +510,11 @@ class T2IAdapter(ControlBase):
         return width, height
 
     def get_control(self, x_noisy, t, cond, batched_number):
+        to = self.transformer_options
+
+        for conditioning_modifier in to.get('controlnet_conditioning_modifiers', []):
+            x_noisy, t, cond, batched_number = conditioning_modifier(self, x_noisy, t, cond, batched_number)
+
         control_prev = None
         if self.previous_controlnet is not None:
             control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
diff --git a/ldm_patched/modules/samplers.py b/ldm_patched/modules/samplers.py
index d6f36b21..4c12380a 100644
--- a/ldm_patched/modules/samplers.py
+++ b/ldm_patched/modules/samplers.py
@@ -244,7 +244,8 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
             while p is not None:
                 p.transformer_options = transformer_options
                 p = p.previous_controlnet
-            c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
+            control_cond = c.copy()  # get_control may change items in this dict, so we need to copy it
+            c['control'] = control.get_control(input_x, timestep_, control_cond, len(cond_or_uncond))
 
         if 'model_function_wrapper' in model_options:
             output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
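
Illustrative sketch (not part of the patch): how compute_controlnet_weighting()
consumes the new advanced-weighting attributes. The attribute names and the
'input'/'middle'/'output' keys mirror the patched controlnet.py; FakeCNet, the
concrete values, and the cond/uncond batch layout are hypothetical.

    import torch
    from ldm_patched.modules.controlnet import compute_controlnet_weighting

    class FakeCNet:
        # Minimal stand-in for a patched ControlNet/T2IAdapter instance.
        def __init__(self):
            # Per-layer strengths for the cond ("positive") and uncond
            # ("negative") halves of the batch, keyed like the control dict.
            self.positive_advanced_weighting = {'middle': [1.0], 'output': [0.9, 0.8]}
            self.negative_advanced_weighting = {'middle': [0.5], 'output': [0.4, 0.3]}
            self.advanced_frame_weighting = None   # optional per-frame weight list
            self.advanced_sigma_weighting = None   # optional callable over sigmas
            self.advanced_mask_weighting = None    # optional [1, 1, H, W] tensor
            self.transformer_options = {
                'cond_or_uncond': [0, 1],                # one cond + one uncond chunk
                'sigmas': torch.tensor([14.6, 14.6]),
                'cond_mark': torch.tensor([0.0, 1.0]),   # assumed 0 = cond, 1 = uncond
            }

    control = {'middle': [torch.ones(2, 320, 8, 8)],
               'output': [torch.ones(2, 320, 16, 16) for _ in range(2)]}
    control = compute_controlnet_weighting(control, FakeCNet())
    # Per final_weight = positive * (1 - cond_mark) + negative * cond_mark, the
    # cond half of each signal is scaled by its positive weight and the uncond
    # half by its negative weight (sigma/frame/mask terms default to identity).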
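
Illustrative sketch (not part of the patch): registering a hook through the new
'controlnet_conditioning_modifiers' list. The modifier signature matches the
call sites in ControlNet.get_control() and T2IAdapter.get_control();
soften_hint_modifier and the registration handle are hypothetical.

    import torch

    def soften_hint_modifier(cnet, x_noisy, t, cond, batched_number):
        # Hypothetical modifier: blur the latent this ControlNet sees. The
        # sampler's own input is unaffected; only get_control gets the result.
        x_noisy = torch.nn.functional.avg_pool2d(x_noisy, 3, stride=1, padding=1)
        return x_noisy, t, cond, batched_number

    # samplers.py reassigns transformer_options onto every ControlNet in the
    # chain just before get_control runs, so the hook list must live in the
    # dict the sampler propagates; the handle below is a hypothetical stand-in.
    model_transformer_options = {}
    model_transformer_options.setdefault(
        'controlnet_conditioning_modifiers', []).append(soften_hint_modifier)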
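
Illustrative sketch (not part of the patch): why samplers.py now passes a
shallow copy. get_control() may write new items into the cond dict it
receives; copying first keeps those writes out of the sampler's own `c`:

    c = {'c_crossattn': 'cond-tensor-here'}
    control_cond = c.copy()            # shallow copy: new dict, same values
    control_cond['hint'] = 'added-downstream'
    assert 'hint' not in c             # the sampler's dict stays untouched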