diff --git a/backend/modules/k_diffusion_extra.py b/backend/modules/k_diffusion_extra.py new file mode 100644 index 00000000..4d9f4693 --- /dev/null +++ b/backend/modules/k_diffusion_extra.py @@ -0,0 +1,39 @@ +# Only include samplers that are not already in A1111 + +import torch +from tqdm import trange + + +def default_noise_sampler(x): + return lambda sigma, sigma_next: torch.randn_like(x) + + +def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None): + extra_args = {} if extra_args is None else extra_args + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + x = step_function(x / torch.sqrt(1.0 + sigmas[i] ** 2.0), sigmas[i], sigmas[i + 1], (x - denoised) / sigmas[i], noise_sampler) + if sigmas[i + 1] != 0: + x *= torch.sqrt(1.0 + sigmas[i + 1] ** 2.0) + return x + + +def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler): + alpha_cumprod = 1 / ((sigma * sigma) + 1) + alpha_cumprod_prev = 1 / ((sigma_prev * sigma_prev) + 1) + alpha = (alpha_cumprod / alpha_cumprod_prev) + + mu = (1.0 / alpha).sqrt() * (x - (1 - alpha) * noise / (1 - alpha_cumprod).sqrt()) + if sigma_prev > 0: + mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev) + return mu + + +@torch.no_grad() +def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None): + return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step) diff --git a/modules_forge/forge_alter_samplers.py b/modules_forge/forge_alter_samplers.py index 1f474c67..15cbe2f0 100644 --- a/modules_forge/forge_alter_samplers.py +++ b/modules_forge/forge_alter_samplers.py @@ -1,23 +1,22 @@ -# from modules import sd_samplers_kdiffusion, sd_samplers_common -# from ldm_patched.k_diffusion import sampling as k_diffusion_sampling -# -# -# class AlterSampler(sd_samplers_kdiffusion.KDiffusionSampler): -# def __init__(self, sd_model, sampler_name): -# self.sampler_name = sampler_name -# self.unet = sd_model.forge_objects.unet -# sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name)) -# super().__init__(sampler_function, sd_model, None) -# -# -# def build_constructor(sampler_name): -# def constructor(m): -# return AlterSampler(m, sampler_name) -# -# return constructor -# -# +from modules import sd_samplers_kdiffusion, sd_samplers_common +from backend.modules import k_diffusion_extra + + +class AlterSampler(sd_samplers_kdiffusion.KDiffusionSampler): + def __init__(self, sd_model, sampler_name): + self.sampler_name = sampler_name + self.unet = sd_model.forge_objects.unet + sampler_function = getattr(k_diffusion_extra, "sample_{}".format(sampler_name)) + super().__init__(sampler_function, sd_model, None) + + +def build_constructor(sampler_name): + def constructor(m): + return AlterSampler(m, sampler_name) + + return constructor + samplers_data_alter = [ - # sd_samplers_common.SamplerData('DDPM', build_constructor(sampler_name='ddpm'), ['ddpm'], {}), + sd_samplers_common.SamplerData('DDPM', build_constructor(sampler_name='ddpm'), ['ddpm'], {}), ]