diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 91580994..39727e2d 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -144,12 +144,12 @@ class KDiffusionSampler(sd_samplers_common.Sampler): unet_patcher = self.model_wrap.inner_model.forge_objects.unet sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x) - self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(unet_patcher.current_device) - self.model_wrap.sigmas = self.model_wrap.sigmas.to(unet_patcher.current_device) + self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(unet_patcher.load_device) + self.model_wrap.sigmas = self.model_wrap.sigmas.to(unet_patcher.load_device) steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps) - sigmas = self.get_sigmas(p, steps) + sigmas = self.get_sigmas(p, steps).to(shared.device) sigma_sched = sigmas[steps - t_enc - 1:] x = x.to(noise) @@ -206,12 +206,12 @@ class KDiffusionSampler(sd_samplers_common.Sampler): unet_patcher = self.model_wrap.inner_model.forge_objects.unet sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x) - self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(unet_patcher.current_device) - self.model_wrap.sigmas = self.model_wrap.sigmas.to(unet_patcher.current_device) + self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(unet_patcher.load_device) + self.model_wrap.sigmas = self.model_wrap.sigmas.to(unet_patcher.load_device) steps = steps or p.steps - sigmas = self.get_sigmas(p, steps) + sigmas = self.get_sigmas(p, steps).to(shared.device) if opts.sgm_noise_multiplier: p.extra_generation_params["SGM noise multiplier"] = True diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py index cf62f53f..f2098378 100644 --- a/modules/sd_samplers_timesteps.py +++ b/modules/sd_samplers_timesteps.py @@ -101,11 +101,11 @@ class CompVisSampler(sd_samplers_common.Sampler): unet_patcher = self.model_wrap.inner_model.forge_objects.unet sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x) - self.model_wrap.inner_model.alphas_cumprod = self.model_wrap.inner_model.alphas_cumprod.to(unet_patcher.current_device) + self.model_wrap.inner_model.alphas_cumprod = self.model_wrap.inner_model.alphas_cumprod.to(unet_patcher.load_device) steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps) - timesteps = self.get_timesteps(p, steps) + timesteps = self.get_timesteps(p, steps).to(shared.device) timesteps_sched = timesteps[:t_enc] alphas_cumprod = shared.sd_model.alphas_cumprod @@ -151,10 +151,10 @@ class CompVisSampler(sd_samplers_common.Sampler): unet_patcher = self.model_wrap.inner_model.forge_objects.unet sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x) - self.model_wrap.inner_model.alphas_cumprod = self.model_wrap.inner_model.alphas_cumprod.to(unet_patcher.current_device) + self.model_wrap.inner_model.alphas_cumprod = self.model_wrap.inner_model.alphas_cumprod.to(unet_patcher.load_device) steps = steps or p.steps - timesteps = self.get_timesteps(p, steps) + timesteps = self.get_timesteps(p, steps).to(shared.device) extra_params_kwargs = self.initialize(p) parameters = inspect.signature(self.func).parameters