From 54f07f1d1d922be068e083a4896da96bbde37c7d Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Thu, 25 Jan 2024 19:43:06 -0800 Subject: [PATCH] Update sd_samplers_cfg_denoiser.py --- modules/sd_samplers_cfg_denoiser.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index 17eabf3d..770ab376 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -65,7 +65,7 @@ class CFGDenoiser(torch.nn.Module): def inner_model(self): raise NotImplementedError() - def combine_denoised(self, x_out, conds_list, uncond, cond_scale, timestep, x): + def combine_denoised(self, x_out, conds_list, uncond, cond_scale, timestep, x_in): model_options = self.inner_model.inner_model.forge_objects.unet.model_options denoised_uncond = x_out[-uncond.shape[0]:] @@ -78,10 +78,28 @@ class CFGDenoiser(torch.nn.Module): if "sampler_cfg_function" in model_options or "sampler_post_cfg_function" in model_options: cond_scale = float(cond_scale) model = self.inner_model.inner_model.forge_objects.unet + x = x_in[-uncond.shape[0]:] + uncond_pred = denoised_uncond + cond_pred = ((denoised - uncond_pred) / cond_scale) + uncond_pred - a = 0 + if "sampler_cfg_function" in model_options: + args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, + "timestep": timestep, "input": x, "sigma": timestep, + "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, + "model_options": model_options} + cfg_result = x - model_options["sampler_cfg_function"](args) + else: + cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale - return denoised + for fn in model_options.get("sampler_post_cfg_function", []): + args = {"denoised": cfg_result, "cond": torch.zeros_like(uncond), "uncond": torch.zeros_like(uncond), "model": model, + "uncond_denoised": uncond_pred, "cond_denoised": cond_pred, + "sigma": timestep, "model_options": model_options, "input": x} + cfg_result = fn(args) + else: + cfg_result = denoised + + return cfg_result def combine_denoised_for_edit_model(self, x_out, cond_scale): out_cond, out_img_cond, out_uncond = x_out.chunk(3)