diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index 4289c4b4..bb86aaa9 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -153,6 +153,16 @@ class CFGDenoiser(torch.nn.Module): if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException + classic_ddim_eps_estimation = 'timesteps' in type(self).__name__.lower() + + if classic_ddim_eps_estimation: + acd = self.inner_model.inner_model.alphas_cumprod + fake_sigmas = ((1 - acd) / acd) ** 0.5 + real_sigma = fake_sigmas[sigma.round().long().clip(0, int(fake_sigmas.shape[0]))] + real_sigma_data = 1.0 + x = x * (real_sigma ** 2.0 + real_sigma_data ** 2.0) ** 0.5 + sigma = real_sigma + if sd_samplers_common.apply_refiner(self, x): cond = self.sampler.sampler_extra_args['cond'] uncond = self.sampler.sampler_extra_args['uncond'] @@ -194,5 +204,10 @@ class CFGDenoiser(torch.nn.Module): denoised = after_cfg_callback_params.x self.step += 1 + + if classic_ddim_eps_estimation: + eps = (x - denoised) / sigma + return eps + return denoised