diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 65a3bcbd..cbd0d4fb 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -507,7 +507,7 @@ class StableDiffusion: latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep) # check if we need to concat timesteps - if isinstance(timestep, torch.Tensor): + if isinstance(timestep, torch.Tensor) and len(timestep.shape) > 1: ts_bs = timestep.shape[0] if ts_bs != latent_model_input.shape[0]: if ts_bs == 1: