Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git
Add samplers: HeunPP2, IPNDM, IPNDM_V, DEIS
Pending: CFG++ Samplers, ODE Samplers. The latter are probably easy to implement; the former need modifications in sd_samplers_cfg_denoiser.py.
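All four functions follow the usual k-diffusion sampler signature (model, x, sigmas, extra_args=None, callback=None, disable=None, ...), so they can be driven exactly like the existing samplers. A minimal, illustrative sketch with a stand-in denoiser (it assumes the functions land in k_diffusion.sampling as in the diff below; the schedule values are arbitrary):

import torch
from k_diffusion.sampling import get_sigmas_karras, sample_heunpp2, sample_ipndm, sample_ipndm_v, sample_deis

def toy_model(x, sigma, **extra_args):
    # Stand-in denoiser; a real run would wrap the SD U-Net (e.g. with k_diffusion.external.CompVisDenoiser).
    return torch.zeros_like(x)

sigmas = get_sigmas_karras(n=12, sigma_min=0.03, sigma_max=14.6)   # 13 sigmas, descending to 0
x = torch.randn(2, 4, 8, 8) * sigmas[0]

for sampler in (sample_heunpp2, sample_ipndm, sample_ipndm_v, sample_deis):
    out = sampler(toy_model, x.clone(), sigmas)
    print(sampler.__name__, tuple(out.shape))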
k_diffusion/deis.py (new file, 121 lines)
@@ -0,0 +1,121 @@
#Taken from: https://github.com/zju-pi/diff-sampler/blob/main/gits-main/solver_utils.py
#under Apache 2 license
import torch
import numpy as np

# A pytorch reimplementation of DEIS (https://github.com/qsh-zh/deis).
#############################
### Utils for DEIS solver ###
#############################
#----------------------------------------------------------------------------
# Transfer from the input time (sigma) used in EDM to that (t) used in DEIS.

def edm2t(edm_steps, epsilon_s=1e-3, sigma_min=0.002, sigma_max=80):
    vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
    vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
    vp_beta_d = 2 * (np.log(torch.tensor(sigma_min).cpu() ** 2 + 1) / epsilon_s - np.log(torch.tensor(sigma_max).cpu() ** 2 + 1)) / (epsilon_s - 1)
    vp_beta_min = np.log(torch.tensor(sigma_max).cpu() ** 2 + 1) - 0.5 * vp_beta_d
    t_steps = vp_sigma_inv(vp_beta_d.clone().detach().cpu(), vp_beta_min.clone().detach().cpu())(edm_steps.clone().detach().cpu())
    return t_steps, vp_beta_min, vp_beta_d + vp_beta_min

#----------------------------------------------------------------------------
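# Editor's note (illustrative, not part of the original file): edm2t is the algebraic inverse of
# the VP sigma(t) relation, so mapping the returned t_steps back through
# sigma(t) = sqrt(exp(0.5 * beta_d * t**2 + beta_min * t) - 1) recovers the input sigmas:
#
#   sigmas = torch.linspace(80.0, 0.002, 20, dtype=torch.float64)
#   t_steps, beta_0, beta_1 = edm2t(sigmas)          # beta_0 = beta_min, beta_1 = beta_min + beta_d
#   recovered = ((0.5 * (beta_1 - beta_0) * t_steps ** 2 + beta_0 * t_steps).exp() - 1).sqrt()
#   assert torch.allclose(recovered, sigmas, rtol=1e-4)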

def cal_poly(prev_t, j, taus):
    poly = 1
    for k in range(prev_t.shape[0]):
        if k == j:
            continue
        poly *= (taus - prev_t[k]) / (prev_t[j] - prev_t[k])
    return poly

#----------------------------------------------------------------------------
# Transfer from t to alpha_t.

def t2alpha_fn(beta_0, beta_1, t):
    return torch.exp(-0.5 * t ** 2 * (beta_1 - beta_0) - t * beta_0)

#----------------------------------------------------------------------------

def cal_intergrand(beta_0, beta_1, taus):
    with torch.inference_mode(mode=False):
        taus = taus.clone()
        beta_0 = beta_0.clone()
        beta_1 = beta_1.clone()
        with torch.enable_grad():
            taus.requires_grad_(True)
            alpha = t2alpha_fn(beta_0, beta_1, taus)
            log_alpha = alpha.log()
            log_alpha.sum().backward()
            d_log_alpha_dtau = taus.grad
    integrand = -0.5 * d_log_alpha_dtau / torch.sqrt(alpha * (1 - alpha))
    return integrand

#----------------------------------------------------------------------------

def get_deis_coeff_list(t_steps, max_order, N=10000, deis_mode='tab'):
    """
    Get the coefficient list for DEIS sampling.

    Args:
        t_steps: A pytorch tensor. The time steps for sampling.
        max_order: An `int`. Maximum order of the solver. 1 <= max_order <= 4.
        N: An `int`. How many points to use for the numerical integration when deis_mode == 'tab'.
        deis_mode: A `str`. Select between 'tab' and 'rhoab'. Type of DEIS.
    Returns:
        A list with one entry per sampling step: an empty list for first-order (Euler) steps,
        otherwise the `order` solver coefficients for that step.
    """
    if deis_mode == 'tab':
        t_steps, beta_0, beta_1 = edm2t(t_steps)
        C = []
        for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
            order = min(i+1, max_order)
            if order == 1:
                C.append([])
            else:
                taus = torch.linspace(t_cur, t_next, N)   # split the interval for integral approximation
                dtau = (t_next - t_cur) / N
                prev_t = t_steps[[i - k for k in range(order)]]
                coeff_temp = []
                integrand = cal_intergrand(beta_0, beta_1, taus)
                for j in range(order):
                    poly = cal_poly(prev_t, j, taus)
                    coeff_temp.append(torch.sum(integrand * poly) * dtau)
                C.append(coeff_temp)

    elif deis_mode == 'rhoab':
        # Analytical solution, second order
        def get_def_intergral_2(a, b, start, end, c):
            coeff = (end**3 - start**3) / 3 - (end**2 - start**2) * (a + b) / 2 + (end - start) * a * b
            return coeff / ((c - a) * (c - b))

        # Analytical solution, third order
        def get_def_intergral_3(a, b, c, start, end, d):
            coeff = (end**4 - start**4) / 4 - (end**3 - start**3) * (a + b + c) / 3 \
                    + (end**2 - start**2) * (a*b + a*c + b*c) / 2 - (end - start) * a * b * c
            return coeff / ((d - a) * (d - b) * (d - c))

        C = []
        for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
            order = min(i, max_order)
            if order == 0:
                C.append([])
            else:
                prev_t = t_steps[[i - k for k in range(order+1)]]
                if order == 1:
                    coeff_cur = ((t_next - prev_t[1])**2 - (t_cur - prev_t[1])**2) / (2 * (t_cur - prev_t[1]))
                    coeff_prev1 = (t_next - t_cur)**2 / (2 * (prev_t[1] - t_cur))
                    coeff_temp = [coeff_cur, coeff_prev1]
                elif order == 2:
                    coeff_cur = get_def_intergral_2(prev_t[1], prev_t[2], t_cur, t_next, t_cur)
                    coeff_prev1 = get_def_intergral_2(t_cur, prev_t[2], t_cur, t_next, prev_t[1])
                    coeff_prev2 = get_def_intergral_2(t_cur, prev_t[1], t_cur, t_next, prev_t[2])
                    coeff_temp = [coeff_cur, coeff_prev1, coeff_prev2]
                elif order == 3:
                    coeff_cur = get_def_intergral_3(prev_t[1], prev_t[2], prev_t[3], t_cur, t_next, t_cur)
                    coeff_prev1 = get_def_intergral_3(t_cur, prev_t[2], prev_t[3], t_cur, t_next, prev_t[1])
                    coeff_prev2 = get_def_intergral_3(t_cur, prev_t[1], prev_t[3], t_cur, t_next, prev_t[2])
                    coeff_prev3 = get_def_intergral_3(t_cur, prev_t[1], prev_t[2], t_cur, t_next, prev_t[3])
                    coeff_temp = [coeff_cur, coeff_prev1, coeff_prev2, coeff_prev3]
                C.append(coeff_temp)
    return C
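For orientation, a hedged sketch of how this helper behaves on its own (the schedule values are illustrative; 'tab' mode integrates numerically, so it takes a moment at the default N=10000):

import torch
from k_diffusion import deis

sigmas = torch.linspace(14.6, 0.03, 13)                 # illustrative sigma schedule, high to low
coeffs = deis.get_deis_coeff_list(sigmas, max_order=3)  # deis_mode='tab' by default

# One (possibly empty) entry per step: the first step has no history and falls back to Euler,
# later steps carry `order` scalar coefficients each.
print(len(coeffs), [len(c) for c in coeffs])            # 12 [0, 2, 3, 3, ...]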
k_diffusion/sampling.py
@@ -6,6 +6,7 @@ from torch import nn
from torchdiffeq import odeint
import torchsde
from tqdm.auto import trange, tqdm
from k_diffusion import deis

from . import utils

@@ -700,3 +701,213 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
        denoised_1, denoised_2 = denoised, denoised_1
        h_1, h_2 = h, h_1
    return x

@torch.no_grad()
def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    # From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    s_end = sigmas[-1]
    for i in trange(len(sigmas) - 1, disable=disable):
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        dt = sigmas[i + 1] - sigma_hat
        if sigmas[i + 1] == s_end:
            # Euler method
            x = x + d * dt
        elif sigmas[i + 2] == s_end:
            # Heun's method
            x_2 = x + d * dt
            denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
            d_2 = to_d(x_2, sigmas[i + 1], denoised_2)

            w = 2 * sigmas[0]
            w2 = sigmas[i + 1] / w
            w1 = 1 - w2

            d_prime = d * w1 + d_2 * w2
            x = x + d_prime * dt
        else:
            # Heun++
            x_2 = x + d * dt
            denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
            d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
            dt_2 = sigmas[i + 2] - sigmas[i + 1]

            x_3 = x_2 + d_2 * dt_2
            denoised_3 = model(x_3, sigmas[i + 2] * s_in, **extra_args)
            d_3 = to_d(x_3, sigmas[i + 2], denoised_3)

            w = 3 * sigmas[0]
            w2 = sigmas[i + 1] / w
            w3 = sigmas[i + 2] / w
            w1 = 1 - w2 - w3

            d_prime = w1 * d + w2 * d_2 + w3 * d_3
            x = x + d_prime * dt
    return x
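# Editor's note (illustrative, not part of the commit): the extra slope estimates above are
# blended with sigma-dependent weights that sum to 1 by construction, so as sigmas[i+1] and
# sigmas[i+2] shrink, the update falls back towards the plain Euler direction d. For example:
sigmas_demo = [14.6, 7.0, 3.0, 1.0, 0.0]   # made-up schedule
i_demo = 1
w = 3 * sigmas_demo[0]
w2, w3 = sigmas_demo[i_demo + 1] / w, sigmas_demo[i_demo + 2] / w
w1 = 1 - w2 - w3
print(round(w1 + w2 + w3, 6), (w1, w2, w3))   # 1.0, with w1 dominating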


#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
#under Apache 2 license
def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    x_next = x

    buffer_model = []
    for i in trange(len(sigmas) - 1, disable=disable):
        t_cur = sigmas[i]
        t_next = sigmas[i + 1]

        x_cur = x_next

        denoised = model(x_cur, t_cur * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})

        d_cur = (x_cur - denoised) / t_cur

        order = min(max_order, i+1)
        if order == 1:      # First Euler step.
            x_next = x_cur + (t_next - t_cur) * d_cur
        elif order == 2:    # Use one history point.
            x_next = x_cur + (t_next - t_cur) * (3 * d_cur - buffer_model[-1]) / 2
        elif order == 3:    # Use two history points.
            x_next = x_cur + (t_next - t_cur) * (23 * d_cur - 16 * buffer_model[-1] + 5 * buffer_model[-2]) / 12
        elif order == 4:    # Use three history points.
            x_next = x_cur + (t_next - t_cur) * (55 * d_cur - 59 * buffer_model[-1] + 37 * buffer_model[-2] - 9 * buffer_model[-3]) / 24

        if len(buffer_model) == max_order - 1:
            for k in range(max_order - 2):
                buffer_model[k] = buffer_model[k+1]
            buffer_model[-1] = d_cur
        else:
            buffer_model.append(d_cur)

    return x_next
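# Editor's note (illustrative, not part of the commit): the hard-coded weights above are the
# classical Adams-Bashforth coefficients, so each set sums to 1 (the usual consistency check):
for ab_coeffs in ([1.0], [3/2, -1/2], [23/12, -16/12, 5/12], [55/24, -59/24, 37/24, -9/24]):
    assert abs(sum(ab_coeffs) - 1.0) < 1e-12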

#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
#under Apache 2 license
def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    x_next = x
    t_steps = sigmas

    buffer_model = []
    for i in trange(len(sigmas) - 1, disable=disable):
        t_cur = sigmas[i]
        t_next = sigmas[i + 1]

        x_cur = x_next

        denoised = model(x_cur, t_cur * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})

        d_cur = (x_cur - denoised) / t_cur

        order = min(max_order, i+1)
        if order == 1:      # First Euler step.
            x_next = x_cur + (t_next - t_cur) * d_cur
        elif order == 2:    # Use one history point.
            h_n = (t_next - t_cur)
            h_n_1 = (t_cur - t_steps[i-1])
            coeff1 = (2 + (h_n / h_n_1)) / 2
            coeff2 = -(h_n / h_n_1) / 2
            x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1])
        elif order == 3:    # Use two history points.
            h_n = (t_next - t_cur)
            h_n_1 = (t_cur - t_steps[i-1])
            h_n_2 = (t_steps[i-1] - t_steps[i-2])
            temp = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2
            coeff1 = (2 + (h_n / h_n_1)) / 2 + temp
            coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp
            coeff3 = temp * h_n_1 / h_n_2
            x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1] + coeff3 * buffer_model[-2])
        elif order == 4:    # Use three history points.
            h_n = (t_next - t_cur)
            h_n_1 = (t_cur - t_steps[i-1])
            h_n_2 = (t_steps[i-1] - t_steps[i-2])
            h_n_3 = (t_steps[i-2] - t_steps[i-3])
            temp1 = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2
            temp2 = ((1 - h_n / (3 * (h_n + h_n_1))) / 2 + (1 - h_n / (2 * (h_n + h_n_1))) * h_n / (6 * (h_n + h_n_1 + h_n_2))) \
                    * (h_n * (h_n + h_n_1) * (h_n + h_n_1 + h_n_2)) / (h_n_1 * (h_n_1 + h_n_2) * (h_n_1 + h_n_2 + h_n_3))
            coeff1 = (2 + (h_n / h_n_1)) / 2 + temp1 + temp2
            coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp1 - (1 + (h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3)))) * temp2
            coeff3 = temp1 * h_n_1 / h_n_2 + ((h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) * (1 + h_n_2 / h_n_3)) * temp2
            coeff4 = -temp2 * (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) * h_n_1 / h_n_2
            x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1] + coeff3 * buffer_model[-2] + coeff4 * buffer_model[-3])

        if len(buffer_model) == max_order - 1:
            for k in range(max_order - 2):
                buffer_model[k] = buffer_model[k+1]
            buffer_model[-1] = d_cur.detach()
        else:
            buffer_model.append(d_cur.detach())

    return x_next
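# Editor's note (illustrative, not part of the commit): for a uniform step size the variable-step
# weights above reduce to the fixed Adams-Bashforth coefficients used by sample_ipndm, e.g. for
# orders 2 and 3 with h_n == h_n_1 == h_n_2 == 1:
h_n = h_n_1 = h_n_2 = 1.0
assert ((2 + h_n / h_n_1) / 2, -(h_n / h_n_1) / 2) == (3/2, -1/2)
temp = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2
uniform_coeffs = ((2 + h_n / h_n_1) / 2 + temp,
                  -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp,
                  temp * h_n_1 / h_n_2)
assert all(abs(a - b) < 1e-12 for a, b in zip(uniform_coeffs, (23/12, -16/12, 5/12)))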

#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
#under Apache 2 license
@torch.no_grad()
def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=3, deis_mode='tab'):
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    x_next = x
    t_steps = sigmas

    coeff_list = deis.get_deis_coeff_list(t_steps, max_order, deis_mode=deis_mode)

    buffer_model = []
    for i in trange(len(sigmas) - 1, disable=disable):
        t_cur = sigmas[i]
        t_next = sigmas[i + 1]

        x_cur = x_next

        denoised = model(x_cur, t_cur * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})

        d_cur = (x_cur - denoised) / t_cur

        order = min(max_order, i+1)
        if t_next <= 0:
            order = 1

        if order == 1:      # First Euler step.
            x_next = x_cur + (t_next - t_cur) * d_cur
        elif order == 2:    # Use one history point.
            coeff_cur, coeff_prev1 = coeff_list[i]
            x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1]
        elif order == 3:    # Use two history points.
            coeff_cur, coeff_prev1, coeff_prev2 = coeff_list[i]
            x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2]
        elif order == 4:    # Use three history points.
            coeff_cur, coeff_prev1, coeff_prev2, coeff_prev3 = coeff_list[i]
            x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2] + coeff_prev3 * buffer_model[-3]

        if len(buffer_model) == max_order - 1:
            for k in range(max_order - 2):
                buffer_model[k] = buffer_model[k+1]
            buffer_model[-1] = d_cur.detach()
        else:
            buffer_model.append(d_cur.detach())

    return x_next
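The diff only touches the k_diffusion package; to appear in the web UI the new functions still have to be registered with the sampler list. That step is not part of this commit, but in the upstream webui it would amount to entries in samplers_k_diffusion in modules/sd_samplers_kdiffusion.py along these lines (labels, aliases, and option dicts here are guesses):

# Hypothetical registration entries (not in this diff); to be appended to samplers_k_diffusion.
new_samplers_k_diffusion = [
    ('HeunPP2', 'sample_heunpp2', ['heunpp2'], {}),
    ('IPNDM', 'sample_ipndm', ['ipndm'], {}),
    ('IPNDM_V', 'sample_ipndm_v', ['ipndm_v'], {}),
    ('DEIS', 'sample_deis', ['deis'], {}),
]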