Update sd_samplers_cfg_denoiser.py

This commit is contained in:
lllyasviel
2024-01-25 19:44:25 -08:00
parent 54f07f1d1d
commit d3cb546cc5

View File

@@ -90,9 +90,11 @@ class CFGDenoiser(torch.nn.Module):
cfg_result = x - model_options["sampler_cfg_function"](args)
else:
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
# sanity_check = torch.allclose(cfg_result, denoised)
for fn in model_options.get("sampler_post_cfg_function", []):
args = {"denoised": cfg_result, "cond": torch.zeros_like(uncond), "uncond": torch.zeros_like(uncond), "model": model,
args = {"denoised": cfg_result, "cond": torch.zeros_like(uncond),
"uncond": torch.zeros_like(uncond), "model": model,
"uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
"sigma": timestep, "model_options": model_options, "input": x}
cfg_result = fn(args)