From 01c422678cab9cbdf4c3ed485df4500bcbc94075 Mon Sep 17 00:00:00 2001
From: lllyasviel
Date: Thu, 25 Jan 2024 19:52:03 -0800
Subject: [PATCH] Update sd_samplers_cfg_denoiser.py

---
 modules/sd_samplers_cfg_denoiser.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
index 9ec6800f..fb2c2834 100644
--- a/modules/sd_samplers_cfg_denoiser.py
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -65,7 +65,7 @@ class CFGDenoiser(torch.nn.Module):
     def inner_model(self):
         raise NotImplementedError()
 
-    def combine_denoised(self, x_out, conds_list, uncond, cond_scale, timestep, x_in):
+    def combine_denoised(self, x_out, conds_list, uncond, cond_scale, timestep, x_in, cond):
         model_options = self.inner_model.inner_model.forge_objects.unet.model_options
 
         denoised_uncond = x_out[-uncond.shape[0]:]
@@ -93,8 +93,8 @@ class CFGDenoiser(torch.nn.Module):
         # sanity_check = torch.allclose(cfg_result, denoised)
 
         for fn in model_options.get("sampler_post_cfg_function", []):
-            args = {"denoised": cfg_result, "cond": torch.zeros_like(uncond),
-                    "uncond": torch.zeros_like(uncond), "model": model,
+            args = {"denoised": cfg_result, "cond": cond,
+                    "uncond": uncond, "model": model,
                     "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
                     "sigma": timestep, "model_options": model_options, "input": x}
             cfg_result = fn(args)
@@ -265,9 +265,9 @@ class CFGDenoiser(torch.nn.Module):
         if is_edit_model:
             denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
         elif skip_uncond:
-            denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0, sigma_in, x_in)
+            denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0, sigma_in, x_in, tensor)
         else:
-            denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale, sigma_in, x_in)
+            denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale, sigma_in, x_in, tensor)
 
         # Blend in the original latents (after)
         if not self.mask_before_denoising and self.mask is not None:
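
For context, below is a minimal sketch of a callback that this patch affects: the loop patched above calls every entry in model_options["sampler_post_cfg_function"] with an args dict, and after the change args["cond"] and args["uncond"] carry the real conditioning tensors instead of torch.zeros_like(uncond). The callback name, the rescale logic, and the 0.7 blend factor are illustrative assumptions, not part of the patch; only the args keys and the model_options list come from the code above.

def rescale_post_cfg(args):
    # args is the dict built in combine_denoised() above; with this patch,
    # args["cond"] / args["uncond"] are the actual conditioning tensors
    # rather than zero placeholders.
    cfg_result = args["denoised"]
    cond_denoised = args["cond_denoised"]

    # Illustrative rescale: match the per-sample std of the guided result to
    # that of the conditional prediction, then blend (0.7 is an arbitrary
    # example value, not something defined by this patch).
    std_cond = cond_denoised.std(dim=(1, 2, 3), keepdim=True)
    std_cfg = cfg_result.std(dim=(1, 2, 3), keepdim=True).clamp_min(1e-8)
    rescaled = cfg_result * (std_cond / std_cfg)
    return 0.7 * rescaled + 0.3 * cfg_result


def register(model_options):
    # model_options is the same dict read at the top of combine_denoised();
    # appending here is what puts the callback into the loop patched above.
    model_options.setdefault("sampler_post_cfg_function", []).append(rescale_post_cfg)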