improve backward combability #936

This commit is contained in:
layerdiffusion
2024-08-06 00:47:33 -07:00
parent e8e5fdee8a
commit b7878058f9
10 changed files with 64 additions and 69 deletions

View File

@@ -10,7 +10,7 @@ from modules.torch_utils import float64
@torch.no_grad()
def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
alphas_cumprod = model.inner_model.alphas_cumprod
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
alphas = alphas_cumprod[timesteps]
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
@@ -46,7 +46,7 @@ def ddim_cfgpp(model, x, timesteps, extra_args=None, callback=None, disable=None
Uses the unconditional noise prediction instead of the conditional noise to guide the denoising direction.
The CFG scale is divided by 12.5 to map CFG from [0.0, 12.5] to [0, 1.0].
"""
alphas_cumprod = model.inner_model.alphas_cumprod
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
alphas = alphas_cumprod[timesteps]
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
@@ -82,7 +82,7 @@ def ddim_cfgpp(model, x, timesteps, extra_args=None, callback=None, disable=None
@torch.no_grad()
def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
alphas_cumprod = model.inner_model.alphas_cumprod
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
alphas = alphas_cumprod[timesteps]
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
@@ -168,7 +168,7 @@ class UniPCCFG(uni_pc.UniPC):
def unipc(model, x, timesteps, extra_args=None, callback=None, disable=None, is_img2img=False):
alphas_cumprod = model.inner_model.alphas_cumprod
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
ns = uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
t_start = timesteps[-1] / 1000 + 1 / 1000 if is_img2img else None # this is likely off by a bit - if someone wants to fix it please by all means