diff --git a/k_diffusion/external.py b/k_diffusion/external.py index 79b51cec..7ecc8c25 100644 --- a/k_diffusion/external.py +++ b/k_diffusion/external.py @@ -38,6 +38,41 @@ class VDenoiser(nn.Module): return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip +class ForgeScheduleLinker(nn.Module): + def __init__(self, predictor): + super().__init__() + self.predictor = predictor + + @property + def sigmas(self): + return self.predictor.sigmas + + @property + def log_sigmas(self): + return self.predictor.sigmas.log() + + @property + def sigma_min(self): + return self.predictor.sigma_min() + + @property + def sigma_max(self): + return self.predictor.sigma_max() + + def get_sigmas(self, n=None): + if n is None: + return sampling.append_zero(self.sigmas.flip(0)) + t_max = len(self.sigmas) - 1 + t = torch.linspace(t_max, 0, n, device=self.sigmas.device) + return sampling.append_zero(self.t_to_sigma(t)) + + def sigma_to_t(self, sigma, quantize=None): + return self.predictor.timestep(sigma) + + def t_to_sigma(self, t): + return self.predictor.sigma(t) + + class DiscreteSchedule(nn.Module): """A mapping between continuous noise levels (sigmas) and a list of discrete noise levels.""" diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 146e9c3a..5d304a28 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -56,10 +56,7 @@ class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser): @property def inner_model(self): if self.model_wrap is None: - self.model_wrap = k_diffusion.external.DiscreteSchedule( - sigmas=shared.sd_model.forge_objects.unet.model.predictor.sigmas, - quantize=shared.opts.enable_quantization - ) + self.model_wrap = k_diffusion.external.ForgeScheduleLinker(shared.sd_model.forge_objects.unet.model.predictor) self.model_wrap.inner_model = shared.sd_model return self.model_wrap @@ -136,8 +133,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler): unet_patcher = self.model_wrap.inner_model.forge_objects.unet sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x) - self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(x.device) - self.model_wrap.sigmas = self.model_wrap.sigmas.to(x.device) + self.model_wrap.predictor.to(x.device) steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps) @@ -198,8 +194,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler): unet_patcher = self.model_wrap.inner_model.forge_objects.unet sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x) - self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(x.device) - self.model_wrap.sigmas = self.model_wrap.sigmas.to(x.device) + self.model_wrap.predictor.to(x.device) steps = steps or p.steps