From add83df5cc336506099f11ad929875e77e494a9e Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Mon, 21 Apr 2025 17:26:29 +0000 Subject: [PATCH] Fixed issue with training hidream when batch size is larger than 1 --- .../diffusion_models/hidream/hidream_model.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/extensions_built_in/diffusion_models/hidream/hidream_model.py b/extensions_built_in/diffusion_models/hidream/hidream_model.py index 5d07eb6c..9dd860c4 100644 --- a/extensions_built_in/diffusion_models/hidream/hidream_model.py +++ b/extensions_built_in/diffusion_models/hidream/hidream_model.py @@ -329,6 +329,7 @@ class HidreamModel(BaseModel): text_embeddings: PromptEmbeds, **kwargs ): + batch_size = latent_model_input.shape[0] with torch.no_grad(): if latent_model_input.shape[-2] != latent_model_input.shape[-1]: B, C, H, W = latent_model_input.shape @@ -342,8 +343,10 @@ class HidreamModel(BaseModel): img_ids_pad = torch.zeros(self.transformer.max_seq, 3) img_ids_pad[:pH*pW, :] = img_ids - img_sizes = img_sizes.unsqueeze(0).to(latent_model_input.device) + img_sizes = img_sizes.unsqueeze(0).to(latent_model_input.device) + img_sizes = torch.cat([img_sizes] * batch_size, dim=0) img_ids = img_ids_pad.unsqueeze(0).to(latent_model_input.device) + img_ids = torch.cat([img_ids] * batch_size, dim=0) else: img_sizes = img_ids = None