From 257ac2653a565672b280f2851f37b1ba6e546548 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Wed, 7 Feb 2024 00:44:12 -0800 Subject: [PATCH] rework sigma device mapping --- modules/sd_samplers_kdiffusion.py | 12 ++++++------ modules/sd_samplers_lcm.py | 2 +- modules/sd_samplers_timesteps.py | 8 ++++---- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 39727e2d..887d180d 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -144,12 +144,12 @@ class KDiffusionSampler(sd_samplers_common.Sampler): unet_patcher = self.model_wrap.inner_model.forge_objects.unet sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x) - self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(unet_patcher.load_device) - self.model_wrap.sigmas = self.model_wrap.sigmas.to(unet_patcher.load_device) + self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(x.device) + self.model_wrap.sigmas = self.model_wrap.sigmas.to(x.device) steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps) - sigmas = self.get_sigmas(p, steps).to(shared.device) + sigmas = self.get_sigmas(p, steps).to(x.device) sigma_sched = sigmas[steps - t_enc - 1:] x = x.to(noise) @@ -206,12 +206,12 @@ class KDiffusionSampler(sd_samplers_common.Sampler): unet_patcher = self.model_wrap.inner_model.forge_objects.unet sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x) - self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(unet_patcher.load_device) - self.model_wrap.sigmas = self.model_wrap.sigmas.to(unet_patcher.load_device) + self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(x.device) + self.model_wrap.sigmas = self.model_wrap.sigmas.to(x.device) steps = steps or p.steps - sigmas = self.get_sigmas(p, steps).to(shared.device) + sigmas = self.get_sigmas(p, steps).to(x.device) if opts.sgm_noise_multiplier: p.extra_generation_params["SGM noise multiplier"] = True diff --git a/modules/sd_samplers_lcm.py b/modules/sd_samplers_lcm.py index b1c1e475..29d453a2 100644 --- a/modules/sd_samplers_lcm.py +++ b/modules/sd_samplers_lcm.py @@ -27,7 +27,7 @@ class LCMCompVisDenoiser(DiscreteEpsDDPMDenoiser): start = self.sigma_to_t(self.sigma_max) end = self.sigma_to_t(self.sigma_min) - t = torch.linspace(start, end, n, device=shared.sd_model.forge_objects.unet.current_device) + t = torch.linspace(start, end, n, device=self.sigmas.device) return sampling.append_zero(self.t_to_sigma(t)) diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py index f2098378..149d6700 100644 --- a/modules/sd_samplers_timesteps.py +++ b/modules/sd_samplers_timesteps.py @@ -101,11 +101,11 @@ class CompVisSampler(sd_samplers_common.Sampler): unet_patcher = self.model_wrap.inner_model.forge_objects.unet sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x) - self.model_wrap.inner_model.alphas_cumprod = self.model_wrap.inner_model.alphas_cumprod.to(unet_patcher.load_device) + self.model_wrap.inner_model.alphas_cumprod = self.model_wrap.inner_model.alphas_cumprod.to(x.device) steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps) - timesteps = self.get_timesteps(p, steps).to(shared.device) + timesteps = self.get_timesteps(p, steps).to(x.device) timesteps_sched = timesteps[:t_enc] alphas_cumprod = shared.sd_model.alphas_cumprod @@ -151,10 +151,10 @@ class CompVisSampler(sd_samplers_common.Sampler): unet_patcher = self.model_wrap.inner_model.forge_objects.unet sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x) - self.model_wrap.inner_model.alphas_cumprod = self.model_wrap.inner_model.alphas_cumprod.to(unet_patcher.load_device) + self.model_wrap.inner_model.alphas_cumprod = self.model_wrap.inner_model.alphas_cumprod.to(x.device) steps = steps or p.steps - timesteps = self.get_timesteps(p, steps).to(shared.device) + timesteps = self.get_timesteps(p, steps).to(x.device) extra_params_kwargs = self.initialize(p) parameters = inspect.signature(self.func).parameters