Hidream is training, but has a memory leak

Jaret Burkett
2025-04-13 23:28:18 +00:00
parent 594e166ca3
commit f80cf99f40
6 changed files with 86 additions and 89 deletions

@@ -1,4 +1,4 @@
FROM nvidia/cuda:12.6.3-base-ubuntu22.04
FROM nvidia/cuda:12.6.3-devel-ubuntu22.04
LABEL authors="jaret"

@@ -122,11 +122,6 @@ class HidreamModel(BaseModel):
flush()
self.print_and_status_update("Loading transformer")
transformer_kwargs = {}
if self.model_config.quantize:
quant_type = f"{self.model_config.qtype}wo"
transformer_kwargs['quantization_config'] = TorchAoConfig(quant_type)
transformer = HiDreamImageTransformer2DModel.from_pretrained(
model_path,
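
For reference, the torchao quantization path removed by this hunk follows the pattern sketched below; the model path, subfolder, dtype, and the "int8" qtype are illustrative assumptions, and TorchAoConfig needs a recent diffusers with torchao installed.

import torch
from diffusers import HiDreamImageTransformer2DModel, TorchAoConfig

# Assumed example values; the real ones come from the training config.
model_path = "HiDream-ai/HiDream-I1-Full"
qtype = "int8"

transformer_kwargs = {}
quant_type = f"{qtype}wo"  # "int8wo" = int8 weight-only quantization
transformer_kwargs["quantization_config"] = TorchAoConfig(quant_type)

transformer = HiDreamImageTransformer2DModel.from_pretrained(
    model_path,
    subfolder="transformer",  # assumed repo layout
    torch_dtype=torch.bfloat16,
    **transformer_kwargs,
)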
@@ -352,9 +347,10 @@ class HidreamModel(BaseModel):
else:
img_sizes = img_ids = None
cast_dtype = self.model.dtype
dtype = self.model.dtype
device = self.device_torch
# noise pred here
# Pack the latent
if latent_model_input.shape[-2] != latent_model_input.shape[-1]:
B, C, H, W = latent_model_input.shape
patch_size = self.transformer.config.patch_size
@@ -365,14 +361,18 @@ class HidreamModel(BaseModel):
device=latent_model_input.device
)
latent_model_input = einops.rearrange(latent_model_input, 'B C (H p1) (W p2) -> B C (H W) (p1 p2)', p1=patch_size, p2=patch_size)
out[:, :, 0:pH*pW] = latent_model_input
latent_model_input = out
text_embeds = text_embeddings.text_embeds
# call .to() on each tensor in the list
text_embeds = [te.to(device, dtype=dtype) for te in text_embeds]
noise_pred = self.transformer(
hidden_states = latent_model_input,
timesteps = timestep,
encoder_hidden_states = text_embeddings.text_embeds.to(cast_dtype, dtype=cast_dtype),
pooled_embeds = text_embeddings.pooled_embeds.text_embeds.to(cast_dtype, dtype=cast_dtype),
encoder_hidden_states = text_embeds,
pooled_embeds = text_embeddings.pooled_embeds.to(device, dtype=dtype),
img_sizes = img_sizes,
img_ids = img_ids,
return_dict = False,
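
To make the packing step above concrete, here is a self-contained sketch with assumed shapes; the real patch_size and max_seq come from the transformer config, and a 64x32 latent stands in for any non-square input.

import einops
import torch

# Assumed illustrative shapes, not taken from the diff.
B, C, H, W = 1, 16, 64, 32
patch_size = 2
max_seq = 4096

latent_model_input = torch.randn(B, C, H, W)
pH, pW = H // patch_size, W // patch_size

out = torch.zeros(B, C, max_seq, patch_size * patch_size,
                  dtype=latent_model_input.dtype, device=latent_model_input.device)
# Flatten patch_size x patch_size patches into a token sequence, then copy it
# into the zero-padded buffer so every sample has the same sequence length.
packed = einops.rearrange(latent_model_input, 'B C (H p1) (W p2) -> B C (H W) (p1 p2)',
                          p1=patch_size, p2=patch_size)
out[:, :, 0:pH * pW] = packed
latent_model_input = out  # (B, C, max_seq, p1 * p2), ready for the transformer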
@@ -395,9 +395,8 @@ class HidreamModel(BaseModel):
max_sequence_length = max_sequence_length,
)
pe = PromptEmbeds(
prompt_embeds
[prompt_embeds, pooled_prompt_embeds]
)
pe.pooled_embeds = pooled_prompt_embeds
return pe
def get_model_has_grad(self):

@@ -304,26 +304,6 @@ class HiDreamImageTransformer2DModel(
self.gradient_checkpointing = False
# def _set_gradient_checkpointing(self, module, value=False):
# if hasattr(module, "gradient_checkpointing"):
# module.gradient_checkpointing = value
def _set_gradient_checkpointing(
self, enable: bool = True, gradient_checkpointing_func: Callable = torch.utils.checkpoint.checkpoint
) -> None:
is_gradient_checkpointing_set = False
for name, module in self.named_modules():
if hasattr(module, "gradient_checkpointing"):
logger.debug(f"Setting `gradient_checkpointing={enable}` for '{name}'")
module._gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = enable
is_gradient_checkpointing_set = True
if not is_gradient_checkpointing_set:
raise ValueError(
f"The module {self.__class__.__name__} does not support gradient checkpointing. Please make sure to "
f"use a module that supports gradient checkpointing by creating a boolean attribute `gradient_checkpointing`."
)
def expand_timesteps(self, timesteps, batch_size, device):
if not torch.is_tensor(timesteps):
@@ -339,18 +319,20 @@ class HiDreamImageTransformer2DModel(
timesteps = timesteps.expand(batch_size)
return timesteps
def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
if is_training:
x = einops.rearrange(x, 'B S (p1 p2 C) -> B C S (p1 p2)', p1=self.config.patch_size, p2=self.config.patch_size)
else:
x_arr = []
for i, img_size in enumerate(img_sizes):
pH, pW = img_size
x_arr.append(
einops.rearrange(x[i, :pH*pW].reshape(1, pH, pW, -1), 'B H W (p1 p2 C) -> B C (H p1) (W p2)',
p1=self.config.patch_size, p2=self.config.patch_size)
# the implementation in HiDream during training was wrong, just use the inference one.
def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
# Process all images in the batch according to their specific dimensions
x_arr = []
for i, img_size in enumerate(img_sizes):
pH, pW = img_size
x_arr.append(
einops.rearrange(
x[i, :pH*pW].reshape(1, pH, pW, -1),
'B H W (p1 p2 C) -> B C (H p1) (W p2)',
p1=self.config.patch_size, p2=self.config.patch_size
)
)
x = torch.cat(x_arr, dim=0)
return x
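
As a shape sanity check on the unpatchify rearrange used above, a standalone sketch with assumed values: patch_size=2, C=16, and a single 64x32 latent whose token row is padded past pH*pW.

import einops
import torch

# Assumed values purely for the shape check.
patch_size, C = 2, 16
pH, pW = 32, 16
x_i = torch.randn(pH * pW + 8, patch_size * patch_size * C)  # padded token row for one image

img = einops.rearrange(
    x_i[:pH * pW].reshape(1, pH, pW, -1),
    'B H W (p1 p2 C) -> B C (H p1) (W p2)',
    p1=patch_size, p2=patch_size,
)
print(img.shape)  # torch.Size([1, 16, 64, 32])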
def patchify(self, x, max_seq, img_sizes=None):
@@ -452,27 +434,18 @@ class HiDreamImageTransformer2DModel(
initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
for bid, block in enumerate(self.double_stream_blocks):
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id].detach()
cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states, initial_encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, initial_encoder_hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
image_tokens_masks,
cur_encoder_hidden_states,
adaln_input,
rope,
**ckpt_kwargs,
adaln_input.clone(),
rope.clone(),
)
else:
hidden_states, initial_encoder_hidden_states = block(
image_tokens = hidden_states,
@@ -495,26 +468,16 @@ class HiDreamImageTransformer2DModel(
image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1)
for bid, block in enumerate(self.single_stream_blocks):
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id].detach()
hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
image_tokens_masks,
None,
adaln_input,
rope,
**ckpt_kwargs,
adaln_input.clone(),
rope.clone(),
)
else:
hidden_states = block(
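
The checkpointing change in this file swaps the manual create_custom_forward wrapper for self._gradient_checkpointing_func, the hook that the removed _set_gradient_checkpointing above installs and that recent diffusers ModelMixin versions provide out of the box. A rough, self-contained sketch of that mechanism, with SomeBlock as a hypothetical stand-in for a transformer block:

import functools
import torch
import torch.nn as nn

class SomeBlock(nn.Module):
    # Hypothetical stand-in for a HiDream transformer block.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.gradient_checkpointing = False

    def forward(self, hidden_states):
        return self.proj(hidden_states)

block = SomeBlock()
# Roughly what enabling gradient checkpointing does: attach a checkpoint wrapper
# and flip the per-module flag checked in forward().
block._gradient_checkpointing_func = functools.partial(
    torch.utils.checkpoint.checkpoint, use_reentrant=False)
block.gradient_checkpointing = True

x = torch.randn(2, 8, requires_grad=True)
if torch.is_grad_enabled() and block.gradient_checkpointing:
    hidden_states = block._gradient_checkpointing_func(block, x)
else:
    hidden_states = block(x)
hidden_states.sum().backward()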

@@ -313,7 +313,7 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
batch_size = prompt_embeds[0].shape[0]
prompt_embeds, pooled_prompt_embeds = self._encode_prompt(
prompt = prompt,
@@ -561,7 +561,7 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
batch_size = prompt_embeds[0].shape[0]
device = self._execution_device

@@ -725,9 +725,13 @@ class BaseModel:
do_classifier_free_guidance = True
# check if batch size of embeddings matches batch size of latents
if latents.shape[0] == text_embeddings.text_embeds.shape[0]:
if isinstance(text_embeddings.text_embeds, list):
te_batch_size = text_embeddings.text_embeds[0].shape[0]
else:
te_batch_size = text_embeddings.text_embeds.shape[0]
if latents.shape[0] == te_batch_size:
do_classifier_free_guidance = False
elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]:
elif latents.shape[0] * 2 != te_batch_size:
raise ValueError(
"Batch size of latents must be the same or half the batch size of text embeddings")
latents = latents.to(self.device_torch)
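
To illustrate the list-aware batch check above, a toy example with assumed shapes in which the text embeds carry negative and positive prompts concatenated for classifier-free guidance:

import torch

# Assumed toy shapes: latent batch of 2, list-style text embeds with batch 4.
latents = torch.randn(2, 16, 64, 64)
text_embeds = [torch.randn(4, 128, 4096), torch.randn(4, 77, 768)]

if isinstance(text_embeds, list):
    te_batch_size = text_embeds[0].shape[0]
else:
    te_batch_size = text_embeds.shape[0]

do_classifier_free_guidance = True
if latents.shape[0] == te_batch_size:
    do_classifier_free_guidance = False
elif latents.shape[0] * 2 != te_batch_size:
    raise ValueError(
        "Batch size of latents must be the same or half the batch size of text embeddings")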

@@ -36,7 +36,10 @@ class PromptEmbeds:
self.attention_mask = attention_mask
def to(self, *args, **kwargs):
self.text_embeds = self.text_embeds.to(*args, **kwargs)
if isinstance(self.text_embeds, list) or isinstance(self.text_embeds, tuple):
self.text_embeds = [t.to(*args, **kwargs) for t in self.text_embeds]
else:
self.text_embeds = self.text_embeds.to(*args, **kwargs)
if self.pooled_embeds is not None:
self.pooled_embeds = self.pooled_embeds.to(*args, **kwargs)
if self.attention_mask is not None:
@@ -45,7 +48,10 @@ class PromptEmbeds:
def detach(self):
new_embeds = self.clone()
new_embeds.text_embeds = new_embeds.text_embeds.detach()
if isinstance(new_embeds.text_embeds, list) or isinstance(new_embeds.text_embeds, tuple):
new_embeds.text_embeds = [t.detach() for t in new_embeds.text_embeds]
else:
new_embeds.text_embeds = new_embeds.text_embeds.detach()
if new_embeds.pooled_embeds is not None:
new_embeds.pooled_embeds = new_embeds.pooled_embeds.detach()
if new_embeds.attention_mask is not None:
@@ -53,10 +59,14 @@ class PromptEmbeds:
return new_embeds
def clone(self):
if self.pooled_embeds is not None:
prompt_embeds = PromptEmbeds([self.text_embeds.clone(), self.pooled_embeds.clone()])
if isinstance(self.text_embeds, list) or isinstance(self.text_embeds, tuple):
cloned_text_embeds = [t.clone() for t in self.text_embeds]
else:
prompt_embeds = PromptEmbeds(self.text_embeds.clone())
cloned_text_embeds = self.text_embeds.clone()
if self.pooled_embeds is not None:
prompt_embeds = PromptEmbeds([cloned_text_embeds, self.pooled_embeds.clone()])
else:
prompt_embeds = PromptEmbeds(cloned_text_embeds)
if self.attention_mask is not None:
prompt_embeds.attention_mask = self.attention_mask.clone()
@@ -64,12 +74,18 @@ class PromptEmbeds:
def expand_to_batch(self, batch_size):
pe = self.clone()
current_batch_size = pe.text_embeds.shape[0]
if isinstance(pe.text_embeds, list) or isinstance(pe.text_embeds, tuple):
current_batch_size = pe.text_embeds[0].shape[0]
else:
current_batch_size = pe.text_embeds.shape[0]
if current_batch_size == batch_size:
return pe
if current_batch_size != 1:
raise Exception("Can only expand batch size for batch size 1")
pe.text_embeds = pe.text_embeds.expand(batch_size, -1)
if isinstance(pe.text_embeds, list) or isinstance(pe.text_embeds, tuple):
pe.text_embeds = [t.expand(batch_size, -1) for t in pe.text_embeds]
else:
pe.text_embeds = pe.text_embeds.expand(batch_size, -1)
if pe.pooled_embeds is not None:
pe.pooled_embeds = pe.pooled_embeds.expand(batch_size, -1)
if pe.attention_mask is not None:
@@ -145,7 +161,13 @@ class EncodedPromptPair:
def concat_prompt_embeds(prompt_embeds: list[PromptEmbeds]):
text_embeds = torch.cat([p.text_embeds for p in prompt_embeds], dim=0)
if isinstance(prompt_embeds[0].text_embeds, list) or isinstance(prompt_embeds[0].text_embeds, tuple):
embed_list = []
for i in range(len(prompt_embeds[0].text_embeds)):
embed_list.append(torch.cat([p.text_embeds[i] for p in prompt_embeds], dim=0))
text_embeds = embed_list
else:
text_embeds = torch.cat([p.text_embeds for p in prompt_embeds], dim=0)
pooled_embeds = None
if prompt_embeds[0].pooled_embeds is not None:
pooled_embeds = torch.cat([p.pooled_embeds for p in prompt_embeds], dim=0)
@@ -196,7 +218,16 @@ def split_prompt_embeds(concatenated: PromptEmbeds, num_parts=None) -> List[Prom
if num_parts is None:
# use batch size
num_parts = concatenated.text_embeds.shape[0]
text_embeds_splits = torch.chunk(concatenated.text_embeds, num_parts, dim=0)
if isinstance(concatenated.text_embeds, list) or isinstance(concatenated.text_embeds, tuple):
# split each part
text_embeds_splits = [
torch.chunk(text, num_parts, dim=0)
for text in concatenated.text_embeds
]
text_embeds_splits = list(zip(*text_embeds_splits))
else:
text_embeds_splits = torch.chunk(concatenated.text_embeds, num_parts, dim=0)
if concatenated.pooled_embeds is not None:
pooled_embeds_splits = torch.chunk(concatenated.pooled_embeds, num_parts, dim=0)
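
Finally, a minimal round-trip sketch of the list-aware concat and split behaviour added in this file, using assumed toy shapes in place of real encoder outputs:

import torch

# Assumed toy shapes standing in for per-encoder text embeds; HiDream keeps one
# tensor per text encoder, so PromptEmbeds.text_embeds becomes a list.
sample_a = [torch.randn(1, 128, 4096), torch.randn(1, 77, 768)]
sample_b = [torch.randn(1, 128, 4096), torch.randn(1, 77, 768)]

# concat_prompt_embeds-style: concatenate each list element along the batch dim.
text_embeds = [torch.cat(pair, dim=0) for pair in zip(sample_a, sample_b)]
# shapes: (2, 128, 4096) and (2, 77, 768)

# split_prompt_embeds-style: chunk every element, then zip back per sample.
num_parts = text_embeds[0].shape[0]
chunked = [torch.chunk(t, num_parts, dim=0) for t in text_embeds]
per_sample = list(zip(*chunked))
# per_sample[0] -> ((1, 128, 4096), (1, 77, 768)), matching sample_a's shapes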