From 1024b6729813bb443f8778372f40606d40d28809 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Sun, 4 Feb 2024 02:03:35 -0800 Subject: [PATCH] fix ddim plms unipc --- modules/sd_samplers_cfg_denoiser.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index 4289c4b4..bb86aaa9 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -153,6 +153,16 @@ class CFGDenoiser(torch.nn.Module): if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException + classic_ddim_eps_estimation = 'timesteps' in type(self).__name__.lower() + + if classic_ddim_eps_estimation: + acd = self.inner_model.inner_model.alphas_cumprod + fake_sigmas = ((1 - acd) / acd) ** 0.5 + real_sigma = fake_sigmas[sigma.round().long().clip(0, int(fake_sigmas.shape[0]))] + real_sigma_data = 1.0 + x = x * (real_sigma ** 2.0 + real_sigma_data ** 2.0) ** 0.5 + sigma = real_sigma + if sd_samplers_common.apply_refiner(self, x): cond = self.sampler.sampler_extra_args['cond'] uncond = self.sampler.sampler_extra_args['uncond'] @@ -194,5 +204,10 @@ class CFGDenoiser(torch.nn.Module): denoised = after_cfg_callback_params.x self.step += 1 + + if classic_ddim_eps_estimation: + eps = (x - denoised) / sigma + return eps + return denoised