Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-04-28 18:21:16 +00:00)

Commit: Hidream is training, but has a memory leak
@@ -122,11 +122,6 @@ class HidreamModel(BaseModel):
         flush()

         self.print_and_status_update("Loading transformer")
-
-        transformer_kwargs = {}
-        if self.model_config.quantize:
-            quant_type = f"{self.model_config.qtype}wo"
-            transformer_kwargs['quantization_config'] = TorchAoConfig(quant_type)

         transformer = HiDreamImageTransformer2DModel.from_pretrained(
             model_path,
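The removed block handed a torchao config straight to from_pretrained. For reference, a minimal sketch of that loading path, assuming a diffusers build that ships TorchAoConfig and HiDreamImageTransformer2DModel, torchao installed, and an illustrative repo id and quant string (not values taken from this commit):

import torch
from diffusers import HiDreamImageTransformer2DModel, TorchAoConfig

model_path = "HiDream-ai/HiDream-I1-Full"   # illustrative; the toolkit supplies its own model_path
quant_config = TorchAoConfig("int8wo")      # int8 weight-only, analogous to f"{qtype}wo"

transformer = HiDreamImageTransformer2DModel.from_pretrained(
    model_path,
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)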
@@ -352,9 +347,10 @@ class HidreamModel(BaseModel):
         else:
             img_sizes = img_ids = None

-        cast_dtype = self.model.dtype
+        dtype = self.model.dtype
+        device = self.device_torch

-        # nosie pred here
+        # Pack the latent
         if latent_model_input.shape[-2] != latent_model_input.shape[-1]:
             B, C, H, W = latent_model_input.shape
             patch_size = self.transformer.config.patch_size
@@ -365,14 +361,18 @@ class HidreamModel(BaseModel):
                 device=latent_model_input.device
             )
             latent_model_input = einops.rearrange(latent_model_input, 'B C (H p1) (W p2) -> B C (H W) (p1 p2)', p1=patch_size, p2=patch_size)
             out[:, :, 0:pH*pW] = latent_model_input
             latent_model_input = out

+        text_embeds = text_embeddings.text_embeds
+        # run the to for the list
+        text_embeds = [te.to(device, dtype=dtype) for te in text_embeds]
+
         noise_pred = self.transformer(
             hidden_states = latent_model_input,
             timesteps = timestep,
-            encoder_hidden_states = text_embeddings.text_embeds.to(cast_dtype, dtype=cast_dtype),
-            pooled_embeds = text_embeddings.pooled_embeds.text_embeds.to(cast_dtype, dtype=cast_dtype),
+            encoder_hidden_states = text_embeds,
+            pooled_embeds = text_embeddings.pooled_embeds.to(device, dtype=dtype),
             img_sizes = img_sizes,
             img_ids = img_ids,
             return_dict = False,
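The packing above flattens a non-square latent into patch tokens and zero-pads up to the transformer's max sequence length. A shape-check sketch with placeholder dimensions (the sizes and max_seq here are illustrative, not the model's real values):

import torch
import einops

B, C, H, W = 1, 16, 64, 96           # non-square latent
patch_size, max_seq = 2, 2048        # placeholder values
latent = torch.randn(B, C, H, W)

pH, pW = H // patch_size, W // patch_size
packed = einops.rearrange(
    latent, 'B C (H p1) (W p2) -> B C (H W) (p1 p2)',
    p1=patch_size, p2=patch_size,
)

out = torch.zeros(B, C, max_seq, patch_size * patch_size, dtype=latent.dtype)
out[:, :, 0:pH * pW] = packed        # pad the unused sequence slots with zeros

print(packed.shape)                  # torch.Size([1, 16, 1536, 4])
print(out.shape)                     # torch.Size([1, 16, 2048, 4])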
@@ -395,9 +395,8 @@ class HidreamModel(BaseModel):
             max_sequence_length = max_sequence_length,
         )
         pe = PromptEmbeds(
-            prompt_embeds
+            [prompt_embeds, pooled_prompt_embeds]
         )
-        pe.pooled_embeds = pooled_prompt_embeds
         return pe

     def get_model_has_grad(self):
@@ -304,26 +304,6 @@ class HiDreamImageTransformer2DModel(

         self.gradient_checkpointing = False

-    # def _set_gradient_checkpointing(self, module, value=False):
-    #     if hasattr(module, "gradient_checkpointing"):
-    #         module.gradient_checkpointing = value
-    def _set_gradient_checkpointing(
-        self, enable: bool = True, gradient_checkpointing_func: Callable = torch.utils.checkpoint.checkpoint
-    ) -> None:
-        is_gradient_checkpointing_set = False
-
-        for name, module in self.named_modules():
-            if hasattr(module, "gradient_checkpointing"):
-                logger.debug(f"Setting `gradient_checkpointing={enable}` for '{name}'")
-                module._gradient_checkpointing_func = gradient_checkpointing_func
-                module.gradient_checkpointing = enable
-                is_gradient_checkpointing_set = True
-
-        if not is_gradient_checkpointing_set:
-            raise ValueError(
-                f"The module {self.__class__.__name__} does not support gradient checkpointing. Please make sure to "
-                f"use a module that supports gradient checkpointing by creating a boolean attribute `gradient_checkpointing`."
-            )

     def expand_timesteps(self, timesteps, batch_size, device):
         if not torch.is_tensor(timesteps):
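The deleted override implements the flag-plus-checkpoint-function pattern: walk the submodules, flip gradient_checkpointing on each block, and hand it a checkpoint function to call in forward. A stand-alone PyTorch sketch of that pattern on a toy module (not the actual HiDream/diffusers classes):

import functools
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TinyBlock(nn.Module):
    def __init__(self, dim=32):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
        self.gradient_checkpointing = False
        self._gradient_checkpointing_func = functools.partial(checkpoint, use_reentrant=False)

    def forward(self, x):
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            # recompute activations during backward instead of storing them
            return self._gradient_checkpointing_func(self.ff, x)
        return self.ff(x)

block = TinyBlock()
block.gradient_checkpointing = True
out = block(torch.randn(4, 32, requires_grad=True))
out.sum().backward()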
@@ -339,18 +319,20 @@ class HiDreamImageTransformer2DModel(
         timesteps = timesteps.expand(batch_size)
         return timesteps

-    def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
-        if is_training:
-            x = einops.rearrange(x, 'B S (p1 p2 C) -> B C S (p1 p2)', p1=self.config.patch_size, p2=self.config.patch_size)
-        else:
-            x_arr = []
-            for i, img_size in enumerate(img_sizes):
-                pH, pW = img_size
-                x_arr.append(
-                    einops.rearrange(x[i, :pH*pW].reshape(1, pH, pW, -1), 'B H W (p1 p2 C) -> B C (H p1) (W p2)',
-                        p1=self.config.patch_size, p2=self.config.patch_size)
-                )
-            x = torch.cat(x_arr, dim=0)
+    # the implementation on hidream during train was wrong, just use the inference one.
+    def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
+        # Process all images in the batch according to their specific dimensions
+        x_arr = []
+        for i, img_size in enumerate(img_sizes):
+            pH, pW = img_size
+            x_arr.append(
+                einops.rearrange(
+                    x[i, :pH*pW].reshape(1, pH, pW, -1),
+                    'B H W (p1 p2 C) -> B C (H p1) (W p2)',
+                    p1=self.config.patch_size, p2=self.config.patch_size
+                )
+            )
+        x = torch.cat(x_arr, dim=0)
         return x

     def patchify(self, x, max_seq, img_sizes=None):
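The rewritten unpatchify simply inverts the per-image patch packing. A round-trip sanity check under assumed dimensions; the packing side here is only the matching inverse written for this demo, not the model's patchify verbatim:

import torch
import einops

patch_size, C, H, W = 2, 16, 64, 64
pH, pW = H // patch_size, W // patch_size

img = torch.randn(1, C, H, W)

# pack: image -> (1, pH*pW, patch_size*patch_size*C) token sequence
tokens = einops.rearrange(
    img, 'B C (H p1) (W p2) -> B (H W) (p1 p2 C)', p1=patch_size, p2=patch_size
)

# unpack: the same rearrange the new unpatchify applies per image
restored = einops.rearrange(
    tokens[0, :pH * pW].reshape(1, pH, pW, -1),
    'B H W (p1 p2 C) -> B C (H p1) (W p2)',
    p1=patch_size, p2=patch_size,
)
assert torch.equal(restored, img)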
@@ -452,27 +434,18 @@ class HiDreamImageTransformer2DModel(
         initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
         initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
         for bid, block in enumerate(self.double_stream_blocks):
-            cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
+            cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id].detach()
             cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1)
-            if self.training and self.gradient_checkpointing:
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states, initial_encoder_hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                hidden_states, initial_encoder_hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     image_tokens_masks,
                     cur_encoder_hidden_states,
-                    adaln_input,
-                    rope,
-                    **ckpt_kwargs,
+                    adaln_input.clone(),
+                    rope.clone(),
                 )
+
             else:
                 hidden_states, initial_encoder_hidden_states = block(
                     image_tokens = hidden_states,
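Both block loops now take the per-block Llama hidden states with .detach(), cutting them out of the autograd graph so backward no longer reaches (or retains) the text-encoder side through every transformer block. A tiny illustration of what detach() does:

import torch

# detach() returns a view that shares storage but is cut from the autograd
# graph, so downstream ops no longer retain the producer's activations.
encoder_out = torch.randn(1, 128, 64, requires_grad=True)  # stand-in for a text-encoder output
proj = torch.nn.Linear(64, 64)

kept = proj(encoder_out)            # graph still reaches back to encoder_out
cut = proj(encoder_out.detach())    # graph stops at the detached view

print(encoder_out.requires_grad, encoder_out.detach().requires_grad)  # True False
cut.sum().backward()                # backprops into proj only
print(encoder_out.grad)             # None: no gradient flows to the encoder output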
@@ -495,26 +468,16 @@ class HiDreamImageTransformer2DModel(
             image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1)

         for bid, block in enumerate(self.single_stream_blocks):
-            cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
+            cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id].detach()
             hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
-            if self.training and self.gradient_checkpointing:
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     image_tokens_masks,
                     None,
-                    adaln_input,
-                    rope,
-                    **ckpt_kwargs,
+                    adaln_input.clone(),
+                    rope.clone(),
                 )
             else:
                 hidden_states = block(
@@ -313,7 +313,7 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
         if prompt is not None:
             batch_size = len(prompt)
         else:
-            batch_size = prompt_embeds.shape[0]
+            batch_size = prompt_embeds[0].shape[0]

         prompt_embeds, pooled_prompt_embeds = self._encode_prompt(
             prompt = prompt,
@@ -561,7 +561,7 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
         elif prompt is not None and isinstance(prompt, list):
             batch_size = len(prompt)
         else:
-            batch_size = prompt_embeds.shape[0]
+            batch_size = prompt_embeds[0].shape[0]

         device = self._execution_device