HiDream is training, but has a memory leak

This commit is contained in:
Jaret Burkett
2025-04-13 23:28:18 +00:00
parent 594e166ca3
commit f80cf99f40
6 changed files with 86 additions and 89 deletions

View File

@@ -725,9 +725,13 @@ class BaseModel:
do_classifier_free_guidance = True
# check if batch size of embeddings matches batch size of latents
if latents.shape[0] == text_embeddings.text_embeds.shape[0]:
if isinstance(text_embeddings.text_embeds, list):
te_batch_size = text_embeddings.text_embeds[0].shape[0]
else:
te_batch_size = text_embeddings.text_embeds.shape[0]
if latents.shape[0] == te_batch_size:
do_classifier_free_guidance = False
elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]:
elif latents.shape[0] * 2 != te_batch_size:
raise ValueError(
"Batch size of latents must be the same or half the batch size of text embeddings")
latents = latents.to(self.device_torch)