Hidream is training, but has a memory leak

This commit is contained in:
Jaret Burkett
2025-04-13 23:28:18 +00:00
parent 594e166ca3
commit f80cf99f40
6 changed files with 86 additions and 89 deletions

View File

@@ -122,11 +122,6 @@ class HidreamModel(BaseModel):
flush()
self.print_and_status_update("Loading transformer")
transformer_kwargs = {}
if self.model_config.quantize:
quant_type = f"{self.model_config.qtype}wo"
transformer_kwargs['quantization_config'] = TorchAoConfig(quant_type)
transformer = HiDreamImageTransformer2DModel.from_pretrained(
model_path,
@@ -352,9 +347,10 @@ class HidreamModel(BaseModel):
else:
img_sizes = img_ids = None
cast_dtype = self.model.dtype
dtype = self.model.dtype
device = self.device_torch
# nosie pred here
# Pack the latent
if latent_model_input.shape[-2] != latent_model_input.shape[-1]:
B, C, H, W = latent_model_input.shape
patch_size = self.transformer.config.patch_size
@@ -365,14 +361,18 @@ class HidreamModel(BaseModel):
device=latent_model_input.device
)
latent_model_input = einops.rearrange(latent_model_input, 'B C (H p1) (W p2) -> B C (H W) (p1 p2)', p1=patch_size, p2=patch_size)
out[:, :, 0:pH*pW] = latent_model_input
out[:, :, 0:pH*pW] = latent_model_input
latent_model_input = out
text_embeds = text_embeddings.text_embeds
# run the to for the list
text_embeds = [te.to(device, dtype=dtype) for te in text_embeds]
noise_pred = self.transformer(
hidden_states = latent_model_input,
timesteps = timestep,
encoder_hidden_states = text_embeddings.text_embeds.to(cast_dtype, dtype=cast_dtype),
pooled_embeds = text_embeddings.pooled_embeds.text_embeds.to(cast_dtype, dtype=cast_dtype),
encoder_hidden_states = text_embeds,
pooled_embeds = text_embeddings.pooled_embeds.to(device, dtype=dtype),
img_sizes = img_sizes,
img_ids = img_ids,
return_dict = False,
@@ -395,9 +395,8 @@ class HidreamModel(BaseModel):
max_sequence_length = max_sequence_length,
)
pe = PromptEmbeds(
prompt_embeds
[prompt_embeds, pooled_prompt_embeds]
)
pe.pooled_embeds = pooled_prompt_embeds
return pe
def get_model_has_grad(self):

View File

@@ -304,26 +304,6 @@ class HiDreamImageTransformer2DModel(
self.gradient_checkpointing = False
# def _set_gradient_checkpointing(self, module, value=False):
# if hasattr(module, "gradient_checkpointing"):
# module.gradient_checkpointing = value
def _set_gradient_checkpointing(
self, enable: bool = True, gradient_checkpointing_func: Callable = torch.utils.checkpoint.checkpoint
) -> None:
is_gradient_checkpointing_set = False
for name, module in self.named_modules():
if hasattr(module, "gradient_checkpointing"):
logger.debug(f"Setting `gradient_checkpointing={enable}` for '{name}'")
module._gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = enable
is_gradient_checkpointing_set = True
if not is_gradient_checkpointing_set:
raise ValueError(
f"The module {self.__class__.__name__} does not support gradient checkpointing. Please make sure to "
f"use a module that supports gradient checkpointing by creating a boolean attribute `gradient_checkpointing`."
)
def expand_timesteps(self, timesteps, batch_size, device):
if not torch.is_tensor(timesteps):
@@ -339,18 +319,20 @@ class HiDreamImageTransformer2DModel(
timesteps = timesteps.expand(batch_size)
return timesteps
def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
if is_training:
x = einops.rearrange(x, 'B S (p1 p2 C) -> B C S (p1 p2)', p1=self.config.patch_size, p2=self.config.patch_size)
else:
x_arr = []
for i, img_size in enumerate(img_sizes):
pH, pW = img_size
x_arr.append(
einops.rearrange(x[i, :pH*pW].reshape(1, pH, pW, -1), 'B H W (p1 p2 C) -> B C (H p1) (W p2)',
p1=self.config.patch_size, p2=self.config.patch_size)
# the implementation on hidream during train was wrong, just use the inference one.
def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
# Process all images in the batch according to their specific dimensions
x_arr = []
for i, img_size in enumerate(img_sizes):
pH, pW = img_size
x_arr.append(
einops.rearrange(
x[i, :pH*pW].reshape(1, pH, pW, -1),
'B H W (p1 p2 C) -> B C (H p1) (W p2)',
p1=self.config.patch_size, p2=self.config.patch_size
)
x = torch.cat(x_arr, dim=0)
)
x = torch.cat(x_arr, dim=0)
return x
def patchify(self, x, max_seq, img_sizes=None):
@@ -452,27 +434,18 @@ class HiDreamImageTransformer2DModel(
initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
for bid, block in enumerate(self.double_stream_blocks):
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id].detach()
cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states, initial_encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, initial_encoder_hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
image_tokens_masks,
cur_encoder_hidden_states,
adaln_input,
rope,
**ckpt_kwargs,
adaln_input.clone(),
rope.clone(),
)
else:
hidden_states, initial_encoder_hidden_states = block(
image_tokens = hidden_states,
@@ -495,26 +468,16 @@ class HiDreamImageTransformer2DModel(
image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1)
for bid, block in enumerate(self.single_stream_blocks):
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id].detach()
hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
image_tokens_masks,
None,
adaln_input,
rope,
**ckpt_kwargs,
adaln_input.clone(),
rope.clone(),
)
else:
hidden_states = block(

View File

@@ -313,7 +313,7 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
batch_size = prompt_embeds[0].shape[0]
prompt_embeds, pooled_prompt_embeds = self._encode_prompt(
prompt = prompt,
@@ -561,7 +561,7 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
batch_size = prompt_embeds[0].shape[0]
device = self._execution_device