fix sigmas device in rare cases #71

This commit is contained in:
lllyasviel
2024-02-06 23:12:35 -08:00
parent 4ea4a92fe9
commit d11c9d7506
2 changed files with 10 additions and 10 deletions

View File

@@ -144,12 +144,12 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(unet_patcher.current_device)
self.model_wrap.sigmas = self.model_wrap.sigmas.to(unet_patcher.current_device)
self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(unet_patcher.load_device)
self.model_wrap.sigmas = self.model_wrap.sigmas.to(unet_patcher.load_device)
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
sigmas = self.get_sigmas(p, steps)
sigmas = self.get_sigmas(p, steps).to(shared.device)
sigma_sched = sigmas[steps - t_enc - 1:]
x = x.to(noise)
@@ -206,12 +206,12 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(unet_patcher.current_device)
self.model_wrap.sigmas = self.model_wrap.sigmas.to(unet_patcher.current_device)
self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(unet_patcher.load_device)
self.model_wrap.sigmas = self.model_wrap.sigmas.to(unet_patcher.load_device)
steps = steps or p.steps
sigmas = self.get_sigmas(p, steps)
sigmas = self.get_sigmas(p, steps).to(shared.device)
if opts.sgm_noise_multiplier:
p.extra_generation_params["SGM noise multiplier"] = True

View File

@@ -101,11 +101,11 @@ class CompVisSampler(sd_samplers_common.Sampler):
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
self.model_wrap.inner_model.alphas_cumprod = self.model_wrap.inner_model.alphas_cumprod.to(unet_patcher.current_device)
self.model_wrap.inner_model.alphas_cumprod = self.model_wrap.inner_model.alphas_cumprod.to(unet_patcher.load_device)
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
timesteps = self.get_timesteps(p, steps)
timesteps = self.get_timesteps(p, steps).to(shared.device)
timesteps_sched = timesteps[:t_enc]
alphas_cumprod = shared.sd_model.alphas_cumprod
@@ -151,10 +151,10 @@ class CompVisSampler(sd_samplers_common.Sampler):
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
self.model_wrap.inner_model.alphas_cumprod = self.model_wrap.inner_model.alphas_cumprod.to(unet_patcher.current_device)
self.model_wrap.inner_model.alphas_cumprod = self.model_wrap.inner_model.alphas_cumprod.to(unet_patcher.load_device)
steps = steps or p.steps
timesteps = self.get_timesteps(p, steps)
timesteps = self.get_timesteps(p, steps).to(shared.device)
extra_params_kwargs = self.initialize(p)
parameters = inspect.signature(self.func).parameters