mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-30 03:01:15 +00:00
link k-diffusion to backend
This commit is contained in:
@@ -38,6 +38,41 @@ class VDenoiser(nn.Module):
|
|||||||
return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
|
return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
|
||||||
|
|
||||||
|
|
||||||
|
class ForgeScheduleLinker(nn.Module):
|
||||||
|
def __init__(self, predictor):
|
||||||
|
super().__init__()
|
||||||
|
self.predictor = predictor
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sigmas(self):
|
||||||
|
return self.predictor.sigmas
|
||||||
|
|
||||||
|
@property
|
||||||
|
def log_sigmas(self):
|
||||||
|
return self.predictor.sigmas.log()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sigma_min(self):
|
||||||
|
return self.predictor.sigma_min()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sigma_max(self):
|
||||||
|
return self.predictor.sigma_max()
|
||||||
|
|
||||||
|
def get_sigmas(self, n=None):
|
||||||
|
if n is None:
|
||||||
|
return sampling.append_zero(self.sigmas.flip(0))
|
||||||
|
t_max = len(self.sigmas) - 1
|
||||||
|
t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
|
||||||
|
return sampling.append_zero(self.t_to_sigma(t))
|
||||||
|
|
||||||
|
def sigma_to_t(self, sigma, quantize=None):
|
||||||
|
return self.predictor.timestep(sigma)
|
||||||
|
|
||||||
|
def t_to_sigma(self, t):
|
||||||
|
return self.predictor.sigma(t)
|
||||||
|
|
||||||
|
|
||||||
class DiscreteSchedule(nn.Module):
|
class DiscreteSchedule(nn.Module):
|
||||||
"""A mapping between continuous noise levels (sigmas) and a list of discrete noise
|
"""A mapping between continuous noise levels (sigmas) and a list of discrete noise
|
||||||
levels."""
|
levels."""
|
||||||
|
|||||||
@@ -56,10 +56,7 @@ class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
|
|||||||
@property
|
@property
|
||||||
def inner_model(self):
|
def inner_model(self):
|
||||||
if self.model_wrap is None:
|
if self.model_wrap is None:
|
||||||
self.model_wrap = k_diffusion.external.DiscreteSchedule(
|
self.model_wrap = k_diffusion.external.ForgeScheduleLinker(shared.sd_model.forge_objects.unet.model.predictor)
|
||||||
sigmas=shared.sd_model.forge_objects.unet.model.predictor.sigmas,
|
|
||||||
quantize=shared.opts.enable_quantization
|
|
||||||
)
|
|
||||||
self.model_wrap.inner_model = shared.sd_model
|
self.model_wrap.inner_model = shared.sd_model
|
||||||
|
|
||||||
return self.model_wrap
|
return self.model_wrap
|
||||||
@@ -136,8 +133,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
|
|||||||
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
|
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
|
||||||
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
|
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
|
||||||
|
|
||||||
self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(x.device)
|
self.model_wrap.predictor.to(x.device)
|
||||||
self.model_wrap.sigmas = self.model_wrap.sigmas.to(x.device)
|
|
||||||
|
|
||||||
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
||||||
|
|
||||||
@@ -198,8 +194,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
|
|||||||
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
|
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
|
||||||
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
|
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
|
||||||
|
|
||||||
self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(x.device)
|
self.model_wrap.predictor.to(x.device)
|
||||||
self.model_wrap.sigmas = self.model_wrap.sigmas.to(x.device)
|
|
||||||
|
|
||||||
steps = steps or p.steps
|
steps = steps or p.steps
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user