From 19ea8ecc3815faccb464d58b2509e295e43902cd Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Wed, 25 Jun 2025 13:58:16 -0600 Subject: [PATCH 1/2] Added support for finetuning OmniGen2. --- config/examples/train_lora_omnigen2_24gb.yaml | 94 ++ .../diffusion_models/__init__.py | 3 +- .../diffusion_models/omnigen2/__init__.py | 327 +++++ .../diffusion_models/omnigen2/src/__init__.py | 0 .../omnigen2/src/models/__init__.py | 0 .../src/models/attention_processor.py | 357 +++++ .../omnigen2/src/models/embeddings.py | 126 ++ .../src/models/transformers/__init__.py | 3 + .../src/models/transformers/block_lumina2.py | 218 +++ .../src/models/transformers/components.py | 4 + .../omnigen2/src/models/transformers/repo.py | 135 ++ .../transformers/transformer_omnigen2.py | 621 ++++++++ .../omnigen2/src/ops/triton/__init__.py | 0 .../omnigen2/src/ops/triton/layer_norm.py | 1257 +++++++++++++++++ .../omnigen2/src/pipelines/__init__.py | 0 .../omnigen2/src/pipelines/image_processor.py | 266 ++++ .../pipelines/omnigen2/pipeline_omnigen2.py | 728 ++++++++++ .../omnigen2/pipeline_omnigen2_chat.py | 830 +++++++++++ .../omnigen2/src/pipelines/pipeline_utils.py | 62 + .../omnigen2/src/schedulers/__init__.py | 0 .../scheduling_dpmsolver_multistep.py | 1052 ++++++++++++++ .../scheduling_flow_match_euler_discrete.py | 229 +++ .../omnigen2/src/utils/__init__.py | 0 .../omnigen2/src/utils/img_util.py | 31 + .../omnigen2/src/utils/import_utils.py | 46 + toolkit/extension.py | 6 +- ui/src/app/jobs/new/options.ts | 13 + version.py | 2 +- 28 files changed, 6405 insertions(+), 5 deletions(-) create mode 100644 config/examples/train_lora_omnigen2_24gb.yaml create mode 100644 extensions_built_in/diffusion_models/omnigen2/__init__.py create mode 100644 extensions_built_in/diffusion_models/omnigen2/src/__init__.py create mode 100644 extensions_built_in/diffusion_models/omnigen2/src/models/__init__.py create mode 100644 extensions_built_in/diffusion_models/omnigen2/src/models/attention_processor.py create mode 100644 extensions_built_in/diffusion_models/omnigen2/src/models/embeddings.py create mode 100644 extensions_built_in/diffusion_models/omnigen2/src/models/transformers/__init__.py create mode 100644 extensions_built_in/diffusion_models/omnigen2/src/models/transformers/block_lumina2.py create mode 100644 extensions_built_in/diffusion_models/omnigen2/src/models/transformers/components.py create mode 100644 extensions_built_in/diffusion_models/omnigen2/src/models/transformers/repo.py create mode 100644 extensions_built_in/diffusion_models/omnigen2/src/models/transformers/transformer_omnigen2.py create mode 100644 extensions_built_in/diffusion_models/omnigen2/src/ops/triton/__init__.py create mode 100644 extensions_built_in/diffusion_models/omnigen2/src/ops/triton/layer_norm.py create mode 100644 extensions_built_in/diffusion_models/omnigen2/src/pipelines/__init__.py create mode 100644 extensions_built_in/diffusion_models/omnigen2/src/pipelines/image_processor.py create mode 100644 extensions_built_in/diffusion_models/omnigen2/src/pipelines/omnigen2/pipeline_omnigen2.py create mode 100644 extensions_built_in/diffusion_models/omnigen2/src/pipelines/omnigen2/pipeline_omnigen2_chat.py create mode 100644 extensions_built_in/diffusion_models/omnigen2/src/pipelines/pipeline_utils.py create mode 100644 extensions_built_in/diffusion_models/omnigen2/src/schedulers/__init__.py create mode 100644 extensions_built_in/diffusion_models/omnigen2/src/schedulers/scheduling_dpmsolver_multistep.py create mode 100644 
extensions_built_in/diffusion_models/omnigen2/src/schedulers/scheduling_flow_match_euler_discrete.py
 create mode 100644 extensions_built_in/diffusion_models/omnigen2/src/utils/__init__.py
 create mode 100644 extensions_built_in/diffusion_models/omnigen2/src/utils/img_util.py
 create mode 100644 extensions_built_in/diffusion_models/omnigen2/src/utils/import_utils.py
diff --git a/config/examples/train_lora_omnigen2_24gb.yaml b/config/examples/train_lora_omnigen2_24gb.yaml
new file mode 100644
index 00000000..6eb15302
--- /dev/null
+++ b/config/examples/train_lora_omnigen2_24gb.yaml
@@ -0,0 +1,94 @@
+---
+job: extension
+config:
+  # this name will be used for the folder and the filenames
+  name: "my_first_omnigen2_lora_v1"
+  process:
+    - type: 'sd_trainer'
+      # root folder to save training sessions/samples/weights
+      training_folder: "output"
+      # uncomment to see performance stats in the terminal every N steps
+#      performance_log_every: 1000
+      device: cuda:0
+      # if a trigger word is specified, it will be added to captions of training data if it does not already exist
+      # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
+#      trigger_word: "p3r5on"
+      network:
+        type: "lora"
+        linear: 16
+        linear_alpha: 16
+      save:
+        dtype: float16 # precision to save
+        save_every: 250 # save every this many steps
+        max_step_saves_to_keep: 4 # how many intermittent saves to keep
+        push_to_hub: false # change this to true to push your trained model to Hugging Face.
+        # You can either set up a HF_TOKEN env variable or you'll be prompted to log in
+#        hf_repo_id: your-username/your-model-slug
+#        hf_private: true # whether the repo is private or public
+      datasets:
+        # datasets are a folder of images. captions need to be txt files with the same name as the image
+        # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
+        # images will automatically be resized and bucketed into the resolution specified
+        # on windows, escape back slashes with another backslash so
+        # "C:\\path\\to\\images\\folder"
+        - folder_path: "/path/to/images/folder"
+          caption_ext: "txt"
+          caption_dropout_rate: 0.05 # will drop out the caption 5% of the time
+          shuffle_tokens: false # shuffle caption order, split by commas
+          cache_latents_to_disk: true # leave this true unless you know what you're doing
+          resolution: [ 512, 768, 1024 ] # omnigen2 should work with multiple resolutions
+      train:
+        batch_size: 1
+        steps: 3000 # total number of steps to train; 500 - 4000 is a good range
+        gradient_accumulation: 1
+        train_unet: true
+        train_text_encoder: false # probably won't work with omnigen2
+        gradient_checkpointing: true # need this on unless you have a ton of vram
+        noise_scheduler: "flowmatch" # for training only
+        optimizer: "adamw8bit"
+        lr: 1e-4
+        timestep_type: 'sigmoid' # sigmoid, linear, shift
+        # uncomment this to skip the pre training sample
+#        skip_first_sample: true
+        # uncomment to completely disable sampling
+#        disable_sampling: true
+
+        # ema will smooth out learning, but could slow it down.
+        # ema_config:
+        #   use_ema: true
+        #   ema_decay: 0.99
+
+        # you will probably need bf16 for omnigen2 if your gpu supports it; other dtypes may not work correctly
+        dtype: bf16
+      model:
+        name_or_path: "OmniGen2/OmniGen2"
+        arch: "omnigen2"
+        quantize_te: true # quantize only the te
+        # quantize: true # quantize transformer
+      sample:
+        sampler: "flowmatch" # must match train.noise_scheduler
+        sample_every: 250 # sample every this many steps
+        width: 1024
+        height: 1024
+        prompts:
+          # you can add [trigger] to the prompts here and it will be replaced with the trigger word
+#          - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"
+          - "woman with red hair, playing chess at the park, bomb going off in the background"
+          - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
+          - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
+          - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
+          - "a bear building a log cabin in the snow covered mountains"
+          - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
+          - "hipster man with a beard, building a chair, in a wood shop"
+          - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
+          - "a man holding a sign that says, 'this is a sign'"
+          - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
+        neg: "" # negative prompt, optional
+        seed: 42
+        walk_seed: true
+        guidance_scale: 4
+        sample_steps: 25
+# you can add any additional meta info here. [name] is replaced with config name at top
+meta:
+  name: "[name]"
+  version: '1.0'
diff --git a/extensions_built_in/diffusion_models/__init__.py b/extensions_built_in/diffusion_models/__init__.py
index 0cc323b9..3faa87cc 100644
--- a/extensions_built_in/diffusion_models/__init__.py
+++ b/extensions_built_in/diffusion_models/__init__.py
@@ -1,8 +1,9 @@
 from .chroma import ChromaModel
 from .hidream import HidreamModel
 from .f_light import FLiteModel
+from .omnigen2 import OmniGen2Model
 
 AI_TOOLKIT_MODELS = [
     # put a list of models here
-    ChromaModel, HidreamModel, FLiteModel
+    ChromaModel, HidreamModel, FLiteModel, OmniGen2Model
 ]
diff --git a/extensions_built_in/diffusion_models/omnigen2/__init__.py b/extensions_built_in/diffusion_models/omnigen2/__init__.py
new file mode 100644
index 00000000..5b20b387
--- /dev/null
+++ b/extensions_built_in/diffusion_models/omnigen2/__init__.py
@@ -0,0 +1,327 @@
+import inspect
+import os
+from typing import TYPE_CHECKING, List, Optional
+
+import einops
+import torch
+import yaml
+from toolkit.config_modules import GenerateImageConfig, ModelConfig
+from toolkit.models.base_model import BaseModel
+from diffusers import AutoencoderKL
+from toolkit.basic import flush
+from toolkit.prompt_utils import PromptEmbeds
+from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
+from toolkit.accelerator import unwrap_model
+from optimum.quanto import freeze
+from toolkit.util.quantize import quantize, get_qtype
+from .src.pipelines.omnigen2.pipeline_omnigen2 import OmniGen2Pipeline
+from .src.models.transformers import OmniGen2Transformer2DModel
+from .src.models.transformers.repo import OmniGen2RotaryPosEmbed
+from .src.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler as OmniFlowMatchEuler
+from transformers import CLIPProcessor, Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration
+
+if
TYPE_CHECKING: + from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO + +scheduler_config = { + "num_train_timesteps": 1000 +} + +BASE_MODEL_PATH = "OmniGen2/OmniGen2" + + +class OmniGen2Model(BaseModel): + arch = "omnigen2" + + def __init__( + self, + device, + model_config: ModelConfig, + dtype='bf16', + custom_pipeline=None, + noise_scheduler=None, + **kwargs + ): + super().__init__( + device, + model_config, + dtype, + custom_pipeline, + noise_scheduler, + **kwargs + ) + self.is_flow_matching = True + self.is_transformer = True + self.target_lora_modules = ['OmniGen2Transformer2DModel'] + + # static method to get the noise scheduler + @staticmethod + def get_train_scheduler(): + return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + + def get_bucket_divisibility(self): + return 16 + + def load_model(self): + dtype = self.torch_dtype + # HiDream-ai/HiDream-I1-Full + self.print_and_status_update("Loading OmniGen2 model") + # will be updated if we detect a existing checkpoint in training folder + model_path = self.model_config.name_or_path + extras_path = self.model_config.extras_name_or_path + + scheduler = OmniGen2Model.get_train_scheduler() + + self.print_and_status_update("Loading Qwen2.5 VL") + processor = CLIPProcessor.from_pretrained( + extras_path, + subfolder="processor", + use_fast=True + ) + + mllm = Qwen2_5_VLForConditionalGeneration.from_pretrained( + extras_path, + subfolder="mllm", + torch_dtype=torch.bfloat16 + ) + + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing Qwen2.5 VL model") + quantization_type = get_qtype(self.model_config.qtype_te) + quantize(mllm, weights=quantization_type) + freeze(mllm) + + if self.low_vram: + # unload it for now + mllm.to('cpu') + + flush() + + self.print_and_status_update("Loading transformer") + + transformer = OmniGen2Transformer2DModel.from_pretrained( + model_path, + subfolder="transformer", + torch_dtype=torch.bfloat16 + ) + + if not self.low_vram: + transformer.to(self.device_torch, dtype=dtype) + + if self.model_config.quantize: + self.print_and_status_update("Quantizing transformer") + quantization_type = get_qtype(self.model_config.qtype) + quantize(transformer, weights=quantization_type) + freeze(transformer) + + if self.low_vram: + # unload it for now + transformer.to('cpu') + + flush() + + self.print_and_status_update("Loading vae") + + vae = AutoencoderKL.from_pretrained( + extras_path, + subfolder="vae", + torch_dtype=torch.bfloat16 + ).to(self.device_torch, dtype=dtype) + + + flush() + self.print_and_status_update("Loading Qwen2.5 VLProcessor") + + flush() + + if self.low_vram: + self.print_and_status_update("Moving everything to device") + # move it all back + transformer.to(self.device_torch, dtype=dtype) + vae.to(self.device_torch, dtype=dtype) + mllm.to(self.device_torch, dtype=dtype) + + # set to eval mode + # transformer.eval() + vae.eval() + mllm.eval() + mllm.requires_grad_(False) + + pipe: OmniGen2Pipeline = OmniGen2Pipeline( + transformer=transformer, + vae=vae, + scheduler=scheduler, + mllm=mllm, + processor=processor, + ) + + # pipe: OmniGen2Pipeline = OmniGen2Pipeline.from_pretrained( + # model_path, + # transformer=transformer, + # vae=vae, + # scheduler=scheduler, + # mllm=mllm, + # trust_remote_code=True, + # ) + # processor = pipe.processor + + flush() + + text_encoder_list = [mllm] + tokenizer_list = [processor] + + + flush() + + # save it to the model class + self.vae = vae + self.text_encoder = text_encoder_list # list of text encoders + 
self.tokenizer = tokenizer_list # list of tokenizers + self.model = pipe.transformer + self.pipeline = pipe + + self.freqs_cis = OmniGen2RotaryPosEmbed.get_freqs_cis( + transformer.config.axes_dim_rope, + transformer.config.axes_lens, + theta=10000, + ) + + self.print_and_status_update("Model Loaded") + + def get_generation_pipeline(self): + scheduler = OmniFlowMatchEuler( + dynamic_time_shift=True, + num_train_timesteps=1000 + ) + + pipeline: OmniGen2Pipeline = OmniGen2Pipeline( + transformer=self.model, + vae=self.vae, + scheduler=scheduler, + mllm=self.text_encoder[0], + processor=self.tokenizer[0], + ) + + pipeline = pipeline.to(self.device_torch) + + return pipeline + + def generate_single_image( + self, + pipeline: OmniGen2Pipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + img = pipeline( + prompt_embeds=conditional_embeds.text_embeds, + prompt_attention_mask=conditional_embeds.attention_mask, + negative_prompt_embeds=unconditional_embeds.text_embeds, + negative_prompt_attention_mask=unconditional_embeds.attention_mask, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + text_guidance_scale=gen_config.guidance_scale, + image_guidance_scale=1.0, # reference image guidance scale. Add this for controls + latents=gen_config.latents, + generator=generator, + **extra + ).images[0] + return img + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + **kwargs + ): + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = timestep.expand(latent_model_input.shape[0]).to(latent_model_input.dtype) + + # optional_kwargs = {} + # if 'ref_image_hidden_states' in set(inspect.signature(self.model.forward).parameters.keys()): + # optional_kwargs['ref_image_hidden_states'] = ref_image_hidden_states + + timesteps = timestep / 1000 # convert to 0 to 1 scale + # timestep for model starts at 0 instead of 1. So we need to reverse them + timestep = 1 - timesteps + model_pred = self.model( + latent_model_input, + timestep, + text_embeddings.text_embeds, + self.freqs_cis, + text_embeddings.attention_mask, + ref_image_hidden_states=None, # todo add ref latent ability + ) + + return model_pred + + def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [self.pipeline._apply_chat_template(_prompt) for _prompt in prompt] + self.text_encoder_to(self.device_torch, dtype=self.torch_dtype) + max_sequence_length = 256 + prompt_embeds, prompt_attention_mask, _, _ = self.pipeline.encode_prompt( + prompt = prompt, + do_classifier_free_guidance=False, + device=self.device_torch, + max_sequence_length=max_sequence_length, + ) + pe = PromptEmbeds(prompt_embeds) + pe.attention_mask = prompt_attention_mask + return pe + + def get_model_has_grad(self): + # return from a weight if it has grad + return False + + def get_te_has_grad(self): + # assume no one wants to finetune 4 text encoders. 
+        return False
+
+    def save_model(self, output_path, meta, save_dtype):
+        # only save the transformer
+        transformer: OmniGen2Transformer2DModel = unwrap_model(self.model)
+        transformer.save_pretrained(
+            save_directory=os.path.join(output_path, 'transformer'),
+            safe_serialization=True,
+        )
+
+        meta_path = os.path.join(output_path, 'aitk_meta.yaml')
+        with open(meta_path, 'w') as f:
+            yaml.dump(meta, f)
+
+    def get_loss_target(self, *args, **kwargs):
+        noise = kwargs.get('noise')
+        batch = kwargs.get('batch')
+        # return (noise - batch.latents).detach()
+        return (batch.latents - noise).detach()
+
+    def get_transformer_block_names(self) -> Optional[List[str]]:
+        # omnigen2 has a few groups of blocks: noise_refiner, ref_image_refiner, context_refiner, and layers.
+        # let's do all but the ref image refiner until we add support for it
+        return ['noise_refiner', 'context_refiner', 'layers']
+        # return ['layers']
+
+    def convert_lora_weights_before_save(self, state_dict):
+        # keys currently start with "transformer." but need to start with "diffusion_model." for comfyui
+        new_sd = {}
+        for key, value in state_dict.items():
+            new_key = key.replace("transformer.", "diffusion_model.")
+            new_sd[new_key] = value
+        return new_sd
+
+    def convert_lora_weights_before_load(self, state_dict):
+        # saved as "diffusion_model." but needs to be "transformer." for ai-toolkit
+        new_sd = {}
+        for key, value in state_dict.items():
+            new_key = key.replace("diffusion_model.", "transformer.")
+            new_sd[new_key] = value
+        return new_sd
+
+    def get_base_model_version(self):
+        return "omnigen2"
+
diff --git a/extensions_built_in/diffusion_models/omnigen2/src/__init__.py b/extensions_built_in/diffusion_models/omnigen2/src/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/extensions_built_in/diffusion_models/omnigen2/src/models/__init__.py b/extensions_built_in/diffusion_models/omnigen2/src/models/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/extensions_built_in/diffusion_models/omnigen2/src/models/attention_processor.py b/extensions_built_in/diffusion_models/omnigen2/src/models/attention_processor.py
new file mode 100644
index 00000000..1f713c75
--- /dev/null
+++ b/extensions_built_in/diffusion_models/omnigen2/src/models/attention_processor.py
@@ -0,0 +1,357 @@
+"""
+OmniGen2 Attention Processor Module
+
+Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
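+
+This module provides two interchangeable processors for the diffusers `Attention` layer:
+`OmniGen2AttnProcessorFlash2Varlen`, which unpads variable-length sequences and runs
+flash-attn's varlen kernel, and `OmniGen2AttnProcessor`, a PyTorch
+`scaled_dot_product_attention` fallback used when flash-attn is not installed.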
+""" + +import warnings +import math +from typing import Optional, Tuple, Dict, Any + +import torch +import torch.nn.functional as F +from einops import repeat + +from ..utils.import_utils import is_flash_attn_available + +if is_flash_attn_available(): + from flash_attn import flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input +else: + warnings.warn("Cannot import flash_attn, install flash_attn to use Flash2Varlen attention for better performance") + + +from diffusers.models.attention_processor import Attention +from .embeddings import apply_rotary_emb + + +class OmniGen2AttnProcessorFlash2Varlen: + """ + Processor for implementing scaled dot-product attention with flash attention and variable length sequences. + + This processor implements: + - Flash attention with variable length sequences + - Rotary position embeddings (RoPE) + - Query-Key normalization + - Proportional attention scaling + + Args: + None + """ + + def __init__(self) -> None: + """Initialize the attention processor.""" + if not is_flash_attn_available(): + raise ImportError( + "OmniGen2AttnProcessorFlash2Varlen requires flash_attn. " + "Please install flash_attn." + ) + + def _upad_input( + self, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: torch.Tensor, + query_length: int, + num_heads: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]: + """ + Unpad the input tensors for flash attention. + + Args: + query_layer: Query tensor of shape (batch_size, seq_len, num_heads, head_dim) + key_layer: Key tensor of shape (batch_size, seq_len, num_kv_heads, head_dim) + value_layer: Value tensor of shape (batch_size, seq_len, num_kv_heads, head_dim) + attention_mask: Attention mask tensor of shape (batch_size, seq_len) + query_length: Length of the query sequence + num_heads: Number of attention heads + + Returns: + Tuple containing: + - Unpadded query tensor + - Unpadded key tensor + - Unpadded value tensor + - Query indices + - Tuple of cumulative sequence lengths for query and key + - Tuple of maximum sequence lengths for query and key + """ + def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]: + """Helper function to get unpadding data from attention mask.""" + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return indices, cu_seqlens, max_seqlen_in_batch + + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + # Unpad key and value layers + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + + # Handle different query length cases + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), + indices_k, + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + 
batch_size + 1, dtype=torch.int32, device=query_layer.device + ) + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + base_sequence_length: Optional[int] = None, + ) -> torch.Tensor: + """ + Process attention computation with flash attention. + + Args: + attn: Attention module + hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim) + encoder_hidden_states: Encoder hidden states tensor + attention_mask: Optional attention mask tensor + image_rotary_emb: Optional rotary embeddings for image tokens + base_sequence_length: Optional base sequence length for proportional attention + + Returns: + torch.Tensor: Processed hidden states after attention computation + """ + batch_size, sequence_length, _ = hidden_states.shape + + # Get Query-Key-Value Pair + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query_dim = query.shape[-1] + inner_dim = key.shape[-1] + head_dim = query_dim // attn.heads + dtype = query.dtype + + # Get key-value heads + kv_heads = inner_dim // head_dim + + # Reshape tensors for attention computation + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, kv_heads, head_dim) + value = value.view(batch_size, -1, kv_heads, head_dim) + + # Apply Query-Key normalization + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply Rotary Position Embeddings + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, use_real=False) + key = apply_rotary_emb(key, image_rotary_emb, use_real=False) + + query, key = query.to(dtype), key.to(dtype) + + # Calculate attention scale + if base_sequence_length is not None: + softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale + else: + softmax_scale = attn.scale + + # Unpad input for flash attention + ( + query_states, + key_states, + value_states, + indices_q, + cu_seq_lens, + max_seq_lens, + ) = self._upad_input(query, key, value, attention_mask, sequence_length, attn.heads) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + # Handle different number of heads + if kv_heads < attn.heads: + key_states = repeat(key_states, "l h c -> l (h k) c", k=attn.heads // kv_heads) + value_states = repeat(value_states, "l h c -> l (h k) c", k=attn.heads // kv_heads) + + # Apply flash attention + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=0.0, + causal=False, + softmax_scale=softmax_scale, + ) + + # Pad output and apply final transformations + hidden_states = pad_input(attn_output_unpad, indices_q, batch_size, sequence_length) + hidden_states = hidden_states.flatten(-2) + hidden_states = 
hidden_states.type_as(query) + + # Apply output projection + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class OmniGen2AttnProcessor: + """ + Processor for implementing scaled dot-product attention with flash attention and variable length sequences. + + This processor is optimized for PyTorch 2.0 and implements: + - Flash attention with variable length sequences + - Rotary position embeddings (RoPE) + - Query-Key normalization + - Proportional attention scaling + + Args: + None + + Raises: + ImportError: If PyTorch version is less than 2.0 + """ + + def __init__(self) -> None: + """Initialize the attention processor.""" + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "OmniGen2AttnProcessorFlash2Varlen requires PyTorch 2.0. " + "Please upgrade PyTorch to version 2.0 or later." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + base_sequence_length: Optional[int] = None, + ) -> torch.Tensor: + """ + Process attention computation with flash attention. + + Args: + attn: Attention module + hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim) + encoder_hidden_states: Encoder hidden states tensor + attention_mask: Optional attention mask tensor + image_rotary_emb: Optional rotary embeddings for image tokens + base_sequence_length: Optional base sequence length for proportional attention + + Returns: + torch.Tensor: Processed hidden states after attention computation + """ + batch_size, sequence_length, _ = hidden_states.shape + + # Get Query-Key-Value Pair + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query_dim = query.shape[-1] + inner_dim = key.shape[-1] + head_dim = query_dim // attn.heads + dtype = query.dtype + + # Get key-value heads + kv_heads = inner_dim // head_dim + + # Reshape tensors for attention computation + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, kv_heads, head_dim) + value = value.view(batch_size, -1, kv_heads, head_dim) + + # Apply Query-Key normalization + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply Rotary Position Embeddings + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, use_real=False) + key = apply_rotary_emb(key, image_rotary_emb, use_real=False) + + query, key = query.to(dtype), key.to(dtype) + + # Calculate attention scale + if base_sequence_length is not None: + softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale + else: + softmax_scale = attn.scale + + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + if attention_mask is not None: + attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # explicitly repeat key and value to match query length, otherwise using enable_gqa=True results in MATH backend of sdpa in our test of pytorch2.6 + key = key.repeat_interleave(query.size(-3) // key.size(-3), -3) + value = value.repeat_interleave(query.size(-3) // value.size(-3), -3) + + hidden_states = F.scaled_dot_product_attention( + 
query, key, value, attn_mask=attention_mask, scale=softmax_scale + ) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.type_as(query) + + # Apply output projection + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states \ No newline at end of file diff --git a/extensions_built_in/diffusion_models/omnigen2/src/models/embeddings.py b/extensions_built_in/diffusion_models/omnigen2/src/models/embeddings.py new file mode 100644 index 00000000..5282f2de --- /dev/null +++ b/extensions_built_in/diffusion_models/omnigen2/src/models/embeddings.py @@ -0,0 +1,126 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn + + +from diffusers.models.activations import get_activation + + +class TimestepEmbedding(nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim=None, + sample_proj_bias=True, + ): + super().__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias) + + if cond_proj_dim is not None: + self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + + self.act = get_activation(act_fn) + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias) + + if post_act_fn is None: + self.post_act = None + else: + self.post_act = get_activation(post_act_fn) + + self.initialize_weights() + + def initialize_weights(self): + nn.init.normal_(self.linear_1.weight, std=0.02) + nn.init.zeros_(self.linear_1.bias) + nn.init.normal_(self.linear_2.weight, std=0.02) + nn.init.zeros_(self.linear_2.bias) + + def forward(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +def apply_rotary_emb( + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + use_real: bool = True, + use_real_unbind_dim: int = -1, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. 
[B, H, S, D] xk (torch.Tensor): Key tensor to apply + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + if use_real: + cos, sin = freqs_cis # [S, D] + cos = cos[None, None] + sin = sin[None, None] + cos, sin = cos.to(x.device), sin.to(x.device) + + if use_real_unbind_dim == -1: + # Used for flux, cogvideox, hunyuan-dit + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + # Used for Stable Audio, OmniGen and CogView4 + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out + else: + # used for lumina + # x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], x.shape[-1] // 2, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + + return x_out.type_as(x) \ No newline at end of file diff --git a/extensions_built_in/diffusion_models/omnigen2/src/models/transformers/__init__.py b/extensions_built_in/diffusion_models/omnigen2/src/models/transformers/__init__.py new file mode 100644 index 00000000..157de1d4 --- /dev/null +++ b/extensions_built_in/diffusion_models/omnigen2/src/models/transformers/__init__.py @@ -0,0 +1,3 @@ +from .transformer_omnigen2 import OmniGen2Transformer2DModel + +__all__ = ["OmniGen2Transformer2DModel"] \ No newline at end of file diff --git a/extensions_built_in/diffusion_models/omnigen2/src/models/transformers/block_lumina2.py b/extensions_built_in/diffusion_models/omnigen2/src/models/transformers/block_lumina2.py new file mode 100644 index 00000000..13739d3a --- /dev/null +++ b/extensions_built_in/diffusion_models/omnigen2/src/models/transformers/block_lumina2.py @@ -0,0 +1,218 @@ + +# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
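+#
+# This module collects the Lumina2-style building blocks used by the OmniGen2 transformer:
+# LuminaRMSNormZero (adaptive RMSNorm that produces scale/gate terms from the timestep
+# embedding), LuminaLayerNormContinuous (AdaLN output norm), LuminaFeedForward (SwiGLU MLP),
+# and Lumina2CombinedTimestepCaptionEmbedding (joint timestep and caption embedder).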
+ +import warnings +from typing import Optional, Tuple + +import torch +import torch.nn as nn + +from diffusers.models.embeddings import Timesteps +from ..embeddings import TimestepEmbedding + +from ...utils.import_utils import is_flash_attn_available, is_triton_available + +if is_triton_available(): + from ...ops.triton.layer_norm import RMSNorm +else: + from torch.nn import RMSNorm + warnings.warn("Cannot import triton, install triton to use fused RMSNorm for better performance") + +if is_flash_attn_available(): + from flash_attn.ops.activations import swiglu +else: + from .components import swiglu + warnings.warn("Cannot import flash_attn, install flash_attn to use fused SwiGLU for better performance") + +# try: +# from flash_attn.ops.activations import swiglu as fused_swiglu +# FUSEDSWIGLU_AVALIBLE = True +# except ImportError: + +# FUSEDSWIGLU_AVALIBLE = False +# warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation") + +class LuminaRMSNormZero(nn.Module): + """ + Norm layer adaptive RMS normalization zero. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + """ + + def __init__( + self, + embedding_dim: int, + norm_eps: float, + norm_elementwise_affine: bool, + ): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear( + min(embedding_dim, 1024), + 4 * embedding_dim, + bias=True, + ) + + self.norm = RMSNorm(embedding_dim, eps=norm_eps) + + def forward( + self, + x: torch.Tensor, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = self.linear(self.silu(emb)) + scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) + x = self.norm(x) * (1 + scale_msa[:, None]) + return x, gate_msa, scale_mlp, gate_mlp + + +class LuminaLayerNormContinuous(nn.Module): + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters + # because the output is immediately scaled and shifted by the projected conditioning embeddings. + # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. + # However, this is how it was implemented in the original code, and it's rather likely you should + # set `elementwise_affine` to False. + elementwise_affine=True, + eps=1e-5, + bias=True, + norm_type="layer_norm", + out_dim: Optional[int] = None, + ): + super().__init__() + + # AdaLN + self.silu = nn.SiLU() + self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias) + + if norm_type == "layer_norm": + self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias) + elif norm_type == "rms_norm": + self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) + else: + raise ValueError(f"unknown norm_type {norm_type}") + + self.linear_2 = None + if out_dim is not None: + self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias) + + def forward( + self, + x: torch.Tensor, + conditioning_embedding: torch.Tensor, + ) -> torch.Tensor: + # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) + emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype)) + scale = emb + x = self.norm(x) * (1 + scale)[:, None, :] + + if self.linear_2 is not None: + x = self.linear_2(x) + + return x + + +class LuminaFeedForward(nn.Module): + r""" + A feed-forward layer. 
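+
+    A SwiGLU-style gated MLP: `linear_1` and `linear_3` project the input to the inner
+    dimension, `swiglu` gates one projection with the other, and `linear_2` projects back
+    to the model dimension.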
+ + Parameters: + hidden_size (`int`): + The dimensionality of the hidden layers in the model. This parameter determines the width of the model's + hidden representations. + intermediate_size (`int`): The intermediate dimension of the feedforward layer. + multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple + of this value. + ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden + dimension. Defaults to None. + """ + + def __init__( + self, + dim: int, + inner_dim: int, + multiple_of: Optional[int] = 256, + ffn_dim_multiplier: Optional[float] = None, + ): + super().__init__() + self.swiglu = swiglu + + # custom hidden_size factor multiplier + if ffn_dim_multiplier is not None: + inner_dim = int(ffn_dim_multiplier * inner_dim) + inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of) + + self.linear_1 = nn.Linear( + dim, + inner_dim, + bias=False, + ) + self.linear_2 = nn.Linear( + inner_dim, + dim, + bias=False, + ) + self.linear_3 = nn.Linear( + dim, + inner_dim, + bias=False, + ) + + def forward(self, x): + h1, h2 = self.linear_1(x), self.linear_3(x) + return self.linear_2(self.swiglu(h1, h2)) + + +class Lumina2CombinedTimestepCaptionEmbedding(nn.Module): + def __init__( + self, + hidden_size: int = 4096, + text_feat_dim: int = 2048, + frequency_embedding_size: int = 256, + norm_eps: float = 1e-5, + timestep_scale: float = 1.0, + ) -> None: + super().__init__() + + self.time_proj = Timesteps( + num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale + ) + + self.timestep_embedder = TimestepEmbedding( + in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024) + ) + + self.caption_embedder = nn.Sequential( + RMSNorm(text_feat_dim, eps=norm_eps), + nn.Linear(text_feat_dim, hidden_size, bias=True), + ) + + self._initialize_weights() + + def _initialize_weights(self): + nn.init.trunc_normal_(self.caption_embedder[1].weight, std=0.02) + nn.init.zeros_(self.caption_embedder[1].bias) + + def forward( + self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype + ) -> Tuple[torch.Tensor, torch.Tensor]: + timestep_proj = self.time_proj(timestep).to(dtype=dtype) + time_embed = self.timestep_embedder(timestep_proj) + caption_embed = self.caption_embedder(text_hidden_states) + return time_embed, caption_embed \ No newline at end of file diff --git a/extensions_built_in/diffusion_models/omnigen2/src/models/transformers/components.py b/extensions_built_in/diffusion_models/omnigen2/src/models/transformers/components.py new file mode 100644 index 00000000..5e654b8c --- /dev/null +++ b/extensions_built_in/diffusion_models/omnigen2/src/models/transformers/components.py @@ -0,0 +1,4 @@ +import torch.nn.functional as F + +def swiglu(x, y): + return F.silu(x.float(), inplace=False).to(x.dtype) * y \ No newline at end of file diff --git a/extensions_built_in/diffusion_models/omnigen2/src/models/transformers/repo.py b/extensions_built_in/diffusion_models/omnigen2/src/models/transformers/repo.py new file mode 100644 index 00000000..ea565bf8 --- /dev/null +++ b/extensions_built_in/diffusion_models/omnigen2/src/models/transformers/repo.py @@ -0,0 +1,135 @@ +from typing import List, Tuple + +import torch +import torch.nn as nn + +from einops import repeat +from diffusers.models.embeddings import get_1d_rotary_pos_embed + +class OmniGen2RotaryPosEmbed(nn.Module): + def __init__(self, theta: int, + axes_dim: Tuple[int, int, int], + axes_lens: Tuple[int, int, int] = 
(300, 512, 512), + patch_size: int = 2): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + self.axes_lens = axes_lens + self.patch_size = patch_size + + @staticmethod + def get_freqs_cis(axes_dim: Tuple[int, int, int], + axes_lens: Tuple[int, int, int], + theta: int) -> List[torch.Tensor]: + freqs_cis = [] + freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 + for i, (d, e) in enumerate(zip(axes_dim, axes_lens)): + emb = get_1d_rotary_pos_embed(d, e, theta=theta, freqs_dtype=freqs_dtype) + freqs_cis.append(emb) + return freqs_cis + + def _get_freqs_cis(self, freqs_cis, ids: torch.Tensor) -> torch.Tensor: + device = ids.device + if ids.device.type == "mps": + ids = ids.to("cpu") + + result = [] + for i in range(len(self.axes_dim)): + freqs = freqs_cis[i].to(ids.device) + index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64) + result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index)) + return torch.cat(result, dim=-1).to(device) + + def forward( + self, + freqs_cis, + attention_mask, + l_effective_ref_img_len, + l_effective_img_len, + ref_img_sizes, + img_sizes, + device + ): + batch_size = len(attention_mask) + p = self.patch_size + + encoder_seq_len = attention_mask.shape[1] + l_effective_cap_len = attention_mask.sum(dim=1).tolist() + + seq_lengths = [cap_len + sum(ref_img_len) + img_len for cap_len, ref_img_len, img_len in zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len)] + + max_seq_len = int(max(seq_lengths)) + max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len]) + max_img_len = max(l_effective_img_len) + + # Create position IDs + position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device) + + for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)): + cap_seq_len = int(cap_seq_len) + seq_len = int(seq_len) + # add text position ids + position_ids[i, :cap_seq_len] = repeat(torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3") + + pe_shift = cap_seq_len + pe_shift_len = cap_seq_len + + if ref_img_sizes[i] is not None: + for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]): + H, W = ref_img_size + ref_H_tokens, ref_W_tokens = H // p, W // p + assert ref_H_tokens * ref_W_tokens == ref_img_len + # add image position ids + + row_ids = repeat(torch.arange(ref_H_tokens, dtype=torch.int32, device=device), "h -> h w", w=ref_W_tokens).flatten() + col_ids = repeat(torch.arange(ref_W_tokens, dtype=torch.int32, device=device), "w -> h w", h=ref_H_tokens).flatten() + position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift + position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids + position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids + + pe_shift += max(ref_H_tokens, ref_W_tokens) + pe_shift_len += ref_img_len + + H, W = img_sizes[i] + H_tokens, W_tokens = H // p, W // p + assert H_tokens * W_tokens == l_effective_img_len[i] + + row_ids = repeat(torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens).flatten() + col_ids = repeat(torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens).flatten() + + assert pe_shift_len + l_effective_img_len[i] == seq_len + position_ids[i, pe_shift_len: seq_len, 0] = pe_shift + position_ids[i, pe_shift_len: seq_len, 1] = row_ids + position_ids[i, pe_shift_len: seq_len, 2] = col_ids + + # Get combined rotary 
embeddings + freqs_cis = self._get_freqs_cis(freqs_cis, position_ids) + + # create separate rotary embeddings for captions and images + cap_freqs_cis = torch.zeros( + batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype + ) + ref_img_freqs_cis = torch.zeros( + batch_size, max_ref_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype + ) + img_freqs_cis = torch.zeros( + batch_size, max_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype + ) + + for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)): + cap_seq_len = int(cap_seq_len) + sum_ref_img_len = int(sum(ref_img_len)) + img_len = int(img_len) + seq_len = int(seq_len) + cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len] + ref_img_freqs_cis[i, :sum_ref_img_len] = freqs_cis[i, cap_seq_len:cap_seq_len + sum_ref_img_len] + img_freqs_cis[i, :img_len] = freqs_cis[i, cap_seq_len + sum_ref_img_len:cap_seq_len + sum_ref_img_len + img_len] + + return ( + cap_freqs_cis, + ref_img_freqs_cis, + img_freqs_cis, + freqs_cis, + l_effective_cap_len, + seq_lengths, + ) \ No newline at end of file diff --git a/extensions_built_in/diffusion_models/omnigen2/src/models/transformers/transformer_omnigen2.py b/extensions_built_in/diffusion_models/omnigen2/src/models/transformers/transformer_omnigen2.py new file mode 100644 index 00000000..8e7fef68 --- /dev/null +++ b/extensions_built_in/diffusion_models/omnigen2/src/models/transformers/transformer_omnigen2.py @@ -0,0 +1,621 @@ +import warnings +import itertools +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from einops import rearrange + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import PeftAdapterMixin +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from diffusers.models.attention_processor import Attention +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin + +from ..attention_processor import OmniGen2AttnProcessorFlash2Varlen, OmniGen2AttnProcessor +from .repo import OmniGen2RotaryPosEmbed +from .block_lumina2 import LuminaLayerNormContinuous, LuminaRMSNormZero, LuminaFeedForward, Lumina2CombinedTimestepCaptionEmbedding + +from ...utils.import_utils import is_triton_available, is_flash_attn_available + +if is_triton_available(): + from ...ops.triton.layer_norm import RMSNorm +else: + from torch.nn import RMSNorm + +logger = logging.get_logger(__name__) + + +class OmniGen2TransformerBlock(nn.Module): + """ + Transformer block for OmniGen2 model. 
+ + This block implements a transformer layer with: + - Multi-head attention with flash attention + - Feed-forward network with SwiGLU activation + - RMS normalization + - Optional modulation for conditional generation + + Args: + dim: Dimension of the input and output tensors + num_attention_heads: Number of attention heads + num_kv_heads: Number of key-value heads + multiple_of: Multiple of which the hidden dimension should be + ffn_dim_multiplier: Multiplier for the feed-forward network dimension + norm_eps: Epsilon value for normalization layers + modulation: Whether to use modulation for conditional generation + use_fused_rms_norm: Whether to use fused RMS normalization + use_fused_swiglu: Whether to use fused SwiGLU activation + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + num_kv_heads: int, + multiple_of: int, + ffn_dim_multiplier: float, + norm_eps: float, + modulation: bool = True, + ) -> None: + """Initialize the transformer block.""" + super().__init__() + self.head_dim = dim // num_attention_heads + self.modulation = modulation + + try: + processor = OmniGen2AttnProcessorFlash2Varlen() + except ImportError: + processor = OmniGen2AttnProcessor() + + # Initialize attention layer + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=dim // num_attention_heads, + qk_norm="rms_norm", + heads=num_attention_heads, + kv_heads=num_kv_heads, + eps=1e-5, + bias=False, + out_bias=False, + processor=processor, + ) + + # Initialize feed-forward network + self.feed_forward = LuminaFeedForward( + dim=dim, + inner_dim=4 * dim, + multiple_of=multiple_of, + ffn_dim_multiplier=ffn_dim_multiplier + ) + + # Initialize normalization layers + if modulation: + self.norm1 = LuminaRMSNormZero( + embedding_dim=dim, + norm_eps=norm_eps, + norm_elementwise_affine=True + ) + else: + self.norm1 = RMSNorm(dim, eps=norm_eps) + + self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) + self.norm2 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) + + self.initialize_weights() + + def initialize_weights(self) -> None: + """ + Initialize the weights of the transformer block. + + Uses Xavier uniform initialization for linear layers and zero initialization for biases. + """ + nn.init.xavier_uniform_(self.attn.to_q.weight) + nn.init.xavier_uniform_(self.attn.to_k.weight) + nn.init.xavier_uniform_(self.attn.to_v.weight) + nn.init.xavier_uniform_(self.attn.to_out[0].weight) + + nn.init.xavier_uniform_(self.feed_forward.linear_1.weight) + nn.init.xavier_uniform_(self.feed_forward.linear_2.weight) + nn.init.xavier_uniform_(self.feed_forward.linear_3.weight) + + if self.modulation: + nn.init.zeros_(self.norm1.linear.weight) + nn.init.zeros_(self.norm1.linear.bias) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + image_rotary_emb: torch.Tensor, + temb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass of the transformer block. 
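+
+        When `modulation` is enabled, `temb` is projected to per-sample scale and gate
+        terms (via `LuminaRMSNormZero`) that rescale the attention and feed-forward
+        branches before they are added back to the residual stream.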
+ + Args: + hidden_states: Input hidden states tensor + attention_mask: Attention mask tensor + image_rotary_emb: Rotary embeddings for image tokens + temb: Optional timestep embedding tensor + + Returns: + torch.Tensor: Output hidden states after transformer block processing + """ + import time + if self.modulation: + if temb is None: + raise ValueError("temb must be provided when modulation is enabled") + + norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) + attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, + ) + hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output) + mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1))) + hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output) + else: + norm_hidden_states = self.norm1(hidden_states) + attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, + ) + hidden_states = hidden_states + self.norm2(attn_output) + mlp_output = self.feed_forward(self.ffn_norm1(hidden_states)) + hidden_states = hidden_states + self.ffn_norm2(mlp_output) + + return hidden_states + + +class OmniGen2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): + """ + OmniGen2 Transformer 2D Model. + + A transformer-based diffusion model for image generation with: + - Patch-based image processing + - Rotary position embeddings + - Multi-head attention + - Conditional generation support + + Args: + patch_size: Size of image patches + in_channels: Number of input channels + out_channels: Number of output channels (defaults to in_channels) + hidden_size: Size of hidden layers + num_layers: Number of transformer layers + num_refiner_layers: Number of refiner layers + num_attention_heads: Number of attention heads + num_kv_heads: Number of key-value heads + multiple_of: Multiple of which the hidden dimension should be + ffn_dim_multiplier: Multiplier for feed-forward network dimension + norm_eps: Epsilon value for normalization layers + axes_dim_rope: Dimensions for rotary position embeddings + axes_lens: Lengths for rotary position embeddings + text_feat_dim: Dimension of text features + timestep_scale: Scale factor for timestep embeddings + use_fused_rms_norm: Whether to use fused RMS normalization + use_fused_swiglu: Whether to use fused SwiGLU activation + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["Omnigen2TransformerBlock"] + _skip_layerwise_casting_patterns = ["x_embedder", "norm"] + + @register_to_config + def __init__( + self, + patch_size: int = 2, + in_channels: int = 16, + out_channels: Optional[int] = None, + hidden_size: int = 2304, + num_layers: int = 26, + num_refiner_layers: int = 2, + num_attention_heads: int = 24, + num_kv_heads: int = 8, + multiple_of: int = 256, + ffn_dim_multiplier: Optional[float] = None, + norm_eps: float = 1e-5, + axes_dim_rope: Tuple[int, int, int] = (32, 32, 32), + axes_lens: Tuple[int, int, int] = (300, 512, 512), + text_feat_dim: int = 1024, + timestep_scale: float = 1.0 + ) -> None: + """Initialize the OmniGen2 transformer model.""" + super().__init__() + + # Validate configuration + if (hidden_size // num_attention_heads) != sum(axes_dim_rope): + raise ValueError( + f"hidden_size // 
num_attention_heads ({hidden_size // num_attention_heads}) " + f"must equal sum(axes_dim_rope) ({sum(axes_dim_rope)})" + ) + + self.out_channels = out_channels or in_channels + + # Initialize embeddings + self.rope_embedder = OmniGen2RotaryPosEmbed( + theta=10000, + axes_dim=axes_dim_rope, + axes_lens=axes_lens, + patch_size=patch_size, + ) + + self.x_embedder = nn.Linear( + in_features=patch_size * patch_size * in_channels, + out_features=hidden_size, + ) + + self.ref_image_patch_embedder = nn.Linear( + in_features=patch_size * patch_size * in_channels, + out_features=hidden_size, + ) + + self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding( + hidden_size=hidden_size, + text_feat_dim=text_feat_dim, + norm_eps=norm_eps, + timestep_scale=timestep_scale + ) + + # Initialize transformer blocks + self.noise_refiner = nn.ModuleList([ + OmniGen2TransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=True + ) + for _ in range(num_refiner_layers) + ]) + + self.ref_image_refiner = nn.ModuleList([ + OmniGen2TransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=True + ) + for _ in range(num_refiner_layers) + ]) + + self.context_refiner = nn.ModuleList( + [ + OmniGen2TransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=False + ) + for _ in range(num_refiner_layers) + ] + ) + + # 3. Transformer blocks + self.layers = nn.ModuleList( + [ + OmniGen2TransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=True + ) + for _ in range(num_layers) + ] + ) + + # 4. Output norm & projection + self.norm_out = LuminaLayerNormContinuous( + embedding_dim=hidden_size, + conditioning_embedding_dim=min(hidden_size, 1024), + elementwise_affine=False, + eps=1e-6, + bias=True, + out_dim=patch_size * patch_size * self.out_channels + ) + + # Add learnable embeddings to distinguish different images + self.image_index_embedding = nn.Parameter(torch.randn(5, hidden_size)) # support max 5 ref images + + self.gradient_checkpointing = False + + self.initialize_weights() + + def initialize_weights(self) -> None: + """ + Initialize the weights of the model. + + Uses Xavier uniform initialization for linear layers. 
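+        The output norm's projections are zero-initialized and the reference image index
+        embedding is drawn from a normal distribution with std 0.02.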
+ """ + nn.init.xavier_uniform_(self.x_embedder.weight) + nn.init.constant_(self.x_embedder.bias, 0.0) + + nn.init.xavier_uniform_(self.ref_image_patch_embedder.weight) + nn.init.constant_(self.ref_image_patch_embedder.bias, 0.0) + + nn.init.zeros_(self.norm_out.linear_1.weight) + nn.init.zeros_(self.norm_out.linear_1.bias) + nn.init.zeros_(self.norm_out.linear_2.weight) + nn.init.zeros_(self.norm_out.linear_2.bias) + + nn.init.normal_(self.image_index_embedding, std=0.02) + + def img_patch_embed_and_refine( + self, + hidden_states, + ref_image_hidden_states, + padded_img_mask, + padded_ref_img_mask, + noise_rotary_emb, + ref_img_rotary_emb, + l_effective_ref_img_len, + l_effective_img_len, + temb + ): + batch_size = len(hidden_states) + max_combined_img_len = max([img_len + sum(ref_img_len) for img_len, ref_img_len in zip(l_effective_img_len, l_effective_ref_img_len)]) + + hidden_states = self.x_embedder(hidden_states) + ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states) + + for i in range(batch_size): + shift = 0 + for j, ref_img_len in enumerate(l_effective_ref_img_len[i]): + ref_image_hidden_states[i, shift:shift + ref_img_len, :] = ref_image_hidden_states[i, shift:shift + ref_img_len, :] + self.image_index_embedding[j] + shift += ref_img_len + + for layer in self.noise_refiner: + hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb) + + flat_l_effective_ref_img_len = list(itertools.chain(*l_effective_ref_img_len)) + num_ref_images = len(flat_l_effective_ref_img_len) + max_ref_img_len = max(flat_l_effective_ref_img_len) + + batch_ref_img_mask = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, dtype=torch.bool) + batch_ref_image_hidden_states = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, self.config.hidden_size) + batch_ref_img_rotary_emb = hidden_states.new_zeros(num_ref_images, max_ref_img_len, ref_img_rotary_emb.shape[-1], dtype=ref_img_rotary_emb.dtype) + batch_temb = temb.new_zeros(num_ref_images, *temb.shape[1:], dtype=temb.dtype) + + # sequence of ref imgs to batch + idx = 0 + for i in range(batch_size): + shift = 0 + for ref_img_len in l_effective_ref_img_len[i]: + batch_ref_img_mask[idx, :ref_img_len] = True + batch_ref_image_hidden_states[idx, :ref_img_len] = ref_image_hidden_states[i, shift:shift + ref_img_len] + batch_ref_img_rotary_emb[idx, :ref_img_len] = ref_img_rotary_emb[i, shift:shift + ref_img_len] + batch_temb[idx] = temb[i] + shift += ref_img_len + idx += 1 + + # refine ref imgs separately + for layer in self.ref_image_refiner: + batch_ref_image_hidden_states = layer(batch_ref_image_hidden_states, batch_ref_img_mask, batch_ref_img_rotary_emb, batch_temb) + + # batch of ref imgs to sequence + idx = 0 + for i in range(batch_size): + shift = 0 + for ref_img_len in l_effective_ref_img_len[i]: + ref_image_hidden_states[i, shift:shift + ref_img_len] = batch_ref_image_hidden_states[idx, :ref_img_len] + shift += ref_img_len + idx += 1 + + combined_img_hidden_states = hidden_states.new_zeros(batch_size, max_combined_img_len, self.config.hidden_size) + for i, (ref_img_len, img_len) in enumerate(zip(l_effective_ref_img_len, l_effective_img_len)): + combined_img_hidden_states[i, :sum(ref_img_len)] = ref_image_hidden_states[i, :sum(ref_img_len)] + combined_img_hidden_states[i, sum(ref_img_len):sum(ref_img_len) + img_len] = hidden_states[i, :img_len] + + return combined_img_hidden_states + + def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states): + batch_size = 
len(hidden_states) + p = self.config.patch_size + device = hidden_states[0].device + + img_sizes = [(img.size(1), img.size(2)) for img in hidden_states] + l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes] + + if ref_image_hidden_states is not None: + ref_img_sizes = [[(img.size(1), img.size(2)) for img in imgs] if imgs is not None else None for imgs in ref_image_hidden_states] + l_effective_ref_img_len = [[(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes] if _ref_img_sizes is not None else [0] for _ref_img_sizes in ref_img_sizes] + else: + ref_img_sizes = [None for _ in range(batch_size)] + l_effective_ref_img_len = [[0] for _ in range(batch_size)] + + max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len]) + max_img_len = max(l_effective_img_len) + + # ref image patch embeddings + flat_ref_img_hidden_states = [] + for i in range(batch_size): + if ref_img_sizes[i] is not None: + imgs = [] + for ref_img in ref_image_hidden_states[i]: + C, H, W = ref_img.size() + ref_img = rearrange(ref_img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p) + imgs.append(ref_img) + + img = torch.cat(imgs, dim=0) + flat_ref_img_hidden_states.append(img) + else: + flat_ref_img_hidden_states.append(None) + + # image patch embeddings + flat_hidden_states = [] + for i in range(batch_size): + img = hidden_states[i] + C, H, W = img.size() + + img = rearrange(img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p) + flat_hidden_states.append(img) + + padded_ref_img_hidden_states = torch.zeros(batch_size, max_ref_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype) + padded_ref_img_mask = torch.zeros(batch_size, max_ref_img_len, dtype=torch.bool, device=device) + for i in range(batch_size): + if ref_img_sizes[i] is not None: + padded_ref_img_hidden_states[i, :sum(l_effective_ref_img_len[i])] = flat_ref_img_hidden_states[i] + padded_ref_img_mask[i, :sum(l_effective_ref_img_len[i])] = True + + padded_hidden_states = torch.zeros(batch_size, max_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype) + padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device) + for i in range(batch_size): + padded_hidden_states[i, :l_effective_img_len[i]] = flat_hidden_states[i] + padded_img_mask[i, :l_effective_img_len[i]] = True + + return ( + padded_hidden_states, + padded_ref_img_hidden_states, + padded_img_mask, + padded_ref_img_mask, + l_effective_ref_img_len, + l_effective_img_len, + ref_img_sizes, + img_sizes, + ) + + def forward( + self, + hidden_states: Union[torch.Tensor, List[torch.Tensor]], + timestep: torch.Tensor, + text_hidden_states: torch.Tensor, + freqs_cis: torch.Tensor, + text_attention_mask: torch.Tensor, + ref_image_hidden_states: Optional[List[List[torch.Tensor]]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = False, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + # 1. 
Condition, positional & patch embedding + batch_size = len(hidden_states) + is_hidden_states_tensor = isinstance(hidden_states, torch.Tensor) + + if is_hidden_states_tensor: + assert hidden_states.ndim == 4 + hidden_states = [_hidden_states for _hidden_states in hidden_states] + + device = hidden_states[0].device + + temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype) + + ( + hidden_states, + ref_image_hidden_states, + img_mask, + ref_img_mask, + l_effective_ref_img_len, + l_effective_img_len, + ref_img_sizes, + img_sizes, + ) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states) + + ( + context_rotary_emb, + ref_img_rotary_emb, + noise_rotary_emb, + rotary_emb, + encoder_seq_lengths, + seq_lengths, + ) = self.rope_embedder( + freqs_cis, + text_attention_mask, + l_effective_ref_img_len, + l_effective_img_len, + ref_img_sizes, + img_sizes, + device, + ) + + # 2. Context refinement + for layer in self.context_refiner: + text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb) + + combined_img_hidden_states = self.img_patch_embed_and_refine( + hidden_states, + ref_image_hidden_states, + img_mask, + ref_img_mask, + noise_rotary_emb, + ref_img_rotary_emb, + l_effective_ref_img_len, + l_effective_img_len, + temb, + ) + + # 3. Joint Transformer blocks + max_seq_len = int(max(seq_lengths)) + + attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool) + joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size) + for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)): + encoder_seq_len = int(encoder_seq_len) + seq_len = int(seq_len) + attention_mask[i, :seq_len] = True + joint_hidden_states[i, :encoder_seq_len] = text_hidden_states[i, :encoder_seq_len] + joint_hidden_states[i, encoder_seq_len:seq_len] = combined_img_hidden_states[i, :seq_len - encoder_seq_len] + + hidden_states = joint_hidden_states + + for layer_idx, layer in enumerate(self.layers): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + layer, hidden_states, attention_mask, rotary_emb, temb + ) + else: + hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb) + + # 4. 
Output norm & projection + hidden_states = self.norm_out(hidden_states, temb) + + p = self.config.patch_size + output = [] + for i, (img_size, img_len, seq_len) in enumerate(zip(img_sizes, l_effective_img_len, seq_lengths)): + img_len = int(img_len) + seq_len = int(seq_len) + height, width = img_size + output.append(rearrange(hidden_states[i][seq_len - img_len:seq_len], '(h w) (p1 p2 c) -> c (h p1) (w p2)', h=height // p, w=width // p, p1=p, p2=p)) + if is_hidden_states_tensor: + output = torch.stack(output, dim=0) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return output + return Transformer2DModelOutput(sample=output) diff --git a/extensions_built_in/diffusion_models/omnigen2/src/ops/triton/__init__.py b/extensions_built_in/diffusion_models/omnigen2/src/ops/triton/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/extensions_built_in/diffusion_models/omnigen2/src/ops/triton/layer_norm.py b/extensions_built_in/diffusion_models/omnigen2/src/ops/triton/layer_norm.py new file mode 100644 index 00000000..b7d12330 --- /dev/null +++ b/extensions_built_in/diffusion_models/omnigen2/src/ops/triton/layer_norm.py @@ -0,0 +1,1257 @@ +# Copyright (c) 2024, Tri Dao. +# Implement dropout + residual + layer_norm / rms_norm. + +# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html +# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. +# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling. +# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. + +import math + +import torch +import torch.nn.functional as F + +import triton +import triton.language as tl + + +from typing import Callable + + +def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool): + def decorator(*args, **kwargs): + if cuda_amp_deprecated: + kwargs["device_type"] = "cuda" + return dec(*args, **kwargs) + return decorator + + +if hasattr(torch.amp, "custom_fwd"): # type: ignore[attr-defined] + deprecated = True + from torch.amp import custom_fwd, custom_bwd # type: ignore[attr-defined] +else: + deprecated = False + from torch.cuda.amp import custom_fwd, custom_bwd + +custom_fwd = custom_amp_decorator(custom_fwd, deprecated) +custom_bwd = custom_amp_decorator(custom_bwd, deprecated) + + +def triton_autotune_configs(): + # Return configs with a valid warp count for the current device + configs=[] + # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024 + max_threads_per_block=1024 + # Default to warp size 32 if not defined by device + warp_size=getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32) + # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit + warp_count=1 + while warp_count*warp_size <= max_threads_per_block: + configs.append(triton.Config({}, num_warps=warp_count)) + warp_count*=2 + return configs + +def layer_norm_ref( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + zero_centered_weight=False, + dropout_mask=None, + dropout_mask1=None, + upcast=False, +): + dtype = x.dtype + if upcast: + x = x.float() + weight = weight.float() + bias = bias.float() if bias is not None else None + residual = residual.float() if residual is not None else 
residual + x1 = x1.float() if x1 is not None else None + weight1 = weight1.float() if weight1 is not None else None + bias1 = bias1.float() if bias1 is not None else None + if zero_centered_weight: + weight = weight + 1.0 + if weight1 is not None: + weight1 = weight1 + 1.0 + if x1 is not None: + assert rowscale is None, "rowscale is not supported with parallel LayerNorm" + if rowscale is not None: + x = x * rowscale[..., None] + if dropout_p > 0.0: + if dropout_mask is not None: + x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) + else: + x = F.dropout(x, p=dropout_p) + if x1 is not None: + if dropout_mask1 is not None: + x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p) + else: + x1 = F.dropout(x1, p=dropout_p) + if x1 is not None: + x = x + x1 + if residual is not None: + x = (x + residual).to(x.dtype) + out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to( + dtype + ) + if weight1 is None: + return out if not prenorm else (out, x) + else: + out1 = F.layer_norm( + x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps + ).to(dtype) + return (out, out1) if not prenorm else (out, out1, x) + + +def rms_norm_ref( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + zero_centered_weight=False, + dropout_mask=None, + dropout_mask1=None, + upcast=False, +): + dtype = x.dtype + if upcast: + x = x.float() + weight = weight.float() + bias = bias.float() if bias is not None else None + residual = residual.float() if residual is not None else residual + x1 = x1.float() if x1 is not None else None + weight1 = weight1.float() if weight1 is not None else None + bias1 = bias1.float() if bias1 is not None else None + if zero_centered_weight: + weight = weight + 1.0 + if weight1 is not None: + weight1 = weight1 + 1.0 + if x1 is not None: + assert rowscale is None, "rowscale is not supported with parallel LayerNorm" + if rowscale is not None: + x = x * rowscale[..., None] + if dropout_p > 0.0: + if dropout_mask is not None: + x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) + else: + x = F.dropout(x, p=dropout_p) + if x1 is not None: + if dropout_mask1 is not None: + x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p) + else: + x1 = F.dropout(x1, p=dropout_p) + if x1 is not None: + x = x + x1 + if residual is not None: + x = (x + residual).to(x.dtype) + rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) + out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype) + if weight1 is None: + return out if not prenorm else (out, x) + else: + out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to( + dtype + ) + return (out, out1) if not prenorm else (out, out1, x) + + +@triton.autotune( + configs=triton_autotune_configs(), + key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"], +) +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) +@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None}) +@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None}) +@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None}) +@triton.jit +def _layer_norm_fwd_1pass_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + RESIDUAL, # pointer to the 
residual + X1, + W1, + B1, + Y1, + RESIDUAL_OUT, # pointer to the residual + ROWSCALE, + SEEDS, # Dropout seeds for each row + DROPOUT_MASK, + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_res_row, + stride_res_out_row, + stride_x1_row, + stride_y1_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + dropout_p, # Dropout probability + zero_centered_weight, # If true, add 1.0 to the weight + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + STORE_RESIDUAL_OUT: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_DROPOUT: tl.constexpr, + STORE_DROPOUT_MASK: tl.constexpr, + HAS_ROWSCALE: tl.constexpr, + HAS_X1: tl.constexpr, + HAS_W1: tl.constexpr, + HAS_B1: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_y_row + if HAS_RESIDUAL: + RESIDUAL += row * stride_res_row + if STORE_RESIDUAL_OUT: + RESIDUAL_OUT += row * stride_res_out_row + if HAS_X1: + X1 += row * stride_x1_row + if HAS_W1: + Y1 += row * stride_y1_row + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + row).to(tl.float32) + x *= rowscale + if HAS_DROPOUT: + # Compute dropout mask + # 7 rounds is good enough, and reduces register pressure + keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0) + if STORE_DROPOUT_MASK: + tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N) + if HAS_X1: + x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + M + row).to(tl.float32) + x1 *= rowscale + if HAS_DROPOUT: + # Compute dropout mask + # 7 rounds is good enough, and reduces register pressure + keep_mask = ( + tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + ) + x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0) + if STORE_DROPOUT_MASK: + tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N) + x += x1 + if HAS_RESIDUAL: + residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) + x += residual + if STORE_RESIDUAL_OUT: + tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean + row, mean) + xbar = tl.where(cols < N, x - mean, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w += 1.0 + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + y = x_hat * w + b if HAS_BIAS else x_hat * w + # Write output + tl.store(Y + cols, y, mask=mask) + if HAS_W1: + w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w1 += 1.0 + if HAS_B1: + b1 = tl.load(B1 + cols, mask=mask).to(tl.float32) + y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1 + tl.store(Y1 + cols, y1, mask=mask) + + +def _layer_norm_fwd( + x, + weight, + bias, + eps, + residual=None, + x1=None, + weight1=None, + bias1=None, + 
dropout_p=0.0, + rowscale=None, + out_dtype=None, + residual_dtype=None, + zero_centered_weight=False, + is_rms_norm=False, + return_dropout_mask=False, + out=None, + residual_out=None +): + if residual is not None: + residual_dtype = residual.dtype + M, N = x.shape + assert x.stride(-1) == 1 + if residual is not None: + assert residual.stride(-1) == 1 + assert residual.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + if x1 is not None: + assert x1.shape == x.shape + assert rowscale is None + assert x1.stride(-1) == 1 + if weight1 is not None: + assert weight1.shape == (N,) + assert weight1.stride(-1) == 1 + if bias1 is not None: + assert bias1.shape == (N,) + assert bias1.stride(-1) == 1 + if rowscale is not None: + assert rowscale.is_contiguous() + assert rowscale.shape == (M,) + # allocate output + if out is None: + out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) + else: + assert out.shape == x.shape + assert out.stride(-1) == 1 + if weight1 is not None: + y1 = torch.empty_like(out) + assert y1.stride(-1) == 1 + else: + y1 = None + if ( + residual is not None + or (residual_dtype is not None and residual_dtype != x.dtype) + or dropout_p > 0.0 + or rowscale is not None + or x1 is not None + ): + if residual_out is None: + residual_out = torch.empty( + M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype + ) + else: + assert residual_out.shape == x.shape + assert residual_out.stride(-1) == 1 + else: + residual_out = None + mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None + rstd = torch.empty((M,), dtype=torch.float32, device=x.device) + if dropout_p > 0.0: + seeds = torch.randint( + 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64 + ) + else: + seeds = None + if return_dropout_mask and dropout_p > 0.0: + dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool) + else: + dropout_mask = None + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + with torch.cuda.device(x.device.index): + _layer_norm_fwd_1pass_kernel[(M,)]( + x, + out, + weight, + bias, + residual, + x1, + weight1, + bias1, + y1, + residual_out, + rowscale, + seeds, + dropout_mask, + mean, + rstd, + x.stride(0), + out.stride(0), + residual.stride(0) if residual is not None else 0, + residual_out.stride(0) if residual_out is not None else 0, + x1.stride(0) if x1 is not None else 0, + y1.stride(0) if y1 is not None else 0, + M, + N, + eps, + dropout_p, + zero_centered_weight, + is_rms_norm, + BLOCK_N, + residual is not None, + residual_out is not None, + bias is not None, + dropout_p > 0.0, + dropout_mask is not None, + rowscale is not None, + ) + # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0 + if dropout_mask is not None and x1 is not None: + dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0) + else: + dropout_mask1 = None + return ( + out, + y1, + mean, + rstd, + residual_out if residual_out is not None else x, + seeds, + dropout_mask, + dropout_mask1, + ) + + +@triton.autotune( + configs=triton_autotune_configs(), + key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", 
"HAS_DROPOUT"], +) +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) +# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None}) +@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None}) +@triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None}) +@triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None}) +@triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None}) +@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) +@triton.jit +def _layer_norm_bwd_kernel( + X, # pointer to the input + W, # pointer to the weights + B, # pointer to the biases + Y, # pointer to the output to be recomputed + DY, # pointer to the output gradient + DX, # pointer to the input gradient + DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + DRESIDUAL, + W1, + DY1, + DX1, + DW1, + DB1, + DRESIDUAL_IN, + ROWSCALE, + SEEDS, + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_dy_row, + stride_dx_row, + stride_dres_row, + stride_dy1_row, + stride_dx1_row, + stride_dres_in_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + dropout_p, + zero_centered_weight, + rows_per_program, + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_DRESIDUAL: tl.constexpr, + STORE_DRESIDUAL: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_DROPOUT: tl.constexpr, + HAS_ROWSCALE: tl.constexpr, + HAS_DY1: tl.constexpr, + HAS_DX1: tl.constexpr, + HAS_B1: tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, +): + # Map the program id to the elements of X, DX, and DY it should compute. 
+ row_block_id = tl.program_id(0) + row_start = row_block_id * rows_per_program + # Do not early exit if row_start >= M, because we need to write DW and DB + cols = tl.arange(0, BLOCK_N) + mask = cols < N + X += row_start * stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += row_start * stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += row_start * stride_dres_in_row + DY += row_start * stride_dy_row + DX += row_start * stride_dx_row + if HAS_DY1: + DY1 += row_start * stride_dy1_row + if HAS_DX1: + DX1 += row_start * stride_dx1_row + if RECOMPUTE_OUTPUT: + Y += row_start * stride_y_row + w = tl.load(W + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w += 1.0 + if RECOMPUTE_OUTPUT and HAS_BIAS: + b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) + if HAS_DY1: + w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w1 += 1.0 + dw = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_BIAS: + db = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_DY1: + dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_B1: + db1 = tl.zeros((BLOCK_N,), dtype=tl.float32) + row_end = min((row_block_id + 1) * rows_per_program, M) + for row in range(row_start, row_end): + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + if HAS_DY1: + dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32) + if not IS_RMS_NORM: + mean = tl.load(Mean + row) + rstd = tl.load(Rstd + row) + # Compute dx + xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + xhat = tl.where(mask, xhat, 0.0) + if RECOMPUTE_OUTPUT: + y = xhat * w + b if HAS_BIAS else xhat * w + tl.store(Y + cols, y, mask=mask) + wdy = w * dy + dw += dy * xhat + if HAS_BIAS: + db += dy + if HAS_DY1: + wdy += w1 * dy1 + dw1 += dy1 * xhat + if HAS_B1: + db1 += dy1 + if not IS_RMS_NORM: + c1 = tl.sum(xhat * wdy, axis=0) / N + c2 = tl.sum(wdy, axis=0) / N + dx = (wdy - (xhat * c1 + c2)) * rstd + else: + c1 = tl.sum(xhat * wdy, axis=0) / N + dx = (wdy - xhat * c1) * rstd + if HAS_DRESIDUAL: + dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32) + dx += dres + # Write dx + if STORE_DRESIDUAL: + tl.store(DRESIDUAL_IN + cols, dx, mask=mask) + if HAS_DX1: + if HAS_DROPOUT: + keep_mask = ( + tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + ) + dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) + else: + dx1 = dx + tl.store(DX1 + cols, dx1, mask=mask) + if HAS_DROPOUT: + keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + row).to(tl.float32) + dx *= rowscale + tl.store(DX + cols, dx, mask=mask) + + X += stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += stride_dres_in_row + if RECOMPUTE_OUTPUT: + Y += stride_y_row + DY += stride_dy_row + DX += stride_dx_row + if HAS_DY1: + DY1 += stride_dy1_row + if HAS_DX1: + DX1 += stride_dx1_row + tl.store(DW + row_block_id * N + cols, dw, mask=mask) + if HAS_BIAS: + tl.store(DB + row_block_id * N + cols, db, mask=mask) + if HAS_DY1: + tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask) + if HAS_B1: + tl.store(DB1 + row_block_id * N + cols, db1, mask=mask) + + +def _layer_norm_bwd( + dy, + x, + weight, + bias, + eps, + mean, + rstd, + dresidual=None, + dy1=None, + weight1=None, + bias1=None, + seeds=None, + dropout_p=0.0, + rowscale=None, + 
has_residual=False, + has_x1=False, + zero_centered_weight=False, + is_rms_norm=False, + x_dtype=None, + recompute_output=False, +): + M, N = x.shape + assert x.stride(-1) == 1 + assert dy.stride(-1) == 1 + assert dy.shape == (M, N) + if dresidual is not None: + assert dresidual.stride(-1) == 1 + assert dresidual.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + if dy1 is not None: + assert weight1 is not None + assert dy1.shape == dy.shape + assert dy1.stride(-1) == 1 + if weight1 is not None: + assert weight1.shape == (N,) + assert weight1.stride(-1) == 1 + if bias1 is not None: + assert bias1.shape == (N,) + assert bias1.stride(-1) == 1 + if seeds is not None: + assert seeds.is_contiguous() + assert seeds.shape == (M if not has_x1 else M * 2,) + if rowscale is not None: + assert rowscale.is_contiguous() + assert rowscale.shape == (M,) + # allocate output + dx = ( + torch.empty_like(x) + if x_dtype is None + else torch.empty(M, N, dtype=x_dtype, device=x.device) + ) + dresidual_in = ( + torch.empty_like(x) + if has_residual + and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1) + else None + ) + dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None + y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None + if recompute_output: + assert weight1 is None, "recompute_output is not supported with parallel LayerNorm" + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # Increasing the multiple (e.g. 8) will allow more thread blocks to be launched and hide the + # latency of the gmem reads/writes, but will increase the time of summing up dw / db. 
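+    # Partial gradient buffers are allocated as (sm_count, N) in float32 and
+    # reduced with .sum(0) after the kernel launch.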
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 8 + _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) + _db = ( + torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) + if bias is not None + else None + ) + _dw1 = torch.empty_like(_dw) if weight1 is not None else None + _db1 = torch.empty_like(_db) if bias1 is not None else None + rows_per_program = math.ceil(M / sm_count) + grid = (sm_count,) + with torch.cuda.device(x.device.index): + _layer_norm_bwd_kernel[grid]( + x, + weight, + bias, + y, + dy, + dx, + _dw, + _db, + dresidual, + weight1, + dy1, + dx1, + _dw1, + _db1, + dresidual_in, + rowscale, + seeds, + mean, + rstd, + x.stride(0), + 0 if not recompute_output else y.stride(0), + dy.stride(0), + dx.stride(0), + dresidual.stride(0) if dresidual is not None else 0, + dy1.stride(0) if dy1 is not None else 0, + dx1.stride(0) if dx1 is not None else 0, + dresidual_in.stride(0) if dresidual_in is not None else 0, + M, + N, + eps, + dropout_p, + zero_centered_weight, + rows_per_program, + is_rms_norm, + BLOCK_N, + dresidual is not None, + dresidual_in is not None, + bias is not None, + dropout_p > 0.0, + ) + dw = _dw.sum(0).to(weight.dtype) + db = _db.sum(0).to(bias.dtype) if bias is not None else None + dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None + db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None + # Don't need to compute dresidual_in separately in this case + if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None: + dresidual_in = dx + if has_x1 and dropout_p == 0.0: + dx1 = dx + return ( + (dx, dw, db, dresidual_in, dx1, dw1, db1) + if not recompute_output + else (dx, dw, db, dresidual_in, dx1, dw1, db1, y) + ) + + +class LayerNormFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + zero_centered_weight=False, + is_rms_norm=False, + return_dropout_mask=False, + out=None, + residual_out=None + ): + x_shape_og = x.shape + # Check for zero sequence length + if x.numel() == 0: + ctx.zero_seq_length = True + # Only save minimal required tensors for backward + # ctx.save_for_backward(weight, bias, weight1, bias1) + ctx.x_shape_og = x_shape_og + ctx.weight_shape = weight.shape + ctx.weight_dtype = weight.dtype + ctx.weight_device = weight.device + + ctx.has_bias = bias is not None + ctx.bias_shape = bias.shape if bias is not None else None + ctx.bias_dtype = bias.dtype if bias is not None else None + ctx.bias_device = bias.device if bias is not None else None + + ctx.has_weight1 = weight1 is not None + ctx.weight1_shape = weight1.shape if weight1 is not None else None + ctx.weight1_dtype = weight1.dtype if weight1 is not None else None + ctx.weight1_device = weight1.device if weight1 is not None else None + + ctx.has_bias1 = bias1 is not None + ctx.bias1_shape = bias1.shape if bias1 is not None else None + ctx.bias1_dtype = bias1.dtype if bias1 is not None else None + ctx.bias1_device = bias1.device if bias1 is not None else None + + ctx.has_residual = residual is not None + ctx.has_x1 = x1 is not None + ctx.dropout_p = dropout_p + + # Handle output tensors with correct dtype + y = x # Preserve input tensor properties + y1 = torch.empty_like(x) if x1 is not None else None + + # Only create residual_out if prenorm is True + residual_out = torch.empty(x.shape, + 
dtype=torch.float32 if residual_in_fp32 else x.dtype, + device=x.device) if prenorm else None + + # Handle dropout masks + dropout_mask = None + dropout_mask1 = None + if return_dropout_mask: + dropout_mask = torch.empty_like(x, dtype=torch.uint8) + if x1 is not None: + dropout_mask1 = torch.empty_like(x, dtype=torch.uint8) + + # Return based on configuration + if not return_dropout_mask: + if weight1 is None: + return y if not prenorm else (y, residual_out) + else: + return (y, y1) if not prenorm else (y, y1, residual_out) + else: + if weight1 is None: + return ((y, dropout_mask, dropout_mask1) if not prenorm + else (y, residual_out, dropout_mask, dropout_mask1)) + else: + return ((y, y1, dropout_mask, dropout_mask1) if not prenorm + else (y, y1, residual_out, dropout_mask, dropout_mask1)) + + ctx.zero_seq_length = False + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + if residual.stride(-1) != 1: + residual = residual.contiguous() + if x1 is not None: + assert x1.shape == x_shape_og + assert rowscale is None, "rowscale is not supported with parallel LayerNorm" + x1 = x1.reshape(-1, x1.shape[-1]) + if x1.stride(-1) != 1: + x1 = x1.contiguous() + weight = weight.contiguous() + if bias is not None: + bias = bias.contiguous() + if weight1 is not None: + weight1 = weight1.contiguous() + if bias1 is not None: + bias1 = bias1.contiguous() + if rowscale is not None: + rowscale = rowscale.reshape(-1).contiguous() + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float32 if residual_in_fp32 else None) + ) + if out is not None: + out = out.reshape(-1, out.shape[-1]) + if residual_out is not None: + residual_out = residual_out.reshape(-1, residual_out.shape[-1]) + y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd( + x, + weight, + bias, + eps, + residual, + x1, + weight1, + bias1, + dropout_p=dropout_p, + rowscale=rowscale, + residual_dtype=residual_dtype, + zero_centered_weight=zero_centered_weight, + is_rms_norm=is_rms_norm, + return_dropout_mask=return_dropout_mask, + out=out, + residual_out=residual_out + ) + ctx.save_for_backward( + residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd + ) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.dropout_p = dropout_p + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.has_x1 = x1 is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + ctx.zero_centered_weight = zero_centered_weight + y = y.reshape(x_shape_og) + y1 = y1.reshape(x_shape_og) if y1 is not None else None + residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None + dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None + dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None + if not return_dropout_mask: + if weight1 is None: + return y if not prenorm else (y, residual_out) + else: + return (y, y1) if not prenorm else (y, y1, residual_out) + else: + if weight1 is None: + return ( + (y, dropout_mask, dropout_mask1) + if not prenorm + else (y, residual_out, dropout_mask, dropout_mask1) + ) + else: + return ( + (y, y1, dropout_mask, dropout_mask1) + if not prenorm + else (y, y1, residual_out, dropout_mask, dropout_mask1) + ) + + @staticmethod + def backward(ctx, dy, *args): + if ctx.zero_seq_length: + 
return ( + torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device), + torch.zeros(ctx.weight_shape, dtype=ctx.weight_dtype, device=ctx.weight_device), + torch.zeros(ctx.bias_shape, dtype=ctx.bias_dtype, device=ctx.bias_device) if ctx.has_bias else None, + torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_residual else None, + torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_x1 and ctx.dropout_p > 0.0 else None, + torch.zeros(ctx.weight1_shape, dtype=ctx.weight1_dtype, device=ctx.weight1_device) if ctx.has_weight1 else None, + torch.zeros(ctx.bias1_shape, dtype=ctx.bias1_dtype, device=ctx.bias1_device) if ctx.has_bias1 else None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors + dy = dy.reshape(-1, dy.shape[-1]) + if dy.stride(-1) != 1: + dy = dy.contiguous() + assert dy.shape == x.shape + if weight1 is not None: + dy1, args = args[0], args[1:] + dy1 = dy1.reshape(-1, dy1.shape[-1]) + if dy1.stride(-1) != 1: + dy1 = dy1.contiguous() + assert dy1.shape == x.shape + else: + dy1 = None + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + if dresidual.stride(-1) != 1: + dresidual = dresidual.contiguous() + assert dresidual.shape == x.shape + else: + dresidual = None + + dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd( + dy, + x, + weight, + bias, + ctx.eps, + mean, + rstd, + dresidual, + dy1, + weight1, + bias1, + seeds, + ctx.dropout_p, + rowscale, + ctx.has_residual, + ctx.has_x1, + ctx.zero_centered_weight, + ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + ) + return ( + dx.reshape(ctx.x_shape_og), + dw, + db, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + dx1.reshape(ctx.x_shape_og) if dx1 is not None else None, + dw1, + db1, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +def layer_norm_fn( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + zero_centered_weight=False, + is_rms_norm=False, + return_dropout_mask=False, + out=None, + residual_out=None +): + return LayerNormFn.apply( + x, + weight, + bias, + residual, + x1, + weight1, + bias1, + eps, + dropout_p, + rowscale, + prenorm, + residual_in_fp32, + zero_centered_weight, + is_rms_norm, + return_dropout_mask, + out, + residual_out + ) + + +def rms_norm_fn( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + zero_centered_weight=False, + return_dropout_mask=False, + out=None, + residual_out=None +): + return LayerNormFn.apply( + x, + weight, + bias, + residual, + x1, + weight1, + bias1, + eps, + dropout_p, + rowscale, + prenorm, + residual_in_fp32, + zero_centered_weight, + True, + return_dropout_mask, + out, + residual_out + ) + + +class RMSNorm(torch.nn.Module): + + def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, zero_centered_weight=False, + device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + if dropout_p > 0.0: + self.drop = torch.nn.Dropout(dropout_p) + else: + self.drop = None + self.zero_centered_weight = zero_centered_weight + self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + 
self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self): + if not self.zero_centered_weight: + torch.nn.init.ones_(self.weight) + else: + torch.nn.init.zeros_(self.weight) + + def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): + return rms_norm_fn( + x, + self.weight, + self.bias, + residual=residual, + eps=self.eps, + dropout_p=self.drop.p if self.drop is not None and self.training else 0.0, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32, + zero_centered_weight=self.zero_centered_weight, + ) + + +class LayerNormLinearFn(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward( + ctx, + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, + ): + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + if residual.stride(-1) != 1: + residual = residual.contiguous() + norm_weight = norm_weight.contiguous() + if norm_bias is not None: + norm_bias = norm_bias.contiguous() + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float32 if residual_in_fp32 else None) + ) + y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd( + x, + norm_weight, + norm_bias, + eps, + residual, + out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_dtype("cuda"), + residual_dtype=residual_dtype, + is_rms_norm=is_rms_norm, + ) + y = y.reshape(x_shape_og) + dtype = torch.get_autocast_dtype("cuda") if torch.is_autocast_enabled() else y.dtype + linear_weight = linear_weight.to(dtype) + linear_bias = linear_bias.to(dtype) if linear_bias is not None else None + out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias) + # We don't store y, will be recomputed in the backward pass to save memory + ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + ctx.linear_bias_is_none = linear_bias is None + return out if not prenorm else (out, residual_out.reshape(x_shape_og)) + + @staticmethod + @custom_bwd + def backward(ctx, dout, *args): + x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors + dout = dout.reshape(-1, dout.shape[-1]) + dy = F.linear(dout, linear_weight.t()) + dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) + if dy.stride(-1) != 1: + dy = dy.contiguous() + assert dy.shape == x.shape + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + if dresidual.stride(-1) != 1: + dresidual = dresidual.contiguous() + assert dresidual.shape == x.shape + else: + dresidual = None + dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd( + dy, + x, + norm_weight, + norm_bias, + ctx.eps, + mean, + rstd, + dresidual=dresidual, + has_residual=ctx.has_residual, + is_rms_norm=ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + recompute_output=True, + ) + dlinear_weight = torch.einsum("bo,bi->oi", dout, y) + return ( + dx.reshape(ctx.x_shape_og), + dnorm_weight, + dnorm_bias, + dlinear_weight, + dlinear_bias, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + None, + None, + None, + None, 
+ ) + + +def layer_norm_linear_fn( + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, +): + return LayerNormLinearFn.apply( + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual, + eps, + prenorm, + residual_in_fp32, + is_rms_norm, + ) diff --git a/extensions_built_in/diffusion_models/omnigen2/src/pipelines/__init__.py b/extensions_built_in/diffusion_models/omnigen2/src/pipelines/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/extensions_built_in/diffusion_models/omnigen2/src/pipelines/image_processor.py b/extensions_built_in/diffusion_models/omnigen2/src/pipelines/image_processor.py new file mode 100644 index 00000000..ec66dcdb --- /dev/null +++ b/extensions_built_in/diffusion_models/omnigen2/src/pipelines/image_processor.py @@ -0,0 +1,266 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import warnings +from typing import List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch + +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor, is_valid_image_imagelist +from diffusers.configuration_utils import register_to_config + +class OmniGen2ImageProcessor(VaeImageProcessor): + """ + Image processor for PixArt image resize and crop. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept + `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method. + vae_scale_factor (`int`, *optional*, defaults to `8`): + VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor. + resample (`str`, *optional*, defaults to `lanczos`): + Resampling filter to use when resizing the image. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image to [-1,1]. + do_binarize (`bool`, *optional*, defaults to `False`): + Whether to binarize the image to 0/1. + do_convert_rgb (`bool`, *optional*, defaults to be `False`): + Whether to convert the images to RGB format. + do_convert_grayscale (`bool`, *optional*, defaults to be `False`): + Whether to convert the images to grayscale format. 
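+        max_pixels (`int`, *optional*, defaults to `None`):
+            Maximum number of pixels (height * width) allowed; larger images are downscaled to fit.
+        max_side_length (`int`, *optional*, defaults to `None`):
+            Maximum length of the longer image side; larger images are downscaled to fit.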
+ """ + + @register_to_config + def __init__( + self, + do_resize: bool = True, + vae_scale_factor: int = 16, + resample: str = "lanczos", + max_pixels: Optional[int] = None, + max_side_length: Optional[int] = None, + do_normalize: bool = True, + do_binarize: bool = False, + do_convert_grayscale: bool = False, + ): + super().__init__( + do_resize=do_resize, + vae_scale_factor=vae_scale_factor, + resample=resample, + do_normalize=do_normalize, + do_binarize=do_binarize, + do_convert_grayscale=do_convert_grayscale, + ) + + self.max_pixels = max_pixels + self.max_side_length = max_side_length + + def get_new_height_width( + self, + image: Union[PIL.Image.Image, np.ndarray, torch.Tensor], + height: Optional[int] = None, + width: Optional[int] = None, + max_pixels: Optional[int] = None, + max_side_length: Optional[int] = None, + ) -> Tuple[int, int]: + r""" + Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`. + + Args: + image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`): + The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it + should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch + tensor, it should have shape `[batch, channels, height, width]`. + height (`Optional[int]`, *optional*, defaults to `None`): + The height of the preprocessed image. If `None`, the height of the `image` input will be used. + width (`Optional[int]`, *optional*, defaults to `None`): + The width of the preprocessed image. If `None`, the width of the `image` input will be used. + + Returns: + `Tuple[int, int]`: + A tuple containing the height and width, both resized to the nearest integer multiple of + `vae_scale_factor`. + """ + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[2] + else: + height = image.shape[1] + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[3] + else: + width = image.shape[2] + + if max_side_length is None: + max_side_length = self.max_side_length + + if max_pixels is None: + max_pixels = self.max_pixels + + ratio = 1.0 + if max_side_length is not None: + if height > width: + max_side_length_ratio = max_side_length / height + else: + max_side_length_ratio = max_side_length / width + + cur_pixels = height * width + max_pixels_ratio = (max_pixels / cur_pixels) ** 0.5 + ratio = min(max_pixels_ratio, max_side_length_ratio, 1.0) # do not upscale input image + + new_height, new_width = int(height * ratio) // self.config.vae_scale_factor * self.config.vae_scale_factor, int(width * ratio) // self.config.vae_scale_factor * self.config.vae_scale_factor + return new_height, new_width + + def preprocess( + self, + image: PipelineImageInput, + height: Optional[int] = None, + width: Optional[int] = None, + max_pixels: Optional[int] = None, + max_side_length: Optional[int] = None, + resize_mode: str = "default", # "default", "fill", "crop" + crops_coords: Optional[Tuple[int, int, int, int]] = None, + ) -> torch.Tensor: + """ + Preprocess the image input. + + Args: + image (`PipelineImageInput`): + The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of + supported formats. + height (`int`, *optional*): + The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default + height. 
+ width (`int`, *optional*): + The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width. + resize_mode (`str`, *optional*, defaults to `default`): + The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within + the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, will + resize the image to fit within the specified width and height, maintaining the aspect ratio, and then + center the image within the dimensions, filling empty with data from image. If `crop`, will resize the + image to fit within the specified width and height, maintaining the aspect ratio, and then center the + image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only + supported for PIL image input. + crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`): + The crop coordinates for each image in the batch. If `None`, will not crop the image. + + Returns: + `torch.Tensor`: + The preprocessed image. + """ + supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor) + + # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image + if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3: + if isinstance(image, torch.Tensor): + # if image is a pytorch tensor could have 2 possible shapes: + # 1. batch x height x width: we should insert the channel dimension at position 1 + # 2. channel x height x width: we should insert batch dimension at position 0, + # however, since both channel and batch dimension has same size 1, it is same to insert at position 1 + # for simplicity, we insert a dimension of size 1 at position 1 for both cases + image = image.unsqueeze(1) + else: + # if it is a numpy array, it could have 2 possible shapes: + # 1. batch x height x width: insert channel dimension on last position + # 2. height x width x channel: insert batch dimension on first position + if image.shape[-1] == 1: + image = np.expand_dims(image, axis=0) + else: + image = np.expand_dims(image, axis=-1) + + if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4: + warnings.warn( + "Passing `image` as a list of 4d np.ndarray is deprecated." + "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray", + FutureWarning, + ) + image = np.concatenate(image, axis=0) + if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4: + warnings.warn( + "Passing `image` as a list of 4d torch.Tensor is deprecated." + "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor", + FutureWarning, + ) + image = torch.cat(image, axis=0) + + if not is_valid_image_imagelist(image): + raise ValueError( + f"Input is in incorrect format. 
Currently, we only support {', '.join(str(x) for x in supported_formats)}" + ) + if not isinstance(image, list): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + if crops_coords is not None: + image = [i.crop(crops_coords) for i in image] + if self.config.do_resize: + height, width = self.get_new_height_width(image[0], height, width, max_pixels, max_side_length) + image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image] + if self.config.do_convert_rgb: + image = [self.convert_to_rgb(i) for i in image] + elif self.config.do_convert_grayscale: + image = [self.convert_to_grayscale(i) for i in image] + image = self.pil_to_numpy(image) # to np + image = self.numpy_to_pt(image) # to pt + + elif isinstance(image[0], np.ndarray): + image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0) + + image = self.numpy_to_pt(image) + + height, width = self.get_new_height_width(image, height, width, max_pixels, max_side_length) + if self.config.do_resize: + image = self.resize(image, height, width) + + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) + + if self.config.do_convert_grayscale and image.ndim == 3: + image = image.unsqueeze(1) + + channel = image.shape[1] + # don't need any preprocess if the image is latents + if channel == self.config.vae_latent_channels: + return image + + height, width = self.get_new_height_width(image, height, width, max_pixels, max_side_length) + if self.config.do_resize: + image = self.resize(image, height, width) + + # expected range [0,1], normalize to [-1,1] + do_normalize = self.config.do_normalize + if do_normalize and image.min() < 0: + warnings.warn( + "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] " + f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]", + FutureWarning, + ) + do_normalize = False + if do_normalize: + image = self.normalize(image) + + if self.config.do_binarize: + image = self.binarize(image) + + return image \ No newline at end of file diff --git a/extensions_built_in/diffusion_models/omnigen2/src/pipelines/omnigen2/pipeline_omnigen2.py b/extensions_built_in/diffusion_models/omnigen2/src/pipelines/omnigen2/pipeline_omnigen2.py new file mode 100644 index 00000000..45e509bc --- /dev/null +++ b/extensions_built_in/diffusion_models/omnigen2/src/pipelines/omnigen2/pipeline_omnigen2.py @@ -0,0 +1,728 @@ +""" +OmniGen2 Diffusion Pipeline + +Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import math + +from PIL import Image +import numpy as np +import torch +import torch.nn.functional as F + +from transformers import Qwen2_5_VLForConditionalGeneration + +from diffusers.models.autoencoders import AutoencoderKL +from ...models.transformers import OmniGen2Transformer2DModel +from ...models.transformers.repo import OmniGen2RotaryPosEmbed +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import ( + is_torch_xla_available, + logging, +) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline + +from dataclasses import dataclass + +import PIL.Image + +from diffusers.utils import BaseOutput + +from ....src.pipelines.image_processor import OmniGen2ImageProcessor + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +@dataclass +class FMPipelineOutput(BaseOutput): + """ + Output class for OmniGen2 pipeline. + + Args: + images (Union[List[PIL.Image.Image], np.ndarray]): + List of denoised PIL images of length `batch_size` or numpy array of shape + `(batch_size, height, width, num_channels)`. Contains the generated images. + """ + images: Union[List[PIL.Image.Image], np.ndarray] + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class OmniGen2Pipeline(DiffusionPipeline): + """ + Pipeline for text-to-image generation using OmniGen2. + + This pipeline implements a text-to-image generation model that uses: + - Qwen2.5-VL for text encoding + - A custom transformer architecture for image generation + - VAE for image encoding/decoding + - FlowMatchEulerDiscreteScheduler for noise scheduling + + Args: + transformer (OmniGen2Transformer2DModel): The transformer model for image generation. + vae (AutoencoderKL): The VAE model for image encoding/decoding. + scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for noise scheduling. + text_encoder (Qwen2_5_VLModel): The text encoder model. + tokenizer (Union[Qwen2Tokenizer, Qwen2TokenizerFast]): The tokenizer for text processing. + """ + + model_cpu_offload_seq = "mllm->transformer->vae" + + def __init__( + self, + transformer: OmniGen2Transformer2DModel, + vae: AutoencoderKL, + scheduler: FlowMatchEulerDiscreteScheduler, + mllm: Qwen2_5_VLForConditionalGeneration, + processor, + ) -> None: + """ + Initialize the OmniGen2 pipeline. + + Args: + transformer: The transformer model for image generation. + vae: The VAE model for image encoding/decoding. + scheduler: The scheduler for noise scheduling. + text_encoder: The text encoder model. + tokenizer: The tokenizer for text processing. + """ + super().__init__() + + self.register_modules( + transformer=transformer, + vae=vae, + scheduler=scheduler, + mllm=mllm, + processor=processor + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = OmniGen2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2, do_resize=True) + self.default_sample_size = 128 + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[torch.Generator], + latents: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Prepare the initial latents for the diffusion process. + + Args: + batch_size: The number of images to generate. + num_channels_latents: The number of channels in the latent space. + height: The height of the generated image. + width: The width of the generated image. + dtype: The data type of the latents. + device: The device to place the latents on. + generator: The random number generator to use. + latents: Optional pre-computed latents to use instead of random initialization. + + Returns: + torch.FloatTensor: The prepared latents tensor. + """ + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + return latents + + def encode_vae(self, img: torch.FloatTensor) -> torch.FloatTensor: + """ + Encode an image into the VAE latent space. + + Args: + img: The input image tensor to encode. + + Returns: + torch.FloatTensor: The encoded latent representation. 
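`retrieve_timesteps` above mirrors the standard diffusers helper: it either forwards an explicit `timesteps` list to `set_timesteps` (when the scheduler supports that argument) or asks the scheduler to build a schedule for `num_inference_steps`. A quick usage sketch against the stock diffusers `FlowMatchEulerDiscreteScheduler`; the step count and device are placeholders:

```python
import torch
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler()

# Default path: the scheduler builds its own 28-step schedule.
timesteps, num_steps = retrieve_timesteps(
    scheduler, num_inference_steps=28, device=torch.device("cpu")
)
assert num_steps == len(timesteps) == 28

# Custom path: an explicit timestep list is forwarded to set_timesteps(timesteps=...);
# a ValueError is raised if the scheduler's set_timesteps() cannot accept it.
```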
+ """ + z0 = self.vae.encode(img.to(dtype=self.vae.dtype)).latent_dist.sample() + if self.vae.config.shift_factor is not None: + z0 = z0 - self.vae.config.shift_factor + if self.vae.config.scaling_factor is not None: + z0 = z0 * self.vae.config.scaling_factor + z0 = z0.to(dtype=self.vae.dtype) + return z0 + + def prepare_image( + self, + images: Union[List[PIL.Image.Image], PIL.Image.Image], + batch_size: int, + num_images_per_prompt: int, + max_pixels: int, + max_side_length: int, + device: torch.device, + dtype: torch.dtype, + ) -> List[Optional[torch.FloatTensor]]: + """ + Prepare input images for processing by encoding them into the VAE latent space. + + Args: + images: Single image or list of images to process. + batch_size: The number of images to generate per prompt. + num_images_per_prompt: The number of images to generate for each prompt. + device: The device to place the encoded latents on. + dtype: The data type of the encoded latents. + + Returns: + List[Optional[torch.FloatTensor]]: List of encoded latent representations for each image. + """ + if batch_size == 1: + images = [images] + latents = [] + for i, img in enumerate(images): + if img is not None and len(img) > 0: + ref_latents = [] + for j, img_j in enumerate(img): + img_j = self.image_processor.preprocess(img_j, max_pixels=max_pixels, max_side_length=max_side_length) + ref_latents.append(self.encode_vae(img_j.to(device=device)).squeeze(0)) + else: + ref_latents = None + for _ in range(num_images_per_prompt): + latents.append(ref_latents) + + return latents + + def _get_qwen2_prompt_embeds( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + max_sequence_length: int = 256, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get prompt embeddings from the Qwen2 text encoder. + + Args: + prompt: The prompt or list of prompts to encode. + device: The device to place the embeddings on. If None, uses the pipeline's device. + max_sequence_length: Maximum sequence length for tokenization. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The prompt embeddings tensor + - The attention mask tensor + + Raises: + Warning: If the input text is truncated due to sequence length limitations. 
+ """ + device = device or self._execution_device + prompt = [prompt] if isinstance(prompt, str) else prompt + # text_inputs = self.processor.tokenizer( + # prompt, + # padding="max_length", + # max_length=max_sequence_length, + # truncation=True, + # return_tensors="pt", + # ) + text_inputs = self.processor.tokenizer( + prompt, + padding="longest", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + untruncated_ids = self.processor.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device) + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.processor.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because Gemma can only handle sequences up to" + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = self.mllm( + text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-1] + + if self.mllm is not None: + dtype = self.mllm.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, prompt_attention_mask + + def _apply_chat_template(self, prompt: str): + prompt = [ + { + "role": "system", + "content": "You are a helpful assistant that generates high-quality images based on user instructions.", + }, + {"role": "user", "content": prompt}, + ] + prompt = self.processor.tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=False) + return prompt + + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 256, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + Lumina-T2I, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For Lumina-T2I, it's should be the embeddings of the "" string. 
+ max_sequence_length (`int`, defaults to `256`): + Maximum sequence length to use for the prompt. + """ + device = device or self._execution_device + + if prompt is not None: + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [self._apply_chat_template(_prompt) for _prompt in prompt] + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_qwen2_prompt_embeds( + prompt=prompt, + device=device, + max_sequence_length=max_sequence_length + ) + + batch_size, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1) + + # Get negative embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt if negative_prompt is not None else "" + + # Normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt = [self._apply_chat_template(_negative_prompt) for _negative_prompt in negative_prompt] + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + negative_prompt_embeds, negative_prompt_attention_mask = self._get_qwen2_prompt_embeds( + prompt=negative_prompt, + device=device, + max_sequence_length=max_sequence_length, + ) + + batch_size, seq_len, _ = negative_prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + negative_prompt_attention_mask = negative_prompt_attention_mask.view( + batch_size * num_images_per_prompt, -1 + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def text_guidance_scale(self): + return self._text_guidance_scale + + @property + def image_guidance_scale(self): + return self._image_guidance_scale + + @property + def cfg_range(self): + return self._cfg_range + + @torch.no_grad() + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.LongTensor] = None, + negative_prompt_attention_mask: Optional[torch.LongTensor] = None, + max_sequence_length: Optional[int] = None, + callback_on_step_end_tensor_inputs: Optional[List[str]] = None, + input_images: Optional[List[PIL.Image.Image]] = None, + num_images_per_prompt: int = 1, + height: Optional[int] = None, + width: Optional[int] = None, + max_pixels: int = 1024 * 1024, + max_input_image_side_length: int = 1024, + align_res: bool = True, + num_inference_steps: int = 28, + text_guidance_scale: float = 4.0, + image_guidance_scale: float = 1.0, + cfg_range: Tuple[float, float] = (0.0, 1.0), + attention_kwargs: Optional[Dict[str, Any]] = None, + timesteps: List[int] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + verbose: bool = False, + step_func=None, + ): + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + self._text_guidance_scale = text_guidance_scale + self._image_guidance_scale = image_guidance_scale + self._cfg_range = cfg_range + self._attention_kwargs = attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. 
Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + self.text_guidance_scale > 1.0, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + ) + + dtype = self.vae.dtype + # 3. Prepare control image + ref_latents = self.prepare_image( + images=input_images, + batch_size=batch_size, + num_images_per_prompt=num_images_per_prompt, + max_pixels=max_pixels, + max_side_length=max_input_image_side_length, + device=device, + dtype=dtype, + ) + + if input_images is None: + input_images = [] + + if len(input_images) == 1 and align_res: + width, height = ref_latents[0][0].shape[-1] * self.vae_scale_factor, ref_latents[0][0].shape[-2] * self.vae_scale_factor + ori_width, ori_height = width, height + else: + ori_width, ori_height = width, height + + cur_pixels = height * width + ratio = (max_pixels / cur_pixels) ** 0.5 + ratio = min(ratio, 1.0) + + height, width = int(height * ratio) // 16 * 16, int(width * ratio) // 16 * 16 + + if len(input_images) == 0: + self._image_guidance_scale = 1 + + # 4. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + freqs_cis = OmniGen2RotaryPosEmbed.get_freqs_cis( + self.transformer.config.axes_dim_rope, + self.transformer.config.axes_lens, + theta=10000, + ) + + image = self.processing( + latents=latents, + ref_latents=ref_latents, + prompt_embeds=prompt_embeds, + freqs_cis=freqs_cis, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + num_inference_steps=num_inference_steps, + timesteps=timesteps, + device=device, + dtype=dtype, + verbose=verbose, + step_func=step_func, + ) + + image = F.interpolate(image, size=(ori_height, ori_width), mode='bilinear') + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return image + else: + return FMPipelineOutput(images=image) + + def processing( + self, + latents, + ref_latents, + prompt_embeds, + freqs_cis, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + num_inference_steps, + timesteps, + device, + dtype, + verbose, + step_func=None + ): + batch_size = latents.shape[0] + + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + num_tokens=latents.shape[-2] * latents.shape[-1] + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + model_pred = self.predict( + t=t, + latents=latents, + prompt_embeds=prompt_embeds, + freqs_cis=freqs_cis, + prompt_attention_mask=prompt_attention_mask, + ref_image_hidden_states=ref_latents, + ) + text_guidance_scale = self.text_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0 + 
image_guidance_scale = self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0 + + if text_guidance_scale > 1.0 and image_guidance_scale > 1.0: + model_pred_ref = self.predict( + t=t, + latents=latents, + prompt_embeds=negative_prompt_embeds, + freqs_cis=freqs_cis, + prompt_attention_mask=negative_prompt_attention_mask, + ref_image_hidden_states=ref_latents, + ) + + if image_guidance_scale != 1: + model_pred_uncond = self.predict( + t=t, + latents=latents, + prompt_embeds=negative_prompt_embeds, + freqs_cis=freqs_cis, + prompt_attention_mask=negative_prompt_attention_mask, + ref_image_hidden_states=None, + ) + else: + model_pred_uncond = torch.zeros_like(model_pred) + + model_pred = model_pred_uncond + image_guidance_scale * (model_pred_ref - model_pred_uncond) + \ + text_guidance_scale * (model_pred - model_pred_ref) + elif text_guidance_scale > 1.0: + model_pred_uncond = self.predict( + t=t, + latents=latents, + prompt_embeds=negative_prompt_embeds, + freqs_cis=freqs_cis, + prompt_attention_mask=negative_prompt_attention_mask, + ref_image_hidden_states=None, + ) + model_pred = model_pred_uncond + text_guidance_scale * (model_pred - model_pred_uncond) + + latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0] + + latents = latents.to(dtype=dtype) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if step_func is not None: + step_func(i, self._num_timesteps) + + latents = latents.to(dtype=dtype) + if self.vae.config.scaling_factor is not None: + latents = latents / self.vae.config.scaling_factor + if self.vae.config.shift_factor is not None: + latents = latents + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + + return image + + def predict( + self, + t, + latents, + prompt_embeds, + freqs_cis, + prompt_attention_mask, + ref_image_hidden_states, + ): + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + batch_size, num_channels_latents, height, width = latents.shape + + optional_kwargs = {} + if 'ref_image_hidden_states' in set(inspect.signature(self.transformer.forward).parameters.keys()): + optional_kwargs['ref_image_hidden_states'] = ref_image_hidden_states + + model_pred = self.transformer( + latents, + timestep, + prompt_embeds, + freqs_cis, + prompt_attention_mask, + **optional_kwargs + ) + return model_pred \ No newline at end of file diff --git a/extensions_built_in/diffusion_models/omnigen2/src/pipelines/omnigen2/pipeline_omnigen2_chat.py b/extensions_built_in/diffusion_models/omnigen2/src/pipelines/omnigen2/pipeline_omnigen2_chat.py new file mode 100644 index 00000000..43d88402 --- /dev/null +++ b/extensions_built_in/diffusion_models/omnigen2/src/pipelines/omnigen2/pipeline_omnigen2_chat.py @@ -0,0 +1,830 @@ +""" +OmniGen2 Diffusion Pipeline + +Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
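The denoising loop in `processing` above combines up to three transformer passes per step once both guidance scales exceed 1: a text+reference-conditioned pass, a reference-only pass (negative prompt, reference latents kept), and a fully unconditional pass. A compact sketch of that combination; the function and variable names are mine, not from the patch:

```python
import torch

def combine_guidance(cond, ref_only, uncond, text_gs: float, image_gs: float):
    if text_gs > 1.0 and image_gs > 1.0:
        # image guidance pulls toward the reference images, text guidance toward the prompt
        return uncond + image_gs * (ref_only - uncond) + text_gs * (cond - ref_only)
    if text_gs > 1.0:
        return uncond + text_gs * (cond - uncond)
    return cond

cond, ref_only, uncond = (torch.randn(1, 16, 64, 64) for _ in range(3))
pred = combine_guidance(cond, ref_only, uncond, text_gs=4.0, image_gs=2.0)
```

Note that both scales are only applied inside the fraction of the schedule selected by `cfg_range`; outside it they fall back to 1.0 and the extra passes are skipped.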
+See the License for the specific language governing permissions and +limitations under the License. +""" + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import math + +from PIL import Image +import numpy as np +import torch +import torch.nn.functional as F + +from transformers import Qwen2_5_VLForConditionalGeneration + +from diffusers.models.autoencoders import AutoencoderKL +from ...models.transformers import OmniGen2Transformer2DModel +from ...models.transformers.repo import OmniGen2RotaryPosEmbed +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import ( + is_torch_xla_available, + logging, +) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline + +from dataclasses import dataclass + +import PIL.Image + +from diffusers.utils import BaseOutput + +from src.pipelines.image_processor import OmniGen2ImageProcessor + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +@dataclass +class OmniGen2PipelineOutput(BaseOutput): + """ + Output class for OmniGen2 pipeline. + + Args: + images (Union[List[PIL.Image.Image], np.ndarray]): + List of denoised PIL images of length `batch_size` or numpy array of shape + `(batch_size, height, width, num_channels)`. Contains the generated images. + """ + text: str + images: Union[List[PIL.Image.Image], np.ndarray] + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class OmniGen2ChatPipeline(DiffusionPipeline): + """ + Pipeline for text-to-image generation using OmniGen2. + + This pipeline implements a text-to-image generation model that uses: + - Qwen2.5-VL for text encoding + - A custom transformer architecture for image generation + - VAE for image encoding/decoding + - FlowMatchEulerDiscreteScheduler for noise scheduling + + Args: + transformer (OmniGen2Transformer2DModel): The transformer model for image generation. + vae (AutoencoderKL): The VAE model for image encoding/decoding. + scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for noise scheduling. + text_encoder (Qwen2_5_VLModel): The text encoder model. + tokenizer (Union[Qwen2Tokenizer, Qwen2TokenizerFast]): The tokenizer for text processing. + """ + + model_cpu_offload_seq = "mllm->transformer->vae" + def __init__( + self, + transformer: OmniGen2Transformer2DModel, + vae: AutoencoderKL, + scheduler: FlowMatchEulerDiscreteScheduler, + mllm: Qwen2_5_VLForConditionalGeneration, + processor, + ) -> None: + """ + Initialize the OmniGen2 pipeline. + + Args: + transformer: The transformer model for image generation. + vae: The VAE model for image encoding/decoding. + scheduler: The scheduler for noise scheduling. + text_encoder: The text encoder model. + tokenizer: The tokenizer for text processing. + """ + super().__init__() + + self.register_modules( + transformer=transformer, + vae=vae, + scheduler=scheduler, + mllm=mllm, + processor=processor + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = OmniGen2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2, do_resize=True) + self.default_sample_size = 128 + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[torch.Generator], + latents: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Prepare the initial latents for the diffusion process. + + Args: + batch_size: The number of images to generate. + num_channels_latents: The number of channels in the latent space. + height: The height of the generated image. + width: The width of the generated image. + dtype: The data type of the latents. + device: The device to place the latents on. + generator: The random number generator to use. + latents: Optional pre-computed latents to use instead of random initialization. + + Returns: + torch.FloatTensor: The prepared latents tensor. + """ + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + return latents + + def encode_vae(self, img: torch.FloatTensor) -> torch.FloatTensor: + """ + Encode an image into the VAE latent space. + + Args: + img: The input image tensor to encode. + + Returns: + torch.FloatTensor: The encoded latent representation. 
+ """ + z0 = self.vae.encode(img.to(dtype=self.vae.dtype)).latent_dist.sample() + if self.vae.config.shift_factor is not None: + z0 = z0 - self.vae.config.shift_factor + if self.vae.config.scaling_factor is not None: + z0 = z0 * self.vae.config.scaling_factor + z0 = z0.to(dtype=self.vae.dtype) + return z0 + + def prepare_image( + self, + images: Union[List[PIL.Image.Image], PIL.Image.Image], + batch_size: int, + num_images_per_prompt: int, + max_pixels: int, + max_side_length: int, + device: torch.device, + dtype: torch.dtype, + ) -> List[Optional[torch.FloatTensor]]: + """ + Prepare input images for processing by encoding them into the VAE latent space. + + Args: + images: Single image or list of images to process. + batch_size: The number of images to generate per prompt. + num_images_per_prompt: The number of images to generate for each prompt. + device: The device to place the encoded latents on. + dtype: The data type of the encoded latents. + + Returns: + List[Optional[torch.FloatTensor]]: List of encoded latent representations for each image. + """ + if batch_size == 1: + images = [images] + latents = [] + for i, img in enumerate(images): + if img is not None and len(img) > 0: + ref_latents = [] + for j, img_j in enumerate(img): + img_j = self.image_processor.preprocess(img_j, max_pixels=max_pixels, max_side_length=max_side_length) + ref_latents.append(self.encode_vae(img_j.to(device=device)).squeeze(0)) + else: + ref_latents = None + for _ in range(num_images_per_prompt): + latents.append(ref_latents) + + return latents + + def _apply_chat_template(self, prompt: str, images: List = None): + if images is not None: + prompt = "".join( + [ + f": <|vision_start|><|image_pad|><|vision_end|>" + for i in range(1, len(images) + 1) + ] + ) + prompt + prompt = f"<|im_start|>system\nYou are a helpful assistant that generates high-quality images based on user instructions.<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" + return prompt + + def _get_qwen2_prompt_embeds( + self, + prompt: Union[str, List[str]], + input_images = None, + device: Optional[torch.device] = None, + use_only_text_hidden_states: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get prompt embeddings from the Qwen2 text encoder. + + Args: + prompt: The prompt or list of prompts to encode. + device: The device to place the embeddings on. If None, uses the pipeline's device. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The prompt embeddings tensor + - The attention mask tensor + + Raises: + Warning: If the input text is truncated due to sequence length limitations. 
+ """ + device = device or self._execution_device + prompt = [prompt] if isinstance(prompt, str) else prompt + + inputs = self.processor( + text=prompt, + images=input_images, + videos=None, + padding=True, + return_tensors="pt", + ) + inputs = inputs.to(device) + + prompt_embeds = self.mllm( + **inputs, + output_hidden_states=True, + ).hidden_states[-1] + + text_input_ids = inputs.input_ids + text_mask = inputs.attention_mask + if use_only_text_hidden_states: + mask = text_input_ids != self.mllm.config.image_token_id + mask = mask & text_mask + mask = mask.bool() + + text_l = mask.sum(dim=-1) + max_l = text_l.max() + text_batch_size = prompt_embeds.size(0) + new_prompt_embeds = torch.zeros((text_batch_size, max_l, prompt_embeds.size(-1)), device=prompt_embeds.device, dtype=prompt_embeds.dtype) + new_text_mask = torch.zeros((text_batch_size, max_l), dtype=text_mask.dtype, device=text_mask.device) + for i in range(text_batch_size): + new_prompt_embeds[i, :text_l[i]] = prompt_embeds[i, mask[i]] + new_text_mask[i, :text_l[i]] = 1 + + prompt_embeds = new_prompt_embeds + text_mask = new_text_mask + + prompt_embeds = prompt_embeds.to(dtype=self.mllm.dtype, device=device) + return prompt_embeds, text_mask + + + def encode_prompt( + self, + prompt: Union[str, List[str]], + input_images: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 256, + use_text_encoder_penultimate_layer_feats: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + Lumina-T2I, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For Lumina-T2I, it's should be the embeddings of the "" string. + max_sequence_length (`int`, defaults to `256`): + Maximum sequence length to use for the prompt. 
+ """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_qwen2_prompt_embeds( + prompt=prompt, + input_images=input_images, + device=device, + ) + + batch_size, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1) + + # Get negative embeddings for classifier free guidance + negative_prompt_embeds, negative_prompt_attention_mask = None, None + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt if negative_prompt is not None else "" + + # Normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt = [self._apply_chat_template(_negative_prompt) for _negative_prompt in negative_prompt] + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + negative_prompt_embeds, negative_prompt_attention_mask = self._get_qwen2_prompt_embeds( + prompt=negative_prompt, + device=device, + ) + + batch_size, seq_len, _ = negative_prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + negative_prompt_attention_mask = negative_prompt_attention_mask.view( + batch_size * num_images_per_prompt, -1 + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def text_guidance_scale(self): + return self._text_guidance_scale + + @property + def image_guidance_scale(self): + return self._image_guidance_scale + + @property + def cfg_range(self): + return self._cfg_range + + def prepare_inputs_for_text_generation(self, prompts, input_images, device): + if isinstance(prompts, str): + prompts = [prompts] + + ori_padding_side = self.processor.tokenizer.padding_side + self.processor.tokenizer.padding_side = "left" + inputs = self.processor( + text=prompts, + images=input_images, + videos=None, + padding=True, + return_tensors="pt", + ).to(device) + self.processor.tokenizer.padding_side = ori_padding_side + return inputs + + def generate_text(self, prompt, input_images): + inputs = self.prepare_inputs_for_text_generation( + prompt, input_images, self.mllm.device + ) + generated_ids = self.mllm.generate( + **inputs, + tokenizer=self.processor.tokenizer, + max_new_tokens=256, + stop_strings=["<|im_end|>", "<|img|>", "<|endoftext|>"], + ) # stop_words=[151643, 151645, 151665] + generated_ids_trimmed = [ + out_ids[len(in_ids) :] + for in_ids, out_ids in zip(inputs.input_ids, generated_ids) + ] + output_texts = self.processor.batch_decode( + generated_ids_trimmed, + # skip_special_tokens=True, + skip_special_tokens=False, + clean_up_tokenization_spaces=False, + ) + return output_texts + + def generate_image( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.LongTensor] = None, + negative_prompt_attention_mask: Optional[torch.LongTensor] = None, + use_text_encoder_penultimate_layer_feats: bool = False, + max_sequence_length: Optional[int] = None, + callback_on_step_end_tensor_inputs: Optional[List[str]] = None, + input_images: Optional[List[PIL.Image.Image]] = None, + num_images_per_prompt: int = 1, + height: Optional[int] = None, + width: Optional[int] = None, + max_pixels: int = 1024 * 1024, + max_input_image_side_length: int = 1024, + align_res: bool = True, + num_inference_steps: int = 28, + text_guidance_scale: float = 4.0, + image_guidance_scale: float = 1.0, + cfg_range: Tuple[float, float] = (0.0, 1.0), + attention_kwargs: Optional[Dict[str, Any]] = None, + timesteps: List[int] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + verbose: bool = False, + step_func=None, + ): + height = height or 
self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + self._text_guidance_scale = text_guidance_scale + self._image_guidance_scale = image_guidance_scale + self._cfg_range = cfg_range + self._attention_kwargs = attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input promptb + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + input_images, + self.text_guidance_scale > 1.0, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + use_text_encoder_penultimate_layer_feats=use_text_encoder_penultimate_layer_feats + ) + + dtype = self.vae.dtype + # 3. Prepare control image + ref_latents = self.prepare_image( + images=input_images, + batch_size=batch_size, + num_images_per_prompt=num_images_per_prompt, + max_pixels=max_pixels, + max_side_length=max_input_image_side_length, + device=device, + dtype=dtype, + ) + + if input_images is None: + input_images = [] + + if len(input_images) == 1 and align_res: + width, height = ref_latents[0][0].shape[-1] * self.vae_scale_factor, ref_latents[0][0].shape[-2] * self.vae_scale_factor + ori_width, ori_height = width, height + else: + ori_width, ori_height = width, height + + cur_pixels = height * width + ratio = (max_pixels / cur_pixels) ** 0.5 + ratio = min(ratio, 1.0) + + height, width = int(height * ratio) // 16 * 16, int(width * ratio) // 16 * 16 + + if len(input_images) == 0: + self._image_guidance_scale = 1 + + # 4. Prepare latents. 
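Just above, the target resolution is capped so that `height * width` does not exceed `max_pixels`, with both sides snapped down to multiples of 16; `prepare_latents` below then allocates noise at 1/`vae_scale_factor` of that resolution. A small numeric sketch (the channel count and sizes are illustrative only):

```python
def cap_resolution(height: int, width: int, max_pixels: int = 1024 * 1024):
    ratio = min((max_pixels / (height * width)) ** 0.5, 1.0)
    return int(height * ratio) // 16 * 16, int(width * ratio) // 16 * 16

height, width = cap_resolution(2048, 2048)        # -> (1024, 1024)
vae_scale_factor = 8                              # typical 8x spatial compression
latent_shape = (1, 16, height // vae_scale_factor, width // vae_scale_factor)
print(height, width, latent_shape)                # 1024 1024 (1, 16, 128, 128)
```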
+ latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + freqs_cis = OmniGen2RotaryPosEmbed.get_freqs_cis( + self.transformer.config.axes_dim_rope, + self.transformer.config.axes_lens, + theta=10000, + ) + + image = self.processing( + latents=latents, + ref_latents=ref_latents, + prompt_embeds=prompt_embeds, + freqs_cis=freqs_cis, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + num_inference_steps=num_inference_steps, + timesteps=timesteps, + device=device, + dtype=dtype, + verbose=verbose, + step_func=step_func, + ) + + image = F.interpolate(image, size=(ori_height, ori_width), mode='bilinear') + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + return image + + @torch.no_grad() + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.LongTensor] = None, + negative_prompt_attention_mask: Optional[torch.LongTensor] = None, + use_text_encoder_penultimate_layer_feats: bool = False, + max_sequence_length: Optional[int] = None, + callback_on_step_end_tensor_inputs: Optional[List[str]] = None, + input_images: Optional[List[PIL.Image.Image]] = None, + num_images_per_prompt: int = 1, + height: Optional[int] = 1024, + width: Optional[int] = 1024, + max_pixels: Optional[int] = 1024 * 1024, + max_input_image_side_length: int = 1024, + align_res: bool = True, + num_inference_steps: int = 28, + text_guidance_scale: float = 4.0, + image_guidance_scale: float = 1.0, + cfg_range: Tuple[float, float] = (0.0, 1.0), + attention_kwargs: Optional[Dict[str, Any]] = None, + timesteps: List[int] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + verbose: bool = False, + step_func=None, + ): + assert isinstance(prompt, str), "prompt must be a string since chat mode only support one prompt per turn" + + # input_images = self.preprocess_images(input_images, max_input_image_size) + prompt = self._apply_chat_template(prompt, input_images) + generated_text = self.generate_text(prompt, input_images)[0] + + images = None + if generated_text.startswith("<|img|>"): + #TODO: reuse the hidden state when generate text instead of re-generating + prompt = prompt + generated_text.split("<|img|>")[0] + images = self.generate_image( + prompt=prompt, + negative_prompt=negative_prompt, + use_text_encoder_penultimate_layer_feats=use_text_encoder_penultimate_layer_feats, + max_sequence_length=max_sequence_length, + input_images=input_images, + num_images_per_prompt=num_images_per_prompt, + height=height, + width=width, + max_pixels=max_pixels, + max_input_image_side_length=max_input_image_side_length, + align_res=align_res, + num_inference_steps=num_inference_steps, + text_guidance_scale=text_guidance_scale, + image_guidance_scale=image_guidance_scale, + cfg_range=cfg_range, + timesteps=timesteps, + generator=generator, + latents=latents, + return_dict=False, + verbose=verbose, + 
step_func=step_func, + ) + + generated_text = generated_text.replace("<|im_end|>", "") + if not return_dict: + return generated_text, images + else: + return OmniGen2PipelineOutput(text=generated_text, images=images) + + def processing( + self, + latents, + ref_latents, + prompt_embeds, + freqs_cis, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + num_inference_steps, + timesteps, + device, + dtype, + verbose, + step_func=None + ): + batch_size = latents.shape[0] + + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + num_tokens=latents.shape[-2] * latents.shape[-1] + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + model_pred = self.predict( + t=t, + latents=latents, + prompt_embeds=prompt_embeds, + freqs_cis=freqs_cis, + prompt_attention_mask=prompt_attention_mask, + ref_image_hidden_states=ref_latents, + ) + + text_guidance_scale = self.text_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0 + image_guidance_scale = self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0 + if text_guidance_scale > 1.0 and image_guidance_scale > 1.0: + model_pred_ref = self.predict( + t=t, + latents=latents, + prompt_embeds=negative_prompt_embeds, + freqs_cis=freqs_cis, + prompt_attention_mask=negative_prompt_attention_mask, + ref_image_hidden_states=ref_latents, + ) + + if image_guidance_scale != 1: + model_pred_uncond = self.predict( + t=t, + latents=latents, + prompt_embeds=negative_prompt_embeds, + freqs_cis=freqs_cis, + prompt_attention_mask=negative_prompt_attention_mask, + ref_image_hidden_states=None, + ) + else: + model_pred_uncond = torch.zeros_like(model_pred) + + model_pred = model_pred_uncond + image_guidance_scale * (model_pred_ref - model_pred_uncond) + \ + text_guidance_scale * (model_pred - model_pred_ref) + elif text_guidance_scale > 1.0: + model_pred_uncond = self.predict( + t=t, + latents=latents, + prompt_embeds=negative_prompt_embeds, + freqs_cis=freqs_cis, + prompt_attention_mask=negative_prompt_attention_mask, + ref_image_hidden_states=None, + ) + model_pred = model_pred_uncond + text_guidance_scale * (model_pred - model_pred_uncond) + + latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0] + + latents = latents.to(dtype=dtype) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if step_func is not None: + step_func(i, self._num_timesteps) + + latents = latents.to(dtype=dtype) + if self.vae.config.scaling_factor is not None: + latents = latents / self.vae.config.scaling_factor + if self.vae.config.shift_factor is not None: + latents = latents + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + + return image + + def predict( + self, + t, + latents, + prompt_embeds, + freqs_cis, + prompt_attention_mask, + ref_image_hidden_states, + ): + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + batch_size, num_channels_latents, height, width = latents.shape + + optional_kwargs = {} + if 'ref_image_hidden_states' in set(inspect.signature(self.transformer.forward).parameters.keys()): + 
optional_kwargs['ref_image_hidden_states'] = ref_image_hidden_states + + model_pred = self.transformer( + latents, + timestep, + prompt_embeds, + freqs_cis, + prompt_attention_mask, + **optional_kwargs + ) + return model_pred diff --git a/extensions_built_in/diffusion_models/omnigen2/src/pipelines/pipeline_utils.py b/extensions_built_in/diffusion_models/omnigen2/src/pipelines/pipeline_utils.py new file mode 100644 index 00000000..de31ff4e --- /dev/null +++ b/extensions_built_in/diffusion_models/omnigen2/src/pipelines/pipeline_utils.py @@ -0,0 +1,62 @@ +import torch + + +def get_pipeline_embeds(pipeline, prompt, negative_prompt, device): + """ Get pipeline embeds for prompts bigger than the maxlength of the pipe + :param pipeline: + :param prompt: + :param negative_prompt: + :param device: + :return: + """ + max_length = pipeline.tokenizer.model_max_length + + # simple way to determine length of tokens + # count_prompt = len(prompt.split(" ")) + # count_negative_prompt = len(negative_prompt.split(" ")) + + # create the tensor based on which prompt is longer + # if count_prompt >= count_negative_prompt: + input_ids = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False, padding='longest').input_ids.to(device) + # input_ids = pipeline.tokenizer(prompt, padding="max_length", + # max_length=pipeline.tokenizer.model_max_length, + # truncation=True, + # return_tensors="pt",).input_ids.to(device) + shape_max_length = input_ids.shape[-1] + + if negative_prompt is not None: + negative_ids = pipeline.tokenizer(negative_prompt, truncation=True, padding="max_length", + max_length=shape_max_length, return_tensors="pt").input_ids.to(device) + + # else: + # negative_ids = pipeline.tokenizer(negative_prompt, return_tensors="pt", truncation=False).input_ids.to(device) + # shape_max_length = negative_ids.shape[-1] + # input_ids = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False, padding="max_length", + # max_length=shape_max_length).input_ids.to(device) + + concat_embeds = [] + neg_embeds = [] + for i in range(0, shape_max_length, max_length): + if hasattr(pipeline.text_encoder.config, "use_attention_mask") and pipeline.text_encoder.config.use_attention_mask: + attention_mask = input_ids[:, i: i + max_length].attention_mask.to(device) + else: + attention_mask = None + concat_embeds.append(pipeline.text_encoder(input_ids[:, i: i + max_length], + attention_mask=attention_mask)[0]) + + if negative_prompt is not None: + if hasattr(pipeline.text_encoder.config, "use_attention_mask") and pipeline.text_encoder.config.use_attention_mask: + attention_mask = negative_ids[:, i: i + max_length].attention_mask.to(device) + else: + attention_mask = None + neg_embeds.append(pipeline.text_encoder(negative_ids[:, i: i + max_length], + attention_mask=attention_mask)[0]) + + concat_embeds = torch.cat(concat_embeds, dim=1) + + if negative_prompt is not None: + neg_embeds = torch.cat(neg_embeds, dim=1) + else: + neg_embeds = None + + return concat_embeds, neg_embeds diff --git a/extensions_built_in/diffusion_models/omnigen2/src/schedulers/__init__.py b/extensions_built_in/diffusion_models/omnigen2/src/schedulers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/extensions_built_in/diffusion_models/omnigen2/src/schedulers/scheduling_dpmsolver_multistep.py b/extensions_built_in/diffusion_models/omnigen2/src/schedulers/scheduling_dpmsolver_multistep.py new file mode 100644 index 00000000..1c85e930 --- /dev/null +++ 
b/extensions_built_in/diffusion_models/omnigen2/src/schedulers/scheduling_dpmsolver_multistep.py @@ -0,0 +1,1052 @@ +# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import deprecate, is_scipy_available +from diffusers.utils.torch_utils import randn_tensor +from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +if is_scipy_available(): + import scipy.stats + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.Tensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.Tensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. 
+ alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + `DPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + solver_order (`int`, defaults to 2): + The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample), `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper), or `flow_prediction`. + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++"`. + algorithm_type (`str`, defaults to `dpmsolver++`): + Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The + `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) + paper, and the `dpmsolver++` type implements the algorithms in the + [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or + `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. + solver_type (`str`, defaults to `midpoint`): + Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the + sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. + lower_order_final (`bool`, defaults to `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. 
+ euler_at_final (`bool`, defaults to `False`): + Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail + richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference + steps, but sometimes may result in blurring. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + use_exponential_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. + use_beta_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta + Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information. + use_lu_lambdas (`bool`, *optional*, defaults to `False`): + Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during + the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of + `lambda(t)`. + use_flow_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use flow sigmas for step sizes in the noise schedule during the sampling process. + flow_shift (`float`, *optional*, defaults to 1.0): + The shift value for the timestep schedule for flow matching. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + lambda_min_clipped (`float`, defaults to `-inf`): + Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the + cosine (`squaredcos_cap_v2`) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output + contains the predicted Gaussian variance. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). 
+ """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", + lower_order_final: bool = True, + euler_at_final: bool = False, + final_sigmas_type: str = 'zero', + dynamic_time_shift: bool = True + ): + if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" + deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message) + + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # settings for DPM-Solver + if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]: + if algorithm_type == "deis": + self.register_to_config(algorithm_type="dpmsolver++") + else: + raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}") + + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") + + # if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero": + # raise ValueError( + # f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead." + # ) + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. 
+ """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def set_timesteps( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + timesteps: Optional[List[int]] = None, + num_tokens: Optional[int] = None + ): + if timesteps is None: + self.num_inference_steps = num_inference_steps + timesteps = np.linspace(0, 1, num_inference_steps + 1, dtype=np.float32)[:-1] + if self.config.dynamic_time_shift and num_tokens is not None: + m = np.sqrt(num_tokens) / 40 # when input resolution is 320 * 320, m = 1, when input resolution is 1024 * 1024, m = 3.2 + timesteps = timesteps / (m - m * timesteps + timesteps) + + timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device) + sigmas = torch.cat([1 - timesteps, torch.zeros(1, device=timesteps.device)]) + + self.sigmas = sigmas + self.timesteps = timesteps + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." 
+ + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(np.maximum(sigma, 1e-10)) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + def _sigma_to_alpha_sigma_t(self, sigma): + alpha_t = 1 - sigma + sigma_t = sigma + + return alpha_t, sigma_t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor: + """Constructs the noise schedule of Lu et al. 
(2022).""" + + lambda_min: float = in_lambdas[-1].item() + lambda_max: float = in_lambdas[0].item() + + rho = 1.0 # 1.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = lambda_min ** (1 / rho) + max_inv_rho = lambda_max ** (1 / rho) + lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return lambdas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential + def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: + """Constructs an exponential noise schedule.""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps)) + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta + def _convert_to_beta( + self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 + ) -> torch.Tensor: + """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.array( + [ + sigma_min + (ppf * (sigma_max - sigma_min)) + for ppf in [ + scipy.stats.beta.ppf(timestep, alpha, beta) + for timestep in 1 - np.linspace(0, 1, num_inference_steps) + ] + ] + ) + return sigmas + + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """ + Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is + designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an + integral of the data prediction model. + + + + The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise + prediction and data prediction models. + + + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The converted model output. 
+ """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + # DPM-Solver++ needs to solve an integral of the data prediction model. + if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: + if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + model_output = model_output[:, :3] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + x0_pred = alpha_t * sample - sigma_t * model_output + elif self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + x0_pred = sample + sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, " + "`v_prediction`, or `flow_prediction` for the DPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + + # DPM-Solver needs to solve an integral of the noise prediction model. + elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + epsilon = model_output[:, :3] + else: + epsilon = model_output + elif self.config.prediction_type == "sample": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + epsilon = (sample - alpha_t * model_output) / sigma_t + elif self.config.prediction_type == "v_prediction": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + epsilon = alpha_t * model_output + sigma_t * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + x0_pred = (sample - sigma_t * epsilon) / alpha_t + x0_pred = self._threshold_sample(x0_pred) + epsilon = (sample - alpha_t * x0_pred) / sigma_t + + return epsilon + + def dpm_solver_first_order_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the first-order DPMSolver (equivalent to DDIM). + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. 
+ """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver++": + x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output + elif self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ( + (sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + x_t = ( + (alpha_t / alpha_s) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + return x_t + + def multistep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the second-order multistep DPMSolver. + + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. 
+ """ + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + + m0, m1 = model_output_list[-1], model_output_list[-2] + + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + ) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * (torch.exp(h) - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + + sigma_t * 
torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + return x_t + + def multistep_dpm_solver_third_order_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the third-order multistep DPMSolver. + + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing`sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + self.sigmas[self.step_index - 2], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) + + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] + + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 + r0, r1 = h_0 / h, h_1 / h + D0 = m0 + D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 + ) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h) - 2.0 * h) / (2.0 * h) ** 2 - 0.5)) * D2 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + return x_t + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if 
schedule_timesteps is None: + schedule_timesteps = self.timesteps + + index_candidates = (schedule_timesteps == timestep).nonzero() + + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + + return step_index + + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + generator=None, + variance_noise: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep DPMSolver. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + variance_noise (`torch.Tensor`): + Alternative to generating noise with `generator` by directly providing the noise for the variance + itself. Useful for methods such as [`LEdits++`]. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. 
+ + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Improve numerical stability for small number of steps + lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( + self.config.euler_at_final + or (self.config.lower_order_final and len(self.timesteps) < 15) + or self.config.final_sigmas_type == "zero" + ) + lower_order_second = ( + (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 + ) + + model_output = self.convert_model_output(model_output, sample=sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None: + noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32 + ) + elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: + noise = variance_noise.to(device=model_output.device, dtype=torch.float32) + else: + noise = None + + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: + prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise) + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: + prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise) + else: + prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample, noise=noise) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # Cast sample back to expected dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + + Returns: + `torch.Tensor`: + A scaled input sample. 
+ """ + return sample + + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps \ No newline at end of file diff --git a/extensions_built_in/diffusion_models/omnigen2/src/schedulers/scheduling_flow_match_euler_discrete.py b/extensions_built_in/diffusion_models/omnigen2/src/schedulers/scheduling_flow_match_euler_discrete.py new file mode 100644 index 00000000..91513476 --- /dev/null +++ b/extensions_built_in/diffusion_models/omnigen2/src/schedulers/scheduling_flow_match_euler_discrete.py @@ -0,0 +1,229 @@ +# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, logging +from diffusers.schedulers.scheduling_utils import SchedulerMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. 
+ """ + + prev_sample: torch.FloatTensor + + +class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Euler scheduler. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + shift (`float`, defaults to 1.0): + The shift value for the timestep schedule. + """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + dynamic_time_shift: bool = True + ): + timesteps = torch.linspace(0, 1, num_train_timesteps + 1, dtype=torch.float32)[:-1] + + self.timesteps = timesteps + + self._step_index = None + self._begin_index = None + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self._timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + # def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + # return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def set_timesteps( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + timesteps: Optional[List[float]] = None, + num_tokens: Optional[int] = None + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
+ """ + + if timesteps is None: + self.num_inference_steps = num_inference_steps + timesteps = np.linspace(0, 1, num_inference_steps + 1, dtype=np.float32)[:-1] + if self.config.dynamic_time_shift and num_tokens is not None: + m = np.sqrt(num_tokens) / 40 # when input resolution is 320 * 320, m = 1, when input resolution is 1024 * 1024, m = 3.2 + timesteps = timesteps / (m - m * timesteps + timesteps) + + timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device) + _timesteps = torch.cat([timesteps, torch.ones(1, device=timesteps.device)]) + + self.timesteps = timesteps + self._timesteps = _timesteps + self._step_index = None + self._begin_index = None + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + s_churn (`float`): + s_tmin (`float`): + s_tmax (`float`): + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or + tuple. + + Returns: + [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. + """ + + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." 
+ ), + ) + + if self.step_index is None: + self._init_step_index(timestep) + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + t = self._timesteps[self.step_index] + t_next = self._timesteps[self.step_index + 1] + + prev_sample = sample + (t_next - t) * model_output + + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample) + + def __len__(self): + return self.config.num_train_timesteps \ No newline at end of file diff --git a/extensions_built_in/diffusion_models/omnigen2/src/utils/__init__.py b/extensions_built_in/diffusion_models/omnigen2/src/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/extensions_built_in/diffusion_models/omnigen2/src/utils/img_util.py b/extensions_built_in/diffusion_models/omnigen2/src/utils/img_util.py new file mode 100644 index 00000000..d9026a63 --- /dev/null +++ b/extensions_built_in/diffusion_models/omnigen2/src/utils/img_util.py @@ -0,0 +1,31 @@ +from typing import List + +from PIL import Image + +import torch +from torchvision.transforms.functional import to_pil_image + +def resize_image(image, max_pixels, img_scale_num): + width, height = image.size + cur_pixels = height * width + ratio = (max_pixels / cur_pixels) ** 0.5 + ratio = min(ratio, 1.0) # do not upscale input image + + new_height, new_width = int(height * ratio) // img_scale_num * img_scale_num, int(width * ratio) // img_scale_num * img_scale_num + + image = image.resize((new_width, new_height), resample=Image.BICUBIC) + return image + +def create_collage(images: List[torch.Tensor]) -> Image.Image: + """Create a horizontal collage from a list of images.""" + max_height = max(img.shape[-2] for img in images) + total_width = sum(img.shape[-1] for img in images) + canvas = torch.zeros((3, max_height, total_width), device=images[0].device) + + current_x = 0 + for img in images: + h, w = img.shape[-2:] + canvas[:, :h, current_x:current_x+w] = img * 0.5 + 0.5 + current_x += w + + return to_pil_image(canvas) \ No newline at end of file diff --git a/extensions_built_in/diffusion_models/omnigen2/src/utils/import_utils.py b/extensions_built_in/diffusion_models/omnigen2/src/utils/import_utils.py new file mode 100644 index 00000000..dc946d77 --- /dev/null +++ b/extensions_built_in/diffusion_models/omnigen2/src/utils/import_utils.py @@ -0,0 +1,46 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Import utilities: Utilities related to imports and our lazy inits. +""" + +import importlib.util +import sys + +# The package importlib_metadata is in a different place, depending on the python version. 
+if sys.version_info < (3, 8): + import importlib_metadata +else: + import importlib.metadata as importlib_metadata + +def _is_package_available(pkg_name: str): + pkg_exists = importlib.util.find_spec(pkg_name) is not None + pkg_version = "N/A" + + if pkg_exists: + try: + pkg_version = importlib_metadata.version(pkg_name) + except (ImportError, importlib_metadata.PackageNotFoundError): + pkg_exists = False + + return pkg_exists, pkg_version + +_triton_available, _triton_version = _is_package_available("triton") +_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn") + +def is_triton_available(): + return _triton_available + +def is_flash_attn_available(): + return _flash_attn_available \ No newline at end of file diff --git a/toolkit/extension.py b/toolkit/extension.py index 8d1f38e5..f4f10e9d 100644 --- a/toolkit/extension.py +++ b/toolkit/extension.py @@ -34,7 +34,7 @@ def get_all_extensions() -> List[Extension]: for sub_dir in extension_folders: extensions_dir = os.path.join(TOOLKIT_ROOT, sub_dir) for (_, name, _) in pkgutil.iter_modules([extensions_dir]): - try: + # try: # Import the module module = importlib.import_module(f"{sub_dir}.{name}") # Get the value of the AI_TOOLKIT_EXTENSIONS variable @@ -43,8 +43,8 @@ def get_all_extensions() -> List[Extension]: if isinstance(extensions, list): # Iterate over the list and add the classes to the main list all_extension_classes.extend(extensions) - except ImportError as e: - print(f"Failed to import the {name} module. Error: {str(e)}") + # except ImportError as e: + # print(f"Failed to import the {name} module. Error: {str(e)}") return all_extension_classes diff --git a/ui/src/app/jobs/new/options.ts b/ui/src/app/jobs/new/options.ts index e84bc4a0..2f4e235c 100644 --- a/ui/src/app/jobs/new/options.ts +++ b/ui/src/app/jobs/new/options.ts @@ -170,6 +170,19 @@ export const modelArchs: ModelArch[] = [ }, disableSections: ['model.quantize', 'train.timestep_type'], }, + { + name: 'omnigen2', + label: 'OmniGen2', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['OmniGen2/OmniGen2', defaultNameOrPath], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].model.quantize': [false, false], + 'config.process[0].model.quantize_te': [true, false], + }, + disableSections: ['network.conv'], + }, ].sort((a, b) => { // Sort by label, case-insensitive return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' }) diff --git a/version.py b/version.py index 0db14ed3..60504c97 100644 --- a/version.py +++ b/version.py @@ -1 +1 @@ -VERSION = "0.3.1" \ No newline at end of file +VERSION = "0.3.2" \ No newline at end of file From 24cd94929e63566306b6f9c6afa41428664f657a Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Wed, 25 Jun 2025 14:01:08 -0600 Subject: [PATCH 2/2] Fix bug that can happen with fast processing dataset --- toolkit/data_transfer_object/data_loader.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py index 0c7d7562..9078fbbb 100644 --- a/toolkit/data_transfer_object/data_loader.py +++ b/toolkit/data_transfer_object/data_loader.py @@ -91,6 +91,8 @@ class FileItemDTO( except image_utils.UnknownImageFormat: print_once(f'Warning: Some images in the dataset cannot be fast read. 
' + \ f'This process is faster for png, jpeg') + img = exif_transpose(Image.open(self.path)) + w, h = img.size else: img = exif_transpose(Image.open(self.path)) w, h = img.size
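
The new DPMSolverMultistepScheduler above re-parameterizes the solver around flow-matching sigmas: add_noise builds x_t = (1 - sigma) * x0 + sigma * noise (with alpha_t = 1 - sigma), and convert_model_output with prediction_type="flow_prediction" recovers x0 as sample + sigma_t * model_output. A minimal numeric check of how those two formulas fit together (a sketch, not part of the patch; it only assumes the model's target is the velocity those formulas imply, v = x0 - noise):

    import torch

    x0 = torch.randn(4)        # clean latent
    eps = torch.randn(4)       # Gaussian noise
    sigma = torch.tensor(0.7)  # noise level in [0, 1]

    # mirrors DPMSolverMultistepScheduler.add_noise: alpha_t * x0 + sigma_t * noise, alpha_t = 1 - sigma
    x_t = (1 - sigma) * x0 + sigma * eps

    # velocity target implied by the two formulas above
    v = x0 - eps

    # mirrors convert_model_output(prediction_type="flow_prediction"): x0_pred = sample + sigma_t * v
    x0_pred = x_t + sigma * v
    assert torch.allclose(x0_pred, x0)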
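
For the FlowMatchEulerDiscreteScheduler added alongside it, a minimal sketch of the sampling loop it expects (also not part of the patch; it assumes the repository root is on sys.path, the latent shape is illustrative, and the random tensor only stands in for the OmniGen2 transformer's prediction):

    import torch

    from extensions_built_in.diffusion_models.omnigen2.src.schedulers.scheduling_flow_match_euler_discrete import (
        FlowMatchEulerDiscreteScheduler,
    )

    scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, dynamic_time_shift=True)

    # a 128x128 latent grid gives 16384 tokens, so set_timesteps shifts the schedule by
    # m = sqrt(16384) / 40 = 3.2 via t' = t / (m - m * t + t)
    latents = torch.randn(1, 16, 128, 128)
    scheduler.set_timesteps(num_inference_steps=28, device="cpu", num_tokens=128 * 128)

    for t in scheduler.timesteps:
        velocity = torch.randn_like(latents)  # stand-in for the transformer call
        latents = scheduler.step(velocity, t, latents).prev_sample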
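
And a short sketch of the two helpers in the new img_util.py (same sys.path assumption as above; the max_pixels and img_scale_num values are illustrative, not defaults taken from the patch):

    import torch
    from PIL import Image

    from extensions_built_in.diffusion_models.omnigen2.src.utils.img_util import (
        create_collage,
        resize_image,
    )

    # resize_image never upscales and snaps both sides to a multiple of img_scale_num
    ref = Image.new("RGB", (1536, 1024))
    ref = resize_image(ref, max_pixels=1024 * 1024, img_scale_num=16)
    print(ref.size)  # (1248, 832)

    # create_collage expects CHW tensors in [-1, 1] and pastes them side by side on a single canvas
    tiles = [torch.zeros(3, 512, 512), torch.zeros(3, 512, 768)]
    print(create_collage(tiles).size)  # (1280, 512)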