This commit is contained in:
lllyasviel
2024-02-06 17:46:23 -08:00
parent 6301a6660e
commit e579fab4d0

View File

@@ -154,6 +154,9 @@ class CFGDenoiser(torch.nn.Module):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
original_x_device = x.device
original_x_dtype = x.dtype
if self.classic_ddim_eps_estimation:
acd = self.inner_model.inner_model.alphas_cumprod
fake_sigmas = ((1 - acd) / acd) ** 0.5
@@ -195,5 +198,5 @@ class CFGDenoiser(torch.nn.Module):
eps = (x - denoised) / sigma
return eps
return denoised
return denoised.to(device=original_x_device, dtype=original_x_dtype)