mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 02:31:17 +00:00
Hidream is training, but has a memory leak
This commit is contained in:
@@ -304,26 +304,6 @@ class HiDreamImageTransformer2DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# def _set_gradient_checkpointing(self, module, value=False):
|
||||
# if hasattr(module, "gradient_checkpointing"):
|
||||
# module.gradient_checkpointing = value
|
||||
def _set_gradient_checkpointing(
|
||||
self, enable: bool = True, gradient_checkpointing_func: Callable = torch.utils.checkpoint.checkpoint
|
||||
) -> None:
|
||||
is_gradient_checkpointing_set = False
|
||||
|
||||
for name, module in self.named_modules():
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
logger.debug(f"Setting `gradient_checkpointing={enable}` for '{name}'")
|
||||
module._gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = enable
|
||||
is_gradient_checkpointing_set = True
|
||||
|
||||
if not is_gradient_checkpointing_set:
|
||||
raise ValueError(
|
||||
f"The module {self.__class__.__name__} does not support gradient checkpointing. Please make sure to "
|
||||
f"use a module that supports gradient checkpointing by creating a boolean attribute `gradient_checkpointing`."
|
||||
)
|
||||
|
||||
def expand_timesteps(self, timesteps, batch_size, device):
|
||||
if not torch.is_tensor(timesteps):
|
||||
@@ -339,18 +319,20 @@ class HiDreamImageTransformer2DModel(
|
||||
timesteps = timesteps.expand(batch_size)
|
||||
return timesteps
|
||||
|
||||
def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
|
||||
if is_training:
|
||||
x = einops.rearrange(x, 'B S (p1 p2 C) -> B C S (p1 p2)', p1=self.config.patch_size, p2=self.config.patch_size)
|
||||
else:
|
||||
x_arr = []
|
||||
for i, img_size in enumerate(img_sizes):
|
||||
pH, pW = img_size
|
||||
x_arr.append(
|
||||
einops.rearrange(x[i, :pH*pW].reshape(1, pH, pW, -1), 'B H W (p1 p2 C) -> B C (H p1) (W p2)',
|
||||
p1=self.config.patch_size, p2=self.config.patch_size)
|
||||
# the implementation on hidream during train was wrong, just use the inference one.
|
||||
def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
|
||||
# Process all images in the batch according to their specific dimensions
|
||||
x_arr = []
|
||||
for i, img_size in enumerate(img_sizes):
|
||||
pH, pW = img_size
|
||||
x_arr.append(
|
||||
einops.rearrange(
|
||||
x[i, :pH*pW].reshape(1, pH, pW, -1),
|
||||
'B H W (p1 p2 C) -> B C (H p1) (W p2)',
|
||||
p1=self.config.patch_size, p2=self.config.patch_size
|
||||
)
|
||||
x = torch.cat(x_arr, dim=0)
|
||||
)
|
||||
x = torch.cat(x_arr, dim=0)
|
||||
return x
|
||||
|
||||
def patchify(self, x, max_seq, img_sizes=None):
|
||||
@@ -452,27 +434,18 @@ class HiDreamImageTransformer2DModel(
|
||||
initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
|
||||
initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
|
||||
for bid, block in enumerate(self.double_stream_blocks):
|
||||
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
|
||||
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id].detach()
|
||||
cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1)
|
||||
if self.training and self.gradient_checkpointing:
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states, initial_encoder_hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states, initial_encoder_hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
image_tokens_masks,
|
||||
cur_encoder_hidden_states,
|
||||
adaln_input,
|
||||
rope,
|
||||
**ckpt_kwargs,
|
||||
adaln_input.clone(),
|
||||
rope.clone(),
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states, initial_encoder_hidden_states = block(
|
||||
image_tokens = hidden_states,
|
||||
@@ -495,26 +468,16 @@ class HiDreamImageTransformer2DModel(
|
||||
image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1)
|
||||
|
||||
for bid, block in enumerate(self.single_stream_blocks):
|
||||
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
|
||||
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id].detach()
|
||||
hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
|
||||
if self.training and self.gradient_checkpointing:
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
image_tokens_masks,
|
||||
None,
|
||||
adaln_input,
|
||||
rope,
|
||||
**ckpt_kwargs,
|
||||
adaln_input.clone(),
|
||||
rope.clone(),
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
|
||||
@@ -313,7 +313,7 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
batch_size = prompt_embeds[0].shape[0]
|
||||
|
||||
prompt_embeds, pooled_prompt_embeds = self._encode_prompt(
|
||||
prompt = prompt,
|
||||
@@ -561,7 +561,7 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
batch_size = prompt_embeds[0].shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
|
||||
Reference in New Issue
Block a user