improve backward compatibility #936

layerdiffusion
2024-08-06 00:47:33 -07:00
parent e8e5fdee8a
commit b7878058f9
10 changed files with 64 additions and 69 deletions

View File

@@ -39,6 +39,8 @@ class ForgeDiffusionEngine:
         self.is_sd2 = False
         self.is_sdxl = False
         self.is_sd3 = False
+        self.parameterization = 'eps'
+        self.alphas_cumprod = None

     def set_clip_skip(self, clip_skip):
         pass
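
The base engine now initializes `parameterization` and `alphas_cumprod` so legacy WebUI code paths can query them on any engine without attribute errors. A minimal sketch of the access pattern this restores, assuming `shared.sd_model` holds a ForgeDiffusionEngine subclass (`legacy_snr` is a hypothetical helper, not code from this commit):

import modules.shared as shared

def legacy_snr(timesteps):
    # Old extensions read the discrete schedule straight off the model.
    acd = shared.sd_model.alphas_cumprod  # stays None for engines without one
    if acd is None:
        raise RuntimeError('engine exposes no discrete alphas_cumprod schedule')
    return acd[timesteps] / (1 - acd[timesteps])  # signal-to-noise ratio per step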

View File

@@ -53,6 +53,7 @@ class StableDiffusion(ForgeDiffusionEngine):
         # WebUI Legacy
         self.is_sd1 = True
         self.first_stage_model = vae.first_stage_model
+        self.alphas_cumprod = unet.model.predictor.alphas_cumprod

     def set_clip_skip(self, clip_skip):
         self.text_processing_engine.clip_skip = clip_skip

View File

@@ -53,6 +53,10 @@ class StableDiffusion2(ForgeDiffusionEngine):
         # WebUI Legacy
         self.is_sd2 = True
         self.first_stage_model = vae.first_stage_model
+        self.alphas_cumprod = unet.model.predictor.alphas_cumprod
+
+        if not self.is_inpaint:
+            self.parameterization = 'v'

     def set_clip_skip(self, clip_skip):
         self.text_processing_engine.clip_skip = clip_skip
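
SD2 (except the inpainting variant) is a v-prediction model, which is why `parameterization` flips to 'v' here. For context, these are the standard identities a downstream denoiser wrapper applies for each mode, written as an illustrative sketch (not code from this commit):

import torch

def denoised_from_output(x_t, out, acd_t, parameterization='eps'):
    sqrt_acd, sqrt_om = acd_t ** 0.5, (1 - acd_t) ** 0.5
    if parameterization == 'v':
        return sqrt_acd * x_t - sqrt_om * out  # x0 = sqrt(a)*x - sqrt(1-a)*v
    return (x_t - sqrt_om * out) / sqrt_acd    # x0 = (x - sqrt(1-a)*eps) / sqrt(a)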

View File

@@ -72,6 +72,7 @@ class StableDiffusionXL(ForgeDiffusionEngine):
         # WebUI Legacy
         self.is_sdxl = True
         self.first_stage_model = vae.first_stage_model
+        self.alphas_cumprod = unet.model.predictor.alphas_cumprod

     def set_clip_skip(self, clip_skip):
         self.text_processing_engine_l.clip_skip = clip_skip

View File

@@ -301,7 +301,7 @@ def sampling_function_inner(model, x, timestep, uncond, cond, cond_scale, model_
 def sampling_function(self, denoiser_params, cond_scale, cond_composition):
-    unet_patcher = self.inner_model.forge_objects.unet
+    unet_patcher = self.inner_model.inner_model.forge_objects.unet
     model = unet_patcher.model
     control = unet_patcher.controlnet_linked_list
     extra_concat_condition = unet_patcher.extra_concat_condition
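
With this commit, `self.inner_model` on the CFG denoiser is the lazily built k-diffusion-style wrapper (see the `inner_model` properties later in this diff), and the Forge engine that owns `forge_objects` sits one level deeper — hence the extra `.inner_model`. A hedged helper making the chain explicit (`unwrap_unet` is a hypothetical name):

def unwrap_unet(cfg_denoiser):
    wrapper = cfg_denoiser.inner_model  # CompVis(V)Denoiser or create_denoiser() result
    engine = wrapper.inner_model        # ForgeDiffusionEngine; owns forge_objects
    return engine.forge_objects.unet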

View File

@@ -31,10 +31,6 @@ def pad_cond(tensor, repeats, empty):
     return tensor


-def append_zero(x):
-    return torch.cat([x, x.new_zeros([1])])
-
-
 class CFGDenoiser(torch.nn.Module):
     """
     Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
@@ -43,10 +39,8 @@ class CFGDenoiser(torch.nn.Module):
     negative prompt.
     """

-    def __init__(self, sampler, model):
+    def __init__(self, sampler):
         super().__init__()
-        self.inner_model = model
         self.model_wrap = None
         self.mask = None
         self.nmask = None
@@ -70,35 +64,9 @@ class CFGDenoiser(torch.nn.Module):
         self.classic_ddim_eps_estimation = False

-        self.sigmas = model.forge_objects.unet.model.predictor.sigmas
-        self.log_sigmas = self.sigmas.log()
-
-    def get_sigmas(self, n=None):
-        if n is None:
-            return append_zero(self.sigmas.flip(0))
-
-        t_max = len(self.sigmas) - 1
-        t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
-        return append_zero(self.t_to_sigma(t))
-
-    def sigma_to_t(self, sigma, quantize=None):
-        quantize = self.quantize if quantize is None else quantize
-        log_sigma = sigma.log()
-        dists = log_sigma - self.log_sigmas[:, None]
-        if quantize:
-            return dists.abs().argmin(dim=0).view(sigma.shape)
-        low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
-        high_idx = low_idx + 1
-        low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
-        w = (low - log_sigma) / (low - high)
-        w = w.clamp(0, 1)
-        t = (1 - w) * low_idx + w * high_idx
-        return t.view(sigma.shape)
-
-    def t_to_sigma(self, t):
-        t = t.float()
-        low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
-        log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
-        return log_sigma.exp()
+    @property
+    def inner_model(self):
+        raise NotImplementedError()

     def combine_denoised(self, x_out, conds_list, uncond, cond_scale, timestep, x_in, cond):
         denoised_uncond = x_out[-uncond.shape[0]:]
@@ -190,7 +158,7 @@ class CFGDenoiser(torch.nn.Module):
         original_x_dtype = x.dtype

         if self.classic_ddim_eps_estimation:
-            acd = self.inner_model.alphas_cumprod
+            acd = self.inner_model.inner_model.alphas_cumprod
             fake_sigmas = ((1 - acd) / acd) ** 0.5
             real_sigma = fake_sigmas[sigma.round().long().clip(0, int(fake_sigmas.shape[0]))]
             real_sigma_data = 1.0
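
The deleted `get_sigmas`/`sigma_to_t`/`t_to_sigma` duplicated helpers that k-diffusion already provides on its denoiser wrappers through `DiscreteSchedule`; after this change the samplers call them on the wrapper instead. A minimal sketch of the equivalent API, assuming a stock k-diffusion install (schedule values are placeholders):

import torch
import k_diffusion.external

sigmas = torch.linspace(0.03, 14.6, 1000)  # placeholder discrete schedule
sched = k_diffusion.external.DiscreteSchedule(sigmas, quantize=False)
t = sched.sigma_to_t(torch.tensor(7.5))    # fractional t via log-sigma interpolation
s = sched.t_to_sigma(t)                    # inverse mapping; s is approximately 7.5
steps = sched.get_sigmas(20)               # 20-step schedule with trailing zero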

View File

@@ -1,11 +1,12 @@
 import torch
 import inspect
 import k_diffusion.sampling
-from modules import sd_samplers_common, sd_samplers_extra, sd_schedulers, devices
-from modules.sd_samplers_cfg_denoiser import CFGDenoiser
+from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser, sd_schedulers, devices
+from modules.sd_samplers_cfg_denoiser import CFGDenoiser  # noqa: F401
 from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback

 from modules.shared import opts
+import modules.shared as shared

 from backend.sampling.sampling_function import sampling_prepare, sampling_cleanup
@@ -50,6 +51,21 @@ k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
 k_diffusion_scheduler = {x.name: x.function for x in sd_schedulers.schedulers}


+class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
+    @property
+    def inner_model(self):
+        if self.model_wrap is None:
+            denoiser_constructor = getattr(shared.sd_model, 'create_denoiser', None)
+
+            if denoiser_constructor is not None:
+                self.model_wrap = denoiser_constructor()
+            else:
+                denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
+                self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)
+
+        return self.model_wrap
+
+
 class KDiffusionSampler(sd_samplers_common.Sampler):
     def __init__(self, funcname, sd_model, options=None):
         super().__init__(funcname)
@@ -59,11 +75,8 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
         self.options = options or {}
         self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)

-        self.model_wrap = self.model_wrap_cfg = CFGDenoiser(self, sd_model)
-        self.predictor = sd_model.forge_objects.unet.model.predictor
-        self.model_wrap_cfg.sigmas = self.predictor.sigmas
-        self.model_wrap_cfg.log_sigmas = self.predictor.sigmas.log()
+        self.model_wrap_cfg = CFGDenoiserKDiffusion(self)
+        self.model_wrap = self.model_wrap_cfg.inner_model

     def get_sigmas(self, p, steps):
         discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
@@ -79,7 +92,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
         scheduler = sd_schedulers.schedulers_map.get(scheduler_name)

-        m_sigma_min, m_sigma_max = self.predictor.sigmas[0].item(), self.predictor.sigmas[-1].item()
+        m_sigma_min, m_sigma_max = self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()
         sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max)

         if p.sampler_noise_scheduler_override:
@@ -107,7 +120,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
p.extra_generation_params["Schedule rho"] = opts.rho p.extra_generation_params["Schedule rho"] = opts.rho
if scheduler.need_inner_model: if scheduler.need_inner_model:
sigmas_kwargs['inner_model'] = self.model_wrap_cfg sigmas_kwargs['inner_model'] = self.model_wrap
if scheduler.label == 'Beta': if scheduler.label == 'Beta':
p.extra_generation_params["Beta schedule alpha"] = opts.beta_dist_alpha p.extra_generation_params["Beta schedule alpha"] = opts.beta_dist_alpha
@@ -121,11 +134,11 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
         return sigmas.cpu()

     def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
-        unet_patcher = self.model_wrap_cfg.inner_model.forge_objects.unet
-        sampling_prepare(self.model_wrap_cfg.inner_model.forge_objects.unet, x=x)
-        self.model_wrap_cfg.sigmas = self.model_wrap_cfg.sigmas.to(x.device)
-        self.model_wrap_cfg.log_sigmas = self.model_wrap_cfg.log_sigmas.to(x.device)
+        unet_patcher = self.model_wrap.inner_model.forge_objects.unet
+        sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
+        self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(x.device)
+        self.model_wrap.sigmas = self.model_wrap.sigmas.to(x.device)

         steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
@@ -183,11 +196,11 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
         return samples

     def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
-        unet_patcher = self.model_wrap_cfg.inner_model.forge_objects.unet
-        sampling_prepare(self.model_wrap_cfg.inner_model.forge_objects.unet, x=x)
-        self.model_wrap_cfg.sigmas = self.model_wrap_cfg.sigmas.to(x.device)
-        self.model_wrap_cfg.log_sigmas = self.model_wrap_cfg.log_sigmas.to(x.device)
+        unet_patcher = self.model_wrap.inner_model.forge_objects.unet
+        sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
+        self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(x.device)
+        self.model_wrap.sigmas = self.model_wrap.sigmas.to(x.device)

         steps = steps or p.steps
@@ -206,8 +219,8 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
             extra_params_kwargs['n'] = steps

         if 'sigma_min' in parameters:
-            extra_params_kwargs['sigma_min'] = self.model_wrap_cfg.sigmas[0].item()
-            extra_params_kwargs['sigma_max'] = self.model_wrap_cfg.sigmas[-1].item()
+            extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
+            extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()

         if 'sigmas' in parameters:
             extra_params_kwargs['sigmas'] = sigmas
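
`CFGDenoiserKDiffusion.inner_model` first looks for a `create_denoiser` method on the loaded model, so an engine can supply its own wrapper instead of the CompVis ones chosen by `parameterization`. A hedged sketch of what such a hook might return, given that the sampler code above expects a callable wrapper exposing `inner_model`, `sigmas`, and `log_sigmas` (all names below are hypothetical):

import torch

class MyDenoiser(torch.nn.Module):
    # Minimal k-diffusion-style wrapper: callable, with sigmas/log_sigmas buffers.
    def __init__(self, model, sigmas):
        super().__init__()
        self.inner_model = model
        self.register_buffer('sigmas', sigmas)
        self.register_buffer('log_sigmas', sigmas.log())

    def forward(self, x, sigma, **extra_args):
        return self.inner_model(x, sigma, **extra_args)

class MyEngine:
    def create_denoiser(self):
        sigmas = torch.linspace(0.03, 14.6, 1000)  # placeholder schedule
        return MyDenoiser(self, sigmas)

    def __call__(self, x, sigma, **extra_args):
        return x  # stub prediction, for illustration only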

View File

@@ -13,7 +13,7 @@ class LCMCompVisDenoiser(DiscreteEpsDDPMDenoiser):
         original_timesteps = 50  # LCM Original Timesteps (default=50, for current version of LCM)
         self.skip_steps = timesteps // original_timesteps

-        alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32, device=model.device)
+        alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32)
         for x in range(original_timesteps):
             alphas_cumprod_valid[original_timesteps - 1 - x] = model.alphas_cumprod[timesteps - 1 - x * self.skip_steps]

View File

@@ -49,10 +49,10 @@ class CompVisTimestepsVDenoiser(torch.nn.Module):
 class CFGDenoiserTimesteps(CFGDenoiser):

-    def __init__(self, sampler, model):
-        super().__init__(sampler, model)
+    def __init__(self, sampler):
+        super().__init__(sampler)

-        self.alphas = model.forge_objects.unet.model.predictor.alphas_cumprod
+        self.alphas = shared.sd_model.alphas_cumprod
         self.classic_ddim_eps_estimation = True

     def get_pred_x0(self, x_in, x_out, sigma):
@@ -66,6 +66,14 @@ class CFGDenoiserTimesteps(CFGDenoiser):
         return pred_x0

+    @property
+    def inner_model(self):
+        if self.model_wrap is None:
+            denoiser = CompVisTimestepsVDenoiser if shared.sd_model.parameterization == "v" else CompVisTimestepsDenoiser
+            self.model_wrap = denoiser(shared.sd_model)
+
+        return self.model_wrap
+

 class CompVisSampler(sd_samplers_common.Sampler):
     def __init__(self, funcname, sd_model):
@@ -75,10 +83,8 @@ class CompVisSampler(sd_samplers_common.Sampler):
         self.eta_infotext_field = 'Eta DDIM'
         self.eta_default = 0.0

-        self.model_wrap = self.model_wrap_cfg = CFGDenoiserTimesteps(self, sd_model)
-        self.predictor = sd_model.forge_objects.unet.model.predictor
-        self.model_wrap.inner_model.alphas_cumprod = self.predictor.alphas_cumprod
+        self.model_wrap_cfg = CFGDenoiserTimesteps(self)
+        self.model_wrap = self.model_wrap_cfg.inner_model

     def get_timesteps(self, p, steps):
         discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
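
As on the k-diffusion path, the timestep-space wrapper is now built lazily from `shared.sd_model` and selected by `parameterization`. For context, a simplified sketch of the shape of an eps-style wrapper (modeled on webui's CompVisTimestepsDenoiser; not the exact implementation):

import torch

class TimestepsEpsDenoiser(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.inner_model = model  # the engine; also carries alphas_cumprod

    def forward(self, x, timesteps, **kwargs):
        # eps parameterization: the raw model output is the noise estimate
        return self.inner_model.apply_model(x, timesteps, **kwargs)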

View File

@@ -10,7 +10,7 @@ from modules.torch_utils import float64
 @torch.no_grad()
 def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
-    alphas_cumprod = model.inner_model.alphas_cumprod
+    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
     alphas = alphas_cumprod[timesteps]
     alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
     sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
@@ -46,7 +46,7 @@ def ddim_cfgpp(model, x, timesteps, extra_args=None, callback=None, disable=None
     Uses the unconditional noise prediction instead of the conditional noise to guide the denoising direction.
     The CFG scale is divided by 12.5 to map CFG from [0.0, 12.5] to [0, 1.0].
     """
-    alphas_cumprod = model.inner_model.alphas_cumprod
+    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
     alphas = alphas_cumprod[timesteps]
     alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
     sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
@@ -82,7 +82,7 @@ def ddim_cfgpp(model, x, timesteps, extra_args=None, callback=None, disable=None
 @torch.no_grad()
 def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
-    alphas_cumprod = model.inner_model.alphas_cumprod
+    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
     alphas = alphas_cumprod[timesteps]
     alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
     sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
@@ -168,7 +168,7 @@ class UniPCCFG(uni_pc.UniPC):
 def unipc(model, x, timesteps, extra_args=None, callback=None, disable=None, is_img2img=False):
-    alphas_cumprod = model.inner_model.alphas_cumprod
+    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
     ns = uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
     t_start = timesteps[-1] / 1000 + 1 / 1000 if is_img2img else None  # this is likely off by a bit - if someone wants to fix it please by all means
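
These samplers only need `alphas_cumprod` to form per-step coefficients; here `model` is CFGDenoiserTimesteps, `model.inner_model` is the CompVis timesteps wrapper, and the wrapped engine carries the schedule. For reference, the deterministic DDIM update they build from those coefficients (standard identity with eta = 0; illustrative sketch only):

import torch

def ddim_step(x_t, eps, a_t, a_prev):
    pred_x0 = (x_t - (1 - a_t).sqrt() * eps) / a_t.sqrt()        # predicted clean latent
    return a_prev.sqrt() * pred_x0 + (1 - a_prev).sqrt() * eps   # step toward t_prev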