From 31e057e7b32e6e7d1f564069ab7f797a8443e082 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Sun, 28 Jan 2024 08:27:43 -0800 Subject: [PATCH] i --- modules_forge/controlnet.py | 2 +- modules_forge/forge_util.py | 11 +++++++++++ modules_forge/patch_basic.py | 4 ++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/modules_forge/controlnet.py b/modules_forge/controlnet.py index 48e924fa..4fa50ed8 100644 --- a/modules_forge/controlnet.py +++ b/modules_forge/controlnet.py @@ -73,7 +73,7 @@ def compute_controlnet_weighting( cond_or_uncond = transformer_options['cond_or_uncond'] sigmas = transformer_options['sigmas'] - cond_or_uncond_size = int(sigmas.shape[0]) + cond_mark = transformer_options['cond_mark'] if advanced_sigma_weighting is not None: advanced_sigma_weighting = advanced_sigma_weighting(sigmas) diff --git a/modules_forge/forge_util.py b/modules_forge/forge_util.py index b1afd9ab..4eae7de2 100644 --- a/modules_forge/forge_util.py +++ b/modules_forge/forge_util.py @@ -6,6 +6,17 @@ import random import string +def compute_cond_mark(cond_or_uncond, sigmas): + cond_or_uncond_size = int(sigmas.shape[0]) + + cond_mark = [] + for cx in cond_or_uncond: + cond_mark += [cx] * cond_or_uncond_size + + cond_mark = torch.Tensor(cond_mark).to(sigmas) + return cond_mark + + def generate_random_filename(extension=".txt"): timestamp = time.strftime("%Y%m%d-%H%M%S") random_string = ''.join(random.choices(string.ascii_lowercase + string.digits, k=5)) diff --git a/modules_forge/patch_basic.py b/modules_forge/patch_basic.py index d5ba3f89..9ca8e58b 100644 --- a/modules_forge/patch_basic.py +++ b/modules_forge/patch_basic.py @@ -5,6 +5,7 @@ from ldm_patched.modules.controlnet import ControlBase from ldm_patched.modules.samplers import get_area_and_mult, can_concat_cond, cond_cat from ldm_patched.modules import model_management from modules_forge.controlnet import compute_controlnet_weighting +from modules_forge.forge_util import compute_cond_mark def patched_control_merge(self, control_input, control_output, control_prev, output_dtype): @@ -151,6 +152,9 @@ def patched_calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_op transformer_options["cond_or_uncond"] = cond_or_uncond[:] transformer_options["sigmas"] = timestep + cond_mark = compute_cond_mark(cond_or_uncond=cond_or_uncond, sigmas=timestep) + transformer_options["cond_mark"] = cond_mark + c['transformer_options'] = transformer_options if control is not None: