From d11c9d75064b93b988a8a029f8361056262cd674 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Tue, 6 Feb 2024 23:12:35 -0800 Subject: [PATCH] fix sigmas device in rare cases #71 --- modules/sd_samplers_kdiffusion.py | 12 ++++++------ modules/sd_samplers_timesteps.py | 8 ++++---- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 91580994..39727e2d 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -144,12 +144,12 @@ class KDiffusionSampler(sd_samplers_common.Sampler): unet_patcher = self.model_wrap.inner_model.forge_objects.unet sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x) - self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(unet_patcher.current_device) - self.model_wrap.sigmas = self.model_wrap.sigmas.to(unet_patcher.current_device) + self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(unet_patcher.load_device) + self.model_wrap.sigmas = self.model_wrap.sigmas.to(unet_patcher.load_device) steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps) - sigmas = self.get_sigmas(p, steps) + sigmas = self.get_sigmas(p, steps).to(shared.device) sigma_sched = sigmas[steps - t_enc - 1:] x = x.to(noise) @@ -206,12 +206,12 @@ class KDiffusionSampler(sd_samplers_common.Sampler): unet_patcher = self.model_wrap.inner_model.forge_objects.unet sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x) - self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(unet_patcher.current_device) - self.model_wrap.sigmas = self.model_wrap.sigmas.to(unet_patcher.current_device) + self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(unet_patcher.load_device) + self.model_wrap.sigmas = self.model_wrap.sigmas.to(unet_patcher.load_device) steps = steps or p.steps - sigmas = self.get_sigmas(p, steps) + sigmas = self.get_sigmas(p, steps).to(shared.device) if opts.sgm_noise_multiplier: p.extra_generation_params["SGM noise multiplier"] = True diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py index cf62f53f..f2098378 100644 --- a/modules/sd_samplers_timesteps.py +++ b/modules/sd_samplers_timesteps.py @@ -101,11 +101,11 @@ class CompVisSampler(sd_samplers_common.Sampler): unet_patcher = self.model_wrap.inner_model.forge_objects.unet sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x) - self.model_wrap.inner_model.alphas_cumprod = self.model_wrap.inner_model.alphas_cumprod.to(unet_patcher.current_device) + self.model_wrap.inner_model.alphas_cumprod = self.model_wrap.inner_model.alphas_cumprod.to(unet_patcher.load_device) steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps) - timesteps = self.get_timesteps(p, steps) + timesteps = self.get_timesteps(p, steps).to(shared.device) timesteps_sched = timesteps[:t_enc] alphas_cumprod = shared.sd_model.alphas_cumprod @@ -151,10 +151,10 @@ class CompVisSampler(sd_samplers_common.Sampler): unet_patcher = self.model_wrap.inner_model.forge_objects.unet sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x) - self.model_wrap.inner_model.alphas_cumprod = self.model_wrap.inner_model.alphas_cumprod.to(unet_patcher.current_device) + self.model_wrap.inner_model.alphas_cumprod = self.model_wrap.inner_model.alphas_cumprod.to(unet_patcher.load_device) steps = steps or p.steps - timesteps = self.get_timesteps(p, steps) + timesteps = self.get_timesteps(p, steps).to(shared.device) extra_params_kwargs = self.initialize(p) parameters = inspect.signature(self.func).parameters