Fixed an issue with training HiDream when the batch size is larger than 1

Jaret Burkett
2025-04-21 17:26:29 +00:00
parent 12e3095d8a
commit add83df5cc


@@ -329,6 +329,7 @@ class HidreamModel(BaseModel):
text_embeddings: PromptEmbeds,
**kwargs
):
batch_size = latent_model_input.shape[0]
with torch.no_grad():
if latent_model_input.shape[-2] != latent_model_input.shape[-1]:
B, C, H, W = latent_model_input.shape
@@ -342,8 +343,10 @@ class HidreamModel(BaseModel):
img_ids_pad = torch.zeros(self.transformer.max_seq, 3)
img_ids_pad[:pH*pW, :] = img_ids
img_sizes = img_sizes.unsqueeze(0).to(latent_model_input.device)
img_sizes = torch.cat([img_sizes] * batch_size, dim=0)
img_ids = img_ids_pad.unsqueeze(0).to(latent_model_input.device)
img_ids = torch.cat([img_ids] * batch_size, dim=0)
else:
img_sizes = img_ids = None
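
For context, here is a minimal standalone sketch of the batching pattern the diff applies: the per-image img_sizes and img_ids tensors are built once for a single sample, then repeated along dim 0 with torch.cat so their leading dimension matches latent_model_input.shape[0]. The helper name expand_image_ids_for_batch and the way img_ids is filled with patch row/column indices are illustrative assumptions, not the repository's exact code; only the unsqueeze(0) / torch.cat step mirrors the change above.

import torch

def expand_image_ids_for_batch(latent_model_input: torch.Tensor,
                               patch_size: int,
                               max_seq: int):
    """Build per-sample img_sizes / img_ids and repeat them along the batch
    dimension so they line up with latent_model_input.shape[0].

    Hypothetical helper for illustration; not the repository's actual code.
    """
    batch_size, _, H, W = latent_model_input.shape
    pH, pW = H // patch_size, W // patch_size

    # One (pH, pW) grid of patch coordinates, flattened and padded to max_seq
    # tokens (the row/column layout here is an assumption for the sketch).
    img_ids = torch.zeros(pH, pW, 3)
    img_ids[..., 1] = torch.arange(pH)[:, None]
    img_ids[..., 2] = torch.arange(pW)[None, :]
    img_ids = img_ids.reshape(pH * pW, 3)

    img_ids_pad = torch.zeros(max_seq, 3)
    img_ids_pad[:pH * pW, :] = img_ids

    img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1)

    # Without the torch.cat, both tensors keep a leading dim of 1 and only
    # match the latents when batch_size == 1; repeating them fixes batches > 1.
    img_sizes = img_sizes.unsqueeze(0).to(latent_model_input.device)
    img_sizes = torch.cat([img_sizes] * batch_size, dim=0)            # (B, 2)
    img_ids = img_ids_pad.unsqueeze(0).to(latent_model_input.device)
    img_ids = torch.cat([img_ids] * batch_size, dim=0)                # (B, max_seq, 3)
    return img_sizes, img_ids

Example call, with patch_size=2 and max_seq=4608 chosen only for the sketch:

latents = torch.randn(4, 16, 128, 64)   # batch of 4 non-square latents
img_sizes, img_ids = expand_image_ids_for_batch(latents, patch_size=2, max_seq=4608)
print(img_sizes.shape, img_ids.shape)    # torch.Size([4, 2]) torch.Size([4, 4608, 3])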