diff --git a/docker/Dockerfile b/docker/Dockerfile
index bd6c1a9d..0d95b254 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -1,4 +1,4 @@
-FROM nvidia/cuda:12.6.3-base-ubuntu22.04
+FROM nvidia/cuda:12.6.3-devel-ubuntu22.04
 
 LABEL authors="jaret"
 
diff --git a/extensions_built_in/diffusion_models/hidream/hidream_model.py b/extensions_built_in/diffusion_models/hidream/hidream_model.py
index 03aaa610..3322a111 100644
--- a/extensions_built_in/diffusion_models/hidream/hidream_model.py
+++ b/extensions_built_in/diffusion_models/hidream/hidream_model.py
@@ -122,11 +122,6 @@ class HidreamModel(BaseModel):
         flush()
 
         self.print_and_status_update("Loading transformer")
-
-        transformer_kwargs = {}
-        if self.model_config.quantize:
-            quant_type = f"{self.model_config.qtype}wo"
-            transformer_kwargs['quantization_config'] = TorchAoConfig(quant_type)
 
         transformer = HiDreamImageTransformer2DModel.from_pretrained(
             model_path,
@@ -352,9 +347,10 @@ class HidreamModel(BaseModel):
         else:
             img_sizes = img_ids = None
 
-        cast_dtype = self.model.dtype
+        dtype = self.model.dtype
+        device = self.device_torch
 
-        # nosie pred here
+        # Pack the latent
         if latent_model_input.shape[-2] != latent_model_input.shape[-1]:
             B, C, H, W = latent_model_input.shape
             patch_size = self.transformer.config.patch_size
@@ -365,14 +361,18 @@ class HidreamModel(BaseModel):
                 device=latent_model_input.device
             )
             latent_model_input = einops.rearrange(latent_model_input, 'B C (H p1) (W p2) -> B C (H W) (p1 p2)', p1=patch_size, p2=patch_size)
-            out[:, :, 0:pH*pW] = latent_model_input
+            out[:, :, 0:pH*pW] = latent_model_input
             latent_model_input = out
 
+        text_embeds = text_embeddings.text_embeds
+        # run the to for the list
+        text_embeds = [te.to(device, dtype=dtype) for te in text_embeds]
+
         noise_pred = self.transformer(
             hidden_states = latent_model_input,
             timesteps = timestep,
-            encoder_hidden_states = text_embeddings.text_embeds.to(cast_dtype, dtype=cast_dtype),
-            pooled_embeds = text_embeddings.pooled_embeds.text_embeds.to(cast_dtype, dtype=cast_dtype),
+            encoder_hidden_states = text_embeds,
+            pooled_embeds = text_embeddings.pooled_embeds.to(device, dtype=dtype),
             img_sizes = img_sizes,
             img_ids = img_ids,
             return_dict = False,
@@ -395,9 +395,8 @@ class HidreamModel(BaseModel):
             max_sequence_length = max_sequence_length,
         )
         pe = PromptEmbeds(
-            prompt_embeds
+            [prompt_embeds, pooled_prompt_embeds]
         )
-        pe.pooled_embeds = pooled_prompt_embeds
         return pe
 
     def get_model_has_grad(self):
diff --git a/extensions_built_in/diffusion_models/hidream/src/models/transformers/transformer_hidream_image.py b/extensions_built_in/diffusion_models/hidream/src/models/transformers/transformer_hidream_image.py
index 58c5daa6..f7eb1045 100644
--- a/extensions_built_in/diffusion_models/hidream/src/models/transformers/transformer_hidream_image.py
+++ b/extensions_built_in/diffusion_models/hidream/src/models/transformers/transformer_hidream_image.py
@@ -304,26 +304,6 @@ class HiDreamImageTransformer2DModel(
 
         self.gradient_checkpointing = False
 
-    # def _set_gradient_checkpointing(self, module, value=False):
-    #     if hasattr(module, "gradient_checkpointing"):
-    #         module.gradient_checkpointing = value
-    def _set_gradient_checkpointing(
-        self, enable: bool = True, gradient_checkpointing_func: Callable = torch.utils.checkpoint.checkpoint
-    ) -> None:
-        is_gradient_checkpointing_set = False
-
-        for name, module in self.named_modules():
-            if hasattr(module, "gradient_checkpointing"):
-                logger.debug(f"Setting `gradient_checkpointing={enable}` for '{name}'")
-                module._gradient_checkpointing_func = gradient_checkpointing_func
-                module.gradient_checkpointing = enable
-                is_gradient_checkpointing_set = True
-
-        if not is_gradient_checkpointing_set:
-            raise ValueError(
-                f"The module {self.__class__.__name__} does not support gradient checkpointing. Please make sure to "
-                f"use a module that supports gradient checkpointing by creating a boolean attribute `gradient_checkpointing`."
-            )
 
     def expand_timesteps(self, timesteps, batch_size, device):
         if not torch.is_tensor(timesteps):
@@ -339,18 +319,20 @@ class HiDreamImageTransformer2DModel(
         timesteps = timesteps.expand(batch_size)
         return timesteps
 
-    def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
-        if is_training:
-            x = einops.rearrange(x, 'B S (p1 p2 C) -> B C S (p1 p2)', p1=self.config.patch_size, p2=self.config.patch_size)
-        else:
-            x_arr = []
-            for i, img_size in enumerate(img_sizes):
-                pH, pW = img_size
-                x_arr.append(
-                    einops.rearrange(x[i, :pH*pW].reshape(1, pH, pW, -1), 'B H W (p1 p2 C) -> B C (H p1) (W p2)',
-                        p1=self.config.patch_size, p2=self.config.patch_size)
+    # the implementation on hidream during train was wrong, just use the inference one.
+    def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
+        # Process all images in the batch according to their specific dimensions
+        x_arr = []
+        for i, img_size in enumerate(img_sizes):
+            pH, pW = img_size
+            x_arr.append(
+                einops.rearrange(
+                    x[i, :pH*pW].reshape(1, pH, pW, -1),
+                    'B H W (p1 p2 C) -> B C (H p1) (W p2)',
+                    p1=self.config.patch_size, p2=self.config.patch_size
                 )
-            x = torch.cat(x_arr, dim=0)
+            )
+        x = torch.cat(x_arr, dim=0)
         return x
 
     def patchify(self, x, max_seq, img_sizes=None):
@@ -452,27 +434,18 @@ class HiDreamImageTransformer2DModel(
         initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
         initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
         for bid, block in enumerate(self.double_stream_blocks):
-            cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
+            cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id].detach()
             cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1)
-            if self.training and self.gradient_checkpointing:
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states, initial_encoder_hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                hidden_states, initial_encoder_hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     image_tokens_masks,
                     cur_encoder_hidden_states,
-                    adaln_input,
-                    rope,
-                    **ckpt_kwargs,
+                    adaln_input.clone(),
+                    rope.clone(),
                 )
+
             else:
                 hidden_states, initial_encoder_hidden_states = block(
                     image_tokens = hidden_states,
@@ -495,26 +468,16 @@ class HiDreamImageTransformer2DModel(
             image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1)
 
         for bid, block in enumerate(self.single_stream_blocks):
-            cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
+            cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id].detach()
             hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
-            if self.training and self.gradient_checkpointing:
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     image_tokens_masks,
                     None,
-                    adaln_input,
-                    rope,
-                    **ckpt_kwargs,
+                    adaln_input.clone(),
+                    rope.clone(),
                 )
             else:
                 hidden_states = block(
diff --git a/extensions_built_in/diffusion_models/hidream/src/pipelines/hidream_image/pipeline_hidream_image.py b/extensions_built_in/diffusion_models/hidream/src/pipelines/hidream_image/pipeline_hidream_image.py
index 7cc55846..9c4e51fa 100644
--- a/extensions_built_in/diffusion_models/hidream/src/pipelines/hidream_image/pipeline_hidream_image.py
+++ b/extensions_built_in/diffusion_models/hidream/src/pipelines/hidream_image/pipeline_hidream_image.py
@@ -313,7 +313,7 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
         if prompt is not None:
             batch_size = len(prompt)
         else:
-            batch_size = prompt_embeds.shape[0]
+            batch_size = prompt_embeds[0].shape[0]
 
         prompt_embeds, pooled_prompt_embeds = self._encode_prompt(
             prompt = prompt,
@@ -561,7 +561,7 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
         elif prompt is not None and isinstance(prompt, list):
             batch_size = len(prompt)
         else:
-            batch_size = prompt_embeds.shape[0]
+            batch_size = prompt_embeds[0].shape[0]
 
         device = self._execution_device
 
diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py
index ddc8ce8b..61c4bcf6 100644
--- a/toolkit/models/base_model.py
+++ b/toolkit/models/base_model.py
@@ -725,9 +725,13 @@ class BaseModel:
         do_classifier_free_guidance = True
 
         # check if batch size of embeddings matches batch size of latents
-        if latents.shape[0] == text_embeddings.text_embeds.shape[0]:
+        if isinstance(text_embeddings.text_embeds, list):
+            te_batch_size = text_embeddings.text_embeds[0].shape[0]
+        else:
+            te_batch_size = text_embeddings.text_embeds.shape[0]
+        if latents.shape[0] == te_batch_size:
             do_classifier_free_guidance = False
-        elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]:
+        elif latents.shape[0] * 2 != te_batch_size:
             raise ValueError(
                 "Batch size of latents must be the same or half the batch size of text embeddings")
         latents = latents.to(self.device_torch)
diff --git a/toolkit/prompt_utils.py b/toolkit/prompt_utils.py
index eb213cca..ff5a68f3 100644
--- a/toolkit/prompt_utils.py
+++ b/toolkit/prompt_utils.py
@@ -36,7 +36,10 @@ class PromptEmbeds:
             self.attention_mask = attention_mask
 
     def to(self, *args, **kwargs):
-        self.text_embeds = self.text_embeds.to(*args, **kwargs)
+        if isinstance(self.text_embeds, list) or isinstance(self.text_embeds, tuple):
+            self.text_embeds = [t.to(*args, **kwargs) for t in self.text_embeds]
+        else:
+            self.text_embeds = self.text_embeds.to(*args, **kwargs)
         if self.pooled_embeds is not None:
             self.pooled_embeds = self.pooled_embeds.to(*args, **kwargs)
         if self.attention_mask is not None:
@@ -45,7 +48,10 @@ class PromptEmbeds:
 
     def detach(self):
         new_embeds = self.clone()
-        new_embeds.text_embeds = new_embeds.text_embeds.detach()
+        if isinstance(new_embeds.text_embeds, list) or isinstance(new_embeds.text_embeds, tuple):
+            new_embeds.text_embeds = [t.detach() for t in new_embeds.text_embeds]
+        else:
+            new_embeds.text_embeds = new_embeds.text_embeds.detach()
         if new_embeds.pooled_embeds is not None:
             new_embeds.pooled_embeds = new_embeds.pooled_embeds.detach()
         if new_embeds.attention_mask is not None:
@@ -53,10 +59,14 @@ class PromptEmbeds:
         return new_embeds
 
     def clone(self):
-        if self.pooled_embeds is not None:
-            prompt_embeds = PromptEmbeds([self.text_embeds.clone(), self.pooled_embeds.clone()])
+        if isinstance(self.text_embeds, list) or isinstance(self.text_embeds, tuple):
+            cloned_text_embeds = [t.clone() for t in self.text_embeds]
         else:
-            prompt_embeds = PromptEmbeds(self.text_embeds.clone())
+            cloned_text_embeds = self.text_embeds.clone()
+        if self.pooled_embeds is not None:
+            prompt_embeds = PromptEmbeds([cloned_text_embeds, self.pooled_embeds.clone()])
+        else:
+            prompt_embeds = PromptEmbeds(cloned_text_embeds)
 
         if self.attention_mask is not None:
             prompt_embeds.attention_mask = self.attention_mask.clone()
@@ -64,12 +74,18 @@ class PromptEmbeds:
 
     def expand_to_batch(self, batch_size):
         pe = self.clone()
-        current_batch_size = pe.text_embeds.shape[0]
+        if isinstance(pe.text_embeds, list) or isinstance(pe.text_embeds, tuple):
+            current_batch_size = pe.text_embeds[0].shape[0]
+        else:
+            current_batch_size = pe.text_embeds.shape[0]
         if current_batch_size == batch_size:
             return pe
         if current_batch_size != 1:
             raise Exception("Can only expand batch size for batch size 1")
-        pe.text_embeds = pe.text_embeds.expand(batch_size, -1)
+        if isinstance(pe.text_embeds, list) or isinstance(pe.text_embeds, tuple):
+            pe.text_embeds = [t.expand(batch_size, -1) for t in pe.text_embeds]
+        else:
+            pe.text_embeds = pe.text_embeds.expand(batch_size, -1)
         if pe.pooled_embeds is not None:
             pe.pooled_embeds = pe.pooled_embeds.expand(batch_size, -1)
         if pe.attention_mask is not None:
@@ -145,7 +161,13 @@ class EncodedPromptPair:
 
 
 def concat_prompt_embeds(prompt_embeds: list[PromptEmbeds]):
-    text_embeds = torch.cat([p.text_embeds for p in prompt_embeds], dim=0)
+    if isinstance(prompt_embeds[0].text_embeds, list) or isinstance(prompt_embeds[0].text_embeds, tuple):
+        embed_list = []
+        for i in range(len(prompt_embeds[0].text_embeds)):
+            embed_list.append(torch.cat([p.text_embeds[i] for p in prompt_embeds], dim=0))
+        text_embeds = embed_list
+    else:
+        text_embeds = torch.cat([p.text_embeds for p in prompt_embeds], dim=0)
     pooled_embeds = None
     if prompt_embeds[0].pooled_embeds is not None:
         pooled_embeds = torch.cat([p.pooled_embeds for p in prompt_embeds], dim=0)
@@ -196,7 +218,16 @@ def split_prompt_embeds(concatenated: PromptEmbeds, num_parts=None) -> List[Prom
     if num_parts is None:
         # use batch size
        num_parts = concatenated.text_embeds.shape[0]
-    text_embeds_splits = torch.chunk(concatenated.text_embeds, num_parts, dim=0)
+
+    if isinstance(concatenated.text_embeds, list) or isinstance(concatenated.text_embeds, tuple):
+        # split each part
+        text_embeds_splits = [
+            torch.chunk(text, num_parts, dim=0)
+            for text in concatenated.text_embeds
+        ]
+        text_embeds_splits = list(zip(*text_embeds_splits))
+    else:
+        text_embeds_splits = torch.chunk(concatenated.text_embeds, num_parts, dim=0)
     if concatenated.pooled_embeds is not None:
         pooled_embeds_splits = torch.chunk(concatenated.pooled_embeds, num_parts, dim=0)
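
The toolkit/prompt_utils.py changes above all repeat one pattern: text_embeds may now be either a single tensor or a list of per-encoder tensors (as the HiDream path produces), so every tensor method call is wrapped in an isinstance branch. A minimal standalone sketch of that pattern, using a hypothetical apply_to_embeds helper that is not part of this patch (shapes are arbitrary):

import torch
from typing import Callable, List, Union

TensorOrList = Union[torch.Tensor, List[torch.Tensor]]

def apply_to_embeds(embeds: TensorOrList, fn: Callable[[torch.Tensor], torch.Tensor]) -> TensorOrList:
    # Apply fn to a single tensor, or element-wise to a list/tuple of tensors.
    if isinstance(embeds, (list, tuple)):
        return [fn(t) for t in embeds]
    return fn(embeds)

# Mirrors what PromptEmbeds.to() / .detach() do after this patch (example shapes only).
text_embeds = [torch.randn(1, 128, 4096), torch.randn(1, 128, 2048)]
text_embeds = apply_to_embeds(text_embeds, lambda t: t.to(torch.bfloat16))
text_embeds = apply_to_embeds(text_embeds, lambda t: t.detach())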
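
In transformer_hidream_image.py, the hand-rolled create_custom_forward wrapper and the model-specific _set_gradient_checkpointing override are replaced by calls to self._gradient_checkpointing_func guarded by torch.is_grad_enabled(), which appears to rely on the hook that newer diffusers ModelMixin releases install when enable_gradient_checkpointing() is called. A self-contained sketch of that control flow, assuming a plain torch.utils.checkpoint function is installed on the module; TinyBlockStack is an illustrative stand-in, not code from this repo:

from functools import partial

import torch
import torch.nn as nn
import torch.utils.checkpoint

class TinyBlockStack(nn.Module):
    def __init__(self, dim: int = 64, depth: int = 4):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))
        self.gradient_checkpointing = False
        # Stand-in for the function diffusers installs; non-reentrant checkpointing.
        self._gradient_checkpointing_func = partial(
            torch.utils.checkpoint.checkpoint, use_reentrant=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # Recompute this block's activations during backward instead of storing them.
                x = self._gradient_checkpointing_func(block, x)
            else:
                x = block(x)
        return x

model = TinyBlockStack()
model.gradient_checkpointing = True
out = model(torch.randn(2, 64))
out.sum().backward()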
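
The rewritten unpatchify always rebuilds each image from its own (pH, pW) token grid using the inference-style einops pattern. A quick round-trip check of that pattern on a single random image; the forward packing pattern and shapes here are chosen for the demo, not taken from the repo:

import torch
import einops

patch_size = 2
C, H, W = 16, 8, 12
pH, pW = H // patch_size, W // patch_size

img = torch.randn(C, H, W)

# Pack pixels into per-patch tokens: one token per patch_size x patch_size tile.
tokens = einops.rearrange(
    img, 'C (H p1) (W p2) -> (H W) (p1 p2 C)',
    p1=patch_size, p2=patch_size,
)  # shape: (pH * pW, patch_size * patch_size * C)

# Invert it the way the new unpatchify does for one image in the batch.
recovered = einops.rearrange(
    tokens[:pH * pW].reshape(1, pH, pW, -1),
    'B H W (p1 p2 C) -> B C (H p1) (W p2)',
    p1=patch_size, p2=patch_size,
)

assert torch.allclose(recovered[0], img)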