Hidream is training, but has a memory leak

This commit is contained in:
Jaret Burkett
2025-04-13 23:28:18 +00:00
parent 594e166ca3
commit f80cf99f40
6 changed files with 86 additions and 89 deletions

View File

@@ -304,26 +304,6 @@ class HiDreamImageTransformer2DModel(
self.gradient_checkpointing = False
# def _set_gradient_checkpointing(self, module, value=False):
# if hasattr(module, "gradient_checkpointing"):
# module.gradient_checkpointing = value
def _set_gradient_checkpointing(
self, enable: bool = True, gradient_checkpointing_func: Callable = torch.utils.checkpoint.checkpoint
) -> None:
is_gradient_checkpointing_set = False
for name, module in self.named_modules():
if hasattr(module, "gradient_checkpointing"):
logger.debug(f"Setting `gradient_checkpointing={enable}` for '{name}'")
module._gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = enable
is_gradient_checkpointing_set = True
if not is_gradient_checkpointing_set:
raise ValueError(
f"The module {self.__class__.__name__} does not support gradient checkpointing. Please make sure to "
f"use a module that supports gradient checkpointing by creating a boolean attribute `gradient_checkpointing`."
)
def expand_timesteps(self, timesteps, batch_size, device):
if not torch.is_tensor(timesteps):
@@ -339,18 +319,20 @@ class HiDreamImageTransformer2DModel(
timesteps = timesteps.expand(batch_size)
return timesteps
def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
if is_training:
x = einops.rearrange(x, 'B S (p1 p2 C) -> B C S (p1 p2)', p1=self.config.patch_size, p2=self.config.patch_size)
else:
x_arr = []
for i, img_size in enumerate(img_sizes):
pH, pW = img_size
x_arr.append(
einops.rearrange(x[i, :pH*pW].reshape(1, pH, pW, -1), 'B H W (p1 p2 C) -> B C (H p1) (W p2)',
p1=self.config.patch_size, p2=self.config.patch_size)
# the implementation on hidream during train was wrong, just use the inference one.
def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
# Process all images in the batch according to their specific dimensions
x_arr = []
for i, img_size in enumerate(img_sizes):
pH, pW = img_size
x_arr.append(
einops.rearrange(
x[i, :pH*pW].reshape(1, pH, pW, -1),
'B H W (p1 p2 C) -> B C (H p1) (W p2)',
p1=self.config.patch_size, p2=self.config.patch_size
)
x = torch.cat(x_arr, dim=0)
)
x = torch.cat(x_arr, dim=0)
return x
def patchify(self, x, max_seq, img_sizes=None):
@@ -452,27 +434,18 @@ class HiDreamImageTransformer2DModel(
initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
for bid, block in enumerate(self.double_stream_blocks):
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id].detach()
cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states, initial_encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, initial_encoder_hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
image_tokens_masks,
cur_encoder_hidden_states,
adaln_input,
rope,
**ckpt_kwargs,
adaln_input.clone(),
rope.clone(),
)
else:
hidden_states, initial_encoder_hidden_states = block(
image_tokens = hidden_states,
@@ -495,26 +468,16 @@ class HiDreamImageTransformer2DModel(
image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1)
for bid, block in enumerate(self.single_stream_blocks):
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id].detach()
hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
image_tokens_masks,
None,
adaln_input,
rope,
**ckpt_kwargs,
adaln_input.clone(),
rope.clone(),
)
else:
hidden_states = block(

View File

@@ -313,7 +313,7 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
batch_size = prompt_embeds[0].shape[0]
prompt_embeds, pooled_prompt_embeds = self._encode_prompt(
prompt = prompt,
@@ -561,7 +561,7 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
batch_size = prompt_embeds[0].shape[0]
device = self._execution_device