mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-01-26 19:09:45 +00:00
add sampler
This commit is contained in:
39
backend/modules/k_diffusion_extra.py
Normal file
39
backend/modules/k_diffusion_extra.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# Only include samplers that are not already in A1111
|
||||
|
||||
import torch
|
||||
from tqdm import trange
|
||||
|
||||
|
||||
def default_noise_sampler(x):
|
||||
return lambda sigma, sigma_next: torch.randn_like(x)
|
||||
|
||||
|
||||
def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
x = step_function(x / torch.sqrt(1.0 + sigmas[i] ** 2.0), sigmas[i], sigmas[i + 1], (x - denoised) / sigmas[i], noise_sampler)
|
||||
if sigmas[i + 1] != 0:
|
||||
x *= torch.sqrt(1.0 + sigmas[i + 1] ** 2.0)
|
||||
return x
|
||||
|
||||
|
||||
def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
|
||||
alpha_cumprod = 1 / ((sigma * sigma) + 1)
|
||||
alpha_cumprod_prev = 1 / ((sigma_prev * sigma_prev) + 1)
|
||||
alpha = (alpha_cumprod / alpha_cumprod_prev)
|
||||
|
||||
mu = (1.0 / alpha).sqrt() * (x - (1 - alpha) * noise / (1 - alpha_cumprod).sqrt())
|
||||
if sigma_prev > 0:
|
||||
mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev)
|
||||
return mu
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
|
||||
return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step)
|
||||
@@ -1,23 +1,22 @@
|
||||
# from modules import sd_samplers_kdiffusion, sd_samplers_common
|
||||
# from ldm_patched.k_diffusion import sampling as k_diffusion_sampling
|
||||
#
|
||||
#
|
||||
# class AlterSampler(sd_samplers_kdiffusion.KDiffusionSampler):
|
||||
# def __init__(self, sd_model, sampler_name):
|
||||
# self.sampler_name = sampler_name
|
||||
# self.unet = sd_model.forge_objects.unet
|
||||
# sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
|
||||
# super().__init__(sampler_function, sd_model, None)
|
||||
#
|
||||
#
|
||||
# def build_constructor(sampler_name):
|
||||
# def constructor(m):
|
||||
# return AlterSampler(m, sampler_name)
|
||||
#
|
||||
# return constructor
|
||||
#
|
||||
#
|
||||
from modules import sd_samplers_kdiffusion, sd_samplers_common
|
||||
from backend.modules import k_diffusion_extra
|
||||
|
||||
|
||||
class AlterSampler(sd_samplers_kdiffusion.KDiffusionSampler):
|
||||
def __init__(self, sd_model, sampler_name):
|
||||
self.sampler_name = sampler_name
|
||||
self.unet = sd_model.forge_objects.unet
|
||||
sampler_function = getattr(k_diffusion_extra, "sample_{}".format(sampler_name))
|
||||
super().__init__(sampler_function, sd_model, None)
|
||||
|
||||
|
||||
def build_constructor(sampler_name):
|
||||
def constructor(m):
|
||||
return AlterSampler(m, sampler_name)
|
||||
|
||||
return constructor
|
||||
|
||||
|
||||
samplers_data_alter = [
|
||||
# sd_samplers_common.SamplerData('DDPM', build_constructor(sampler_name='ddpm'), ['ddpm'], {}),
|
||||
sd_samplers_common.SamplerData('DDPM', build_constructor(sampler_name='ddpm'), ['ddpm'], {}),
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user