diff --git a/backend/diffusion_engine/base.py b/backend/diffusion_engine/base.py index 424b451b..76048020 100644 --- a/backend/diffusion_engine/base.py +++ b/backend/diffusion_engine/base.py @@ -39,8 +39,6 @@ class ForgeDiffusionEngine: self.is_sd2 = False self.is_sdxl = False self.is_sd3 = False - self.parameterization = 'eps' - self.alphas_cumprod = None def set_clip_skip(self, clip_skip): pass diff --git a/backend/diffusion_engine/sd15.py b/backend/diffusion_engine/sd15.py index 0395933e..2117c79f 100644 --- a/backend/diffusion_engine/sd15.py +++ b/backend/diffusion_engine/sd15.py @@ -53,7 +53,6 @@ class StableDiffusion(ForgeDiffusionEngine): # WebUI Legacy self.is_sd1 = True self.first_stage_model = vae.first_stage_model - self.alphas_cumprod = unet.model.predictor.alphas_cumprod def set_clip_skip(self, clip_skip): self.text_processing_engine.clip_skip = clip_skip diff --git a/backend/diffusion_engine/sd20.py b/backend/diffusion_engine/sd20.py index d4bc1d7d..5620fbb7 100644 --- a/backend/diffusion_engine/sd20.py +++ b/backend/diffusion_engine/sd20.py @@ -53,10 +53,6 @@ class StableDiffusion2(ForgeDiffusionEngine): # WebUI Legacy self.is_sd2 = True self.first_stage_model = vae.first_stage_model - self.alphas_cumprod = unet.model.predictor.alphas_cumprod - - if not self.is_inpaint: - self.parameterization = 'v' def set_clip_skip(self, clip_skip): self.text_processing_engine.clip_skip = clip_skip diff --git a/backend/diffusion_engine/sdxl.py b/backend/diffusion_engine/sdxl.py index dff8cca2..6bb6ecd1 100644 --- a/backend/diffusion_engine/sdxl.py +++ b/backend/diffusion_engine/sdxl.py @@ -72,7 +72,6 @@ class StableDiffusionXL(ForgeDiffusionEngine): # WebUI Legacy self.is_sdxl = True self.first_stage_model = vae.first_stage_model - self.alphas_cumprod = unet.model.predictor.alphas_cumprod def set_clip_skip(self, clip_skip): self.text_processing_engine_l.clip_skip = clip_skip diff --git a/backend/sampling/sampling_function.py b/backend/sampling/sampling_function.py index a78223af..c51ee742 100644 --- a/backend/sampling/sampling_function.py +++ b/backend/sampling/sampling_function.py @@ -270,7 +270,7 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): return out_cond, out_uncond -def sampling_function_inner(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): +def sampling_function_inner(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None, return_full=False): edit_strength = sum((item['strength'] if 'strength' in item else 1) for item in cond) if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False: @@ -297,6 +297,9 @@ def sampling_function_inner(model, x, timestep, uncond, cond, cond_scale, model_ "sigma": timestep, "model_options": model_options, "input": x} cfg_result = fn(args) + if return_full: + return cfg_result, cond_pred, uncond_pred + return cfg_result @@ -333,8 +336,8 @@ def sampling_function(self, denoiser_params, cond_scale, cond_composition): for modifier in model_options.get('conditioning_modifiers', []): model, x, timestep, uncond, cond, cond_scale, model_options, seed = modifier(model, x, timestep, uncond, cond, cond_scale, model_options, seed) - denoised = sampling_function_inner(model, x, timestep, uncond, cond, cond_scale, model_options, seed) - return denoised + denoised, cond_pred, uncond_pred = sampling_function_inner(model, x, timestep, uncond, cond, cond_scale, model_options, seed, return_full=True) + return denoised, cond_pred, uncond_pred def sampling_prepare(unet, x): diff --git a/k_diffusion/external.py b/k_diffusion/external.py new file mode 100644 index 00000000..79b51cec --- /dev/null +++ b/k_diffusion/external.py @@ -0,0 +1,177 @@ +import math + +import torch +from torch import nn + +from . import sampling, utils + + +class VDenoiser(nn.Module): + """A v-diffusion-pytorch model wrapper for k-diffusion.""" + + def __init__(self, inner_model): + super().__init__() + self.inner_model = inner_model + self.sigma_data = 1. + + def get_scalings(self, sigma): + c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + return c_skip, c_out, c_in + + def sigma_to_t(self, sigma): + return sigma.atan() / math.pi * 2 + + def t_to_sigma(self, t): + return (t * math.pi / 2).tan() + + def loss(self, input, noise, sigma, **kwargs): + c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + noised_input = input + noise * utils.append_dims(sigma, input.ndim) + model_output = self.inner_model(noised_input * c_in, self.sigma_to_t(sigma), **kwargs) + target = (input - c_skip * noised_input) / c_out + return (model_output - target).pow(2).flatten(1).mean(1) + + def forward(self, input, sigma, **kwargs): + c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip + + +class DiscreteSchedule(nn.Module): + """A mapping between continuous noise levels (sigmas) and a list of discrete noise + levels.""" + + def __init__(self, sigmas, quantize): + super().__init__() + self.register_buffer('sigmas', sigmas) + self.register_buffer('log_sigmas', sigmas.log()) + self.quantize = quantize + + @property + def sigma_min(self): + return self.sigmas[0] + + @property + def sigma_max(self): + return self.sigmas[-1] + + def get_sigmas(self, n=None): + if n is None: + return sampling.append_zero(self.sigmas.flip(0)) + t_max = len(self.sigmas) - 1 + t = torch.linspace(t_max, 0, n, device=self.sigmas.device) + return sampling.append_zero(self.t_to_sigma(t)) + + def sigma_to_t(self, sigma, quantize=None): + quantize = self.quantize if quantize is None else quantize + log_sigma = sigma.log() + dists = log_sigma - self.log_sigmas[:, None] + if quantize: + return dists.abs().argmin(dim=0).view(sigma.shape) + low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx] + w = (low - log_sigma) / (low - high) + w = w.clamp(0, 1) + t = (1 - w) * low_idx + w * high_idx + return t.view(sigma.shape) + + def t_to_sigma(self, t): + t = t.float() + low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac() + log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] + return log_sigma.exp() + + +class DiscreteEpsDDPMDenoiser(DiscreteSchedule): + """A wrapper for discrete schedule DDPM models that output eps (the predicted + noise).""" + + def __init__(self, model, alphas_cumprod, quantize): + super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize) + self.inner_model = model + self.sigma_data = 1. + + def get_scalings(self, sigma): + c_out = -sigma + c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + return c_out, c_in + + def get_eps(self, *args, **kwargs): + return self.inner_model(*args, **kwargs) + + def loss(self, input, noise, sigma, **kwargs): + c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + noised_input = input + noise * utils.append_dims(sigma, input.ndim) + eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs) + return (eps - noise).pow(2).flatten(1).mean(1) + + def forward(self, input, sigma, **kwargs): + c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs) + return input + eps * c_out + + +class OpenAIDenoiser(DiscreteEpsDDPMDenoiser): + """A wrapper for OpenAI diffusion models.""" + + def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'): + alphas_cumprod = torch.tensor(diffusion.alphas_cumprod, device=device, dtype=torch.float32) + super().__init__(model, alphas_cumprod, quantize=quantize) + self.has_learned_sigmas = has_learned_sigmas + + def get_eps(self, *args, **kwargs): + model_output = self.inner_model(*args, **kwargs) + if self.has_learned_sigmas: + return model_output.chunk(2, dim=1)[0] + return model_output + + +class CompVisDenoiser(DiscreteEpsDDPMDenoiser): + """A wrapper for CompVis diffusion models.""" + + def __init__(self, model, quantize=False, device='cpu'): + super().__init__(model, model.alphas_cumprod, quantize=quantize) + + def get_eps(self, *args, **kwargs): + return self.inner_model.apply_model(*args, **kwargs) + + +class DiscreteVDDPMDenoiser(DiscreteSchedule): + """A wrapper for discrete schedule DDPM models that output v.""" + + def __init__(self, model, alphas_cumprod, quantize): + super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize) + self.inner_model = model + self.sigma_data = 1. + + def get_scalings(self, sigma): + c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + return c_skip, c_out, c_in + + def get_v(self, *args, **kwargs): + return self.inner_model(*args, **kwargs) + + def loss(self, input, noise, sigma, **kwargs): + c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + noised_input = input + noise * utils.append_dims(sigma, input.ndim) + model_output = self.get_v(noised_input * c_in, self.sigma_to_t(sigma), **kwargs) + target = (input - c_skip * noised_input) / c_out + return (model_output - target).pow(2).flatten(1).mean(1) + + def forward(self, input, sigma, **kwargs): + c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + return self.get_v(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip + + +class CompVisVDenoiser(DiscreteVDDPMDenoiser): + """A wrapper for CompVis diffusion models that output v.""" + + def __init__(self, model, quantize=False, device='cpu'): + super().__init__(model, model.alphas_cumprod, quantize=quantize) + + def get_v(self, x, t, cond, **kwargs): + return self.inner_model.apply_model(x, t, cond) diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py new file mode 100644 index 00000000..9f10e8d8 --- /dev/null +++ b/k_diffusion/sampling.py @@ -0,0 +1,702 @@ +import math + +from scipy import integrate +import torch +from torch import nn +from torchdiffeq import odeint +import torchsde +from tqdm.auto import trange, tqdm + +from . import utils + + +def append_zero(x): + return torch.cat([x, x.new_zeros([1])]) + + +def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'): + """Constructs the noise schedule of Karras et al. (2022).""" + ramp = torch.linspace(0, 1, n) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return append_zero(sigmas).to(device) + + +def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'): + """Constructs an exponential noise schedule.""" + sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp() + return append_zero(sigmas) + + +def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'): + """Constructs an polynomial in log sigma noise schedule.""" + ramp = torch.linspace(1, 0, n, device=device) ** rho + sigmas = torch.exp(ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min)) + return append_zero(sigmas) + + +def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'): + """Constructs a continuous VP noise schedule.""" + t = torch.linspace(1, eps_s, n, device=device) + sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1) + return append_zero(sigmas) + + +def to_d(x, sigma, denoised): + """Converts a denoiser output to a Karras ODE derivative.""" + return (x - denoised) / utils.append_dims(sigma, x.ndim) + + +def get_ancestral_step(sigma_from, sigma_to, eta=1.): + """Calculates the noise level (sigma_down) to step down to and the amount + of noise to add (sigma_up) when doing an ancestral sampling step.""" + if not eta: + return sigma_to, 0. + sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5) + sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 + return sigma_down, sigma_up + + +def default_noise_sampler(x): + return lambda sigma, sigma_next: torch.randn_like(x) + + +class BatchedBrownianTree: + """A wrapper around torchsde.BrownianTree that enables batches of entropy.""" + + def __init__(self, x, t0, t1, seed=None, **kwargs): + t0, t1, self.sign = self.sort(t0, t1) + w0 = kwargs.get('w0', torch.zeros_like(x)) + if seed is None: + seed = torch.randint(0, 2 ** 63 - 1, []).item() + self.batched = True + try: + assert len(seed) == x.shape[0] + w0 = w0[0] + except TypeError: + seed = [seed] + self.batched = False + self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed] + + @staticmethod + def sort(a, b): + return (a, b, 1) if a < b else (b, a, -1) + + def __call__(self, t0, t1): + t0, t1, sign = self.sort(t0, t1) + w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign) + return w if self.batched else w[0] + + +class BrownianTreeNoiseSampler: + """A noise sampler backed by a torchsde.BrownianTree. + + Args: + x (Tensor): The tensor whose shape, device and dtype to use to generate + random samples. + sigma_min (float): The low end of the valid interval. + sigma_max (float): The high end of the valid interval. + seed (int or List[int]): The random seed. If a list of seeds is + supplied instead of a single integer, then the noise sampler will + use one BrownianTree per batch item, each with its own seed. + transform (callable): A function that maps sigma to the sampler's + internal timestep. + """ + + def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x): + self.transform = transform + t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max)) + self.tree = BatchedBrownianTree(x, t0, t1, seed) + + def __call__(self, sigma, sigma_next): + t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next)) + return self.tree(t0, t1) / (t1 - t0).abs().sqrt() + + +@torch.no_grad() +def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): + """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. + eps = torch.randn_like(x) * s_noise + sigma_hat = sigmas[i] * (gamma + 1) + if gamma > 0: + x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + denoised = model(x, sigma_hat * s_in, **extra_args) + d = to_d(x, sigma_hat, denoised) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + dt = sigmas[i + 1] - sigma_hat + # Euler method + x = x + d * dt + return x + + +@torch.no_grad() +def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): + """Ancestral sampling with Euler method steps.""" + extra_args = {} if extra_args is None else extra_args + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + d = to_d(x, sigmas[i], denoised) + # Euler method + dt = sigma_down - sigmas[i] + x = x + d * dt + if sigmas[i + 1] > 0: + x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up + return x + + +@torch.no_grad() +def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): + """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. + eps = torch.randn_like(x) * s_noise + sigma_hat = sigmas[i] * (gamma + 1) + if gamma > 0: + x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + denoised = model(x, sigma_hat * s_in, **extra_args) + d = to_d(x, sigma_hat, denoised) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + dt = sigmas[i + 1] - sigma_hat + if sigmas[i + 1] == 0: + # Euler method + x = x + d * dt + else: + # Heun's method + x_2 = x + d * dt + denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args) + d_2 = to_d(x_2, sigmas[i + 1], denoised_2) + d_prime = (d + d_2) / 2 + x = x + d_prime * dt + return x + + +@torch.no_grad() +def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): + """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. + eps = torch.randn_like(x) * s_noise + sigma_hat = sigmas[i] * (gamma + 1) + if gamma > 0: + x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + denoised = model(x, sigma_hat * s_in, **extra_args) + d = to_d(x, sigma_hat, denoised) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + if sigmas[i + 1] == 0: + # Euler method + dt = sigmas[i + 1] - sigma_hat + x = x + d * dt + else: + # DPM-Solver-2 + sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp() + dt_1 = sigma_mid - sigma_hat + dt_2 = sigmas[i + 1] - sigma_hat + x_2 = x + d * dt_1 + denoised_2 = model(x_2, sigma_mid * s_in, **extra_args) + d_2 = to_d(x_2, sigma_mid, denoised_2) + x = x + d_2 * dt_2 + return x + + +@torch.no_grad() +def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): + """Ancestral sampling with DPM-Solver second-order steps.""" + extra_args = {} if extra_args is None else extra_args + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + d = to_d(x, sigmas[i], denoised) + if sigma_down == 0: + # Euler method + dt = sigma_down - sigmas[i] + x = x + d * dt + else: + # DPM-Solver-2 + sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp() + dt_1 = sigma_mid - sigmas[i] + dt_2 = sigma_down - sigmas[i] + x_2 = x + d * dt_1 + denoised_2 = model(x_2, sigma_mid * s_in, **extra_args) + d_2 = to_d(x_2, sigma_mid, denoised_2) + x = x + d_2 * dt_2 + x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up + return x + + +def linear_multistep_coeff(order, t, i, j): + if order - 1 > i: + raise ValueError(f'Order {order} too high for step {i}') + def fn(tau): + prod = 1. + for k in range(order): + if j == k: + continue + prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) + return prod + return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0] + + +@torch.no_grad() +def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4): + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + sigmas_cpu = sigmas.detach().cpu().numpy() + ds = [] + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + d = to_d(x, sigmas[i], denoised) + ds.append(d) + if len(ds) > order: + ds.pop(0) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + cur_order = min(i + 1, order) + coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)] + x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) + return x + + +@torch.no_grad() +def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4): + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + v = torch.randint_like(x, 2) * 2 - 1 + fevals = 0 + def ode_fn(sigma, x): + nonlocal fevals + with torch.enable_grad(): + x = x[0].detach().requires_grad_() + denoised = model(x, sigma * s_in, **extra_args) + d = to_d(x, sigma, denoised) + fevals += 1 + grad = torch.autograd.grad((d * v).sum(), x)[0] + d_ll = (v * grad).flatten(1).sum(1) + return d.detach(), d_ll + x_min = x, x.new_zeros([x.shape[0]]) + t = x.new_tensor([sigma_min, sigma_max]) + sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5') + latent, delta_ll = sol[0][-1], sol[1][-1] + ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1) + return ll_prior + delta_ll, {'fevals': fevals} + + +class PIDStepSizeController: + """A PID controller for ODE adaptive step size control.""" + def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8): + self.h = h + self.b1 = (pcoeff + icoeff + dcoeff) / order + self.b2 = -(pcoeff + 2 * dcoeff) / order + self.b3 = dcoeff / order + self.accept_safety = accept_safety + self.eps = eps + self.errs = [] + + def limiter(self, x): + return 1 + math.atan(x - 1) + + def propose_step(self, error): + inv_error = 1 / (float(error) + self.eps) + if not self.errs: + self.errs = [inv_error, inv_error, inv_error] + self.errs[0] = inv_error + factor = self.errs[0] ** self.b1 * self.errs[1] ** self.b2 * self.errs[2] ** self.b3 + factor = self.limiter(factor) + accept = factor >= self.accept_safety + if accept: + self.errs[2] = self.errs[1] + self.errs[1] = self.errs[0] + self.h *= factor + return accept + + +class DPMSolver(nn.Module): + """DPM-Solver. See https://arxiv.org/abs/2206.00927.""" + + def __init__(self, model, extra_args=None, eps_callback=None, info_callback=None): + super().__init__() + self.model = model + self.extra_args = {} if extra_args is None else extra_args + self.eps_callback = eps_callback + self.info_callback = info_callback + + def t(self, sigma): + return -sigma.log() + + def sigma(self, t): + return t.neg().exp() + + def eps(self, eps_cache, key, x, t, *args, **kwargs): + if key in eps_cache: + return eps_cache[key], eps_cache + sigma = self.sigma(t) * x.new_ones([x.shape[0]]) + eps = (x - self.model(x, sigma, *args, **self.extra_args, **kwargs)) / self.sigma(t) + if self.eps_callback is not None: + self.eps_callback() + return eps, {key: eps, **eps_cache} + + def dpm_solver_1_step(self, x, t, t_next, eps_cache=None): + eps_cache = {} if eps_cache is None else eps_cache + h = t_next - t + eps, eps_cache = self.eps(eps_cache, 'eps', x, t) + x_1 = x - self.sigma(t_next) * h.expm1() * eps + return x_1, eps_cache + + def dpm_solver_2_step(self, x, t, t_next, r1=1 / 2, eps_cache=None): + eps_cache = {} if eps_cache is None else eps_cache + h = t_next - t + eps, eps_cache = self.eps(eps_cache, 'eps', x, t) + s1 = t + r1 * h + u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps + eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1) + x_2 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps) + return x_2, eps_cache + + def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None): + eps_cache = {} if eps_cache is None else eps_cache + h = t_next - t + eps, eps_cache = self.eps(eps_cache, 'eps', x, t) + s1 = t + r1 * h + s2 = t + r2 * h + u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps + eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1) + u2 = x - self.sigma(s2) * (r2 * h).expm1() * eps - self.sigma(s2) * (r2 / r1) * ((r2 * h).expm1() / (r2 * h) - 1) * (eps_r1 - eps) + eps_r2, eps_cache = self.eps(eps_cache, 'eps_r2', u2, s2) + x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps) + return x_3, eps_cache + + def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None): + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + if not t_end > t_start and eta: + raise ValueError('eta must be 0 for reverse sampling') + + m = math.floor(nfe / 3) + 1 + ts = torch.linspace(t_start, t_end, m + 1, device=x.device) + + if nfe % 3 == 0: + orders = [3] * (m - 2) + [2, 1] + else: + orders = [3] * (m - 1) + [nfe % 3] + + for i in range(len(orders)): + eps_cache = {} + t, t_next = ts[i], ts[i + 1] + if eta: + sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta) + t_next_ = torch.minimum(t_end, self.t(sd)) + su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5 + else: + t_next_, su = t_next, 0. + + eps, eps_cache = self.eps(eps_cache, 'eps', x, t) + denoised = x - self.sigma(t) * eps + if self.info_callback is not None: + self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised}) + + if orders[i] == 1: + x, eps_cache = self.dpm_solver_1_step(x, t, t_next_, eps_cache=eps_cache) + elif orders[i] == 2: + x, eps_cache = self.dpm_solver_2_step(x, t, t_next_, eps_cache=eps_cache) + else: + x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache) + + x = x + su * s_noise * noise_sampler(self.sigma(t), self.sigma(t_next)) + + return x + + def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None): + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + if order not in {2, 3}: + raise ValueError('order should be 2 or 3') + forward = t_end > t_start + if not forward and eta: + raise ValueError('eta must be 0 for reverse sampling') + h_init = abs(h_init) * (1 if forward else -1) + atol = torch.tensor(atol) + rtol = torch.tensor(rtol) + s = t_start + x_prev = x + accept = True + pid = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety) + info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0} + + while s < t_end - 1e-5 if forward else s > t_end + 1e-5: + eps_cache = {} + t = torch.minimum(t_end, s + pid.h) if forward else torch.maximum(t_end, s + pid.h) + if eta: + sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta) + t_ = torch.minimum(t_end, self.t(sd)) + su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5 + else: + t_, su = t, 0. + + eps, eps_cache = self.eps(eps_cache, 'eps', x, s) + denoised = x - self.sigma(s) * eps + + if order == 2: + x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache) + x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache) + else: + x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache) + x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache) + delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs())) + error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5 + accept = pid.propose_step(error) + if accept: + x_prev = x_low + x = x_high + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t)) + s = t + info['n_accept'] += 1 + else: + info['n_reject'] += 1 + info['nfe'] += order + info['steps'] += 1 + + if self.info_callback is not None: + self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s, 'denoised': denoised, 'error': error, 'h': pid.h, **info}) + + return x, info + + +@torch.no_grad() +def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None): + """DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927.""" + if sigma_min <= 0 or sigma_max <= 0: + raise ValueError('sigma_min and sigma_max must not be 0') + with tqdm(total=n, disable=disable) as pbar: + dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update) + if callback is not None: + dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info}) + return dpm_solver.dpm_solver_fast(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), n, eta, s_noise, noise_sampler) + + +@torch.no_grad() +def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False): + """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927.""" + if sigma_min <= 0 or sigma_max <= 0: + raise ValueError('sigma_min and sigma_max must not be 0') + with tqdm(disable=disable) as pbar: + dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update) + if callback is not None: + dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info}) + x, info = dpm_solver.dpm_solver_adaptive(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise, noise_sampler) + if return_info: + return x, info + return x + + +@torch.no_grad() +def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): + """Ancestral sampling with DPM-Solver++(2S) second-order steps.""" + extra_args = {} if extra_args is None else extra_args + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + sigma_fn = lambda t: t.neg().exp() + t_fn = lambda sigma: sigma.log().neg() + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigma_down == 0: + # Euler method + d = to_d(x, sigmas[i], denoised) + dt = sigma_down - sigmas[i] + x = x + d * dt + else: + # DPM-Solver++(2S) + t, t_next = t_fn(sigmas[i]), t_fn(sigma_down) + r = 1 / 2 + h = t_next - t + s = t + r * h + x_2 = (sigma_fn(s) / sigma_fn(t)) * x - (-h * r).expm1() * denoised + denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args) + x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_2 + # Noise addition + if sigmas[i + 1] > 0: + x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up + return x + + +@torch.no_grad() +def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): + """DPM-Solver++ (stochastic).""" + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + sigma_fn = lambda t: t.neg().exp() + t_fn = lambda sigma: sigma.log().neg() + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigmas[i + 1] == 0: + # Euler method + d = to_d(x, sigmas[i], denoised) + dt = sigmas[i + 1] - sigmas[i] + x = x + d * dt + else: + # DPM-Solver++ + t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1]) + h = t_next - t + s = t + h * r + fac = 1 / (2 * r) + + # Step 1 + sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta) + s_ = t_fn(sd) + x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised + x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su + denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args) + + # Step 2 + sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta) + t_next_ = t_fn(sd) + denoised_d = (1 - fac) * denoised + fac * denoised_2 + x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d + x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su + return x + + +@torch.no_grad() +def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None): + """DPM-Solver++(2M).""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + sigma_fn = lambda t: t.neg().exp() + t_fn = lambda sigma: sigma.log().neg() + old_denoised = None + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1]) + h = t_next - t + if old_denoised is None or sigmas[i + 1] == 0: + x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised + else: + h_last = t - t_fn(sigmas[i - 1]) + r = h_last / h + denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised + x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d + old_denoised = denoised + return x + + +@torch.no_grad() +def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): + """DPM-Solver++(2M) SDE.""" + + if solver_type not in {'heun', 'midpoint'}: + raise ValueError('solver_type must be \'heun\' or \'midpoint\'') + + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + + old_denoised = None + h_last = None + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigmas[i + 1] == 0: + # Denoising step + x = denoised + else: + # DPM-Solver++(2M) SDE + t, s = -sigmas[i].log(), -sigmas[i + 1].log() + h = s - t + eta_h = eta * h + + x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised + + if old_denoised is not None: + r = h_last / h + if solver_type == 'heun': + x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised) + elif solver_type == 'midpoint': + x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised) + + if eta: + x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise + + old_denoised = denoised + h_last = h + return x + + +@torch.no_grad() +def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): + """DPM-Solver++(3M) SDE.""" + + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + + denoised_1, denoised_2 = None, None + h_1, h_2 = None, None + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigmas[i + 1] == 0: + # Denoising step + x = denoised + else: + t, s = -sigmas[i].log(), -sigmas[i + 1].log() + h = s - t + h_eta = h * (eta + 1) + + x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised + + if h_2 is not None: + r0 = h_1 / h + r1 = h_2 / h + d1_0 = (denoised - denoised_1) / r0 + d1_1 = (denoised_1 - denoised_2) / r1 + d1 = d1_0 + (d1_0 - d1_1) * r0 / (r0 + r1) + d2 = (d1_0 - d1_1) / (r0 + r1) + phi_2 = h_eta.neg().expm1() / h_eta + 1 + phi_3 = phi_2 / h_eta - 0.5 + x = x + phi_2 * d1 - phi_3 * d2 + elif h_1 is not None: + r = h_1 / h + d = (denoised - denoised_1) / r + phi_2 = h_eta.neg().expm1() / h_eta + 1 + x = x + phi_2 * d + + if eta: + x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise + + denoised_1, denoised_2 = denoised, denoised_1 + h_1, h_2 = h, h_1 + return x diff --git a/k_diffusion/utils.py b/k_diffusion/utils.py new file mode 100644 index 00000000..946f2da5 --- /dev/null +++ b/k_diffusion/utils.py @@ -0,0 +1,458 @@ +from contextlib import contextmanager +import hashlib +import math +from pathlib import Path +import shutil +import threading +import time +import urllib +import warnings + +from PIL import Image +import safetensors +import torch +from torch import nn, optim +from torch.utils import data +from torchvision.transforms import functional as TF + + +def from_pil_image(x): + """Converts from a PIL image to a tensor.""" + x = TF.to_tensor(x) + if x.ndim == 2: + x = x[..., None] + return x * 2 - 1 + + +def to_pil_image(x): + """Converts from a tensor to a PIL image.""" + if x.ndim == 4: + assert x.shape[0] == 1 + x = x[0] + if x.shape[0] == 1: + x = x[0] + return TF.to_pil_image((x.clamp(-1, 1) + 1) / 2) + + +def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'): + """Apply passed in transforms for HuggingFace Datasets.""" + images = [transform(image.convert(mode)) for image in examples[image_key]] + return {image_key: images} + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') + return x[(...,) + (None,) * dims_to_append] + + +def n_params(module): + """Returns the number of trainable parameters in a module.""" + return sum(p.numel() for p in module.parameters()) + + +def download_file(path, url, digest=None): + """Downloads a file if it does not exist, optionally checking its SHA-256 hash.""" + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + if not path.exists(): + with urllib.request.urlopen(url) as response, open(path, 'wb') as f: + shutil.copyfileobj(response, f) + if digest is not None: + file_digest = hashlib.sha256(open(path, 'rb').read()).hexdigest() + if digest != file_digest: + raise OSError(f'hash of {path} (url: {url}) failed to validate') + return path + + +@contextmanager +def train_mode(model, mode=True): + """A context manager that places a model into training mode and restores + the previous mode on exit.""" + modes = [module.training for module in model.modules()] + try: + yield model.train(mode) + finally: + for i, module in enumerate(model.modules()): + module.training = modes[i] + + +def eval_mode(model): + """A context manager that places a model into evaluation mode and restores + the previous mode on exit.""" + return train_mode(model, False) + + +@torch.no_grad() +def ema_update(model, averaged_model, decay): + """Incorporates updated model parameters into an exponential moving averaged + version of a model. It should be called after each optimizer step.""" + model_params = dict(model.named_parameters()) + averaged_params = dict(averaged_model.named_parameters()) + assert model_params.keys() == averaged_params.keys() + + for name, param in model_params.items(): + averaged_params[name].lerp_(param, 1 - decay) + + model_buffers = dict(model.named_buffers()) + averaged_buffers = dict(averaged_model.named_buffers()) + assert model_buffers.keys() == averaged_buffers.keys() + + for name, buf in model_buffers.items(): + averaged_buffers[name].copy_(buf) + + +class EMAWarmup: + """Implements an EMA warmup using an inverse decay schedule. + If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are + good values for models you plan to train for a million or more steps (reaches decay + factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models + you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at + 215.4k steps). + Args: + inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. + power (float): Exponential factor of EMA warmup. Default: 1. + min_value (float): The minimum EMA decay rate. Default: 0. + max_value (float): The maximum EMA decay rate. Default: 1. + start_at (int): The epoch to start averaging at. Default: 0. + last_epoch (int): The index of last epoch. Default: 0. + """ + + def __init__(self, inv_gamma=1., power=1., min_value=0., max_value=1., start_at=0, + last_epoch=0): + self.inv_gamma = inv_gamma + self.power = power + self.min_value = min_value + self.max_value = max_value + self.start_at = start_at + self.last_epoch = last_epoch + + def state_dict(self): + """Returns the state of the class as a :class:`dict`.""" + return dict(self.__dict__.items()) + + def load_state_dict(self, state_dict): + """Loads the class's state. + Args: + state_dict (dict): scaler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_value(self): + """Gets the current EMA decay rate.""" + epoch = max(0, self.last_epoch - self.start_at) + value = 1 - (1 + epoch / self.inv_gamma) ** -self.power + return 0. if epoch < 0 else min(self.max_value, max(self.min_value, value)) + + def step(self): + """Updates the step count.""" + self.last_epoch += 1 + + +class InverseLR(optim.lr_scheduler._LRScheduler): + """Implements an inverse decay learning rate schedule with an optional exponential + warmup. When last_epoch=-1, sets initial lr as lr. + inv_gamma is the number of steps/epochs required for the learning rate to decay to + (1 / 2)**power of its original value. + Args: + optimizer (Optimizer): Wrapped optimizer. + inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1. + power (float): Exponential factor of learning rate decay. Default: 1. + warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable) + Default: 0. + min_lr (float): The minimum learning rate. Default: 0. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + """ + + def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., min_lr=0., + last_epoch=-1, verbose=False): + self.inv_gamma = inv_gamma + self.power = power + if not 0. <= warmup < 1: + raise ValueError('Invalid value for warmup') + self.warmup = warmup + self.min_lr = min_lr + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.") + + return self._get_closed_form_lr() + + def _get_closed_form_lr(self): + warmup = 1 - self.warmup ** (self.last_epoch + 1) + lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power + return [warmup * max(self.min_lr, base_lr * lr_mult) + for base_lr in self.base_lrs] + + +class ExponentialLR(optim.lr_scheduler._LRScheduler): + """Implements an exponential learning rate schedule with an optional exponential + warmup. When last_epoch=-1, sets initial lr as lr. Decays the learning rate + continuously by decay (default 0.5) every num_steps steps. + Args: + optimizer (Optimizer): Wrapped optimizer. + num_steps (float): The number of steps to decay the learning rate by decay in. + decay (float): The factor by which to decay the learning rate every num_steps + steps. Default: 0.5. + warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable) + Default: 0. + min_lr (float): The minimum learning rate. Default: 0. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + """ + + def __init__(self, optimizer, num_steps, decay=0.5, warmup=0., min_lr=0., + last_epoch=-1, verbose=False): + self.num_steps = num_steps + self.decay = decay + if not 0. <= warmup < 1: + raise ValueError('Invalid value for warmup') + self.warmup = warmup + self.min_lr = min_lr + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.") + + return self._get_closed_form_lr() + + def _get_closed_form_lr(self): + warmup = 1 - self.warmup ** (self.last_epoch + 1) + lr_mult = (self.decay ** (1 / self.num_steps)) ** self.last_epoch + return [warmup * max(self.min_lr, base_lr * lr_mult) + for base_lr in self.base_lrs] + + +class ConstantLRWithWarmup(optim.lr_scheduler._LRScheduler): + """Implements a constant learning rate schedule with an optional exponential + warmup. When last_epoch=-1, sets initial lr as lr. + Args: + optimizer (Optimizer): Wrapped optimizer. + warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable) + Default: 0. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + """ + + def __init__(self, optimizer, warmup=0., last_epoch=-1, verbose=False): + if not 0. <= warmup < 1: + raise ValueError('Invalid value for warmup') + self.warmup = warmup + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.") + + return self._get_closed_form_lr() + + def _get_closed_form_lr(self): + warmup = 1 - self.warmup ** (self.last_epoch + 1) + return [warmup * base_lr for base_lr in self.base_lrs] + + +def stratified_uniform(shape, group=0, groups=1, dtype=None, device=None): + """Draws stratified samples from a uniform distribution.""" + if groups <= 0: + raise ValueError(f"groups must be positive, got {groups}") + if group < 0 or group >= groups: + raise ValueError(f"group must be in [0, {groups})") + n = shape[-1] * groups + offsets = torch.arange(group, n, groups, dtype=dtype, device=device) + u = torch.rand(shape, dtype=dtype, device=device) + return (offsets + u) / n + + +stratified_settings = threading.local() + + +@contextmanager +def enable_stratified(group=0, groups=1, disable=False): + """A context manager that enables stratified sampling.""" + try: + stratified_settings.disable = disable + stratified_settings.group = group + stratified_settings.groups = groups + yield + finally: + del stratified_settings.disable + del stratified_settings.group + del stratified_settings.groups + + +@contextmanager +def enable_stratified_accelerate(accelerator, disable=False): + """A context manager that enables stratified sampling, distributing the strata across + all processes and gradient accumulation steps using settings from Hugging Face Accelerate.""" + try: + rank = accelerator.process_index + world_size = accelerator.num_processes + acc_steps = accelerator.gradient_state.num_steps + acc_step = accelerator.step % acc_steps + group = rank * acc_steps + acc_step + groups = world_size * acc_steps + with enable_stratified(group, groups, disable=disable): + yield + finally: + pass + + +def stratified_with_settings(shape, dtype=None, device=None): + """Draws stratified samples from a uniform distribution, using settings from a context + manager.""" + if not hasattr(stratified_settings, 'disable') or stratified_settings.disable: + return torch.rand(shape, dtype=dtype, device=device) + return stratified_uniform( + shape, stratified_settings.group, stratified_settings.groups, dtype=dtype, device=device + ) + + +def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32): + """Draws samples from an lognormal distribution.""" + u = stratified_with_settings(shape, device=device, dtype=dtype) * (1 - 2e-7) + 1e-7 + return torch.distributions.Normal(loc, scale).icdf(u).exp() + + +def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32): + """Draws samples from an optionally truncated log-logistic distribution.""" + min_value = torch.as_tensor(min_value, device=device, dtype=torch.float64) + max_value = torch.as_tensor(max_value, device=device, dtype=torch.float64) + min_cdf = min_value.log().sub(loc).div(scale).sigmoid() + max_cdf = max_value.log().sub(loc).div(scale).sigmoid() + u = stratified_with_settings(shape, device=device, dtype=torch.float64) * (max_cdf - min_cdf) + min_cdf + return u.logit().mul(scale).add(loc).exp().to(dtype) + + +def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32): + """Draws samples from an log-uniform distribution.""" + min_value = math.log(min_value) + max_value = math.log(max_value) + return (stratified_with_settings(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value).exp() + + +def rand_v_diffusion(shape, sigma_data=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32): + """Draws samples from a truncated v-diffusion training timestep distribution.""" + min_cdf = math.atan(min_value / sigma_data) * 2 / math.pi + max_cdf = math.atan(max_value / sigma_data) * 2 / math.pi + u = stratified_with_settings(shape, device=device, dtype=dtype) * (max_cdf - min_cdf) + min_cdf + return torch.tan(u * math.pi / 2) * sigma_data + + +def rand_cosine_interpolated(shape, image_d, noise_d_low, noise_d_high, sigma_data=1., min_value=1e-3, max_value=1e3, device='cpu', dtype=torch.float32): + """Draws samples from an interpolated cosine timestep distribution (from simple diffusion).""" + + def logsnr_schedule_cosine(t, logsnr_min, logsnr_max): + t_min = math.atan(math.exp(-0.5 * logsnr_max)) + t_max = math.atan(math.exp(-0.5 * logsnr_min)) + return -2 * torch.log(torch.tan(t_min + t * (t_max - t_min))) + + def logsnr_schedule_cosine_shifted(t, image_d, noise_d, logsnr_min, logsnr_max): + shift = 2 * math.log(noise_d / image_d) + return logsnr_schedule_cosine(t, logsnr_min - shift, logsnr_max - shift) + shift + + def logsnr_schedule_cosine_interpolated(t, image_d, noise_d_low, noise_d_high, logsnr_min, logsnr_max): + logsnr_low = logsnr_schedule_cosine_shifted(t, image_d, noise_d_low, logsnr_min, logsnr_max) + logsnr_high = logsnr_schedule_cosine_shifted(t, image_d, noise_d_high, logsnr_min, logsnr_max) + return torch.lerp(logsnr_low, logsnr_high, t) + + logsnr_min = -2 * math.log(min_value / sigma_data) + logsnr_max = -2 * math.log(max_value / sigma_data) + u = stratified_with_settings(shape, device=device, dtype=dtype) + logsnr = logsnr_schedule_cosine_interpolated(u, image_d, noise_d_low, noise_d_high, logsnr_min, logsnr_max) + return torch.exp(-logsnr / 2) * sigma_data + + +def rand_split_log_normal(shape, loc, scale_1, scale_2, device='cpu', dtype=torch.float32): + """Draws samples from a split lognormal distribution.""" + n = torch.randn(shape, device=device, dtype=dtype).abs() + u = torch.rand(shape, device=device, dtype=dtype) + n_left = n * -scale_1 + loc + n_right = n * scale_2 + loc + ratio = scale_1 / (scale_1 + scale_2) + return torch.where(u < ratio, n_left, n_right).exp() + + +class FolderOfImages(data.Dataset): + """Recursively finds all images in a directory. It does not support + classes/targets.""" + + IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'} + + def __init__(self, root, transform=None): + super().__init__() + self.root = Path(root) + self.transform = nn.Identity() if transform is None else transform + self.paths = sorted(path for path in self.root.rglob('*') if path.suffix.lower() in self.IMG_EXTENSIONS) + + def __repr__(self): + return f'FolderOfImages(root="{self.root}", len: {len(self)})' + + def __len__(self): + return len(self.paths) + + def __getitem__(self, key): + path = self.paths[key] + with open(path, 'rb') as f: + image = Image.open(f).convert('RGB') + image = self.transform(image) + return image, + + +class CSVLogger: + def __init__(self, filename, columns): + self.filename = Path(filename) + self.columns = columns + if self.filename.exists(): + self.file = open(self.filename, 'a') + else: + self.file = open(self.filename, 'w') + self.write(*self.columns) + + def write(self, *args): + print(*args, sep=',', file=self.file, flush=True) + + +@contextmanager +def tf32_mode(cudnn=None, matmul=None): + """A context manager that sets whether TF32 is allowed on cuDNN or matmul.""" + cudnn_old = torch.backends.cudnn.allow_tf32 + matmul_old = torch.backends.cuda.matmul.allow_tf32 + try: + if cudnn is not None: + torch.backends.cudnn.allow_tf32 = cudnn + if matmul is not None: + torch.backends.cuda.matmul.allow_tf32 = matmul + yield + finally: + if cudnn is not None: + torch.backends.cudnn.allow_tf32 = cudnn_old + if matmul is not None: + torch.backends.cuda.matmul.allow_tf32 = matmul_old + + +def get_safetensors_metadata(path): + """Retrieves the metadata from a safetensors file.""" + return safetensors.safe_open(path, "pt").metadata() + + +def ema_update_dict(values, updates, decay): + for k, v in updates.items(): + if k not in values: + values[k] = v + else: + values[k] *= decay + values[k] += (1 - decay) * v + return values diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 74468550..6cca114b 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -393,14 +393,14 @@ def prepare_environment(): assets_repo = os.environ.get('ASSETS_REPO', "https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets.git") # stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") # stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git") - k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') + # k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') huggingface_guess_repo = os.environ.get('HUGGINGFACE_GUESS_REPO', 'https://github.com/lllyasviel/huggingface_guess.git') blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git') assets_commit_hash = os.environ.get('ASSETS_COMMIT_HASH', "6f7db241d2f8ba7457bac5ca9753331f0c266917") # stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf") # stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f") - k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c") + # k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c") huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "3f96b28763515dbe609792135df3615a440c66dc") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") @@ -458,7 +458,7 @@ def prepare_environment(): git_clone(assets_repo, repo_dir('stable-diffusion-webui-assets'), "assets", assets_commit_hash) # git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash) # git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash) - git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) + # git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) git_clone(huggingface_guess_repo, repo_dir('huggingface_guess'), "huggingface_guess", huggingface_guess_commit_hash) git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash) diff --git a/modules/paths.py b/modules/paths.py index 83bbbc79..d857c96f 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -9,7 +9,7 @@ sd_path = os.path.dirname(__file__) path_dirs = [ (os.path.join(sd_path, '../repositories/BLIP'), 'models/blip.py', 'BLIP', []), - (os.path.join(sd_path, '../repositories/k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), + # (os.path.join(sd_path, '../repositories/k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), (os.path.join(sd_path, '../repositories/huggingface_guess'), 'huggingface_guess/detection.py', 'huggingface_guess', []), ] diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index 0934ccfe..b329a6c9 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -59,6 +59,9 @@ class CFGDenoiser(torch.nn.Module): self.model_wrap = None self.p = None + self.need_last_noise_uncond = False + self.last_noise_uncond = None + # Backward Compatibility self.mask_before_denoising = False @@ -179,7 +182,10 @@ class CFGDenoiser(torch.nn.Module): denoiser_params = CFGDenoiserParams(x, image_cond, sigma, state.sampling_step, state.sampling_steps, cond, uncond, self) cfg_denoiser_callback(denoiser_params) - denoised = sampling_function(self, denoiser_params=denoiser_params, cond_scale=cond_scale, cond_composition=cond_composition) + denoised, cond_pred, uncond_pred = sampling_function(self, denoiser_params=denoiser_params, cond_scale=cond_scale, cond_composition=cond_composition) + + if self.need_last_noise_uncond: + self.last_noise_uncond = (x - uncond_pred) / sigma[:, None, None, None] if self.mask is not None: blended_latent = denoised * self.nmask + self.init_latent * self.mask diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index f0b88c95..146e9c3a 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -1,6 +1,7 @@ import torch import inspect import k_diffusion.sampling +import k_diffusion.external from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser, sd_schedulers, devices from modules.sd_samplers_cfg_denoiser import CFGDenoiser # noqa: F401 from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback @@ -55,13 +56,11 @@ class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser): @property def inner_model(self): if self.model_wrap is None: - denoiser_constructor = getattr(shared.sd_model, 'create_denoiser', None) - - if denoiser_constructor is not None: - self.model_wrap = denoiser_constructor() - else: - denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser - self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization) + self.model_wrap = k_diffusion.external.DiscreteSchedule( + sigmas=shared.sd_model.forge_objects.unet.model.predictor.sigmas, + quantize=shared.opts.enable_quantization + ) + self.model_wrap.inner_model = shared.sd_model return self.model_wrap diff --git a/modules/sd_samplers_lcm.py b/modules/sd_samplers_lcm.py index df4d97b1..4ac57ef0 100644 --- a/modules/sd_samplers_lcm.py +++ b/modules/sd_samplers_lcm.py @@ -13,9 +13,10 @@ class LCMCompVisDenoiser(DiscreteEpsDDPMDenoiser): original_timesteps = 50 # LCM Original Timesteps (default=50, for current version of LCM) self.skip_steps = timesteps // original_timesteps - alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32) + alphas_cumprod = 1.0 / (model.forge_objects.unet.model.predictor.sigmas ** 2.0 + 1.0) + alphas_cumprod_valid = torch.zeros(original_timesteps, dtype=torch.float32) for x in range(original_timesteps): - alphas_cumprod_valid[original_timesteps - 1 - x] = model.alphas_cumprod[timesteps - 1 - x * self.skip_steps] + alphas_cumprod_valid[original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps] super().__init__(model, alphas_cumprod_valid, quantize=None) diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py index 08956497..3e410f4d 100644 --- a/modules/sd_samplers_timesteps.py +++ b/modules/sd_samplers_timesteps.py @@ -28,31 +28,18 @@ class CompVisTimestepsDenoiser(torch.nn.Module): def __init__(self, model, *args, **kwargs): super().__init__(*args, **kwargs) self.inner_model = model + self.inner_model.alphas_cumprod = 1.0 / (self.inner_model.forge_objects.unet.model.predictor.sigmas ** 2.0 + 1.0) def forward(self, input, timesteps, **kwargs): return self.inner_model.apply_model(input, timesteps, **kwargs) -class CompVisTimestepsVDenoiser(torch.nn.Module): - def __init__(self, model, *args, **kwargs): - super().__init__(*args, **kwargs) - self.inner_model = model - - def predict_eps_from_z_and_v(self, x_t, t, v): - return torch.sqrt(self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * v + torch.sqrt(1 - self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * x_t - - def forward(self, input, timesteps, **kwargs): - model_output = self.inner_model.apply_model(input, timesteps, **kwargs) - e_t = self.predict_eps_from_z_and_v(input, timesteps, model_output) - return e_t - - class CFGDenoiserTimesteps(CFGDenoiser): def __init__(self, sampler): super().__init__(sampler) - self.alphas = shared.sd_model.alphas_cumprod + self.alphas = 1.0 / (shared.sd_model.forge_objects.unet.model.predictor.sigmas ** 2.0 + 1.0) self.classic_ddim_eps_estimation = True def get_pred_x0(self, x_in, x_out, sigma): @@ -69,8 +56,7 @@ class CFGDenoiserTimesteps(CFGDenoiser): @property def inner_model(self): if self.model_wrap is None: - denoiser = CompVisTimestepsVDenoiser if shared.sd_model.parameterization == "v" else CompVisTimestepsDenoiser - self.model_wrap = denoiser(shared.sd_model) + self.model_wrap = CompVisTimestepsDenoiser(shared.sd_model) return self.model_wrap