From ee023f4fbf47a20773289e1e0ef2494e38212b17 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Sat, 10 Feb 2024 18:35:42 -0800 Subject: [PATCH] Fix UniPC data cast and shape broadcast in #184 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix UniPC data cast and shape broadcast in #184 This also fix potential problems in DDIM. The cause of this BUG is A1111’s `modules\models\diffusion\uni_pc\uni_pc.py` does not have a data cast and the Forge’s DDIM estimator forget to match the broadcast shape of sigmas. (At the same time when we are fixing this BUG in A1111’s very original and high-quality samplers, comfyanonymous is still believing that Forge is using comfyui to sample images, eg, ComfyUI UniPC. Comfyanonymous is so cute. See also the jokes here https://github.com/lllyasviel/stable-diffusion-webui-forge/discussions/169#discussioncomment-8428689) --- modules/models/diffusion/uni_pc/uni_pc.py | 2 +- modules/sd_samplers_cfg_denoiser.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py index d257a728..4a365151 100644 --- a/modules/models/diffusion/uni_pc/uni_pc.py +++ b/modules/models/diffusion/uni_pc/uni_pc.py @@ -445,7 +445,7 @@ class UniPC: s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) x0 = torch.clamp(x0, -s, s) / s - return x0 + return x0.to(x) def model_fn(self, x, t): """ diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index b1dca624..e2d3826b 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -162,7 +162,7 @@ class CFGDenoiser(torch.nn.Module): fake_sigmas = ((1 - acd) / acd) ** 0.5 real_sigma = fake_sigmas[sigma.round().long().clip(0, int(fake_sigmas.shape[0]))] real_sigma_data = 1.0 - x = x * (real_sigma ** 2.0 + real_sigma_data ** 2.0) ** 0.5 + x = x * (((real_sigma ** 2.0 + real_sigma_data ** 2.0) ** 0.5)[:, None, None, None]) sigma = real_sigma if sd_samplers_common.apply_refiner(self, x): @@ -195,7 +195,7 @@ class CFGDenoiser(torch.nn.Module): self.step += 1 if self.classic_ddim_eps_estimation: - eps = (x - denoised) / sigma + eps = (x - denoised) / sigma[:, None, None, None] return eps return denoised.to(device=original_x_device, dtype=original_x_dtype)