diff --git a/config/examples/train_lora_hidream_48.yaml b/config/examples/train_lora_hidream_48.yaml new file mode 100644 index 00000000..f915f4c7 --- /dev/null +++ b/config/examples/train_lora_hidream_48.yaml @@ -0,0 +1,112 @@ +# HiDream training is still highly experimental. The settings here will take ~35.2GB of VRAM to train. +# It is not possible to train on a single 24GB card yet, but I am working on it. If you have more VRAM +# I highly recommend first disabling quantization on the model itself if you can. You can leave the TEs quantized. +# HiDream has a mixture of experts that may require special training considerations that I do not +# have implemented properly. The current implementation seems to work well for LoRA training, but +# may not be effective for longer training runs. The implementation could change in future updates +# so your results may vary when this happens. + +--- +job: extension +config: + # this name will be the folder and file name + name: "my_first_hidream_lora_v1" + process: + - type: 'sd_trainer' + # root folder to save training sessions/samples/weights + training_folder: "output" + # uncomment to see performance stats in the terminal every N steps +# performance_log_every: 1000 + device: cuda:0 + # if a trigger word is specified, it will be added to captions of training data if it does not already exist + # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word +# trigger_word: "p3r5on" + network: + type: "lora" + linear: 32 + linear_alpha: 32 + network_kwargs: + # it is probably best to ignore the mixture of experts since only 2 are active each block. It works if you activate it, but I wouldn't. + # proper training of it is not fully implemented + ignore_if_contains: + - "ff_i.experts" + - "ff_i.gate" + save: + dtype: bfloat16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 4 # how many intermittent saves to keep + datasets: + # datasets are a folder of images. captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently + # images will automatically be resized and bucketed into the resolution specified + # on windows, escape back slashes with another backslash so + # "C:\\path\\to\\images\\folder" + - folder_path: "/path/to/images/folder" + caption_ext: "txt" + caption_dropout_rate: 0.05 # will drop out the caption 5% of the time + resolution: [ 512, 768, 1024 ] # hidream enjoys multiple resolutions + train: + batch_size: 1 + steps: 3000 # total number of steps to train. 500 - 4000 is a good range + gradient_accumulation_steps: 1 + train_unet: true + train_text_encoder: false # won't work with hidream + gradient_checkpointing: true # need this on unless you have a ton of VRAM + noise_scheduler: "flowmatch" # for training only + timestep_type: shift # sigmoid, shift, linear + optimizer: "adamw8bit" + lr: 2e-4 + # uncomment this to skip the pre-training sample +# skip_first_sample: true + # uncomment to completely disable sampling +# disable_sampling: true + # uncomment to use new bell curved weighting. Experimental but may produce better results +# linear_timesteps: true + + # ema will smooth out learning, but could slow it down.
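As a quick illustration of how the `trigger_word` and `caption_dropout_rate` options above behave, here is a minimal standalone sketch, not ai-toolkit's actual dataloader code; `prepare_caption` is a hypothetical helper name used only for this example.

```python
import random
from typing import Optional

def prepare_caption(caption: str, trigger_word: Optional[str] = None,
                    dropout_rate: float = 0.05) -> str:
    """Sketch of the caption handling described in the comments above."""
    if trigger_word:
        if "[trigger]" in caption:
            # an explicit placeholder gets substituted
            caption = caption.replace("[trigger]", trigger_word)
        elif trigger_word not in caption:
            # otherwise the trigger word is added if it is missing
            caption = f"{trigger_word}, {caption}"
    # caption_dropout_rate: train on an empty caption this fraction of the time
    if random.random() < dropout_rate:
        return ""
    return caption

print(prepare_caption("[trigger] riding a bike", trigger_word="p3r5on"))
```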
Defaults off + ema_config: + use_ema: false + ema_decay: 0.99 + + # will probably need this if gpu supports it for hidream, other dtypes may not work correctly + dtype: bf16 + model: + # the transformer will get grabbed from this hf repo + # warning ONLY train on Full. The dev and fast models are distilled and will break + name_or_path: "HiDream-ai/HiDream-I1-Full" + # the extras will be grabbed from this hf repo. (text encoder, vae) + extras_name_or_path: "HiDream-ai/HiDream-I1-Full" + arch: "hidream" + # both need to be quantized to train on 48GB currently + quantize: true + quantize_te: true + model_kwargs: + # llama is a gated model, It defaults to unsloth version, but you can set the llama path here + llama_model_path: "unsloth/Meta-Llama-3.1-8B-Instruct" + sample: + sampler: "flowmatch" # must match train.noise_scheduler + sample_every: 250 # sample every this many steps + width: 1024 + height: 1024 + prompts: + # you can add [trigger] to the prompts here and it will be replaced with the trigger word +# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\ + - "woman with red hair, playing chess at the park, bomb going off in the background" + - "a woman holding a coffee cup, in a beanie, sitting at a cafe" + - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini" + - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background" + - "a bear building a log cabin in the snow covered mountains" + - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker" + - "hipster man with a beard, building a chair, in a wood shop" + - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop" + - "a man holding a sign that says, 'this is a sign'" + - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle" + neg: "" + seed: 42 + walk_seed: true + guidance_scale: 4 + sample_steps: 25 +# you can add any additional meta info here. 
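The training section above uses `timestep_type: shift`, and the model code further down builds its flow-matching scheduler with `shift: 3.0`. Assuming the custom scheduler follows the same sigma remapping as diffusers' `FlowMatchEulerDiscreteScheduler`, this rough sketch shows what that shift does to the noise levels sampled during training:

```python
# sigma in (0, 1) is the flow-matching noise level; shift > 1 pushes the
# schedule toward the high-noise end, where large-image models tend to need
# more training signal.
def shift_sigma(sigma: float, shift: float = 3.0) -> float:
    return shift * sigma / (1 + (shift - 1) * sigma)

for sigma in (0.1, 0.25, 0.5, 0.75, 0.9):
    print(f"{sigma} -> {shift_sigma(sigma):.3f}")
# 0.1 -> 0.250, 0.25 -> 0.500, 0.5 -> 0.750, 0.75 -> 0.900, 0.9 -> 0.964
```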
[name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0' diff --git a/docker/Dockerfile b/docker/Dockerfile index bd6c1a9d..020f4c1d 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,4 +1,4 @@ -FROM nvidia/cuda:12.6.3-base-ubuntu22.04 +FROM nvidia/cuda:12.6.3-devel-ubuntu22.04 LABEL authors="jaret" @@ -58,7 +58,8 @@ RUN echo "Cache bust: ${CACHEBUST}" && \ WORKDIR /app/ai-toolkit # Install Python dependencies -RUN pip install --no-cache-dir -r requirements.txt +RUN pip install --no-cache-dir -r requirements.txt && \ + pip install flash-attn --no-build-isolation --no-cache-dir # Build UI WORKDIR /app/ai-toolkit/ui diff --git a/extensions_built_in/diffusion_models/__init__.py b/extensions_built_in/diffusion_models/__init__.py index 90724450..c775be3b 100644 --- a/extensions_built_in/diffusion_models/__init__.py +++ b/extensions_built_in/diffusion_models/__init__.py @@ -1,6 +1,7 @@ from .chroma import ChromaModel +from .hidream import HidreamModel AI_TOOLKIT_MODELS = [ # put a list of models here - ChromaModel + ChromaModel, HidreamModel ] diff --git a/extensions_built_in/diffusion_models/hidream/__init__.py b/extensions_built_in/diffusion_models/hidream/__init__.py new file mode 100644 index 00000000..1af6f465 --- /dev/null +++ b/extensions_built_in/diffusion_models/hidream/__init__.py @@ -0,0 +1 @@ +from .hidream_model import HidreamModel \ No newline at end of file diff --git a/extensions_built_in/diffusion_models/hidream/hidream_model.py b/extensions_built_in/diffusion_models/hidream/hidream_model.py new file mode 100644 index 00000000..35b4a192 --- /dev/null +++ b/extensions_built_in/diffusion_models/hidream/hidream_model.py @@ -0,0 +1,445 @@ +import os +from typing import TYPE_CHECKING, List, Optional + +import einops +import torch +import torchvision +import yaml +from toolkit import train_tools +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from PIL import Image +from toolkit.models.base_model import BaseModel +from diffusers import AutoencoderKL, TorchAoConfig +from toolkit.basic import flush +from toolkit.prompt_utils import PromptEmbeds +from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler +from toolkit.models.flux import add_model_gpu_splitter_to_flux, bypass_flux_guidance, restore_flux_guidance +from toolkit.dequantize import patch_dequantization_on_save +from toolkit.accelerator import get_accelerator, unwrap_model +from optimum.quanto import freeze, QTensor +from toolkit.util.mask import generate_random_mask, random_dialate_mask +from toolkit.util.quantize import quantize, get_qtype +from transformers import T5TokenizerFast, T5EncoderModel, CLIPTextModel, CLIPTokenizer, TorchAoConfig as TorchAoConfigTransformers +from .src.pipelines.hidream_image.pipeline_hidream_image import HiDreamImagePipeline +from .src.models.transformers.transformer_hidream_image import HiDreamImageTransformer2DModel +from .src.schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler +from transformers import LlamaForCausalLM, PreTrainedTokenizerFast +from einops import rearrange, repeat +import random +import torch.nn.functional as F +from tqdm import tqdm +from transformers import ( + CLIPTextModelWithProjection, + CLIPTokenizer, + T5EncoderModel, + T5Tokenizer, + LlamaForCausalLM, + PreTrainedTokenizerFast +) + +if TYPE_CHECKING: + from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO + +scheduler_config = { + "num_train_timesteps": 1000, + "shift": 3.0 +} + +# LLAMA_MODEL_NAME = 
"meta-llama/Meta-Llama-3.1-8B-Instruct" +LLAMA_MODEL_PATH = "unsloth/Meta-Llama-3.1-8B-Instruct" +BASE_MODEL_PATH = "HiDream-ai/HiDream-I1-Full" + + +class HidreamModel(BaseModel): + arch = "hidream" + + def __init__( + self, + device, + model_config: ModelConfig, + dtype='bf16', + custom_pipeline=None, + noise_scheduler=None, + **kwargs + ): + super().__init__( + device, + model_config, + dtype, + custom_pipeline, + noise_scheduler, + **kwargs + ) + self.is_flow_matching = True + self.is_transformer = True + self.target_lora_modules = ['HiDreamImageTransformer2DModel'] + + # static method to get the noise scheduler + @staticmethod + def get_train_scheduler(): + return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + + def get_bucket_divisibility(self): + return 16 + + def load_model(self): + dtype = self.torch_dtype + # HiDream-ai/HiDream-I1-Full + self.print_and_status_update("Loading HiDream model") + # will be updated if we detect a existing checkpoint in training folder + model_path = self.model_config.name_or_path + extras_path = self.model_config.extras_name_or_path + + llama_model_path = self.model_config.model_kwargs.get('llama_model_path', LLAMA_MODEL_PATH) + + scheduler = HidreamModel.get_train_scheduler() + + self.print_and_status_update("Loading llama 8b model") + + tokenizer_4 = PreTrainedTokenizerFast.from_pretrained( + llama_model_path, + use_fast=False + ) + + text_encoder_4 = LlamaForCausalLM.from_pretrained( + llama_model_path, + output_hidden_states=True, + output_attentions=True, + torch_dtype=torch.bfloat16, + ) + text_encoder_4.to(self.device_torch, dtype=dtype) + + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing llama 8b model") + quantization_type = get_qtype(self.model_config.qtype_te) + quantize(text_encoder_4, weights=quantization_type) + freeze(text_encoder_4) + + if self.low_vram: + # unload it for now + text_encoder_4.to('cpu') + + flush() + + self.print_and_status_update("Loading transformer") + + transformer = HiDreamImageTransformer2DModel.from_pretrained( + model_path, + subfolder="transformer", + torch_dtype=torch.bfloat16 + ) + + if not self.low_vram: + transformer.to(self.device_torch, dtype=dtype) + + if self.model_config.quantize: + self.print_and_status_update("Quantizing transformer") + quantization_type = get_qtype(self.model_config.qtype) + if self.low_vram: + # move and quantize only certain pieces at a time. 
+ all_blocks = list(transformer.double_stream_blocks) + list(transformer.single_stream_blocks) + self.print_and_status_update(" - quantizing transformer blocks") + for block in tqdm(all_blocks): + block.to(self.device_torch, dtype=dtype) + quantize(block, weights=quantization_type) + freeze(block) + block.to('cpu') + # flush() + + self.print_and_status_update(" - quantizing extras") + transformer.to(self.device_torch, dtype=dtype) + quantize(transformer, weights=quantization_type) + freeze(transformer) + else: + quantize(transformer, weights=quantization_type) + freeze(transformer) + + if self.low_vram: + # unload it for now + transformer.to('cpu') + + flush() + + self.print_and_status_update("Loading vae") + + vae = AutoencoderKL.from_pretrained( + extras_path, + subfolder="vae", + torch_dtype=torch.bfloat16 + ).to(self.device_torch, dtype=dtype) + + + self.print_and_status_update("Loading clip encoders") + + text_encoder = CLIPTextModelWithProjection.from_pretrained( + extras_path, + subfolder="text_encoder", + torch_dtype=torch.bfloat16 + ).to(self.device_torch, dtype=dtype) + + tokenizer = CLIPTokenizer.from_pretrained( + extras_path, + subfolder="tokenizer" + ) + + text_encoder_2 = CLIPTextModelWithProjection.from_pretrained( + extras_path, + subfolder="text_encoder_2", + torch_dtype=torch.bfloat16 + ).to(self.device_torch, dtype=dtype) + + tokenizer_2 = CLIPTokenizer.from_pretrained( + extras_path, + subfolder="tokenizer_2" + ) + + flush() + self.print_and_status_update("Loading T5 encoders") + + text_encoder_3 = T5EncoderModel.from_pretrained( + extras_path, + subfolder="text_encoder_3", + torch_dtype=torch.bfloat16 + ).to(self.device_torch, dtype=dtype) + + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing T5") + quantization_type = get_qtype(self.model_config.qtype_te) + quantize(text_encoder_3, weights=quantization_type) + freeze(text_encoder_3) + flush() + + tokenizer_3 = T5Tokenizer.from_pretrained( + extras_path, + subfolder="tokenizer_3" + ) + flush() + + if self.low_vram: + self.print_and_status_update("Moving ecerything to device") + # move it all back + transformer.to(self.device_torch, dtype=dtype) + vae.to(self.device_torch, dtype=dtype) + text_encoder.to(self.device_torch, dtype=dtype) + text_encoder_2.to(self.device_torch, dtype=dtype) + text_encoder_4.to(self.device_torch, dtype=dtype) + text_encoder_3.to(self.device_torch, dtype=dtype) + + # set to eval mode + # transformer.eval() + vae.eval() + text_encoder.eval() + text_encoder_2.eval() + text_encoder_4.eval() + text_encoder_3.eval() + + pipe = HiDreamImagePipeline( + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + text_encoder_3=text_encoder_3, + tokenizer_3=tokenizer_3, + text_encoder_4=text_encoder_4, + tokenizer_4=tokenizer_4, + transformer=transformer, + ) + + flush() + + text_encoder_list = [text_encoder, text_encoder_2, text_encoder_3, text_encoder_4] + tokenizer_list = [tokenizer, tokenizer_2, tokenizer_3, tokenizer_4] + + for te in text_encoder_list: + # set the dtype + te.to(self.device_torch, dtype=dtype) + # freeze the model + freeze(te) + # set to eval mode + te.eval() + # set the requires grad to false + te.requires_grad_(False) + + flush() + + # save it to the model class + self.vae = vae + self.text_encoder = text_encoder_list # list of text encoders + self.tokenizer = tokenizer_list # list of tokenizers + self.model = pipe.transformer + self.pipeline = pipe + 
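Further down in this file, `get_loss_target` returns `(noise - latents)` and `get_noise_prediction` negates the transformer output to match it. A small sketch of the rectified-flow bookkeeping behind that target, assuming the usual `x_t = (1 - t) * x0 + t * noise` interpolation used by flow-matching trainers:

```python
import torch

x0 = torch.randn(1, 16, 64, 64)   # clean latents
noise = torch.randn_like(x0)
t = 0.3                            # noise level in (0, 1)

x_t = (1 - t) * x0 + t * noise     # noisy latents fed to the model
target = noise - x0                # velocity the model should regress to

# integrating the velocity recovers both endpoints exactly
assert torch.allclose(x_t + (1 - t) * target, noise, atol=1e-5)
assert torch.allclose(x_t - t * target, x0, atol=1e-5)
```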
self.print_and_status_update("Model Loaded") + + def get_generation_pipeline(self): + scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=1000, + shift=3.0, + use_dynamic_shifting=False + ) + + pipeline: HiDreamImagePipeline = HiDreamImagePipeline( + scheduler=scheduler, + vae=self.vae, + text_encoder=self.text_encoder[0], + tokenizer=self.tokenizer[0], + text_encoder_2=self.text_encoder[1], + tokenizer_2=self.tokenizer[1], + text_encoder_3=self.text_encoder[2], + tokenizer_3=self.tokenizer[2], + text_encoder_4=self.text_encoder[3], + tokenizer_4=self.tokenizer[3], + transformer=unwrap_model(self.model), + aggressive_unloading=self.low_vram + ) + + pipeline = pipeline.to(self.device_torch) + + return pipeline + + def generate_single_image( + self, + pipeline: HiDreamImagePipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + img = pipeline( + prompt_embeds=conditional_embeds.text_embeds, + pooled_prompt_embeds=conditional_embeds.pooled_embeds, + negative_prompt_embeds=unconditional_embeds.text_embeds, + negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + **extra + ).images[0] + return img + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + **kwargs + ): + with torch.no_grad(): + if latent_model_input.shape[-2] != latent_model_input.shape[-1]: + B, C, H, W = latent_model_input.shape + pH, pW = H // self.model.config.patch_size, W // self.model.config.patch_size + + img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1) + img_ids = torch.zeros(pH, pW, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :] + img_ids = img_ids.reshape(pH * pW, -1) + img_ids_pad = torch.zeros(self.transformer.max_seq, 3) + img_ids_pad[:pH*pW, :] = img_ids + + img_sizes = img_sizes.unsqueeze(0).to(latent_model_input.device) + img_ids = img_ids_pad.unsqueeze(0).to(latent_model_input.device) + else: + img_sizes = img_ids = None + + dtype = self.model.dtype + device = self.device_torch + + # Pack the latent + if latent_model_input.shape[-2] != latent_model_input.shape[-1]: + B, C, H, W = latent_model_input.shape + patch_size = self.transformer.config.patch_size + pH, pW = H // patch_size, W // patch_size + out = torch.zeros( + (B, C, self.transformer.max_seq, patch_size * patch_size), + dtype=latent_model_input.dtype, + device=latent_model_input.device + ) + latent_model_input = einops.rearrange(latent_model_input, 'B C (H p1) (W p2) -> B C (H W) (p1 p2)', p1=patch_size, p2=patch_size) + out[:, :, 0:pH*pW] = latent_model_input + latent_model_input = out + + text_embeds = text_embeddings.text_embeds + # run the to for the list + text_embeds = [te.to(device, dtype=dtype) for te in text_embeds] + + noise_pred = self.transformer( + hidden_states = latent_model_input, + timesteps = timestep, + encoder_hidden_states = text_embeds, + pooled_embeds = text_embeddings.pooled_embeds.to(device, dtype=dtype), + img_sizes = img_sizes, + img_ids = img_ids, + return_dict = False, + )[0] + noise_pred = -noise_pred + + return noise_pred + + def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: + 
self.text_encoder_to(self.device_torch, dtype=self.torch_dtype) + max_sequence_length = 128 + prompt_embeds, pooled_prompt_embeds = self.pipeline._encode_prompt( + prompt = prompt, + prompt_2 = prompt, + prompt_3 = prompt, + prompt_4 = prompt, + device = self.device_torch, + dtype = self.torch_dtype, + num_images_per_prompt = 1, + max_sequence_length = max_sequence_length, + ) + pe = PromptEmbeds( + [prompt_embeds, pooled_prompt_embeds] + ) + return pe + + def get_model_has_grad(self): + # check a weight to see if it has grad + return self.model.double_stream_blocks[0].block.attn1.to_q.weight.requires_grad + + def get_te_has_grad(self): + # assume no one wants to finetune 4 text encoders. + return False + + def save_model(self, output_path, meta, save_dtype): + # only save the transformer + transformer: HiDreamImageTransformer2DModel = unwrap_model(self.model) + transformer.save_pretrained( + save_directory=os.path.join(output_path, 'transformer'), + safe_serialization=True, + ) + + meta_path = os.path.join(output_path, 'aitk_meta.yaml') + with open(meta_path, 'w') as f: + yaml.dump(meta, f) + + def get_loss_target(self, *args, **kwargs): + noise = kwargs.get('noise') + batch = kwargs.get('batch') + return (noise - batch.latents).detach() + + def get_transformer_block_names(self) -> Optional[List[str]]: + return ['double_stream_blocks', 'single_stream_blocks'] + + def convert_lora_weights_before_save(self, state_dict): + # keys currently start with transformer. but need to start with diffusion_model. for comfyui + new_sd = {} + for key, value in state_dict.items(): + new_key = key.replace("transformer.", "diffusion_model.") + new_sd[new_key] = value + return new_sd + + def convert_lora_weights_before_load(self, state_dict): + # keys are saved as diffusion_model. but need to be transformer.
for ai-toolkit + new_sd = {} + for key, value in state_dict.items(): + new_key = key.replace("diffusion_model.", "transformer.") + new_sd[new_key] = value + return new_sd + diff --git a/extensions_built_in/diffusion_models/hidream/src/__init__.py b/extensions_built_in/diffusion_models/hidream/src/__init__.py new file mode 100644 index 00000000..b64a05bb --- /dev/null +++ b/extensions_built_in/diffusion_models/hidream/src/__init__.py @@ -0,0 +1,2 @@ +from .models.transformers.transformer_hidream_image import HiDreamImageTransformer2DModel +from .pipelines.hidream_image.pipeline_hidream_image import HiDreamImagePipeline diff --git a/extensions_built_in/diffusion_models/hidream/src/models/attention.py b/extensions_built_in/diffusion_models/hidream/src/models/attention.py new file mode 100644 index 00000000..95f338cb --- /dev/null +++ b/extensions_built_in/diffusion_models/hidream/src/models/attention.py @@ -0,0 +1,106 @@ +import torch +from torch import nn +from typing import Optional +from diffusers.models.attention_processor import Attention +from diffusers.utils.torch_utils import maybe_allow_in_graph + +@maybe_allow_in_graph +class HiDreamAttention(Attention): + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + upcast_attention: bool = False, + upcast_softmax: bool = False, + scale_qk: bool = True, + eps: float = 1e-5, + processor = None, + out_dim: int = None, + single: bool = False + ): + super(Attention, self).__init__() + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.out_dim = out_dim if out_dim is not None else query_dim + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = out_dim // dim_head if out_dim is not None else heads + self.sliceable_head_dim = heads + self.single = single + + linear_cls = nn.Linear + self.linear_cls = linear_cls + self.to_q = linear_cls(query_dim, self.inner_dim) + self.to_k = linear_cls(self.inner_dim, self.inner_dim) + self.to_v = linear_cls(self.inner_dim, self.inner_dim) + self.to_out = linear_cls(self.inner_dim, self.out_dim) + self.q_rms_norm = nn.RMSNorm(self.inner_dim, eps) + self.k_rms_norm = nn.RMSNorm(self.inner_dim, eps) + + if not single: + self.to_q_t = linear_cls(query_dim, self.inner_dim) + self.to_k_t = linear_cls(self.inner_dim, self.inner_dim) + self.to_v_t = linear_cls(self.inner_dim, self.inner_dim) + self.to_out_t = linear_cls(self.inner_dim, self.out_dim) + self.q_rms_norm_t = nn.RMSNorm(self.inner_dim, eps) + self.k_rms_norm_t = nn.RMSNorm(self.inner_dim, eps) + + self.set_processor(processor) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward( + self, + norm_image_tokens: torch.FloatTensor, + image_tokens_masks: torch.FloatTensor = None, + norm_text_tokens: torch.FloatTensor = None, + rope: torch.FloatTensor = None, + ) -> torch.Tensor: + return self.processor( + self, + image_tokens = norm_image_tokens, + image_tokens_masks = image_tokens_masks, + text_tokens = norm_text_tokens, + rope = rope, + ) + +class FeedForwardSwiGLU(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int = 256, + ffn_dim_multiplier: Optional[float] = None, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if 
ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ( + (hidden_dim + multiple_of - 1) // multiple_of + ) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) \ No newline at end of file diff --git a/extensions_built_in/diffusion_models/hidream/src/models/attention_processor.py b/extensions_built_in/diffusion_models/hidream/src/models/attention_processor.py new file mode 100644 index 00000000..abc2ff40 --- /dev/null +++ b/extensions_built_in/diffusion_models/hidream/src/models/attention_processor.py @@ -0,0 +1,121 @@ +from typing import Optional +import torch +from .attention import HiDreamAttention + +# Try to import Flash Attention first +flash_attn_available = False +try: + from flash_attn_interface import flash_attn_func + USE_FLASH_ATTN3 = True + flash_attn_available = True +except ImportError: + try: + from flash_attn import flash_attn_func + USE_FLASH_ATTN3 = False + flash_attn_available = True + except ImportError: + USE_FLASH_ATTN3 = False + flash_attn_available = False + +# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py +def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) + +def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor): + if flash_attn_available: + if USE_FLASH_ATTN3: + hidden_states = flash_attn_func(query, key, value, causal=False, deterministic=False)[0] + else: + hidden_states = flash_attn_func(query, key, value, dropout_p=0., causal=False) + else: + # Use torch's scaled dot-product attention as fallback + # Reshape for torch.nn.functional.scaled_dot_product_attention which expects [batch, heads, seq_len, head_dim] + query = query.transpose(1, 2) # [batch, heads, seq_len, head_dim] + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + hidden_states = torch.nn.functional.scaled_dot_product_attention( + query, key, value, + attn_mask=None, + dropout_p=0.0, + is_causal=False + ) + + # Restore original shape + hidden_states = hidden_states.transpose(1, 2) # [batch, seq_len, heads, head_dim] + + hidden_states = hidden_states.flatten(-2) + hidden_states = hidden_states.to(query.dtype) + return hidden_states + +class HiDreamAttnProcessor_flashattn: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __call__( + self, + attn: HiDreamAttention, + image_tokens: torch.FloatTensor, + image_tokens_masks: Optional[torch.FloatTensor] = None, + text_tokens: Optional[torch.FloatTensor] = None, + rope: torch.FloatTensor = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + dtype = image_tokens.dtype + batch_size = image_tokens.shape[0] + + query_i = attn.q_rms_norm(attn.to_q(image_tokens)).to(dtype=dtype) + key_i = 
attn.k_rms_norm(attn.to_k(image_tokens)).to(dtype=dtype) + value_i = attn.to_v(image_tokens) + + inner_dim = key_i.shape[-1] + head_dim = inner_dim // attn.heads + + query_i = query_i.view(batch_size, -1, attn.heads, head_dim) + key_i = key_i.view(batch_size, -1, attn.heads, head_dim) + value_i = value_i.view(batch_size, -1, attn.heads, head_dim) + if image_tokens_masks is not None: + key_i = key_i * image_tokens_masks.view(batch_size, -1, 1, 1) + + if not attn.single: + query_t = attn.q_rms_norm_t(attn.to_q_t(text_tokens)).to(dtype=dtype) + key_t = attn.k_rms_norm_t(attn.to_k_t(text_tokens)).to(dtype=dtype) + value_t = attn.to_v_t(text_tokens) + + query_t = query_t.view(batch_size, -1, attn.heads, head_dim) + key_t = key_t.view(batch_size, -1, attn.heads, head_dim) + value_t = value_t.view(batch_size, -1, attn.heads, head_dim) + + num_image_tokens = query_i.shape[1] + num_text_tokens = query_t.shape[1] + query = torch.cat([query_i, query_t], dim=1) + key = torch.cat([key_i, key_t], dim=1) + value = torch.cat([value_i, value_t], dim=1) + else: + query = query_i + key = key_i + value = value_i + + if query.shape[-1] == rope.shape[-3] * 2: + query, key = apply_rope(query, key, rope) + else: + query_1, query_2 = query.chunk(2, dim=-1) + key_1, key_2 = key.chunk(2, dim=-1) + query_1, key_1 = apply_rope(query_1, key_1, rope) + query = torch.cat([query_1, query_2], dim=-1) + key = torch.cat([key_1, key_2], dim=-1) + + hidden_states = attention(query, key, value) + + if not attn.single: + hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1) + hidden_states_i = attn.to_out(hidden_states_i) + hidden_states_t = attn.to_out_t(hidden_states_t) + return hidden_states_i, hidden_states_t + else: + hidden_states = attn.to_out(hidden_states) + return hidden_states \ No newline at end of file diff --git a/extensions_built_in/diffusion_models/hidream/src/models/embeddings.py b/extensions_built_in/diffusion_models/hidream/src/models/embeddings.py new file mode 100644 index 00000000..5f57b36e --- /dev/null +++ b/extensions_built_in/diffusion_models/hidream/src/models/embeddings.py @@ -0,0 +1,114 @@ +import torch +from torch import nn +from typing import List +from diffusers.models.embeddings import Timesteps, TimestepEmbedding + +# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py +def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: + assert dim % 2 == 0, "The dimension must be even." 
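When flash-attn is not installed, the `attention()` helper in `attention_processor.py` above falls back to `torch.nn.functional.scaled_dot_product_attention`, transposing between the flash-attn `[batch, seq, heads, head_dim]` layout and the `[batch, heads, seq, head_dim]` layout SDPA expects. A self-contained check of what that fallback computes:

```python
import torch
import torch.nn.functional as F

B, S, H, D = 2, 16, 4, 32
q, k, v = (torch.randn(B, S, H, D) for _ in range(3))

# fallback path: transpose in, run SDPA, transpose back
out_sdpa = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
    dropout_p=0.0, is_causal=False,
).transpose(1, 2)

# reference: explicit softmax(QK^T / sqrt(D)) V in the flash-attn layout
scores = torch.einsum("bshd,bthd->bhst", q, k) / D ** 0.5
out_ref = torch.einsum("bhst,bthd->bshd", scores.softmax(dim=-1), v)

print(torch.allclose(out_sdpa, out_ref, atol=1e-4))  # True
```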
+ + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + + batch_size, seq_length = pos.shape + out = torch.einsum("...n,d->...nd", pos, omega) + cos_out = torch.cos(out) + sin_out = torch.sin(out) + + stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1) + out = stacked_out.view(batch_size, -1, dim // 2, 2, 2) + return out.float() + +# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py +class EmbedND(nn.Module): + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + return emb.unsqueeze(2) + +class PatchEmbed(nn.Module): + def __init__( + self, + patch_size=2, + in_channels=4, + out_channels=1024, + ): + super().__init__() + self.patch_size = patch_size + self.out_channels = out_channels + self.proj = nn.Linear(in_channels * patch_size * patch_size, out_channels, bias=True) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, latent): + latent = self.proj(latent) + return latent + +class PooledEmbed(nn.Module): + def __init__(self, text_emb_dim, hidden_size): + super().__init__() + self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, pooled_embed): + return self.pooled_embedder(pooled_embed) + +class TimestepEmbed(nn.Module): + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, timesteps, wdtype): + t_emb = self.time_proj(timesteps).to(dtype=wdtype) + t_emb = self.timestep_embedder(t_emb) + return t_emb + +class OutEmbed(nn.Module): + def __init__(self, hidden_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True) + ) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.zeros_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x, adaln_input): + shift, scale = self.adaLN_modulation(adaln_input).chunk(2, dim=1) + x = self.norm_final(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + x = self.linear(x) + return x \ No newline at end of file diff --git a/extensions_built_in/diffusion_models/hidream/src/models/moe.py b/extensions_built_in/diffusion_models/hidream/src/models/moe.py new file 
mode 100644 index 00000000..3b3b6ce2 --- /dev/null +++ b/extensions_built_in/diffusion_models/hidream/src/models/moe.py @@ -0,0 +1,157 @@ +import math +import torch +from torch import nn +import torch.nn.functional as F +from .attention import FeedForwardSwiGLU +from torch.distributed.nn.functional import all_gather + +_LOAD_BALANCING_LOSS = [] +def save_load_balancing_loss(loss): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.append(loss) + +def clear_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.clear() + +def get_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + return _LOAD_BALANCING_LOSS + +def batched_load_balancing_loss(): + aux_losses_arr = get_load_balancing_loss() + alpha = aux_losses_arr[0][-1] + Pi = torch.stack([ent[1] for ent in aux_losses_arr], dim=0) + fi = torch.stack([ent[2] for ent in aux_losses_arr], dim=0) + + fi_list = all_gather(fi) + fi = torch.stack(fi_list, 0).mean(0) + + aux_loss = (Pi * fi).sum(-1).mean() * alpha + return aux_loss + +# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py +class MoEGate(nn.Module): + def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux_loss_alpha=0.01): + super().__init__() + self.top_k = num_activated_experts + self.n_routed_experts = num_routed_experts + + self.scoring_func = 'softmax' + self.alpha = aux_loss_alpha + self.seq_aux = False + + # topk selection algorithm + self.norm_topk_prob = False + self.gating_dim = embed_dim + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim))) + self.reset_parameters() + + def reset_parameters(self) -> None: + import torch.nn.init as init + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + def forward(self, hidden_states): + bsz, seq_len, h = hidden_states.shape + # print(bsz, seq_len, h) + ### compute gating score + hidden_states = hidden_states.view(-1, h) + logits = F.linear(hidden_states, self.weight, None) + if self.scoring_func == 'softmax': + scores = logits.softmax(dim=-1) + else: + raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}') + + ### select top-k experts + topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False) + + ### norm gate to sum 1 + if self.top_k > 1 and self.norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 + topk_weight = topk_weight / denominator + # this was in original and memory leaks, not needed + + # ### expert-level computation auxiliary loss + # if self.training and self.alpha > 0.0: + # scores_for_aux = scores + # aux_topk = self.top_k + # # always compute aux loss based on the naive greedy topk method + # topk_idx_for_aux_loss = topk_idx.view(bsz, -1) + # if self.seq_aux: + # scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1) + # ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device) + # ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(seq_len * aux_topk / self.n_routed_experts) + # aux_loss = (ce * scores_for_seq_aux.mean(dim = 1)).sum(dim = 1).mean() * self.alpha + # else: + # mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts) + # ce = mask_ce.float().mean(0) + + # Pi = scores_for_aux.mean(0) + # fi = ce * self.n_routed_experts + # aux_loss = (Pi * fi).sum() * self.alpha + # save_load_balancing_loss((aux_loss, Pi, fi, self.alpha)) + # else: + aux_loss = None + return topk_idx, topk_weight, aux_loss + 
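A toy, dense-loop version of the routing that `MoEGate` and `moe_infer` (below) implement together: every token is scored against every expert, keeps its top-k, and sums those experts' outputs weighted by the gate scores. The real `moe_infer` sorts tokens by expert so each expert runs once on a batch; the loop here is just for readability.

```python
import torch
import torch.nn.functional as F

num_tokens, dim, n_experts, top_k = 8, 16, 4, 2
x = torch.randn(num_tokens, dim)
gate_weight = torch.randn(n_experts, dim)                  # plays the role of MoEGate.weight
experts = [torch.nn.Linear(dim, dim) for _ in range(n_experts)]

scores = F.linear(x, gate_weight).softmax(dim=-1)          # [tokens, experts]
topk_weight, topk_idx = torch.topk(scores, k=top_k, dim=-1, sorted=False)

out = torch.zeros_like(x)
for token in range(num_tokens):
    for slot in range(top_k):
        expert = experts[topk_idx[token, slot]]
        out[token] += topk_weight[token, slot] * expert(x[token])
# MOEFeedForwardSwiGLU then adds a shared expert applied to every token.
```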
+# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py +class MOEFeedForwardSwiGLU(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + num_routed_experts: int, + num_activated_experts: int, + ): + super().__init__() + self.shared_experts = FeedForwardSwiGLU(dim, hidden_dim // 2) + self.experts = nn.ModuleList([FeedForwardSwiGLU(dim, hidden_dim) for i in range(num_routed_experts)]) + self.gate = MoEGate( + embed_dim = dim, + num_routed_experts = num_routed_experts, + num_activated_experts = num_activated_experts + ) + self.num_activated_experts = num_activated_experts + + def forward(self, x): + wtype = x.dtype + identity = x + orig_shape = x.shape + topk_idx, topk_weight, aux_loss = self.gate(x) + x = x.view(-1, x.shape[-1]) + flat_topk_idx = topk_idx.view(-1) + y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape) + # this was in original and memory leaks, not needed + # if self.training: + # x = x.repeat_interleave(self.num_activated_experts, dim=0) + # y = torch.empty_like(x, dtype=wtype) + # for i, expert in enumerate(self.experts): + # y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype) + # y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) + # y = y.view(*orig_shape).to(dtype=wtype) + # #y = AddAuxiliaryLoss.apply(y, aux_loss) + # else: + # y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape) + y = y + self.shared_experts(identity) + return y + + # @torch.no_grad() + def moe_infer(self, x, flat_expert_indices, flat_expert_weights): + expert_cache = torch.zeros_like(x) + idxs = flat_expert_indices.argsort() + tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0) + token_idxs = idxs // self.num_activated_experts + for i, end_idx in enumerate(tokens_per_expert): + start_idx = 0 if i == 0 else tokens_per_expert[i-1] + if start_idx == end_idx: + continue + expert = self.experts[i] + exp_token_idx = token_idxs[start_idx:end_idx] + expert_tokens = x[exp_token_idx] + expert_out = expert(expert_tokens) + expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) + + # for fp16 and other dtype + expert_cache = expert_cache.to(expert_out.dtype) + expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum') + return expert_cache diff --git a/extensions_built_in/diffusion_models/hidream/src/models/transformers/transformer_hidream_image.py b/extensions_built_in/diffusion_models/hidream/src/models/transformers/transformer_hidream_image.py new file mode 100644 index 00000000..f7eb1045 --- /dev/null +++ b/extensions_built_in/diffusion_models/hidream/src/models/transformers/transformer_hidream_image.py @@ -0,0 +1,506 @@ +from typing import Any, Callable, Dict, Optional, Tuple, List + +import torch +import torch.nn as nn +import einops +from einops import repeat + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from diffusers.utils.torch_utils import maybe_allow_in_graph +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from ..embeddings import PatchEmbed, PooledEmbed, TimestepEmbed, EmbedND, OutEmbed +from ..attention import HiDreamAttention, FeedForwardSwiGLU +from ..attention_processor import 
HiDreamAttnProcessor_flashattn +from ..moe import MOEFeedForwardSwiGLU + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +class TextProjection(nn.Module): + def __init__(self, in_features, hidden_size): + super().__init__() + self.linear = nn.Linear(in_features=in_features, out_features=hidden_size, bias=False) + + def forward(self, caption): + hidden_states = self.linear(caption) + return hidden_states + +class BlockType: + TransformerBlock = 1 + SingleTransformerBlock = 2 + +@maybe_allow_in_graph +class HiDreamImageSingleTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + num_routed_experts: int = 4, + num_activated_experts: int = 2 + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, 6 * dim, bias=True) + ) + nn.init.zeros_(self.adaLN_modulation[1].weight) + nn.init.zeros_(self.adaLN_modulation[1].bias) + + # 1. Attention + self.norm1_i = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False) + self.attn1 = HiDreamAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + processor = HiDreamAttnProcessor_flashattn(), + single = True + ) + + # 3. Feed-forward + self.norm3_i = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False) + if num_routed_experts > 0: + self.ff_i = MOEFeedForwardSwiGLU( + dim = dim, + hidden_dim = 4 * dim, + num_routed_experts = num_routed_experts, + num_activated_experts = num_activated_experts, + ) + else: + self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim) + + def forward( + self, + image_tokens: torch.FloatTensor, + image_tokens_masks: Optional[torch.FloatTensor] = None, + text_tokens: Optional[torch.FloatTensor] = None, + adaln_input: Optional[torch.FloatTensor] = None, + rope: torch.FloatTensor = None, + + ) -> torch.FloatTensor: + wtype = image_tokens.dtype + shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \ + self.adaLN_modulation(adaln_input)[:,None].chunk(6, dim=-1) + + # 1. MM-Attention + norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype) + norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i + attn_output_i = self.attn1( + norm_image_tokens, + image_tokens_masks, + rope = rope, + ) + image_tokens = gate_msa_i * attn_output_i + image_tokens + + # 2. Feed-forward + norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype) + norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i + ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens.to(dtype=wtype)) + image_tokens = ff_output_i + image_tokens + return image_tokens + +@maybe_allow_in_graph +class HiDreamImageTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + num_routed_experts: int = 4, + num_activated_experts: int = 2 + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, 12 * dim, bias=True) + ) + nn.init.zeros_(self.adaLN_modulation[1].weight) + nn.init.zeros_(self.adaLN_modulation[1].bias) + + # 1. 
Attention + self.norm1_i = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False) + self.norm1_t = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False) + self.attn1 = HiDreamAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + processor = HiDreamAttnProcessor_flashattn(), + single = False + ) + + # 3. Feed-forward + self.norm3_i = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False) + if num_routed_experts > 0: + self.ff_i = MOEFeedForwardSwiGLU( + dim = dim, + hidden_dim = 4 * dim, + num_routed_experts = num_routed_experts, + num_activated_experts = num_activated_experts, + ) + else: + self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim) + self.norm3_t = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False) + self.ff_t = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim) + + def forward( + self, + image_tokens: torch.FloatTensor, + image_tokens_masks: Optional[torch.FloatTensor] = None, + text_tokens: Optional[torch.FloatTensor] = None, + adaln_input: Optional[torch.FloatTensor] = None, + rope: torch.FloatTensor = None, + ) -> torch.FloatTensor: + wtype = image_tokens.dtype + shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \ + shift_msa_t, scale_msa_t, gate_msa_t, shift_mlp_t, scale_mlp_t, gate_mlp_t = \ + self.adaLN_modulation(adaln_input)[:,None].chunk(12, dim=-1) + + # 1. MM-Attention + norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype) + norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i + norm_text_tokens = self.norm1_t(text_tokens).to(dtype=wtype) + norm_text_tokens = norm_text_tokens * (1 + scale_msa_t) + shift_msa_t + + attn_output_i, attn_output_t = self.attn1( + norm_image_tokens, + image_tokens_masks, + norm_text_tokens, + rope = rope, + ) + + image_tokens = gate_msa_i * attn_output_i + image_tokens + text_tokens = gate_msa_t * attn_output_t + text_tokens + + # 2. 
Feed-forward + norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype) + norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i + norm_text_tokens = self.norm3_t(text_tokens).to(dtype=wtype) + norm_text_tokens = norm_text_tokens * (1 + scale_mlp_t) + shift_mlp_t + + ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens) + ff_output_t = gate_mlp_t * self.ff_t(norm_text_tokens) + image_tokens = ff_output_i + image_tokens + text_tokens = ff_output_t + text_tokens + return image_tokens, text_tokens + +@maybe_allow_in_graph +class HiDreamImageBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + num_routed_experts: int = 4, + num_activated_experts: int = 2, + block_type: BlockType = BlockType.TransformerBlock, + ): + super().__init__() + block_classes = { + BlockType.TransformerBlock: HiDreamImageTransformerBlock, + BlockType.SingleTransformerBlock: HiDreamImageSingleTransformerBlock, + } + self.block = block_classes[block_type]( + dim, + num_attention_heads, + attention_head_dim, + num_routed_experts, + num_activated_experts + ) + + def forward( + self, + image_tokens: torch.FloatTensor, + image_tokens_masks: Optional[torch.FloatTensor] = None, + text_tokens: Optional[torch.FloatTensor] = None, + adaln_input: torch.FloatTensor = None, + rope: torch.FloatTensor = None, + ) -> torch.FloatTensor: + return self.block( + image_tokens, + image_tokens_masks, + text_tokens, + adaln_input, + rope, + ) + +class HiDreamImageTransformer2DModel( + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin +): + _supports_gradient_checkpointing = True + _no_split_modules = ["HiDreamImageBlock"] + + @register_to_config + def __init__( + self, + patch_size: Optional[int] = None, + in_channels: int = 64, + out_channels: Optional[int] = None, + num_layers: int = 16, + num_single_layers: int = 32, + attention_head_dim: int = 128, + num_attention_heads: int = 20, + caption_channels: List[int] = None, + text_emb_dim: int = 2048, + num_routed_experts: int = 4, + num_activated_experts: int = 2, + axes_dims_rope: Tuple[int, int] = (32, 32), + max_resolution: Tuple[int, int] = (128, 128), + llama_layers: List[int] = None, + ): + super().__init__() + self.out_channels = out_channels or in_channels + self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + self.llama_layers = llama_layers + + self.t_embedder = TimestepEmbed(self.inner_dim) + self.p_embedder = PooledEmbed(text_emb_dim, self.inner_dim) + self.x_embedder = PatchEmbed( + patch_size = patch_size, + in_channels = in_channels, + out_channels = self.inner_dim, + ) + self.pe_embedder = EmbedND(theta=10000, axes_dim=axes_dims_rope) + + self.double_stream_blocks = nn.ModuleList( + [ + HiDreamImageBlock( + dim = self.inner_dim, + num_attention_heads = self.config.num_attention_heads, + attention_head_dim = self.config.attention_head_dim, + num_routed_experts = num_routed_experts, + num_activated_experts = num_activated_experts, + block_type = BlockType.TransformerBlock + ) + for i in range(self.config.num_layers) + ] + ) + + self.single_stream_blocks = nn.ModuleList( + [ + HiDreamImageBlock( + dim = self.inner_dim, + num_attention_heads = self.config.num_attention_heads, + attention_head_dim = self.config.attention_head_dim, + num_routed_experts = num_routed_experts, + num_activated_experts = num_activated_experts, + block_type = BlockType.SingleTransformerBlock + ) + for i in range(self.config.num_single_layers) + ] + ) + + self.final_layer = 
OutEmbed(self.inner_dim, patch_size, self.out_channels) + + caption_channels = [caption_channels[1], ] * (num_layers + num_single_layers) + [caption_channels[0], ] + caption_projection = [] + for caption_channel in caption_channels: + caption_projection.append(TextProjection(in_features = caption_channel, hidden_size = self.inner_dim)) + self.caption_projection = nn.ModuleList(caption_projection) + self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size) + + self.gradient_checkpointing = False + + + def expand_timesteps(self, timesteps, batch_size, device): + if not torch.is_tensor(timesteps): + is_mps = device.type == "mps" + if isinstance(timesteps, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(batch_size) + return timesteps + + # the implementation on hidream during train was wrong, just use the inference one. + def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]: + # Process all images in the batch according to their specific dimensions + x_arr = [] + for i, img_size in enumerate(img_sizes): + pH, pW = img_size + x_arr.append( + einops.rearrange( + x[i, :pH*pW].reshape(1, pH, pW, -1), + 'B H W (p1 p2 C) -> B C (H p1) (W p2)', + p1=self.config.patch_size, p2=self.config.patch_size + ) + ) + x = torch.cat(x_arr, dim=0) + return x + + def patchify(self, x, max_seq, img_sizes=None): + pz2 = self.config.patch_size * self.config.patch_size + if isinstance(x, torch.Tensor): + B, C = x.shape[0], x.shape[1] + device = x.device + dtype = x.dtype + else: + B, C = len(x), x[0].shape[0] + device = x[0].device + dtype = x[0].dtype + x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device) + + if img_sizes is not None: + for i, img_size in enumerate(img_sizes): + x_masks[i, 0:img_size[0] * img_size[1]] = 1 + x = einops.rearrange(x, 'B C S p -> B S (p C)', p=pz2) + elif isinstance(x, torch.Tensor): + pH, pW = x.shape[-2] // self.config.patch_size, x.shape[-1] // self.config.patch_size + x = einops.rearrange(x, 'B C (H p1) (W p2) -> B (H W) (p1 p2 C)', p1=self.config.patch_size, p2=self.config.patch_size) + img_sizes = [[pH, pW]] * B + x_masks = None + else: + raise NotImplementedError + return x, x_masks, img_sizes + + def forward( + self, + hidden_states: torch.Tensor, + timesteps: torch.LongTensor = None, + encoder_hidden_states: torch.Tensor = None, + pooled_embeds: torch.Tensor = None, + img_sizes: Optional[List[Tuple[int, int]]] = None, + img_ids: Optional[torch.Tensor] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ): + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 
+ ) + + # spatial forward + batch_size = hidden_states.shape[0] + hidden_states_type = hidden_states.dtype + + # 0. time + timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device) + timesteps = self.t_embedder(timesteps, hidden_states_type) + p_embedder = self.p_embedder(pooled_embeds) + adaln_input = timesteps + p_embedder + + hidden_states, image_tokens_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes) + if image_tokens_masks is None: + pH, pW = img_sizes[0] + img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size) + hidden_states = self.x_embedder(hidden_states) + + T5_encoder_hidden_states = encoder_hidden_states[0] + encoder_hidden_states = encoder_hidden_states[-1] + encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers] + + if self.caption_projection is not None: + new_encoder_hidden_states = [] + for i, enc_hidden_state in enumerate(encoder_hidden_states): + enc_hidden_state = self.caption_projection[i](enc_hidden_state) + enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1]) + new_encoder_hidden_states.append(enc_hidden_state) + encoder_hidden_states = new_encoder_hidden_states + T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states) + T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + encoder_hidden_states.append(T5_encoder_hidden_states) + + txt_ids = torch.zeros( + batch_size, + encoder_hidden_states[-1].shape[1] + encoder_hidden_states[-2].shape[1] + encoder_hidden_states[0].shape[1], + 3, + device=img_ids.device, dtype=img_ids.dtype + ) + ids = torch.cat((img_ids, txt_ids), dim=1) + rope = self.pe_embedder(ids) + + # 2. 
Blocks + block_id = 0 + initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1) + initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1] + for bid, block in enumerate(self.double_stream_blocks): + cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id].detach() + cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1) + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states, initial_encoder_hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + image_tokens_masks, + cur_encoder_hidden_states, + adaln_input.clone(), + rope.clone(), + ) + + else: + hidden_states, initial_encoder_hidden_states = block( + image_tokens = hidden_states, + image_tokens_masks = image_tokens_masks, + text_tokens = cur_encoder_hidden_states, + adaln_input = adaln_input, + rope = rope, + ) + initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len] + block_id += 1 + + image_tokens_seq_len = hidden_states.shape[1] + hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1) + hidden_states_seq_len = hidden_states.shape[1] + if image_tokens_masks is not None: + encoder_attention_mask_ones = torch.ones( + (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]), + device=image_tokens_masks.device, dtype=image_tokens_masks.dtype + ) + image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1) + + for bid, block in enumerate(self.single_stream_blocks): + cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id].detach() + hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1) + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + image_tokens_masks, + None, + adaln_input.clone(), + rope.clone(), + ) + else: + hidden_states = block( + image_tokens = hidden_states, + image_tokens_masks = image_tokens_masks, + text_tokens = None, + adaln_input = adaln_input, + rope = rope, + ) + hidden_states = hidden_states[:, :hidden_states_seq_len] + block_id += 1 + + hidden_states = hidden_states[:, :image_tokens_seq_len, ...] 
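The forward pass above keeps image tokens in the packed `[B, seq, patch*patch*C]` layout produced by `patchify`, and `unpatchify` (called right after this point) undoes it. A tiny round trip of those einops rearranges on a toy latent:

```python
import torch
import einops

B, C, H, W, p = 1, 16, 8, 6, 2
latent = torch.randn(B, C, H, W)

# patchify: [B, C, H, W] -> [B, (H/p)*(W/p), p*p*C]
tokens = einops.rearrange(
    latent, 'B C (H p1) (W p2) -> B (H W) (p1 p2 C)', p1=p, p2=p)
print(tokens.shape)  # torch.Size([1, 12, 64])

# unpatchify: back to [B, C, H, W] given the patch grid (pH, pW)
pH, pW = H // p, W // p
restored = einops.rearrange(
    tokens.reshape(B, pH, pW, -1),
    'B H W (p1 p2 C) -> B C (H p1) (W p2)', p1=p, p2=p)
print(torch.equal(latent, restored))  # True
```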
+ output = self.final_layer(hidden_states, adaln_input) + output = self.unpatchify(output, img_sizes, self.training) + if image_tokens_masks is not None: + image_tokens_masks = image_tokens_masks[:, :image_tokens_seq_len] + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output, image_tokens_masks) + return Transformer2DModelOutput(sample=output, mask=image_tokens_masks) + diff --git a/extensions_built_in/diffusion_models/hidream/src/pipelines/hidream_image/pipeline_hidream_image.py b/extensions_built_in/diffusion_models/hidream/src/pipelines/hidream_image/pipeline_hidream_image.py new file mode 100644 index 00000000..9c4e51fa --- /dev/null +++ b/extensions_built_in/diffusion_models/hidream/src/pipelines/hidream_image/pipeline_hidream_image.py @@ -0,0 +1,737 @@ +import inspect +from typing import Any, Callable, Dict, List, Optional, Union +import math +import einops +import torch +from transformers import ( + CLIPTextModelWithProjection, + CLIPTokenizer, + T5EncoderModel, + T5Tokenizer, + LlamaForCausalLM, + PreTrainedTokenizerFast +) + +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin +from diffusers.models.autoencoders import AutoencoderKL +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, +) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from .pipeline_output import HiDreamImagePipelineOutput +from ...models.transformers.transformer_hidream_image import HiDreamImageTransformer2DModel +from ...schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. 
+ sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + +class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin): + model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->text_encoder_4->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer_2: CLIPTokenizer, + text_encoder_3: T5EncoderModel, + tokenizer_3: T5Tokenizer, + text_encoder_4: LlamaForCausalLM, + tokenizer_4: PreTrainedTokenizerFast, + transformer: HiDreamImageTransformer2DModel, + aggressive_unloading: bool = False, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + text_encoder_3=text_encoder_3, + text_encoder_4=text_encoder_4, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + tokenizer_3=tokenizer_3, + tokenizer_4=tokenizer_4, + scheduler=scheduler, + transformer=transformer, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + # HiDreamImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. 
So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.default_sample_size = 128 + self.tokenizer_4.pad_token = self.tokenizer_4.eos_token + self.aggressive_unloading = aggressive_unloading + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 128, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder_3.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_3( + prompt, + padding="max_length", + max_length=min(max_sequence_length, self.tokenizer_3.model_max_length), + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, min(max_sequence_length, self.tokenizer_3.model_max_length) - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {min(max_sequence_length, self.tokenizer_3.model_max_length)} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_3(text_input_ids.to(device), attention_mask=attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + return prompt_embeds + + def _get_clip_prompt_embeds( + self, + tokenizer, + text_encoder, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + max_sequence_length: int = 128, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=min(max_sequence_length, 218), + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, 218 - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {218} tokens: {removed_text}" + ) + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * 
num_images_per_prompt, -1) + + return prompt_embeds + + def _get_llama3_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 128, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder_4.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_4( + prompt, + padding="max_length", + max_length=min(max_sequence_length, self.tokenizer_4.model_max_length), + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer_4(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_4.batch_decode(untruncated_ids[:, min(max_sequence_length, self.tokenizer_4.model_max_length) - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {min(max_sequence_length, self.tokenizer_4.model_max_length)} tokens: {removed_text}" + ) + + outputs = self.text_encoder_4( + text_input_ids.to(device), + attention_mask=attention_mask.to(device), + output_hidden_states=True, + output_attentions=True + ) + + prompt_embeds = outputs.hidden_states[1:] + prompt_embeds = torch.stack(prompt_embeds, dim=0) + _, _, seq_len, dim = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, 1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(-1, batch_size * num_images_per_prompt, seq_len, dim) + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + prompt_3: Union[str, List[str]], + prompt_4: Union[str, List[str]], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + negative_prompt_4: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 128, + lora_scale: Optional[float] = None, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds[0].shape[0] + + prompt_embeds, pooled_prompt_embeds = self._encode_prompt( + prompt = prompt, + prompt_2 = prompt_2, + prompt_3 = prompt_3, + prompt_4 = prompt_4, + device = device, + dtype = dtype, + num_images_per_prompt = num_images_per_prompt, + prompt_embeds = prompt_embeds, + pooled_prompt_embeds = pooled_prompt_embeds, + max_sequence_length = max_sequence_length, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_3 = negative_prompt_3 
or negative_prompt + negative_prompt_4 = negative_prompt_4 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + negative_prompt_3 = ( + batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 + ) + negative_prompt_4 = ( + batch_size * [negative_prompt_4] if isinstance(negative_prompt_4, str) else negative_prompt_4 + ) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_pooled_prompt_embeds = self._encode_prompt( + prompt = negative_prompt, + prompt_2 = negative_prompt_2, + prompt_3 = negative_prompt_3, + prompt_4 = negative_prompt_4, + device = device, + dtype = dtype, + num_images_per_prompt = num_images_per_prompt, + prompt_embeds = negative_prompt_embeds, + pooled_prompt_embeds = negative_pooled_prompt_embeds, + max_sequence_length = max_sequence_length, + ) + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + def _encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + prompt_3: Union[str, List[str]], + prompt_4: Union[str, List[str]], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 128, + ): + device = device or self._execution_device + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + prompt_4 = prompt_4 or prompt + prompt_4 = [prompt_4] if isinstance(prompt_4, str) else prompt_4 + + pooled_prompt_embeds_1 = self._get_clip_prompt_embeds( + self.tokenizer, + self.text_encoder, + prompt = prompt, + num_images_per_prompt = num_images_per_prompt, + max_sequence_length = max_sequence_length, + device = device, + dtype = dtype, + ) + + pooled_prompt_embeds_2 = self._get_clip_prompt_embeds( + self.tokenizer_2, + self.text_encoder_2, + prompt = prompt_2, + num_images_per_prompt = num_images_per_prompt, + max_sequence_length = max_sequence_length, + device = device, + dtype = dtype, + ) + + pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_1, pooled_prompt_embeds_2], dim=-1) + + t5_prompt_embeds = self._get_t5_prompt_embeds( + prompt = prompt_3, + num_images_per_prompt = num_images_per_prompt, + max_sequence_length = max_sequence_length, + device = device, + dtype = dtype + ) + llama3_prompt_embeds = self._get_llama3_prompt_embeds( + prompt = prompt_4, + num_images_per_prompt = num_images_per_prompt, + max_sequence_length = max_sequence_length, + device = device, + dtype = dtype + ) + prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds] + + return prompt_embeds, 
pooled_prompt_embeds + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + prompt_4: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + negative_prompt_4: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 128, + ): + 
height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + division = self.vae_scale_factor * 2 + S_max = (self.default_sample_size * self.vae_scale_factor) ** 2 + scale = S_max / (width * height) + scale = math.sqrt(scale) + width, height = int(width * scale // division * division), int(height * scale // division * division) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds[0].shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + prompt_4=prompt_4, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + negative_prompt_4=negative_prompt_4, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + if self.do_classifier_free_guidance: + prompt_embeds_arr = [] + for n, p in zip(negative_prompt_embeds, prompt_embeds): + if len(n.shape) == 3: + prompt_embeds_arr.append(torch.cat([n, p], dim=0)) + else: + prompt_embeds_arr.append(torch.cat([n, p], dim=1)) + prompt_embeds = prompt_embeds_arr + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + pooled_prompt_embeds.dtype, + device, + generator, + latents, + ) + + if latents.shape[-2] != latents.shape[-1]: + B, C, H, W = latents.shape + pH, pW = H // self.transformer.config.patch_size, W // self.transformer.config.patch_size + + img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1) + img_ids = torch.zeros(pH, pW, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :] + img_ids = img_ids.reshape(pH * pW, -1) + img_ids_pad = torch.zeros(self.transformer.max_seq, 3) + img_ids_pad[:pH*pW, :] = img_ids + + img_sizes = img_sizes.unsqueeze(0).to(latents.device) + img_ids = img_ids_pad.unsqueeze(0).to(latents.device) + if self.do_classifier_free_guidance: + img_sizes = img_sizes.repeat(2 * B, 1) + img_ids = img_ids.repeat(2 * B, 1, 1) + else: + img_sizes = img_ids = None + + # 5. 
Prepare timesteps + mu = calculate_shift(self.transformer.max_seq) + scheduler_kwargs = {"mu": mu} + if isinstance(self.scheduler, FlowUniPCMultistepScheduler): + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=math.exp(mu)) + timesteps = self.scheduler.timesteps + else: + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + **scheduler_kwargs, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + if latent_model_input.shape[-2] != latent_model_input.shape[-1]: + B, C, H, W = latent_model_input.shape + patch_size = self.transformer.config.patch_size + pH, pW = H // patch_size, W // patch_size + out = torch.zeros( + (B, C, self.transformer.max_seq, patch_size * patch_size), + dtype=latent_model_input.dtype, + device=latent_model_input.device + ) + latent_model_input = einops.rearrange(latent_model_input, 'B C (H p1) (W p2) -> B C (H W) (p1 p2)', p1=patch_size, p2=patch_size) + out[:, :, 0:pH*pW] = latent_model_input + latent_model_input = out + + noise_pred = self.transformer( + hidden_states = latent_model_input, + timesteps = timestep, + encoder_hidden_states = prompt_embeds, + pooled_embeds = pooled_prompt_embeds, + img_sizes = img_sizes, + img_ids = img_ids, + return_dict = False, + )[0] + noise_pred = -noise_pred + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return HiDreamImagePipelineOutput(images=image) \ No newline at end of file diff --git a/extensions_built_in/diffusion_models/hidream/src/pipelines/hidream_image/pipeline_output.py b/extensions_built_in/diffusion_models/hidream/src/pipelines/hidream_image/pipeline_output.py new file mode 100644 index 00000000..b03e2fae --- /dev/null +++ b/extensions_built_in/diffusion_models/hidream/src/pipelines/hidream_image/pipeline_output.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL.Image + +from diffusers.utils import BaseOutput + + +@dataclass +class HiDreamImagePipelineOutput(BaseOutput): + """ + Output class for HiDreamImage pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] \ No newline at end of file diff --git a/extensions_built_in/diffusion_models/hidream/src/schedulers/flash_flow_match.py b/extensions_built_in/diffusion_models/hidream/src/schedulers/flash_flow_match.py new file mode 100644 index 00000000..122f8ed7 --- /dev/null +++ b/extensions_built_in/diffusion_models/hidream/src/schedulers/flash_flow_match.py @@ -0,0 +1,428 @@ +# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
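+# The sampler implemented below is a "flash" variant of the flow-match Euler
+# scheduler: each step converts the model's flow prediction into a clean-image
+# estimate and then re-noises it down to the next sigma with fresh noise,
+# rather than taking a plain Euler step along the ODE. A minimal sketch of the
+# update (illustrative; the actual logic lives in `step()` further down):
+#
+#     denoised = sample - sigma * model_output                    # x0 estimate
+#     sample   = sigma_next * noise + (1 - sigma_next) * denoised # re-noise
+#
+# which mirrors the forward process used by `scale_noise()`:
+#
+#     x_t = sigma * noise + (1 - sigma) * x_0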
+ +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.utils import BaseOutput, is_scipy_available, logging +from diffusers.utils.torch_utils import randn_tensor + +if is_scipy_available(): + import scipy.stats + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class FlashFlowMatchEulerDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +class FlashFlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Euler scheduler. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + shift (`float`, defaults to 1.0): + The shift value for the timestep schedule. + """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + use_dynamic_shifting=False, + base_shift: Optional[float] = 0.5, + max_shift: Optional[float] = 1.15, + base_image_seq_len: Optional[int] = 256, + max_image_seq_len: Optional[int] = 4096, + invert_sigmas: bool = False, + use_karras_sigmas: Optional[bool] = False, + use_exponential_sigmas: Optional[bool] = False, + use_beta_sigmas: Optional[bool] = False, + ): + if self.config.use_beta_sigmas and not is_scipy_available(): + raise ImportError("Make sure to install scipy if you want to use beta sigmas.") + if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError( + "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." + ) + timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() + timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) + + sigmas = timesteps / num_train_timesteps + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + self.timesteps = sigmas * num_train_timesteps + + self._step_index = None + self._begin_index = None + + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. 
It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def scale_noise( + self, + sample: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + noise: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Forward process in flow-matching + + Args: + sample (`torch.FloatTensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype) + + if sample.device.type == "mps" and torch.is_floating_point(timestep): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32) + timestep = timestep.to(sample.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(sample.device) + timestep = timestep.to(sample.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timestep.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timestep.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(sample.shape): + sigma = sigma.unsqueeze(-1) + + sample = sigma * noise + (1.0 - sigma) * sample + + return sample + + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def set_timesteps( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[float] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
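+
+ A worked illustration (for reference only, using the defaults above): with
+ `num_inference_steps=4`, `shift=1.0` and dynamic shifting disabled, this produces
+ sigmas of roughly [1.000, 0.667, 0.334, 0.001] with a trailing 0.0 appended, i.e.
+ timesteps of roughly [1000, 667, 334, 1]. With `shift=3.0` the same grid becomes
+ roughly [1.000, 0.857, 0.601, 0.003], so more of the schedule is spent at high noise.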
+ """ + if self.config.use_dynamic_shifting and mu is None: + raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`") + + if sigmas is None: + timesteps = np.linspace( + self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps + ) + + sigmas = timesteps / self.config.num_train_timesteps + else: + sigmas = np.array(sigmas).astype(np.float32) + num_inference_steps = len(sigmas) + self.num_inference_steps = num_inference_steps + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) + + if self.config.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + + elif self.config.use_exponential_sigmas: + sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + + elif self.config.use_beta_sigmas: + sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + timesteps = sigmas * self.config.num_train_timesteps + + if self.config.invert_sigmas: + sigmas = 1.0 - sigmas + timesteps = sigmas * self.config.num_train_timesteps + sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)]) + else: + sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + self.timesteps = timesteps.to(device=device) + self.sigmas = sigmas + self._step_index = None + self._begin_index = None + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[FlashFlowMatchEulerDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + s_churn (`float`): + s_tmin (`float`): + s_tmax (`float`): + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + generator (`torch.Generator`, *optional*): + A random number generator. 
+ return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or + tuple. + + Returns: + [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. + """ + + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Upcast to avoid precision issues when computing prev_sample + + sigma = self.sigmas[self.step_index] + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + + denoised = sample - model_output * sigma + + if self.step_index < self.num_inference_steps - 1: + sigma_next = self.sigmas[self.step_index + 1] + noise = randn_tensor( + model_output.shape, + generator=generator, + device=model_output.device, + dtype=denoised.dtype, + ) + sample = sigma_next * noise + (1.0 - sigma_next) * denoised + + self._step_index += 1 + sample = sample.to(model_output.dtype) + + if not return_dict: + return (sample,) + + return FlashFlowMatchEulerDiscreteSchedulerOutput(prev_sample=sample) + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. 
(2022).""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential + def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: + """Constructs an exponential noise schedule.""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps)) + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta + def _convert_to_beta( + self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 + ) -> torch.Tensor: + """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.array( + [ + sigma_min + (ppf * (sigma_max - sigma_min)) + for ppf in [ + scipy.stats.beta.ppf(timestep, alpha, beta) + for timestep in 1 - np.linspace(0, 1, num_inference_steps) + ] + ] + ) + return sigmas + + def __len__(self): + return self.config.num_train_timesteps diff --git a/extensions_built_in/diffusion_models/hidream/src/schedulers/fm_solvers_unipc.py b/extensions_built_in/diffusion_models/hidream/src/schedulers/fm_solvers_unipc.py new file mode 100644 index 00000000..57321baa --- /dev/null +++ b/extensions_built_in/diffusion_models/hidream/src/schedulers/fm_solvers_unipc.py @@ -0,0 +1,800 @@ +# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py +# Convert unipc for flow matching +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
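+# Note on the flow-matching adaptation (illustrative summary of the code below):
+# this port treats the flow-match noising x_t = (1 - sigma) * x_0 + sigma * noise
+# as a diffusion with alpha_t = 1 - sigma_t (see `_sigma_to_alpha_sigma_t`), so
+# the model output is first converted to a clean-sample estimate via
+#
+#     x0_pred = sample - sigma_t * model_output
+#
+# before the usual UniP predictor / UniC corrector updates are applied. The
+# HiDream pipeline drives it with a resolution-dependent shift: for example,
+# with the defaults in `calculate_shift`, an image sequence length of 4096
+# gives mu = 1.15, and the pipeline passes shift = exp(1.15), about 3.16, to
+# `set_timesteps`.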
+ +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers, + SchedulerMixin, + SchedulerOutput) +from diffusers.utils import deprecate, is_scipy_available + +if is_scipy_available(): + import scipy.stats + + +class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + solver_order (`int`, default `2`): + The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1` + due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for + unconditional sampling. + prediction_type (`str`, defaults to "flow_prediction"): + Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts + the flow of the diffusion process. + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. + predict_x0 (`bool`, defaults to `True`): + Whether to use the updating algorithm on the predicted x0. + solver_type (`str`, default `bh2`): + Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2` + otherwise. + lower_order_final (`bool`, default `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + disable_corrector (`list`, default `[]`): + Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)` + and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is + usually disabled during the first few steps. + solver_p (`SchedulerMixin`, default `None`): + Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + use_exponential_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. 
+ steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "flow_prediction", + shift: Optional[float] = 1.0, + use_dynamic_shifting=False, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: List[int] = [], + solver_p: SchedulerMixin = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + ): + + if solver_type not in ["bh1", "bh2"]: + if solver_type in ["midpoint", "heun", "logrho"]: + self.register_to_config(solver_type="bh2") + else: + raise NotImplementedError( + f"{solver_type} is not implemented for {self.__class__}") + + self.predict_x0 = predict_x0 + # setable values + self.num_inference_steps = None + alphas = np.linspace(1, 1 / num_train_timesteps, + num_train_timesteps)[::-1].copy() + sigmas = 1.0 - alphas + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) + + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + self.sigmas = sigmas + self.timesteps = sigmas * num_train_timesteps + + self.model_outputs = [None] * solver_order + self.timestep_list = [None] * solver_order + self.lower_order_nums = 0 + self.disable_corrector = disable_corrector + self.solver_p = solver_p + self.last_sample = None + self._step_index = None + self._begin_index = None + + self.sigmas = self.sigmas.to( + "cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps + def set_timesteps( + self, + num_inference_steps: Union[int, None] = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[Union[float, None]] = None, + shift: Optional[Union[float, None]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). 
+ Args: + num_inference_steps (`int`): + Total number of the spacing of the time steps. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + + if self.config.use_dynamic_shifting and mu is None: + raise ValueError( + " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`" + ) + + if sigmas is None: + sigmas = np.linspace(self.sigma_max, self.sigma_min, + num_inference_steps + + 1).copy()[:-1] # pyright: ignore + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore + else: + if shift is None: + shift = self.config.shift + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / + self.alphas_cumprod[0])**0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + timesteps = sigmas * self.config.num_train_timesteps + sigmas = np.concatenate([sigmas, [sigma_last] + ]).astype(np.float32) # pyright: ignore + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to( + device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + self.last_sample = None + if self.solver_p: + self.solver_p.set_timesteps(self.num_inference_steps, device=device) + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to( + "cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." 
+ + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float( + ) # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile( + abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze( + 1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp( + sample, -s, s + ) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def _sigma_to_alpha_sigma_t(self, sigma): + return 1 - sigma, sigma + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma) + + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + r""" + Convert the model output to the corresponding type the UniPC algorithm needs. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError( + "missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + + if self.predict_x0: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + else: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + epsilon = sample - (1 - sigma_t) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." 
+ ) + + if self.config.thresholding: + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + x0_pred = self._threshold_sample(x0_pred) + epsilon = model_output + x0_pred + + return epsilon + + def multistep_uni_p_bh_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + order: int = None, # pyright: ignore + **kwargs, + ) -> torch.Tensor: + """ + One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model at the current timestep. + prev_timestep (`int`): + The previous discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + order (`int`): + The order of UniP at this timestep (corresponds to the *p* in UniPC-p). + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + prev_timestep = args[0] if len(args) > 0 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError( + " missing `sample` as a required keyward argument") + if order is None: + if len(args) > 2: + order = args[2] + else: + raise ValueError( + " missing `order` as a required keyward argument") + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + model_output_list = self.model_outputs + + s0 = self.timestep_list[-1] + m0 = model_output_list[-1] + x = sample + + if self.solver_p: + x_t = self.solver_p.step(model_output, s0, x).prev_sample + return x_t + + sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[ + self.step_index] # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - i # pyright: ignore + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) # pyright: ignore + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + # for order 2, we use a simplified version + if order == 2: + rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_p = torch.linalg.solve(R[:-1, :-1], + b[:-1]).to(device).to(x.dtype) + else: + D1s = None + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + pred_res = 
torch.einsum("k,bkc...->bc...", rhos_p, + D1s) # pyright: ignore + else: + pred_res = 0 + x_t = x_t_ - alpha_t * B_h * pred_res + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, + D1s) # pyright: ignore + else: + pred_res = 0 + x_t = x_t_ - sigma_t * B_h * pred_res + + x_t = x_t.to(x.dtype) + return x_t + + def multistep_uni_c_bh_update( + self, + this_model_output: torch.Tensor, + *args, + last_sample: torch.Tensor = None, + this_sample: torch.Tensor = None, + order: int = None, # pyright: ignore + **kwargs, + ) -> torch.Tensor: + """ + One step for the UniC (B(h) version). + + Args: + this_model_output (`torch.Tensor`): + The model outputs at `x_t`. + this_timestep (`int`): + The current timestep `t`. + last_sample (`torch.Tensor`): + The generated sample before the last predictor `x_{t-1}`. + this_sample (`torch.Tensor`): + The generated sample after the last predictor `x_{t}`. + order (`int`): + The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`. + + Returns: + `torch.Tensor`: + The corrected sample tensor at the current timestep. + """ + this_timestep = args[0] if len(args) > 0 else kwargs.pop( + "this_timestep", None) + if last_sample is None: + if len(args) > 1: + last_sample = args[1] + else: + raise ValueError( + " missing`last_sample` as a required keyward argument") + if this_sample is None: + if len(args) > 2: + this_sample = args[2] + else: + raise ValueError( + " missing`this_sample` as a required keyward argument") + if order is None: + if len(args) > 3: + order = args[3] + else: + raise ValueError( + " missing`order` as a required keyward argument") + if this_timestep is not None: + deprecate( + "this_timestep", + "1.0.0", + "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + model_output_list = self.model_outputs + + m0 = model_output_list[-1] + x = last_sample + x_t = this_sample + model_t = this_model_output + + sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[ + self.step_index - 1] # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = this_sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - (i + 1) # pyright: ignore + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) # pyright: ignore + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) + else: + D1s = None + + # for order 1, we use a 
simplified version + if order == 1: + rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype) + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) + x_t = x_t.to(x.dtype) + return x_t + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step(self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + return_dict: bool = True, + generator=None) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep UniPC. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. 
+ + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + use_corrector = ( + self.step_index > 0 and + self.step_index - 1 not in self.disable_corrector and + self.last_sample is not None # pyright: ignore + ) + + model_output_convert = self.convert_model_output( + model_output, sample=sample) + if use_corrector: + sample = self.multistep_uni_c_bh_update( + this_model_output=model_output_convert, + last_sample=self.last_sample, + this_sample=sample, + order=self.this_order, + ) + + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep # pyright: ignore + + if self.config.lower_order_final: + this_order = min(self.config.solver_order, + len(self.timesteps) - + self.step_index) # pyright: ignore + else: + this_order = self.config.solver_order + + self.this_order = min(this_order, + self.lower_order_nums + 1) # warmup for multistep + assert self.this_order > 0 + + self.last_sample = sample + prev_sample = self.multistep_uni_p_bh_update( + model_output=model_output, # pass the original non-converted model output, in case solver-p is used + sample=sample, + order=self.this_order, + ) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # upon completion increase step index by one + self._step_index += 1 # pyright: ignore + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.Tensor, *args, + **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + + Returns: + `torch.Tensor`: + A scaled input sample. 
+ """ + return sample + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to( + device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point( + timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to( + original_samples.device, dtype=torch.float32) + timesteps = timesteps.to( + original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [ + self.index_for_timestep(t, schedule_timesteps) + for t in timesteps + ] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/toolkit/guidance.py b/toolkit/guidance.py index 84242423..287d17e7 100644 --- a/toolkit/guidance.py +++ b/toolkit/guidance.py @@ -649,11 +649,13 @@ def targeted_flow_guidance( noise, timesteps ).detach() + unconditional_noisy_latents = sd.condition_noisy_latents(unconditional_noisy_latents, batch) conditional_noisy_latents = sd.add_noise( conditional_latents, noise, timesteps ).detach() + conditional_noisy_latents = sd.condition_noisy_latents(conditional_noisy_latents, batch) # disable the lora to get a baseline prediction sd.network.is_active = False diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py index ddc8ce8b..61c4bcf6 100644 --- a/toolkit/models/base_model.py +++ b/toolkit/models/base_model.py @@ -725,9 +725,13 @@ class BaseModel: do_classifier_free_guidance = True # check if batch size of embeddings matches batch size of latents - if latents.shape[0] == text_embeddings.text_embeds.shape[0]: + if isinstance(text_embeddings.text_embeds, list): + te_batch_size = text_embeddings.text_embeds[0].shape[0] + else: + te_batch_size = text_embeddings.text_embeds.shape[0] + if latents.shape[0] == te_batch_size: do_classifier_free_guidance = False - elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]: + elif latents.shape[0] * 2 != te_batch_size: raise ValueError( "Batch size of latents must be the same or half the batch size of text embeddings") latents = latents.to(self.device_torch) diff --git a/toolkit/prompt_utils.py b/toolkit/prompt_utils.py index eb213cca..ff5a68f3 100644 --- a/toolkit/prompt_utils.py +++ b/toolkit/prompt_utils.py @@ -36,7 +36,10 @@ class PromptEmbeds: self.attention_mask = attention_mask def to(self, *args, **kwargs): - self.text_embeds = self.text_embeds.to(*args, **kwargs) + if isinstance(self.text_embeds, list) or 
isinstance(self.text_embeds, tuple): + self.text_embeds = [t.to(*args, **kwargs) for t in self.text_embeds] + else: + self.text_embeds = self.text_embeds.to(*args, **kwargs) if self.pooled_embeds is not None: self.pooled_embeds = self.pooled_embeds.to(*args, **kwargs) if self.attention_mask is not None: @@ -45,7 +48,10 @@ class PromptEmbeds: def detach(self): new_embeds = self.clone() - new_embeds.text_embeds = new_embeds.text_embeds.detach() + if isinstance(new_embeds.text_embeds, list) or isinstance(new_embeds.text_embeds, tuple): + new_embeds.text_embeds = [t.detach() for t in new_embeds.text_embeds] + else: + new_embeds.text_embeds = new_embeds.text_embeds.detach() if new_embeds.pooled_embeds is not None: new_embeds.pooled_embeds = new_embeds.pooled_embeds.detach() if new_embeds.attention_mask is not None: @@ -53,10 +59,14 @@ class PromptEmbeds: return new_embeds def clone(self): - if self.pooled_embeds is not None: - prompt_embeds = PromptEmbeds([self.text_embeds.clone(), self.pooled_embeds.clone()]) + if isinstance(self.text_embeds, list) or isinstance(self.text_embeds, tuple): + cloned_text_embeds = [t.clone() for t in self.text_embeds] else: - prompt_embeds = PromptEmbeds(self.text_embeds.clone()) + cloned_text_embeds = self.text_embeds.clone() + if self.pooled_embeds is not None: + prompt_embeds = PromptEmbeds([cloned_text_embeds, self.pooled_embeds.clone()]) + else: + prompt_embeds = PromptEmbeds(cloned_text_embeds) if self.attention_mask is not None: prompt_embeds.attention_mask = self.attention_mask.clone() @@ -64,12 +74,18 @@ class PromptEmbeds: def expand_to_batch(self, batch_size): pe = self.clone() - current_batch_size = pe.text_embeds.shape[0] + if isinstance(pe.text_embeds, list) or isinstance(pe.text_embeds, tuple): + current_batch_size = pe.text_embeds[0].shape[0] + else: + current_batch_size = pe.text_embeds.shape[0] if current_batch_size == batch_size: return pe if current_batch_size != 1: raise Exception("Can only expand batch size for batch size 1") - pe.text_embeds = pe.text_embeds.expand(batch_size, -1) + if isinstance(pe.text_embeds, list) or isinstance(pe.text_embeds, tuple): + pe.text_embeds = [t.expand(batch_size, -1) for t in pe.text_embeds] + else: + pe.text_embeds = pe.text_embeds.expand(batch_size, -1) if pe.pooled_embeds is not None: pe.pooled_embeds = pe.pooled_embeds.expand(batch_size, -1) if pe.attention_mask is not None: @@ -145,7 +161,13 @@ class EncodedPromptPair: def concat_prompt_embeds(prompt_embeds: list[PromptEmbeds]): - text_embeds = torch.cat([p.text_embeds for p in prompt_embeds], dim=0) + if isinstance(prompt_embeds[0].text_embeds, list) or isinstance(prompt_embeds[0].text_embeds, tuple): + embed_list = [] + for i in range(len(prompt_embeds[0].text_embeds)): + embed_list.append(torch.cat([p.text_embeds[i] for p in prompt_embeds], dim=0)) + text_embeds = embed_list + else: + text_embeds = torch.cat([p.text_embeds for p in prompt_embeds], dim=0) pooled_embeds = None if prompt_embeds[0].pooled_embeds is not None: pooled_embeds = torch.cat([p.pooled_embeds for p in prompt_embeds], dim=0) @@ -196,7 +218,16 @@ def split_prompt_embeds(concatenated: PromptEmbeds, num_parts=None) -> List[Prom if num_parts is None: # use batch size num_parts = concatenated.text_embeds.shape[0] - text_embeds_splits = torch.chunk(concatenated.text_embeds, num_parts, dim=0) + + if isinstance(concatenated.text_embeds, list) or isinstance(concatenated.text_embeds, tuple): + # split each part + text_embeds_splits = [ + torch.chunk(text, num_parts, dim=0) + for text in 
concatenated.text_embeds + ] + text_embeds_splits = list(zip(*text_embeds_splits)) + else: + text_embeds_splits = torch.chunk(concatenated.text_embeds, num_parts, dim=0) if concatenated.pooled_embeds is not None: pooled_embeds_splits = torch.chunk(concatenated.pooled_embeds, num_parts, dim=0) diff --git a/ui/src/app/jobs/new/SimpleJob.tsx b/ui/src/app/jobs/new/SimpleJob.tsx index 75bab27d..84d98330 100644 --- a/ui/src/app/jobs/new/SimpleJob.tsx +++ b/ui/src/app/jobs/new/SimpleJob.tsx @@ -283,7 +283,7 @@ export default function SimpleJob({ options={[ { value: 'sigmoid', label: 'Sigmoid' }, { value: 'linear', label: 'Linear' }, - { value: 'flux_shift', label: 'Flux Shift' }, + { value: 'shift', label: 'Shift' }, ]} /> { @@ -83,6 +84,20 @@ export const options = { 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], }, }, + { + name_or_path: 'HiDream-ai/HiDream-I1-Full', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].model.arch': ['hidream', defaultModelArch], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.lr': [0.0002, 0.0001], + 'config.process[0].train.timestep_type': ['shift', 'sigmoid'], + 'config.process[0].network.network_kwargs.ignore_if_contains': [['ff_i.experts', 'ff_i.gate'], []], + }, + }, { name_or_path: 'ostris/objective-reality', dev_only: true, diff --git a/ui/src/types.ts b/ui/src/types.ts index 9841553e..fdac4d46 100644 --- a/ui/src/types.ts +++ b/ui/src/types.ts @@ -55,6 +55,9 @@ export interface NetworkConfig { linear_alpha: number; lokr_full_rank: boolean; lokr_factor: number; + network_kwargs: { + ignore_if_contains: string[]; + } } export interface SaveConfig { diff --git a/version.py b/version.py index bccbcfdf..88364324 100644 --- a/version.py +++ b/version.py @@ -1 +1 @@ -VERSION = "0.2.5" \ No newline at end of file +VERSION = "0.2.6" \ No newline at end of file
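
A minimal sketch of what the static-shift branch of `set_timesteps` in the scheduler above computes, reduced to plain numpy. The values for `shift`, `num_inference_steps`, `sigma_max`, and `sigma_min` are assumed here for illustration; the real scheduler reads them from its config and call arguments, and this snippet is not part of the diff.

import numpy as np

num_train_timesteps = 1000
num_inference_steps = 25
shift = 3.0                            # assumed example value
sigma_max = 1.0                        # assumed example value
sigma_min = 1.0 / num_train_timesteps  # assumed example value

# linearly spaced sigmas from sigma_max down to sigma_min, dropping the endpoint
sigmas = np.linspace(sigma_max, sigma_min, num_inference_steps + 1)[:-1]

# static shift (use_dynamic_shifting = False):
#   sigma' = shift * sigma / (1 + (shift - 1) * sigma)
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

# timesteps are the shifted sigmas scaled into the training range
timesteps = (sigmas * num_train_timesteps).astype(np.int64)

# final_sigmas_type == "zero" appends a terminal sigma of 0
sigmas = np.append(sigmas, 0.0).astype(np.float32)

print(timesteps[:3], sigmas[:3])

When `use_dynamic_shifting` is enabled, the shift instead comes from `time_shift(mu, 1.0, sigma)`, i.e. `exp(mu) / (exp(mu) + 1 / sigma - 1)`, which is why `set_timesteps` raises if `mu` is not provided in that mode.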