diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 01533d41..dcdb2638 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -444,8 +444,8 @@ class StableDiffusion: if do_classifier_free_guidance: # todo check this with larget batches - train_util.concat_embeddings( - add_time_ids, add_time_ids, 1 + add_time_ids = train_util.concat_embeddings( + add_time_ids, add_time_ids, int(latents.shape[0]) ) else: # concat to fit batch size