diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py index ba94dd82..0071cc87 100644 --- a/k_diffusion/sampling.py +++ b/k_diffusion/sampling.py @@ -648,8 +648,9 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl if eta: x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise + h_last = h + old_denoised = denoised - h_last = h return x @@ -698,8 +699,9 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl if eta: x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise + h_1, h_2 = h, h_1 + denoised_1, denoised_2 = denoised, denoised_1 - h_1, h_2 = h, h_1 return x @torch.no_grad() @@ -910,4 +912,4 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None, else: buffer_model.append(d_cur.detach()) - return x_next \ No newline at end of file + return x_next