Update sd_samplers_timesteps.py

This commit is contained in:
lllyasviel
2024-01-25 07:37:59 -08:00
parent fd6369d106
commit 01efa9fb7e

View File

@@ -153,6 +153,11 @@ class CompVisSampler(sd_samplers_common.Sampler):
[unet_patcher],
unet_patcher.memory_required([x.shape[0] * 2] + list(x.shape[1:])) + inference_memory)
self.model_wrap.inner_model.betas = self.model_wrap.inner_model.betas.to(unet_patcher.current_device)
self.model_wrap.inner_model.alphas_cumprod = self.model_wrap.inner_model.alphas_cumprod.to(unet_patcher.current_device)
self.model_wrap.inner_model.alphas_cumprod_prev = self.model_wrap.inner_model.alphas_cumprod_prev.to(unet_patcher.current_device)
self.model_wrap.inner_model.logvar = self.model_wrap.inner_model.logvar.to(unet_patcher.current_device)
steps = steps or p.steps
timesteps = self.get_timesteps(p, steps)