Add samplers: HeunPP2, IPDNM, IPNDM_V, DEIS

Pending: CFG++ Samplers, ODE Samplers
The latter is probably easy to implement, the former needs modifications in sd_samplers_cfg_denoiser.py
This commit is contained in:
Panchovix
2024-08-19 20:48:41 -04:00
parent 9bc2d04ca9
commit 2fc1708a59
3 changed files with 336 additions and 0 deletions

View File

@@ -6,6 +6,7 @@ from torch import nn
from torchdiffeq import odeint
import torchsde
from tqdm.auto import trange, tqdm
from k_diffusion import deis
from . import utils
@@ -700,3 +701,213 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
denoised_1, denoised_2 = denoised, denoised_1
h_1, h_2 = h, h_1
return x
@torch.no_grad()
def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
# From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
s_end = sigmas[-1]
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, denoised)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
dt = sigmas[i + 1] - sigma_hat
if sigmas[i + 1] == s_end:
# Euler method
x = x + d * dt
elif sigmas[i + 2] == s_end:
# Heun's method
x_2 = x + d * dt
denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
w = 2 * sigmas[0]
w2 = sigmas[i+1]/w
w1 = 1 - w2
d_prime = d * w1 + d_2 * w2
x = x + d_prime * dt
else:
# Heun++
x_2 = x + d * dt
denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
dt_2 = sigmas[i + 2] - sigmas[i + 1]
x_3 = x_2 + d_2 * dt_2
denoised_3 = model(x_3, sigmas[i + 2] * s_in, **extra_args)
d_3 = to_d(x_3, sigmas[i + 2], denoised_3)
w = 3 * sigmas[0]
w2 = sigmas[i + 1] / w
w3 = sigmas[i + 2] / w
w1 = 1 - w2 - w3
d_prime = w1 * d + w2 * d_2 + w3 * d_3
x = x + d_prime * dt
return x
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
#under Apache 2 license
def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
x_next = x
buffer_model = []
for i in trange(len(sigmas) - 1, disable=disable):
t_cur = sigmas[i]
t_next = sigmas[i + 1]
x_cur = x_next
denoised = model(x_cur, t_cur * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
d_cur = (x_cur - denoised) / t_cur
order = min(max_order, i+1)
if order == 1: # First Euler step.
x_next = x_cur + (t_next - t_cur) * d_cur
elif order == 2: # Use one history point.
x_next = x_cur + (t_next - t_cur) * (3 * d_cur - buffer_model[-1]) / 2
elif order == 3: # Use two history points.
x_next = x_cur + (t_next - t_cur) * (23 * d_cur - 16 * buffer_model[-1] + 5 * buffer_model[-2]) / 12
elif order == 4: # Use three history points.
x_next = x_cur + (t_next - t_cur) * (55 * d_cur - 59 * buffer_model[-1] + 37 * buffer_model[-2] - 9 * buffer_model[-3]) / 24
if len(buffer_model) == max_order - 1:
for k in range(max_order - 2):
buffer_model[k] = buffer_model[k+1]
buffer_model[-1] = d_cur
else:
buffer_model.append(d_cur)
return x_next
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
#under Apache 2 license
def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
x_next = x
t_steps = sigmas
buffer_model = []
for i in trange(len(sigmas) - 1, disable=disable):
t_cur = sigmas[i]
t_next = sigmas[i + 1]
x_cur = x_next
denoised = model(x_cur, t_cur * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
d_cur = (x_cur - denoised) / t_cur
order = min(max_order, i+1)
if order == 1: # First Euler step.
x_next = x_cur + (t_next - t_cur) * d_cur
elif order == 2: # Use one history point.
h_n = (t_next - t_cur)
h_n_1 = (t_cur - t_steps[i-1])
coeff1 = (2 + (h_n / h_n_1)) / 2
coeff2 = -(h_n / h_n_1) / 2
x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1])
elif order == 3: # Use two history points.
h_n = (t_next - t_cur)
h_n_1 = (t_cur - t_steps[i-1])
h_n_2 = (t_steps[i-1] - t_steps[i-2])
temp = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2
coeff1 = (2 + (h_n / h_n_1)) / 2 + temp
coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp
coeff3 = temp * h_n_1 / h_n_2
x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1] + coeff3 * buffer_model[-2])
elif order == 4: # Use three history points.
h_n = (t_next - t_cur)
h_n_1 = (t_cur - t_steps[i-1])
h_n_2 = (t_steps[i-1] - t_steps[i-2])
h_n_3 = (t_steps[i-2] - t_steps[i-3])
temp1 = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2
temp2 = ((1 - h_n / (3 * (h_n + h_n_1))) / 2 + (1 - h_n / (2 * (h_n + h_n_1))) * h_n / (6 * (h_n + h_n_1 + h_n_2))) \
* (h_n * (h_n + h_n_1) * (h_n + h_n_1 + h_n_2)) / (h_n_1 * (h_n_1 + h_n_2) * (h_n_1 + h_n_2 + h_n_3))
coeff1 = (2 + (h_n / h_n_1)) / 2 + temp1 + temp2
coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp1 - (1 + (h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3)))) * temp2
coeff3 = temp1 * h_n_1 / h_n_2 + ((h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) * (1 + h_n_2 / h_n_3)) * temp2
coeff4 = -temp2 * (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) * h_n_1 / h_n_2
x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1] + coeff3 * buffer_model[-2] + coeff4 * buffer_model[-3])
if len(buffer_model) == max_order - 1:
for k in range(max_order - 2):
buffer_model[k] = buffer_model[k+1]
buffer_model[-1] = d_cur.detach()
else:
buffer_model.append(d_cur.detach())
return x_next
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
#under Apache 2 license
@torch.no_grad()
def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=3, deis_mode='tab'):
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
x_next = x
t_steps = sigmas
coeff_list = deis.get_deis_coeff_list(t_steps, max_order, deis_mode=deis_mode)
buffer_model = []
for i in trange(len(sigmas) - 1, disable=disable):
t_cur = sigmas[i]
t_next = sigmas[i + 1]
x_cur = x_next
denoised = model(x_cur, t_cur * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
d_cur = (x_cur - denoised) / t_cur
order = min(max_order, i+1)
if t_next <= 0:
order = 1
if order == 1: # First Euler step.
x_next = x_cur + (t_next - t_cur) * d_cur
elif order == 2: # Use one history point.
coeff_cur, coeff_prev1 = coeff_list[i]
x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1]
elif order == 3: # Use two history points.
coeff_cur, coeff_prev1, coeff_prev2 = coeff_list[i]
x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2]
elif order == 4: # Use three history points.
coeff_cur, coeff_prev1, coeff_prev2, coeff_prev3 = coeff_list[i]
x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2] + coeff_prev3 * buffer_model[-3]
if len(buffer_model) == max_order - 1:
for k in range(max_order - 2):
buffer_model[k] = buffer_model[k+1]
buffer_model[-1] = d_cur.detach()
else:
buffer_model.append(d_cur.detach())
return x_next