This commit is contained in:
lllyasviel
2024-01-28 07:56:42 -08:00
parent 8158e31d80
commit 3d6d19a893
4 changed files with 29 additions and 16 deletions

View File

@@ -4,6 +4,7 @@ import ldm_patched.modules.samplers
from ldm_patched.modules.controlnet import ControlBase
from ldm_patched.modules.samplers import get_area_and_mult, can_concat_cond, cond_cat
from ldm_patched.modules import model_management
from modules_forge.controlnet import compute_controlnet_weighting
def patched_control_merge(self, control_input, control_output, control_prev, output_dtype):
@@ -38,11 +39,14 @@ def patched_control_merge(self, control_input, control_output, control_prev, out
out[key].append(x)
if self.positive_advanced_weighting is not None or self.negative_advanced_weighting:
# TODO: Implement here
cond_or_uncond = self.current_cond_or_uncond
a = 0
pass
out = compute_controlnet_weighting(
out,
positive_advanced_weighting=self.positive_advanced_weighting,
negative_advanced_weighting=self.negative_advanced_weighting,
advanced_frame_weighting=self.advanced_frame_weighting,
advanced_sigma_weighting=self.advanced_sigma_weighting,
transformer_options=self.transformer_options
)
if control_prev is not None:
for x in ['input', 'middle', 'output']:
@@ -129,9 +133,6 @@ def patched_calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_op
c = cond_cat(c)
timestep_ = torch.cat([timestep] * batch_chunks)
if control is not None:
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
transformer_options = {}
if 'transformer_options' in model_options:
transformer_options = model_options['transformer_options'].copy()
@@ -154,6 +155,7 @@ def patched_calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_op
if control is not None:
control.transformer_options = transformer_options
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)