This commit is contained in:
lllyasviel
2024-02-04 13:47:31 -08:00
parent d32b5e5458
commit 54d12ef989
2 changed files with 6 additions and 7 deletions

View File

@@ -59,10 +59,11 @@ class CFGDenoiser(torch.nn.Module):
self.model_wrap = None
self.p = None
# NOTE: masking before denoising can cause the original latents to be oversmoothed
# as the original latents do not have noise
# Backward Compatibility
self.mask_before_denoising = False
self.classic_ddim_eps_estimation = False
@property
def inner_model(self):
raise NotImplementedError()
@@ -153,9 +154,7 @@ class CFGDenoiser(torch.nn.Module):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
classic_ddim_eps_estimation = 'timesteps' in type(self).__name__.lower()
if classic_ddim_eps_estimation:
if self.classic_ddim_eps_estimation:
acd = self.inner_model.inner_model.alphas_cumprod
fake_sigmas = ((1 - acd) / acd) ** 0.5
real_sigma = fake_sigmas[sigma.round().long().clip(0, int(fake_sigmas.shape[0]))]
@@ -205,7 +204,7 @@ class CFGDenoiser(torch.nn.Module):
self.step += 1
if classic_ddim_eps_estimation:
if self.classic_ddim_eps_estimation:
eps = (x - denoised) / sigma
return eps

View File

@@ -52,7 +52,7 @@ class CFGDenoiserTimesteps(CFGDenoiser):
super().__init__(sampler)
self.alphas = shared.sd_model.alphas_cumprod
self.mask_before_denoising = True
self.classic_ddim_eps_estimation = True
def get_pred_x0(self, x_in, x_out, sigma):
ts = sigma.to(dtype=int)