From b7878058f9144201924cf983e64d08174cf1756e Mon Sep 17 00:00:00 2001
From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com>
Date: Tue, 6 Aug 2024 00:47:33 -0700
Subject: [PATCH] improve backward compatibility #936

---
 backend/diffusion_engine/base.py      |  2 ++
 backend/diffusion_engine/sd15.py      |  1 +
 backend/diffusion_engine/sd20.py      |  4 +++
 backend/diffusion_engine/sdxl.py      |  1 +
 backend/sampling/sampling_function.py |  2 +-
 modules/sd_samplers_cfg_denoiser.py   | 42 +++-------------------
 modules/sd_samplers_kdiffusion.py     | 51 +++++++++++++++++----------
 modules/sd_samplers_lcm.py            |  2 +-
 modules/sd_samplers_timesteps.py      | 20 +++++++----
 modules/sd_samplers_timesteps_impl.py |  8 ++---
 10 files changed, 64 insertions(+), 69 deletions(-)

diff --git a/backend/diffusion_engine/base.py b/backend/diffusion_engine/base.py
index 76048020..424b451b 100644
--- a/backend/diffusion_engine/base.py
+++ b/backend/diffusion_engine/base.py
@@ -39,6 +39,8 @@ class ForgeDiffusionEngine:
         self.is_sd2 = False
         self.is_sdxl = False
         self.is_sd3 = False
+        self.parameterization = 'eps'
+        self.alphas_cumprod = None
 
     def set_clip_skip(self, clip_skip):
         pass
diff --git a/backend/diffusion_engine/sd15.py b/backend/diffusion_engine/sd15.py
index 2117c79f..0395933e 100644
--- a/backend/diffusion_engine/sd15.py
+++ b/backend/diffusion_engine/sd15.py
@@ -53,6 +53,7 @@ class StableDiffusion(ForgeDiffusionEngine):
         # WebUI Legacy
         self.is_sd1 = True
         self.first_stage_model = vae.first_stage_model
+        self.alphas_cumprod = unet.model.predictor.alphas_cumprod
 
     def set_clip_skip(self, clip_skip):
         self.text_processing_engine.clip_skip = clip_skip
diff --git a/backend/diffusion_engine/sd20.py b/backend/diffusion_engine/sd20.py
index 5620fbb7..d4bc1d7d 100644
--- a/backend/diffusion_engine/sd20.py
+++ b/backend/diffusion_engine/sd20.py
@@ -53,6 +53,10 @@ class StableDiffusion2(ForgeDiffusionEngine):
         # WebUI Legacy
         self.is_sd2 = True
         self.first_stage_model = vae.first_stage_model
+        self.alphas_cumprod = unet.model.predictor.alphas_cumprod
+
+        if not self.is_inpaint:
+            self.parameterization = 'v'
 
     def set_clip_skip(self, clip_skip):
         self.text_processing_engine.clip_skip = clip_skip
diff --git a/backend/diffusion_engine/sdxl.py b/backend/diffusion_engine/sdxl.py
index 6bb6ecd1..dff8cca2 100644
--- a/backend/diffusion_engine/sdxl.py
+++ b/backend/diffusion_engine/sdxl.py
@@ -72,6 +72,7 @@ class StableDiffusionXL(ForgeDiffusionEngine):
         # WebUI Legacy
         self.is_sdxl = True
         self.first_stage_model = vae.first_stage_model
+        self.alphas_cumprod = unet.model.predictor.alphas_cumprod
 
     def set_clip_skip(self, clip_skip):
         self.text_processing_engine_l.clip_skip = clip_skip
diff --git a/backend/sampling/sampling_function.py b/backend/sampling/sampling_function.py
index aa39961f..a78223af 100644
--- a/backend/sampling/sampling_function.py
+++ b/backend/sampling/sampling_function.py
@@ -301,7 +301,7 @@ def sampling_function_inner(model, x, timestep, uncond, cond, cond_scale, model_
 
 
 def sampling_function(self, denoiser_params, cond_scale, cond_composition):
-    unet_patcher = self.inner_model.forge_objects.unet
+    unet_patcher = self.inner_model.inner_model.forge_objects.unet
     model = unet_patcher.model
     control = unet_patcher.controlnet_linked_list
     extra_concat_condition = unet_patcher.extra_concat_condition
diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
index 07e80cca..0934ccfe 100644
--- a/modules/sd_samplers_cfg_denoiser.py
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -31,10 +31,6 @@ def pad_cond(tensor, repeats, empty):
     return tensor
 
 
-def append_zero(x):
-    return torch.cat([x, x.new_zeros([1])])
-
-
 class CFGDenoiser(torch.nn.Module):
     """
     Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
@@ -43,10 +39,8 @@ class CFGDenoiser(torch.nn.Module):
     negative prompt.
     """
 
-    def __init__(self, sampler, model):
+    def __init__(self, sampler):
         super().__init__()
-        self.inner_model = model
-        self.model_wrap = None
 
         self.mask = None
         self.nmask = None
@@ -70,35 +64,9 @@ class CFGDenoiser(torch.nn.Module):
 
         self.classic_ddim_eps_estimation = False
 
-        self.sigmas = model.forge_objects.unet.model.predictor.sigmas
-        self.log_sigmas = self.sigmas.log()
-
-    def get_sigmas(self, n=None):
-        if n is None:
-            return append_zero(self.sigmas.flip(0))
-        t_max = len(self.sigmas) - 1
-        t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
-        return append_zero(self.t_to_sigma(t))
-
-    def sigma_to_t(self, sigma, quantize=None):
-        quantize = self.quantize if quantize is None else quantize
-        log_sigma = sigma.log()
-        dists = log_sigma - self.log_sigmas[:, None]
-        if quantize:
-            return dists.abs().argmin(dim=0).view(sigma.shape)
-        low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
-        high_idx = low_idx + 1
-        low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
-        w = (low - log_sigma) / (low - high)
-        w = w.clamp(0, 1)
-        t = (1 - w) * low_idx + w * high_idx
-        return t.view(sigma.shape)
-
-    def t_to_sigma(self, t):
-        t = t.float()
-        low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
-        log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
-        return log_sigma.exp()
+    @property
+    def inner_model(self):
+        raise NotImplementedError()
 
     def combine_denoised(self, x_out, conds_list, uncond, cond_scale, timestep, x_in, cond):
         denoised_uncond = x_out[-uncond.shape[0]:]
@@ -190,7 +158,7 @@
         original_x_dtype = x.dtype
 
         if self.classic_ddim_eps_estimation:
-            acd = self.inner_model.alphas_cumprod
+            acd = self.inner_model.inner_model.alphas_cumprod
             fake_sigmas = ((1 - acd) / acd) ** 0.5
             real_sigma = fake_sigmas[sigma.round().long().clip(0, int(fake_sigmas.shape[0]))]
             real_sigma_data = 1.0
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 79d94446..f0b88c95 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -1,11 +1,12 @@
 import torch
 import inspect
 import k_diffusion.sampling
-from modules import sd_samplers_common, sd_samplers_extra, sd_schedulers, devices
-from modules.sd_samplers_cfg_denoiser import CFGDenoiser
+from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser, sd_schedulers, devices
+from modules.sd_samplers_cfg_denoiser import CFGDenoiser  # noqa: F401
 from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
 from modules.shared import opts
+import modules.shared as shared
 
 from backend.sampling.sampling_function import sampling_prepare, sampling_cleanup
 
@@ -50,6 +51,21 @@ k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
 k_diffusion_scheduler = {x.name: x.function for x in sd_schedulers.schedulers}
 
 
+class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
+    @property
+    def inner_model(self):
+        if self.model_wrap is None:
+            denoiser_constructor = getattr(shared.sd_model, 'create_denoiser', None)
+
+            if denoiser_constructor is not None:
+                self.model_wrap = denoiser_constructor()
+            else:
+                denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
+                self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)
+
+        return self.model_wrap
+
+
 class KDiffusionSampler(sd_samplers_common.Sampler):
     def __init__(self, funcname, sd_model, options=None):
         super().__init__(funcname)
@@ -59,11 +75,8 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
         self.options = options or {}
         self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
 
-        self.model_wrap = self.model_wrap_cfg = CFGDenoiser(self, sd_model)
-        self.predictor = sd_model.forge_objects.unet.model.predictor
-
-        self.model_wrap_cfg.sigmas = self.predictor.sigmas
-        self.model_wrap_cfg.log_sigmas = self.predictor.sigmas.log()
+        self.model_wrap_cfg = CFGDenoiserKDiffusion(self)
+        self.model_wrap = self.model_wrap_cfg.inner_model
 
     def get_sigmas(self, p, steps):
         discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
@@ -79,7 +92,7 @@
 
         scheduler = sd_schedulers.schedulers_map.get(scheduler_name)
 
-        m_sigma_min, m_sigma_max = self.predictor.sigmas[0].item(), self.predictor.sigmas[-1].item()
+        m_sigma_min, m_sigma_max = self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()
         sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max)
 
         if p.sampler_noise_scheduler_override:
@@ -107,7 +120,7 @@
             p.extra_generation_params["Schedule rho"] = opts.rho
 
         if scheduler.need_inner_model:
-            sigmas_kwargs['inner_model'] = self.model_wrap_cfg
+            sigmas_kwargs['inner_model'] = self.model_wrap
 
         if scheduler.label == 'Beta':
             p.extra_generation_params["Beta schedule alpha"] = opts.beta_dist_alpha
@@ -121,11 +134,11 @@
         return sigmas.cpu()
 
     def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
-        unet_patcher = self.model_wrap_cfg.inner_model.forge_objects.unet
-        sampling_prepare(self.model_wrap_cfg.inner_model.forge_objects.unet, x=x)
+        unet_patcher = self.model_wrap.inner_model.forge_objects.unet
+        sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
 
-        self.model_wrap_cfg.sigmas = self.model_wrap_cfg.sigmas.to(x.device)
-        self.model_wrap_cfg.log_sigmas = self.model_wrap_cfg.log_sigmas.to(x.device)
+        self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(x.device)
+        self.model_wrap.sigmas = self.model_wrap.sigmas.to(x.device)
 
         steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
 
@@ -183,11 +196,11 @@
         return samples
 
     def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
-        unet_patcher = self.model_wrap_cfg.inner_model.forge_objects.unet
-        sampling_prepare(self.model_wrap_cfg.inner_model.forge_objects.unet, x=x)
+        unet_patcher = self.model_wrap.inner_model.forge_objects.unet
+        sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
 
-        self.model_wrap_cfg.sigmas = self.model_wrap_cfg.sigmas.to(x.device)
-        self.model_wrap_cfg.log_sigmas = self.model_wrap_cfg.log_sigmas.to(x.device)
+        self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(x.device)
+        self.model_wrap.sigmas = self.model_wrap.sigmas.to(x.device)
 
         steps = steps or p.steps
 
@@ -206,8 +219,8 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
             extra_params_kwargs['n'] = steps
 
         if 'sigma_min' in parameters:
-            extra_params_kwargs['sigma_min'] = self.model_wrap_cfg.sigmas[0].item()
-            extra_params_kwargs['sigma_max'] = self.model_wrap_cfg.sigmas[-1].item()
+            extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
+            extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
 
         if 'sigmas' in parameters:
             extra_params_kwargs['sigmas'] = sigmas
diff --git a/modules/sd_samplers_lcm.py b/modules/sd_samplers_lcm.py
index 29d453a2..df4d97b1 100644
--- a/modules/sd_samplers_lcm.py
+++ b/modules/sd_samplers_lcm.py
@@ -13,7 +13,7 @@ class LCMCompVisDenoiser(DiscreteEpsDDPMDenoiser):
         original_timesteps = 50  # LCM Original Timesteps (default=50, for current version of LCM)
         self.skip_steps = timesteps // original_timesteps
 
-        alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32, device=model.device)
+        alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32)
         for x in range(original_timesteps):
             alphas_cumprod_valid[original_timesteps - 1 - x] = model.alphas_cumprod[timesteps - 1 - x * self.skip_steps]
 
diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py
index c5a97290..08956497 100644
--- a/modules/sd_samplers_timesteps.py
+++ b/modules/sd_samplers_timesteps.py
@@ -49,10 +49,10 @@ class CompVisTimestepsVDenoiser(torch.nn.Module):
 
 
 class CFGDenoiserTimesteps(CFGDenoiser):
-    def __init__(self, sampler, model):
-        super().__init__(sampler, model)
+    def __init__(self, sampler):
+        super().__init__(sampler)
 
-        self.alphas = model.forge_objects.unet.model.predictor.alphas_cumprod
+        self.alphas = shared.sd_model.alphas_cumprod
         self.classic_ddim_eps_estimation = True
 
     def get_pred_x0(self, x_in, x_out, sigma):
@@ -66,6 +66,14 @@
 
         return pred_x0
 
+    @property
+    def inner_model(self):
+        if self.model_wrap is None:
+            denoiser = CompVisTimestepsVDenoiser if shared.sd_model.parameterization == "v" else CompVisTimestepsDenoiser
+            self.model_wrap = denoiser(shared.sd_model)
+
+        return self.model_wrap
+
 
 class CompVisSampler(sd_samplers_common.Sampler):
     def __init__(self, funcname, sd_model):
@@ -75,10 +83,8 @@ class CompVisSampler(sd_samplers_common.Sampler):
         self.eta_infotext_field = 'Eta DDIM'
         self.eta_default = 0.0
 
-        self.model_wrap = self.model_wrap_cfg = CFGDenoiserTimesteps(self, sd_model)
-        self.predictor = sd_model.forge_objects.unet.model.predictor
-
-        self.model_wrap.inner_model.alphas_cumprod = self.predictor.alphas_cumprod
+        self.model_wrap_cfg = CFGDenoiserTimesteps(self)
+        self.model_wrap = self.model_wrap_cfg.inner_model
 
     def get_timesteps(self, p, steps):
         discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
diff --git a/modules/sd_samplers_timesteps_impl.py b/modules/sd_samplers_timesteps_impl.py
index a63179e6..180e4389 100644
--- a/modules/sd_samplers_timesteps_impl.py
+++ b/modules/sd_samplers_timesteps_impl.py
@@ -10,7 +10,7 @@ from modules.torch_utils import float64
 
 @torch.no_grad()
 def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
-    alphas_cumprod = model.inner_model.alphas_cumprod
+    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
     alphas = alphas_cumprod[timesteps]
     alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
     sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
@@ -46,7 +46,7 @@ def ddim_cfgpp(model, x, timesteps, extra_args=None, callback=None, disable=None
     Uses the unconditional noise prediction instead of the conditional noise to guide the denoising direction.
     The CFG scale is divided by 12.5 to map CFG from [0.0, 12.5] to [0, 1.0].
     """
-    alphas_cumprod = model.inner_model.alphas_cumprod
+    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
     alphas = alphas_cumprod[timesteps]
     alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
     sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
@@ -82,7 +82,7 @@ def ddim_cfgpp(model, x, timesteps, extra_args=None, callback=None, disable=None
 
 @torch.no_grad()
 def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
-    alphas_cumprod = model.inner_model.alphas_cumprod
+    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
     alphas = alphas_cumprod[timesteps]
     alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
     sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
@@ -168,7 +168,7 @@ class UniPCCFG(uni_pc.UniPC):
 
 
 def unipc(model, x, timesteps, extra_args=None, callback=None, disable=None, is_img2img=False):
-    alphas_cumprod = model.inner_model.alphas_cumprod
+    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
 
     ns = uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
     t_start = timesteps[-1] / 1000 + 1 / 1000 if is_img2img else None  # this is likely off by a bit - if someone wants to fix it please by all means
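
A note on the pattern the patch adopts (this note and sketch are not part of the patch): samplers no longer capture the model at construction time. Instead, `inner_model` becomes a lazily evaluated property that wraps whatever `shared.sd_model` is loaded at first access, picking an eps- or v-prediction wrapper from the engine's new `parameterization` attribute. This is also why `sampling_function` and the `sd_samplers_timesteps_impl` functions gain an extra hop: the denoiser's `inner_model` is now the wrapper, and the Forge engine (which owns `forge_objects` and `alphas_cumprod`) sits one level below it, at `inner_model.inner_model`. A minimal, self-contained sketch of the idea follows; `DummyEngine`, `EpsWrapper`, and `VWrapper` are illustrative stand-ins, while the real code wraps `shared.sd_model` in `k_diffusion.external.CompVisDenoiser` or `CompVisVDenoiser`:

    # Minimal sketch of the lazy inner_model pattern used above.
    # DummyEngine stands in for shared.sd_model; Eps/VWrapper stand in for
    # k_diffusion.external.CompVisDenoiser / CompVisVDenoiser.
    class DummyEngine:
        # every engine now exposes these two attributes; the base class
        # defaults them to 'eps' and None
        parameterization = 'v'
        alphas_cumprod = None

    class EpsWrapper:
        def __init__(self, model):
            self.inner_model = model  # the engine sits one level deeper

    class VWrapper(EpsWrapper):
        pass

    class CFGDenoiser:
        def __init__(self, sampler):
            self.sampler = sampler
            self.model_wrap = None  # built on first access, not in __init__

        @property
        def inner_model(self):
            if self.model_wrap is None:
                engine = DummyEngine()  # the real code reads shared.sd_model here
                cls = VWrapper if engine.parameterization == 'v' else EpsWrapper
                self.model_wrap = cls(engine)
            return self.model_wrap

    d = CFGDenoiser(sampler=None)
    print(type(d.inner_model).__name__)     # VWrapper
    print(type(d.inner_model.inner_model))  # DummyEngine: the engine is two hops away

Because the wrapper is built only when first accessed, constructing a sampler before a checkpoint switch can no longer leave a stale denoiser pointing at the previously loaded model.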