Hidream is training, but has a memory leak

This commit is contained in:
Jaret Burkett
2025-04-13 23:28:18 +00:00
parent 594e166ca3
commit f80cf99f40
6 changed files with 86 additions and 89 deletions

View File

@@ -725,9 +725,13 @@ class BaseModel:
do_classifier_free_guidance = True
# check if batch size of embeddings matches batch size of latents
if latents.shape[0] == text_embeddings.text_embeds.shape[0]:
if isinstance(text_embeddings.text_embeds, list):
te_batch_size = text_embeddings.text_embeds[0].shape[0]
else:
te_batch_size = text_embeddings.text_embeds.shape[0]
if latents.shape[0] == te_batch_size:
do_classifier_free_guidance = False
elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]:
elif latents.shape[0] * 2 != te_batch_size:
raise ValueError(
"Batch size of latents must be the same or half the batch size of text embeddings")
latents = latents.to(self.device_torch)

View File

@@ -36,7 +36,10 @@ class PromptEmbeds:
self.attention_mask = attention_mask
def to(self, *args, **kwargs):
self.text_embeds = self.text_embeds.to(*args, **kwargs)
if isinstance(self.text_embeds, list) or isinstance(self.text_embeds, tuple):
self.text_embeds = [t.to(*args, **kwargs) for t in self.text_embeds]
else:
self.text_embeds = self.text_embeds.to(*args, **kwargs)
if self.pooled_embeds is not None:
self.pooled_embeds = self.pooled_embeds.to(*args, **kwargs)
if self.attention_mask is not None:
@@ -45,7 +48,10 @@ class PromptEmbeds:
def detach(self):
new_embeds = self.clone()
new_embeds.text_embeds = new_embeds.text_embeds.detach()
if isinstance(new_embeds.text_embeds, list) or isinstance(new_embeds.text_embeds, tuple):
new_embeds.text_embeds = [t.detach() for t in new_embeds.text_embeds]
else:
new_embeds.text_embeds = new_embeds.text_embeds.detach()
if new_embeds.pooled_embeds is not None:
new_embeds.pooled_embeds = new_embeds.pooled_embeds.detach()
if new_embeds.attention_mask is not None:
@@ -53,10 +59,14 @@ class PromptEmbeds:
return new_embeds
def clone(self):
if self.pooled_embeds is not None:
prompt_embeds = PromptEmbeds([self.text_embeds.clone(), self.pooled_embeds.clone()])
if isinstance(self.text_embeds, list) or isinstance(self.text_embeds, tuple):
cloned_text_embeds = [t.clone() for t in self.text_embeds]
else:
prompt_embeds = PromptEmbeds(self.text_embeds.clone())
cloned_text_embeds = self.text_embeds.clone()
if self.pooled_embeds is not None:
prompt_embeds = PromptEmbeds([cloned_text_embeds, self.pooled_embeds.clone()])
else:
prompt_embeds = PromptEmbeds(cloned_text_embeds)
if self.attention_mask is not None:
prompt_embeds.attention_mask = self.attention_mask.clone()
@@ -64,12 +74,18 @@ class PromptEmbeds:
def expand_to_batch(self, batch_size):
pe = self.clone()
current_batch_size = pe.text_embeds.shape[0]
if isinstance(pe.text_embeds, list) or isinstance(pe.text_embeds, tuple):
current_batch_size = pe.text_embeds[0].shape[0]
else:
current_batch_size = pe.text_embeds.shape[0]
if current_batch_size == batch_size:
return pe
if current_batch_size != 1:
raise Exception("Can only expand batch size for batch size 1")
pe.text_embeds = pe.text_embeds.expand(batch_size, -1)
if isinstance(pe.text_embeds, list) or isinstance(pe.text_embeds, tuple):
pe.text_embeds = [t.expand(batch_size, -1) for t in pe.text_embeds]
else:
pe.text_embeds = pe.text_embeds.expand(batch_size, -1)
if pe.pooled_embeds is not None:
pe.pooled_embeds = pe.pooled_embeds.expand(batch_size, -1)
if pe.attention_mask is not None:
@@ -145,7 +161,13 @@ class EncodedPromptPair:
def concat_prompt_embeds(prompt_embeds: list[PromptEmbeds]):
text_embeds = torch.cat([p.text_embeds for p in prompt_embeds], dim=0)
if isinstance(prompt_embeds[0].text_embeds, list) or isinstance(prompt_embeds[0].text_embeds, tuple):
embed_list = []
for i in range(len(prompt_embeds[0].text_embeds)):
embed_list.append(torch.cat([p.text_embeds[i] for p in prompt_embeds], dim=0))
text_embeds = embed_list
else:
text_embeds = torch.cat([p.text_embeds for p in prompt_embeds], dim=0)
pooled_embeds = None
if prompt_embeds[0].pooled_embeds is not None:
pooled_embeds = torch.cat([p.pooled_embeds for p in prompt_embeds], dim=0)
@@ -196,7 +218,16 @@ def split_prompt_embeds(concatenated: PromptEmbeds, num_parts=None) -> List[Prom
if num_parts is None:
# use batch size
num_parts = concatenated.text_embeds.shape[0]
text_embeds_splits = torch.chunk(concatenated.text_embeds, num_parts, dim=0)
if isinstance(concatenated.text_embeds, list) or isinstance(concatenated.text_embeds, tuple):
# split each part
text_embeds_splits = [
torch.chunk(text, num_parts, dim=0)
for text in concatenated.text_embeds
]
text_embeds_splits = list(zip(*text_embeds_splits))
else:
text_embeds_splits = torch.chunk(concatenated.text_embeds, num_parts, dim=0)
if concatenated.pooled_embeds is not None:
pooled_embeds_splits = torch.chunk(concatenated.pooled_embeds, num_parts, dim=0)