From ed495d72867376c4e46ad602da89d2ede3f42db2 Mon Sep 17 00:00:00 2001
From: lllyasviel
Date: Thu, 25 Jan 2024 20:22:04 -0800
Subject: [PATCH] Forward extra apply_model arguments and convert A1111 conds
 for patched-ldm post-CFG callbacks

---
 modules/sd_models_xl.py             |  4 ++--
 modules/sd_samplers_cfg_denoiser.py | 18 +++++++++++++-----
 modules_forge/forge_util.py         | 17 +++++++++++++++++
 3 files changed, 32 insertions(+), 7 deletions(-)
 create mode 100644 modules_forge/forge_util.py

diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py
index 686aec99..30cd9c23 100644
--- a/modules/sd_models_xl.py
+++ b/modules/sd_models_xl.py
@@ -38,11 +38,11 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:
     return c
 
 
-def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
+def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond, *args, **kwargs):
     if self.model.diffusion_model.in_channels == 9:
         x = torch.cat([x] + cond['c_concat'], dim=1)
 
-    return self.model(x, t, cond)
+    return self.model(x, t, cond, *args, **kwargs)
 
 
 def get_first_stage_encoding(self, x):  # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility
diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
index fb2c2834..f5f7a815 100644
--- a/modules/sd_samplers_cfg_denoiser.py
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -77,10 +77,13 @@ class CFGDenoiser(torch.nn.Module):
 
         if "sampler_cfg_function" in model_options or "sampler_post_cfg_function" in model_options:
             cond_scale = float(cond_scale)
-            model = self.inner_model.inner_model.forge_objects.unet
+            model = self.inner_model.inner_model.forge_objects.unet.model
             x = x_in[-uncond.shape[0]:]
             uncond_pred = denoised_uncond
             cond_pred = ((denoised - uncond_pred) / cond_scale) + uncond_pred
+            timestep = timestep[-uncond.shape[0]:]
+
+            from modules_forge.forge_util import cond_from_a1111_to_patched_ldm
 
             if "sampler_cfg_function" in model_options:
                 args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale,
@@ -93,10 +96,15 @@ class CFGDenoiser(torch.nn.Module):
             # sanity_check = torch.allclose(cfg_result, denoised)
 
             for fn in model_options.get("sampler_post_cfg_function", []):
-                args = {"denoised": cfg_result, "cond": cond,
-                        "uncond": uncond, "model": model,
-                        "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
-                        "sigma": timestep, "model_options": model_options, "input": x}
+                args = {"denoised": cfg_result,
+                        "cond": cond_from_a1111_to_patched_ldm(cond),
+                        "uncond": cond_from_a1111_to_patched_ldm(uncond),
+                        "model": model,
+                        "uncond_denoised": uncond_pred,
+                        "cond_denoised": cond_pred,
+                        "sigma": timestep,
+                        "model_options": model_options,
+                        "input": x}
                 cfg_result = fn(args)
         else:
             cfg_result = denoised
diff --git a/modules_forge/forge_util.py b/modules_forge/forge_util.py
new file mode 100644
index 00000000..c8000890
--- /dev/null
+++ b/modules_forge/forge_util.py
@@ -0,0 +1,17 @@
+from ldm_patched.modules.conds import CONDRegular, CONDCrossAttn
+
+
+def cond_from_a1111_to_patched_ldm(cond):
+    cross_attn = cond['crossattn']
+    pooled_output = cond['vector']
+
+    result = dict(
+        cross_attn=cross_attn,
+        pooled_output=pooled_output,
+        model_conds=dict(
+            c_crossattn=CONDCrossAttn(cross_attn),
+            y=CONDRegular(pooled_output)
+        )
+    )
+
+    return [result, ]
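
Editor's note on the modules/sd_models_xl.py hunk: widening apply_model to accept *args/**kwargs lets the patched backend hand extra arguments straight through to self.model instead of raising a TypeError. A minimal sketch of a call that only works after this change; the engine object, the tensor shapes, and the transformer_options keyword are illustrative assumptions, not something this patch defines:

    import torch

    # 'engine' is assumed to be a patched sgm.models.diffusion.DiffusionEngine.
    x = torch.randn(2, 4, 128, 128)                 # SDXL latent batch (illustrative)
    t = torch.tensor([999, 999])                    # timesteps
    cond = {"crossattn": torch.randn(2, 77, 2048),  # token cross-attention context
            "vector": torch.randn(2, 2816)}         # pooled embedding

    # Extra keyword arguments are now forwarded verbatim to self.model(...).
    eps = engine.apply_model(x, t, cond, transformer_options={})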
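
For reference, the new modules_forge/forge_util.py helper repacks the A1111-style SDXL cond dict into the single-entry list-of-dicts layout used by the patched-ldm backend. A small illustrative check, with typical SDXL tensor shapes assumed:

    import torch
    from modules_forge.forge_util import cond_from_a1111_to_patched_ldm

    a1111_cond = {
        "crossattn": torch.randn(1, 77, 2048),  # per-token context for cross-attention
        "vector": torch.randn(1, 2816),         # pooled embedding, exposed as 'y'
    }
    converted = cond_from_a1111_to_patched_ldm(a1111_cond)
    assert isinstance(converted, list) and len(converted) == 1
    assert set(converted[0]["model_conds"]) == {"c_crossattn", "y"}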
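
That conversion is what lets post-CFG callbacks written against the patched-ldm convention run unmodified inside CFGDenoiser. A hedged sketch of such a callback; the 0.7 blend factor is arbitrary, and 'unet' is assumed to be the forge_objects.unet ModelPatcher, whose set_model_sampler_post_cfg_function populates the model_options["sampler_post_cfg_function"] list iterated above:

    def rescale_after_cfg(args):
        # args carries the keys assembled in the CFGDenoiser hunk above;
        # "cond"/"uncond" now arrive in patched-ldm format, not A1111 dicts.
        denoised = args["denoised"]
        uncond_denoised = args["uncond_denoised"]
        # Pull the CFG-combined prediction 30% toward the unconditional one
        # (an arbitrary illustrative adjustment).
        return uncond_denoised + 0.7 * (denoised - uncond_denoised)

    unet.set_model_sampler_post_cfg_function(rescale_after_cfg)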