This commit is contained in:
lllyasviel
2024-02-04 13:47:31 -08:00
parent d32b5e5458
commit 54d12ef989
2 changed files with 6 additions and 7 deletions

View File

@@ -59,10 +59,11 @@ class CFGDenoiser(torch.nn.Module):
self.model_wrap = None self.model_wrap = None
self.p = None self.p = None
# NOTE: masking before denoising can cause the original latents to be oversmoothed # Backward Compatibility
# as the original latents do not have noise
self.mask_before_denoising = False self.mask_before_denoising = False
self.classic_ddim_eps_estimation = False
@property @property
def inner_model(self): def inner_model(self):
raise NotImplementedError() raise NotImplementedError()
@@ -153,9 +154,7 @@ class CFGDenoiser(torch.nn.Module):
if state.interrupted or state.skipped: if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException raise sd_samplers_common.InterruptedException
classic_ddim_eps_estimation = 'timesteps' in type(self).__name__.lower() if self.classic_ddim_eps_estimation:
if classic_ddim_eps_estimation:
acd = self.inner_model.inner_model.alphas_cumprod acd = self.inner_model.inner_model.alphas_cumprod
fake_sigmas = ((1 - acd) / acd) ** 0.5 fake_sigmas = ((1 - acd) / acd) ** 0.5
real_sigma = fake_sigmas[sigma.round().long().clip(0, int(fake_sigmas.shape[0]))] real_sigma = fake_sigmas[sigma.round().long().clip(0, int(fake_sigmas.shape[0]))]
@@ -205,7 +204,7 @@ class CFGDenoiser(torch.nn.Module):
self.step += 1 self.step += 1
if classic_ddim_eps_estimation: if self.classic_ddim_eps_estimation:
eps = (x - denoised) / sigma eps = (x - denoised) / sigma
return eps return eps

View File

@@ -52,7 +52,7 @@ class CFGDenoiserTimesteps(CFGDenoiser):
super().__init__(sampler) super().__init__(sampler)
self.alphas = shared.sd_model.alphas_cumprod self.alphas = shared.sd_model.alphas_cumprod
self.mask_before_denoising = True self.classic_ddim_eps_estimation = True
def get_pred_x0(self, x_in, x_out, sigma): def get_pred_x0(self, x_in, x_out, sigma):
ts = sigma.to(dtype=int) ts = sigma.to(dtype=int)