diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index 770ab376..9ec6800f 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -90,9 +90,11 @@ class CFGDenoiser(torch.nn.Module): cfg_result = x - model_options["sampler_cfg_function"](args) else: cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale + # sanity_check = torch.allclose(cfg_result, denoised) for fn in model_options.get("sampler_post_cfg_function", []): - args = {"denoised": cfg_result, "cond": torch.zeros_like(uncond), "uncond": torch.zeros_like(uncond), "model": model, + args = {"denoised": cfg_result, "cond": torch.zeros_like(uncond), + "uncond": torch.zeros_like(uncond), "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred, "sigma": timestep, "model_options": model_options, "input": x} cfg_result = fn(args)