From c2e2d0ea9be84858b8e7ad09b171267d72e163ed Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Tue, 30 Jan 2024 20:27:06 -0800 Subject: [PATCH] i --- .../scripts/forge_reference.py | 37 ++++++++++++++++--- modules_forge/forge_util.py | 14 +++++++ modules_forge/patch_basic.py | 6 +-- 3 files changed, 48 insertions(+), 9 deletions(-) diff --git a/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py b/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py index b77eba86..6f23cefd 100644 --- a/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py +++ b/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py @@ -82,7 +82,8 @@ class PreprocessorReference(Preprocessor): self.recorded_h[location] = h return h else: - cond_mark = transformer_options['cond_mark'][:, None, None, None] # cond is 0 + cond_indices = transformer_options['cond_indices'] + uncond_indices = transformer_options['uncond_indices'] rh = self.recorded_h[location] return h @@ -101,11 +102,35 @@ class PreprocessorReference(Preprocessor): self.recorded_attn1[location] = (k, v) return sdp(q, k, v, transformer_options) else: - cond_mark = transformer_options['cond_mark'][:, None, None, None] # cond is 0 - rk, rv = self.recorded_attn1[location] - rk = torch.cat([k, rk], dim=1) - rv = torch.cat([v, rv], dim=1) - return sdp(q, k, v, transformer_options) + cond_indices = transformer_options['cond_indices'] + uncond_indices = transformer_options['uncond_indices'] + cond_or_uncond = transformer_options['cond_or_uncond'] + + q_c = q[cond_indices] + q_uc = q[uncond_indices] + + k_c = k[cond_indices] + k_uc = k[uncond_indices] + + v_c = v[cond_indices] + v_uc = v[uncond_indices] + + k_r, v_r = self.recorded_attn1[location] + + o_c = sdp(q_c, torch.cat([k_c, k_r], dim=1), torch.cat([v_c, v_r], dim=1), transformer_options) + o_uc_strong = sdp(q_uc, k_uc, v_uc, transformer_options) + o_uc_weak = sdp(q_uc, torch.cat([k_uc, k_r], dim=1), torch.cat([v_uc, v_r], dim=1), transformer_options) + o_uc = o_uc_weak + (o_uc_strong - o_uc_weak) * style_fidelity + + recon = [] + for cx in cond_or_uncond: + if cx == 0: + recon.append(o_c) + else: + recon.append(o_uc) + + o = torch.cat(recon, dim=0) + return o unet.add_block_modifier(block_proc) unet.add_conditioning_modifier(conditioning_modifier) diff --git a/modules_forge/forge_util.py b/modules_forge/forge_util.py index ba1e8bb5..7a2c7688 100644 --- a/modules_forge/forge_util.py +++ b/modules_forge/forge_util.py @@ -39,6 +39,20 @@ def compute_cond_mark(cond_or_uncond, sigmas): return cond_mark +def compute_cond_indices(cond_or_uncond, sigmas): + cl = int(sigmas.shape[0]) + + cond_indices = [] + uncond_indices = [] + for i, cx in enumerate(cond_or_uncond): + if cx == 0: + cond_indices += list(range(i * cl, (i + 1) * cl)) + else: + uncond_indices += list(range(i * cl, (i + 1) * cl)) + + return cond_indices, uncond_indices + + def generate_random_filename(extension=".txt"): timestamp = time.strftime("%Y%m%d-%H%M%S") random_string = ''.join(random.choices(string.ascii_lowercase + string.digits, k=5)) diff --git a/modules_forge/patch_basic.py b/modules_forge/patch_basic.py index f215bb46..45b78d72 100644 --- a/modules_forge/patch_basic.py +++ b/modules_forge/patch_basic.py @@ -8,7 +8,7 @@ from ldm_patched.modules.controlnet import ControlBase from ldm_patched.modules.samplers import get_area_and_mult, can_concat_cond, cond_cat from ldm_patched.modules import model_management from modules_forge.controlnet import compute_controlnet_weighting -from modules_forge.forge_util import compute_cond_mark +from modules_forge.forge_util import compute_cond_mark, compute_cond_indices def patched_control_merge(self, control_input, control_output, control_prev, output_dtype): @@ -155,8 +155,8 @@ def patched_calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_op transformer_options["cond_or_uncond"] = cond_or_uncond[:] transformer_options["sigmas"] = timestep - cond_mark = compute_cond_mark(cond_or_uncond=cond_or_uncond, sigmas=timestep) - transformer_options["cond_mark"] = cond_mark + transformer_options["cond_mark"] = compute_cond_mark(cond_or_uncond=cond_or_uncond, sigmas=timestep) + transformer_options["cond_indices"], transformer_options["uncond_indices"] = compute_cond_indices(cond_or_uncond=cond_or_uncond, sigmas=timestep) c['transformer_options'] = transformer_options