mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Fixed issue with training hidream when batch size is larger than 1
This commit is contained in:
@@ -329,6 +329,7 @@ class HidreamModel(BaseModel):
|
||||
text_embeddings: PromptEmbeds,
|
||||
**kwargs
|
||||
):
|
||||
batch_size = latent_model_input.shape[0]
|
||||
with torch.no_grad():
|
||||
if latent_model_input.shape[-2] != latent_model_input.shape[-1]:
|
||||
B, C, H, W = latent_model_input.shape
|
||||
@@ -342,8 +343,10 @@ class HidreamModel(BaseModel):
|
||||
img_ids_pad = torch.zeros(self.transformer.max_seq, 3)
|
||||
img_ids_pad[:pH*pW, :] = img_ids
|
||||
|
||||
img_sizes = img_sizes.unsqueeze(0).to(latent_model_input.device)
|
||||
img_sizes = img_sizes.unsqueeze(0).to(latent_model_input.device)
|
||||
img_sizes = torch.cat([img_sizes] * batch_size, dim=0)
|
||||
img_ids = img_ids_pad.unsqueeze(0).to(latent_model_input.device)
|
||||
img_ids = torch.cat([img_ids] * batch_size, dim=0)
|
||||
else:
|
||||
img_sizes = img_ids = None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user