diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index a5f59693..35c0a37c 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -1960,7 +1960,8 @@ class StableDiffusion: else: latents = self.vae.encode(images).latent_dist.sample() # latents = self.vae.encode(images, return_dict=False)[0] - latents = latents * (self.vae.config['scaling_factor'] - self.vae.config['shift_factor']) + shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0 + latents = latents * (self.vae.config['scaling_factor'] - shift) latents = latents.to(device, dtype=dtype) return latents