fix ddim plms unipc

This commit is contained in:
lllyasviel
2024-02-04 02:03:35 -08:00
parent 0a5ac13b14
commit 1024b67298

View File

@@ -153,6 +153,16 @@ class CFGDenoiser(torch.nn.Module):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
classic_ddim_eps_estimation = 'timesteps' in type(self).__name__.lower()
if classic_ddim_eps_estimation:
acd = self.inner_model.inner_model.alphas_cumprod
fake_sigmas = ((1 - acd) / acd) ** 0.5
real_sigma = fake_sigmas[sigma.round().long().clip(0, int(fake_sigmas.shape[0]))]
real_sigma_data = 1.0
x = x * (real_sigma ** 2.0 + real_sigma_data ** 2.0) ** 0.5
sigma = real_sigma
if sd_samplers_common.apply_refiner(self, x):
cond = self.sampler.sampler_extra_args['cond']
uncond = self.sampler.sampler_extra_args['uncond']
@@ -194,5 +204,10 @@ class CFGDenoiser(torch.nn.Module):
denoised = after_cfg_callback_params.x
self.step += 1
if classic_ddim_eps_estimation:
eps = (x - denoised) / sigma
return eps
return denoised