mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-01-26 19:09:45 +00:00
improve backward combability #936
This commit is contained in:
@@ -39,6 +39,8 @@ class ForgeDiffusionEngine:
|
|||||||
self.is_sd2 = False
|
self.is_sd2 = False
|
||||||
self.is_sdxl = False
|
self.is_sdxl = False
|
||||||
self.is_sd3 = False
|
self.is_sd3 = False
|
||||||
|
self.parameterization = 'eps'
|
||||||
|
self.alphas_cumprod = None
|
||||||
|
|
||||||
def set_clip_skip(self, clip_skip):
|
def set_clip_skip(self, clip_skip):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ class StableDiffusion(ForgeDiffusionEngine):
|
|||||||
# WebUI Legacy
|
# WebUI Legacy
|
||||||
self.is_sd1 = True
|
self.is_sd1 = True
|
||||||
self.first_stage_model = vae.first_stage_model
|
self.first_stage_model = vae.first_stage_model
|
||||||
|
self.alphas_cumprod = unet.model.predictor.alphas_cumprod
|
||||||
|
|
||||||
def set_clip_skip(self, clip_skip):
|
def set_clip_skip(self, clip_skip):
|
||||||
self.text_processing_engine.clip_skip = clip_skip
|
self.text_processing_engine.clip_skip = clip_skip
|
||||||
|
|||||||
@@ -53,6 +53,10 @@ class StableDiffusion2(ForgeDiffusionEngine):
|
|||||||
# WebUI Legacy
|
# WebUI Legacy
|
||||||
self.is_sd2 = True
|
self.is_sd2 = True
|
||||||
self.first_stage_model = vae.first_stage_model
|
self.first_stage_model = vae.first_stage_model
|
||||||
|
self.alphas_cumprod = unet.model.predictor.alphas_cumprod
|
||||||
|
|
||||||
|
if not self.is_inpaint:
|
||||||
|
self.parameterization = 'v'
|
||||||
|
|
||||||
def set_clip_skip(self, clip_skip):
|
def set_clip_skip(self, clip_skip):
|
||||||
self.text_processing_engine.clip_skip = clip_skip
|
self.text_processing_engine.clip_skip = clip_skip
|
||||||
|
|||||||
@@ -72,6 +72,7 @@ class StableDiffusionXL(ForgeDiffusionEngine):
|
|||||||
# WebUI Legacy
|
# WebUI Legacy
|
||||||
self.is_sdxl = True
|
self.is_sdxl = True
|
||||||
self.first_stage_model = vae.first_stage_model
|
self.first_stage_model = vae.first_stage_model
|
||||||
|
self.alphas_cumprod = unet.model.predictor.alphas_cumprod
|
||||||
|
|
||||||
def set_clip_skip(self, clip_skip):
|
def set_clip_skip(self, clip_skip):
|
||||||
self.text_processing_engine_l.clip_skip = clip_skip
|
self.text_processing_engine_l.clip_skip = clip_skip
|
||||||
|
|||||||
@@ -301,7 +301,7 @@ def sampling_function_inner(model, x, timestep, uncond, cond, cond_scale, model_
|
|||||||
|
|
||||||
|
|
||||||
def sampling_function(self, denoiser_params, cond_scale, cond_composition):
|
def sampling_function(self, denoiser_params, cond_scale, cond_composition):
|
||||||
unet_patcher = self.inner_model.forge_objects.unet
|
unet_patcher = self.inner_model.inner_model.forge_objects.unet
|
||||||
model = unet_patcher.model
|
model = unet_patcher.model
|
||||||
control = unet_patcher.controlnet_linked_list
|
control = unet_patcher.controlnet_linked_list
|
||||||
extra_concat_condition = unet_patcher.extra_concat_condition
|
extra_concat_condition = unet_patcher.extra_concat_condition
|
||||||
|
|||||||
@@ -31,10 +31,6 @@ def pad_cond(tensor, repeats, empty):
|
|||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
def append_zero(x):
|
|
||||||
return torch.cat([x, x.new_zeros([1])])
|
|
||||||
|
|
||||||
|
|
||||||
class CFGDenoiser(torch.nn.Module):
|
class CFGDenoiser(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
|
Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
|
||||||
@@ -43,10 +39,8 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
negative prompt.
|
negative prompt.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, sampler, model):
|
def __init__(self, sampler):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.inner_model = model
|
|
||||||
|
|
||||||
self.model_wrap = None
|
self.model_wrap = None
|
||||||
self.mask = None
|
self.mask = None
|
||||||
self.nmask = None
|
self.nmask = None
|
||||||
@@ -70,35 +64,9 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
|
|
||||||
self.classic_ddim_eps_estimation = False
|
self.classic_ddim_eps_estimation = False
|
||||||
|
|
||||||
self.sigmas = model.forge_objects.unet.model.predictor.sigmas
|
@property
|
||||||
self.log_sigmas = self.sigmas.log()
|
def inner_model(self):
|
||||||
|
raise NotImplementedError()
|
||||||
def get_sigmas(self, n=None):
|
|
||||||
if n is None:
|
|
||||||
return append_zero(self.sigmas.flip(0))
|
|
||||||
t_max = len(self.sigmas) - 1
|
|
||||||
t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
|
|
||||||
return append_zero(self.t_to_sigma(t))
|
|
||||||
|
|
||||||
def sigma_to_t(self, sigma, quantize=None):
|
|
||||||
quantize = self.quantize if quantize is None else quantize
|
|
||||||
log_sigma = sigma.log()
|
|
||||||
dists = log_sigma - self.log_sigmas[:, None]
|
|
||||||
if quantize:
|
|
||||||
return dists.abs().argmin(dim=0).view(sigma.shape)
|
|
||||||
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
|
|
||||||
high_idx = low_idx + 1
|
|
||||||
low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
|
|
||||||
w = (low - log_sigma) / (low - high)
|
|
||||||
w = w.clamp(0, 1)
|
|
||||||
t = (1 - w) * low_idx + w * high_idx
|
|
||||||
return t.view(sigma.shape)
|
|
||||||
|
|
||||||
def t_to_sigma(self, t):
|
|
||||||
t = t.float()
|
|
||||||
low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
|
|
||||||
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
|
|
||||||
return log_sigma.exp()
|
|
||||||
|
|
||||||
def combine_denoised(self, x_out, conds_list, uncond, cond_scale, timestep, x_in, cond):
|
def combine_denoised(self, x_out, conds_list, uncond, cond_scale, timestep, x_in, cond):
|
||||||
denoised_uncond = x_out[-uncond.shape[0]:]
|
denoised_uncond = x_out[-uncond.shape[0]:]
|
||||||
@@ -190,7 +158,7 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
original_x_dtype = x.dtype
|
original_x_dtype = x.dtype
|
||||||
|
|
||||||
if self.classic_ddim_eps_estimation:
|
if self.classic_ddim_eps_estimation:
|
||||||
acd = self.inner_model.alphas_cumprod
|
acd = self.inner_model.inner_model.alphas_cumprod
|
||||||
fake_sigmas = ((1 - acd) / acd) ** 0.5
|
fake_sigmas = ((1 - acd) / acd) ** 0.5
|
||||||
real_sigma = fake_sigmas[sigma.round().long().clip(0, int(fake_sigmas.shape[0]))]
|
real_sigma = fake_sigmas[sigma.round().long().clip(0, int(fake_sigmas.shape[0]))]
|
||||||
real_sigma_data = 1.0
|
real_sigma_data = 1.0
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
import torch
|
import torch
|
||||||
import inspect
|
import inspect
|
||||||
import k_diffusion.sampling
|
import k_diffusion.sampling
|
||||||
from modules import sd_samplers_common, sd_samplers_extra, sd_schedulers, devices
|
from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser, sd_schedulers, devices
|
||||||
from modules.sd_samplers_cfg_denoiser import CFGDenoiser
|
from modules.sd_samplers_cfg_denoiser import CFGDenoiser # noqa: F401
|
||||||
from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
|
from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
|
||||||
|
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
|
import modules.shared as shared
|
||||||
from backend.sampling.sampling_function import sampling_prepare, sampling_cleanup
|
from backend.sampling.sampling_function import sampling_prepare, sampling_cleanup
|
||||||
|
|
||||||
|
|
||||||
@@ -50,6 +51,21 @@ k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
|
|||||||
k_diffusion_scheduler = {x.name: x.function for x in sd_schedulers.schedulers}
|
k_diffusion_scheduler = {x.name: x.function for x in sd_schedulers.schedulers}
|
||||||
|
|
||||||
|
|
||||||
|
class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
|
||||||
|
@property
|
||||||
|
def inner_model(self):
|
||||||
|
if self.model_wrap is None:
|
||||||
|
denoiser_constructor = getattr(shared.sd_model, 'create_denoiser', None)
|
||||||
|
|
||||||
|
if denoiser_constructor is not None:
|
||||||
|
self.model_wrap = denoiser_constructor()
|
||||||
|
else:
|
||||||
|
denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
|
||||||
|
self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)
|
||||||
|
|
||||||
|
return self.model_wrap
|
||||||
|
|
||||||
|
|
||||||
class KDiffusionSampler(sd_samplers_common.Sampler):
|
class KDiffusionSampler(sd_samplers_common.Sampler):
|
||||||
def __init__(self, funcname, sd_model, options=None):
|
def __init__(self, funcname, sd_model, options=None):
|
||||||
super().__init__(funcname)
|
super().__init__(funcname)
|
||||||
@@ -59,11 +75,8 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
|
|||||||
self.options = options or {}
|
self.options = options or {}
|
||||||
self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
|
self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
|
||||||
|
|
||||||
self.model_wrap = self.model_wrap_cfg = CFGDenoiser(self, sd_model)
|
self.model_wrap_cfg = CFGDenoiserKDiffusion(self)
|
||||||
self.predictor = sd_model.forge_objects.unet.model.predictor
|
self.model_wrap = self.model_wrap_cfg.inner_model
|
||||||
|
|
||||||
self.model_wrap_cfg.sigmas = self.predictor.sigmas
|
|
||||||
self.model_wrap_cfg.log_sigmas = self.predictor.sigmas.log()
|
|
||||||
|
|
||||||
def get_sigmas(self, p, steps):
|
def get_sigmas(self, p, steps):
|
||||||
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
|
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
|
||||||
@@ -79,7 +92,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
|
|||||||
|
|
||||||
scheduler = sd_schedulers.schedulers_map.get(scheduler_name)
|
scheduler = sd_schedulers.schedulers_map.get(scheduler_name)
|
||||||
|
|
||||||
m_sigma_min, m_sigma_max = self.predictor.sigmas[0].item(), self.predictor.sigmas[-1].item()
|
m_sigma_min, m_sigma_max = self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()
|
||||||
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max)
|
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max)
|
||||||
|
|
||||||
if p.sampler_noise_scheduler_override:
|
if p.sampler_noise_scheduler_override:
|
||||||
@@ -107,7 +120,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
|
|||||||
p.extra_generation_params["Schedule rho"] = opts.rho
|
p.extra_generation_params["Schedule rho"] = opts.rho
|
||||||
|
|
||||||
if scheduler.need_inner_model:
|
if scheduler.need_inner_model:
|
||||||
sigmas_kwargs['inner_model'] = self.model_wrap_cfg
|
sigmas_kwargs['inner_model'] = self.model_wrap
|
||||||
|
|
||||||
if scheduler.label == 'Beta':
|
if scheduler.label == 'Beta':
|
||||||
p.extra_generation_params["Beta schedule alpha"] = opts.beta_dist_alpha
|
p.extra_generation_params["Beta schedule alpha"] = opts.beta_dist_alpha
|
||||||
@@ -121,11 +134,11 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
|
|||||||
return sigmas.cpu()
|
return sigmas.cpu()
|
||||||
|
|
||||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||||
unet_patcher = self.model_wrap_cfg.inner_model.forge_objects.unet
|
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
|
||||||
sampling_prepare(self.model_wrap_cfg.inner_model.forge_objects.unet, x=x)
|
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
|
||||||
|
|
||||||
self.model_wrap_cfg.sigmas = self.model_wrap_cfg.sigmas.to(x.device)
|
self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(x.device)
|
||||||
self.model_wrap_cfg.log_sigmas = self.model_wrap_cfg.log_sigmas.to(x.device)
|
self.model_wrap.sigmas = self.model_wrap.sigmas.to(x.device)
|
||||||
|
|
||||||
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
||||||
|
|
||||||
@@ -183,11 +196,11 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
|
|||||||
return samples
|
return samples
|
||||||
|
|
||||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||||
unet_patcher = self.model_wrap_cfg.inner_model.forge_objects.unet
|
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
|
||||||
sampling_prepare(self.model_wrap_cfg.inner_model.forge_objects.unet, x=x)
|
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
|
||||||
|
|
||||||
self.model_wrap_cfg.sigmas = self.model_wrap_cfg.sigmas.to(x.device)
|
self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(x.device)
|
||||||
self.model_wrap_cfg.log_sigmas = self.model_wrap_cfg.log_sigmas.to(x.device)
|
self.model_wrap.sigmas = self.model_wrap.sigmas.to(x.device)
|
||||||
|
|
||||||
steps = steps or p.steps
|
steps = steps or p.steps
|
||||||
|
|
||||||
@@ -206,8 +219,8 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
|
|||||||
extra_params_kwargs['n'] = steps
|
extra_params_kwargs['n'] = steps
|
||||||
|
|
||||||
if 'sigma_min' in parameters:
|
if 'sigma_min' in parameters:
|
||||||
extra_params_kwargs['sigma_min'] = self.model_wrap_cfg.sigmas[0].item()
|
extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
|
||||||
extra_params_kwargs['sigma_max'] = self.model_wrap_cfg.sigmas[-1].item()
|
extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
|
||||||
|
|
||||||
if 'sigmas' in parameters:
|
if 'sigmas' in parameters:
|
||||||
extra_params_kwargs['sigmas'] = sigmas
|
extra_params_kwargs['sigmas'] = sigmas
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ class LCMCompVisDenoiser(DiscreteEpsDDPMDenoiser):
|
|||||||
original_timesteps = 50 # LCM Original Timesteps (default=50, for current version of LCM)
|
original_timesteps = 50 # LCM Original Timesteps (default=50, for current version of LCM)
|
||||||
self.skip_steps = timesteps // original_timesteps
|
self.skip_steps = timesteps // original_timesteps
|
||||||
|
|
||||||
alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32, device=model.device)
|
alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32)
|
||||||
for x in range(original_timesteps):
|
for x in range(original_timesteps):
|
||||||
alphas_cumprod_valid[original_timesteps - 1 - x] = model.alphas_cumprod[timesteps - 1 - x * self.skip_steps]
|
alphas_cumprod_valid[original_timesteps - 1 - x] = model.alphas_cumprod[timesteps - 1 - x * self.skip_steps]
|
||||||
|
|
||||||
|
|||||||
@@ -49,10 +49,10 @@ class CompVisTimestepsVDenoiser(torch.nn.Module):
|
|||||||
|
|
||||||
class CFGDenoiserTimesteps(CFGDenoiser):
|
class CFGDenoiserTimesteps(CFGDenoiser):
|
||||||
|
|
||||||
def __init__(self, sampler, model):
|
def __init__(self, sampler):
|
||||||
super().__init__(sampler, model)
|
super().__init__(sampler)
|
||||||
|
|
||||||
self.alphas = model.forge_objects.unet.model.predictor.alphas_cumprod
|
self.alphas = shared.sd_model.alphas_cumprod
|
||||||
self.classic_ddim_eps_estimation = True
|
self.classic_ddim_eps_estimation = True
|
||||||
|
|
||||||
def get_pred_x0(self, x_in, x_out, sigma):
|
def get_pred_x0(self, x_in, x_out, sigma):
|
||||||
@@ -66,6 +66,14 @@ class CFGDenoiserTimesteps(CFGDenoiser):
|
|||||||
|
|
||||||
return pred_x0
|
return pred_x0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inner_model(self):
|
||||||
|
if self.model_wrap is None:
|
||||||
|
denoiser = CompVisTimestepsVDenoiser if shared.sd_model.parameterization == "v" else CompVisTimestepsDenoiser
|
||||||
|
self.model_wrap = denoiser(shared.sd_model)
|
||||||
|
|
||||||
|
return self.model_wrap
|
||||||
|
|
||||||
|
|
||||||
class CompVisSampler(sd_samplers_common.Sampler):
|
class CompVisSampler(sd_samplers_common.Sampler):
|
||||||
def __init__(self, funcname, sd_model):
|
def __init__(self, funcname, sd_model):
|
||||||
@@ -75,10 +83,8 @@ class CompVisSampler(sd_samplers_common.Sampler):
|
|||||||
self.eta_infotext_field = 'Eta DDIM'
|
self.eta_infotext_field = 'Eta DDIM'
|
||||||
self.eta_default = 0.0
|
self.eta_default = 0.0
|
||||||
|
|
||||||
self.model_wrap = self.model_wrap_cfg = CFGDenoiserTimesteps(self, sd_model)
|
self.model_wrap_cfg = CFGDenoiserTimesteps(self)
|
||||||
self.predictor = sd_model.forge_objects.unet.model.predictor
|
self.model_wrap = self.model_wrap_cfg.inner_model
|
||||||
|
|
||||||
self.model_wrap.inner_model.alphas_cumprod = self.predictor.alphas_cumprod
|
|
||||||
|
|
||||||
def get_timesteps(self, p, steps):
|
def get_timesteps(self, p, steps):
|
||||||
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
|
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from modules.torch_utils import float64
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
|
def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
|
||||||
alphas_cumprod = model.inner_model.alphas_cumprod
|
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
|
||||||
alphas = alphas_cumprod[timesteps]
|
alphas = alphas_cumprod[timesteps]
|
||||||
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
|
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
|
||||||
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
|
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
|
||||||
@@ -46,7 +46,7 @@ def ddim_cfgpp(model, x, timesteps, extra_args=None, callback=None, disable=None
|
|||||||
Uses the unconditional noise prediction instead of the conditional noise to guide the denoising direction.
|
Uses the unconditional noise prediction instead of the conditional noise to guide the denoising direction.
|
||||||
The CFG scale is divided by 12.5 to map CFG from [0.0, 12.5] to [0, 1.0].
|
The CFG scale is divided by 12.5 to map CFG from [0.0, 12.5] to [0, 1.0].
|
||||||
"""
|
"""
|
||||||
alphas_cumprod = model.inner_model.alphas_cumprod
|
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
|
||||||
alphas = alphas_cumprod[timesteps]
|
alphas = alphas_cumprod[timesteps]
|
||||||
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
|
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
|
||||||
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
|
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
|
||||||
@@ -82,7 +82,7 @@ def ddim_cfgpp(model, x, timesteps, extra_args=None, callback=None, disable=None
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
|
def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
|
||||||
alphas_cumprod = model.inner_model.alphas_cumprod
|
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
|
||||||
alphas = alphas_cumprod[timesteps]
|
alphas = alphas_cumprod[timesteps]
|
||||||
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
|
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
|
||||||
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
|
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
|
||||||
@@ -168,7 +168,7 @@ class UniPCCFG(uni_pc.UniPC):
|
|||||||
|
|
||||||
|
|
||||||
def unipc(model, x, timesteps, extra_args=None, callback=None, disable=None, is_img2img=False):
|
def unipc(model, x, timesteps, extra_args=None, callback=None, disable=None, is_img2img=False):
|
||||||
alphas_cumprod = model.inner_model.alphas_cumprod
|
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
|
||||||
|
|
||||||
ns = uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
|
ns = uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
|
||||||
t_start = timesteps[-1] / 1000 + 1 / 1000 if is_img2img else None # this is likely off by a bit - if someone wants to fix it please by all means
|
t_start = timesteps[-1] / 1000 + 1 / 1000 if is_img2img else None # this is likely off by a bit - if someone wants to fix it please by all means
|
||||||
|
|||||||
Reference in New Issue
Block a user