diff --git a/config/examples/train_lora_omnigen2_24gb.yaml b/config/examples/train_lora_omnigen2_24gb.yaml
new file mode 100644
index 00000000..6eb15302
--- /dev/null
+++ b/config/examples/train_lora_omnigen2_24gb.yaml
@@ -0,0 +1,94 @@
+---
+job: extension
+config:
+ # this name will be used for the output folder and file names
+ name: "my_first_omnigen2_lora_v1"
+ process:
+ - type: 'sd_trainer'
+ # root folder to save training sessions/samples/weights
+ training_folder: "output"
+ # uncomment to see performance stats in the terminal every N steps
+# performance_log_every: 1000
+ device: cuda:0
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
+# trigger_word: "p3r5on"
+ network:
+ type: "lora"
+ linear: 16
+ linear_alpha: 16
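+ # linear is the LoRA rank (dim); linear_alpha is the alpha used to scale the LoRA weights (commonly set equal to the rank)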
+ save:
+ dtype: float16 # precision to save
+ save_every: 250 # save every this many steps
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
+ push_to_hub: false # change this to true to push your trained model to Hugging Face
+ # You can either set an HF_TOKEN env variable or you'll be prompted to log in
+# hf_repo_id: your-username/your-model-slug
+# hf_private: true #whether the repo is private or public
+ datasets:
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
+ # images will automatically be resized and bucketed into the resolution specified
+ # on windows, escape back slashes with another backslash so
+ # "C:\\path\\to\\images\\folder"
+ - folder_path: "/path/to/images/folder"
+ caption_ext: "txt"
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of the time
+ shuffle_tokens: false # shuffle caption order, split by commas
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
+ resolution: [ 512, 768, 1024 ] # omnigen2 should work with multiple resolutions
+ train:
+ batch_size: 1
+ steps: 3000 # total number of steps to train; 500 - 4000 is a good range
+ gradient_accumulation: 1
+ train_unet: true
+ train_text_encoder: false # probably won't work with omnigen2
+ gradient_checkpointing: true # need this on unless you have a ton of vram
+ noise_scheduler: "flowmatch" # for training only
+ optimizer: "adamw8bit"
+ lr: 1e-4
+ timestep_type: 'sigmoid' # sigmoid, linear, shift
+ # uncomment this to skip the pre training sample
+# skip_first_sample: true
+ # uncomment to completely disable sampling
+# disable_sampling: true
+
+ # ema will smooth out learning, but could slow it down.
+ # ema_config:
+ # use_ema: true
+ # ema_decay: 0.99
+
+ # you will probably need this if your gpu supports it; other dtypes may not work correctly with omnigen2
+ dtype: bf16
+ model:
+ name_or_path: "OmniGen2/OmniGen2"
+ arch: "omnigen2"
+ quantize_te: true # quantize only the text encoder (Qwen2.5 VL)
+ # quantize: true # quantize transformer
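+# low_vram: true # uncomment if your ai-toolkit version supports model.low_vram; keeps weights on cpu until everything is loaded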
+ sample:
+ sampler: "flowmatch" # must match train.noise_scheduler
+ sample_every: 250 # sample every this many steps
+ width: 1024
+ height: 1024
+ prompts:
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
+# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
+ - "a bear building a log cabin in the snow covered mountains"
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
+ - "hipster man with a beard, building a chair, in a wood shop"
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
+ - "a man holding a sign that says, 'this is a sign'"
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
+ neg: "" # negative prompt, optional
+ seed: 42
+ walk_seed: true
+ guidance_scale: 4
+ sample_steps: 25
+# you can add any additional meta info here. [name] is replaced with config name at top
+meta:
+ name: "[name]"
+ version: '1.0'
diff --git a/extensions_built_in/diffusion_models/__init__.py b/extensions_built_in/diffusion_models/__init__.py
index 0cc323b9..3faa87cc 100644
--- a/extensions_built_in/diffusion_models/__init__.py
+++ b/extensions_built_in/diffusion_models/__init__.py
@@ -1,8 +1,9 @@
from .chroma import ChromaModel
from .hidream import HidreamModel
from .f_light import FLiteModel
+from .omnigen2 import OmniGen2Model
AI_TOOLKIT_MODELS = [
# put a list of models here
- ChromaModel, HidreamModel, FLiteModel
+ ChromaModel, HidreamModel, FLiteModel, OmniGen2Model
]
diff --git a/extensions_built_in/diffusion_models/omnigen2/__init__.py b/extensions_built_in/diffusion_models/omnigen2/__init__.py
new file mode 100644
index 00000000..5b20b387
--- /dev/null
+++ b/extensions_built_in/diffusion_models/omnigen2/__init__.py
@@ -0,0 +1,327 @@
+import inspect
+import os
+from typing import TYPE_CHECKING, List, Optional
+
+import einops
+import torch
+import yaml
+from toolkit.config_modules import GenerateImageConfig, ModelConfig
+from toolkit.models.base_model import BaseModel
+from diffusers import AutoencoderKL
+from toolkit.basic import flush
+from toolkit.prompt_utils import PromptEmbeds
+from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
+from toolkit.accelerator import unwrap_model
+from optimum.quanto import freeze
+from toolkit.util.quantize import quantize, get_qtype
+from .src.pipelines.omnigen2.pipeline_omnigen2 import OmniGen2Pipeline
+from .src.models.transformers import OmniGen2Transformer2DModel
+from .src.models.transformers.repo import OmniGen2RotaryPosEmbed
+from .src.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler as OmniFlowMatchEuler
+from transformers import CLIPProcessor, Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration
+
+if TYPE_CHECKING:
+ from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
+
+scheduler_config = {
+ "num_train_timesteps": 1000
+}
+
+BASE_MODEL_PATH = "OmniGen2/OmniGen2"
+
+
+class OmniGen2Model(BaseModel):
+ arch = "omnigen2"
+
+ def __init__(
+ self,
+ device,
+ model_config: ModelConfig,
+ dtype='bf16',
+ custom_pipeline=None,
+ noise_scheduler=None,
+ **kwargs
+ ):
+ super().__init__(
+ device,
+ model_config,
+ dtype,
+ custom_pipeline,
+ noise_scheduler,
+ **kwargs
+ )
+ self.is_flow_matching = True
+ self.is_transformer = True
+ self.target_lora_modules = ['OmniGen2Transformer2DModel']
+
+ # static method to get the noise scheduler
+ @staticmethod
+ def get_train_scheduler():
+ return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
+
+ def get_bucket_divisibility(self):
+ return 16
+
+ def load_model(self):
+ dtype = self.torch_dtype
+ self.print_and_status_update("Loading OmniGen2 model")
+ # will be updated if we detect an existing checkpoint in the training folder
+ model_path = self.model_config.name_or_path
+ extras_path = self.model_config.extras_name_or_path
+
+ scheduler = OmniGen2Model.get_train_scheduler()
+
+ self.print_and_status_update("Loading Qwen2.5 VL")
+ processor = CLIPProcessor.from_pretrained(
+ extras_path,
+ subfolder="processor",
+ use_fast=True
+ )
+
+ mllm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+ extras_path,
+ subfolder="mllm",
+ torch_dtype=torch.bfloat16
+ )
+
+ if self.model_config.quantize_te:
+ self.print_and_status_update("Quantizing Qwen2.5 VL model")
+ quantization_type = get_qtype(self.model_config.qtype_te)
+ quantize(mllm, weights=quantization_type)
+ freeze(mllm)
+
+ if self.low_vram:
+ # unload it for now
+ mllm.to('cpu')
+
+ flush()
+
+ self.print_and_status_update("Loading transformer")
+
+ transformer = OmniGen2Transformer2DModel.from_pretrained(
+ model_path,
+ subfolder="transformer",
+ torch_dtype=torch.bfloat16
+ )
+
+ if not self.low_vram:
+ transformer.to(self.device_torch, dtype=dtype)
+
+ if self.model_config.quantize:
+ self.print_and_status_update("Quantizing transformer")
+ quantization_type = get_qtype(self.model_config.qtype)
+ quantize(transformer, weights=quantization_type)
+ freeze(transformer)
+
+ if self.low_vram:
+ # unload it for now
+ transformer.to('cpu')
+
+ flush()
+
+ self.print_and_status_update("Loading vae")
+
+ vae = AutoencoderKL.from_pretrained(
+ extras_path,
+ subfolder="vae",
+ torch_dtype=torch.bfloat16
+ ).to(self.device_torch, dtype=dtype)
+
+
+ flush()
+
+ if self.low_vram:
+ self.print_and_status_update("Moving everything to device")
+ # move it all back
+ transformer.to(self.device_torch, dtype=dtype)
+ vae.to(self.device_torch, dtype=dtype)
+ mllm.to(self.device_torch, dtype=dtype)
+
+ # set to eval mode
+ # transformer.eval()
+ vae.eval()
+ mllm.eval()
+ mllm.requires_grad_(False)
+
+ pipe: OmniGen2Pipeline = OmniGen2Pipeline(
+ transformer=transformer,
+ vae=vae,
+ scheduler=scheduler,
+ mllm=mllm,
+ processor=processor,
+ )
+
+ # pipe: OmniGen2Pipeline = OmniGen2Pipeline.from_pretrained(
+ # model_path,
+ # transformer=transformer,
+ # vae=vae,
+ # scheduler=scheduler,
+ # mllm=mllm,
+ # trust_remote_code=True,
+ # )
+ # processor = pipe.processor
+
+ flush()
+
+ text_encoder_list = [mllm]
+ tokenizer_list = [processor]
+
+
+ flush()
+
+ # save it to the model class
+ self.vae = vae
+ self.text_encoder = text_encoder_list # list of text encoders
+ self.tokenizer = tokenizer_list # list of tokenizers
+ self.model = pipe.transformer
+ self.pipeline = pipe
+
+ self.freqs_cis = OmniGen2RotaryPosEmbed.get_freqs_cis(
+ transformer.config.axes_dim_rope,
+ transformer.config.axes_lens,
+ theta=10000,
+ )
+
+ self.print_and_status_update("Model Loaded")
+
+ def get_generation_pipeline(self):
+ scheduler = OmniFlowMatchEuler(
+ dynamic_time_shift=True,
+ num_train_timesteps=1000
+ )
+
+ pipeline: OmniGen2Pipeline = OmniGen2Pipeline(
+ transformer=self.model,
+ vae=self.vae,
+ scheduler=scheduler,
+ mllm=self.text_encoder[0],
+ processor=self.tokenizer[0],
+ )
+
+ pipeline = pipeline.to(self.device_torch)
+
+ return pipeline
+
+ def generate_single_image(
+ self,
+ pipeline: OmniGen2Pipeline,
+ gen_config: GenerateImageConfig,
+ conditional_embeds: PromptEmbeds,
+ unconditional_embeds: PromptEmbeds,
+ generator: torch.Generator,
+ extra: dict,
+ ):
+ img = pipeline(
+ prompt_embeds=conditional_embeds.text_embeds,
+ prompt_attention_mask=conditional_embeds.attention_mask,
+ negative_prompt_embeds=unconditional_embeds.text_embeds,
+ negative_prompt_attention_mask=unconditional_embeds.attention_mask,
+ height=gen_config.height,
+ width=gen_config.width,
+ num_inference_steps=gen_config.num_inference_steps,
+ text_guidance_scale=gen_config.guidance_scale,
+ image_guidance_scale=1.0, # reference image guidance scale. Add this for controls
+ latents=gen_config.latents,
+ generator=generator,
+ **extra
+ ).images[0]
+ return img
+
+ def get_noise_prediction(
+ self,
+ latent_model_input: torch.Tensor,
+ timestep: torch.Tensor, # 0 to 1000 scale
+ text_embeddings: PromptEmbeds,
+ **kwargs
+ ):
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = timestep.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
+
+ # optional_kwargs = {}
+ # if 'ref_image_hidden_states' in set(inspect.signature(self.model.forward).parameters.keys()):
+ # optional_kwargs['ref_image_hidden_states'] = ref_image_hidden_states
+
+ timesteps = timestep / 1000 # convert to 0 to 1 scale
+ # timestep for model starts at 0 instead of 1. So we need to reverse them
+ timestep = 1 - timesteps
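+ # e.g. a scheduler timestep of 1000 maps to 0.0 for this model, and 0 maps to 1.0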
+ model_pred = self.model(
+ latent_model_input,
+ timestep,
+ text_embeddings.text_embeds,
+ self.freqs_cis,
+ text_embeddings.attention_mask,
+ ref_image_hidden_states=None, # todo add ref latent ability
+ )
+
+ return model_pred
+
+ def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [self.pipeline._apply_chat_template(_prompt) for _prompt in prompt]
+ self.text_encoder_to(self.device_torch, dtype=self.torch_dtype)
+ max_sequence_length = 256
+ prompt_embeds, prompt_attention_mask, _, _ = self.pipeline.encode_prompt(
+ prompt=prompt,
+ do_classifier_free_guidance=False,
+ device=self.device_torch,
+ max_sequence_length=max_sequence_length,
+ )
+ pe = PromptEmbeds(prompt_embeds)
+ pe.attention_mask = prompt_attention_mask
+ return pe
+
+ def get_model_has_grad(self):
+ # return from a weight if it has grad
+ return False
+
+ def get_te_has_grad(self):
+ # assume no one wants to finetune the text encoder.
+ return False
+
+ def save_model(self, output_path, meta, save_dtype):
+ # only save the transformer
+ transformer: OmniGen2Transformer2DModel = unwrap_model(self.model)
+ transformer.save_pretrained(
+ save_directory=os.path.join(output_path, 'transformer'),
+ safe_serialization=True,
+ )
+
+ meta_path = os.path.join(output_path, 'aitk_meta.yaml')
+ with open(meta_path, 'w') as f:
+ yaml.dump(meta, f)
+
+ def get_loss_target(self, *args, **kwargs):
+ noise = kwargs.get('noise')
+ batch = kwargs.get('batch')
+ # return (noise - batch.latents).detach()
+ return (batch.latents - noise).detach()
+
+ def get_transformer_block_names(self) -> Optional[List[str]]:
+ # omnigen2 has a few block groups: noise_refiner, ref_image_refiner, context_refiner, and layers.
+ # let's train all but the ref image refiner until we add reference image support
+ return ['noise_refiner', 'context_refiner', 'layers']
+ # return ['layers']
+
+ def convert_lora_weights_before_save(self, state_dict):
+ # keys currently start with 'transformer.' but need to start with 'diffusion_model.' for comfyui
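+ # e.g. (illustrative key) "transformer.layers.0.attn.to_q.lora_A.weight" -> "diffusion_model.layers.0.attn.to_q.lora_A.weight"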
+ new_sd = {}
+ for key, value in state_dict.items():
+ new_key = key.replace("transformer.", "diffusion_model.")
+ new_sd[new_key] = value
+ return new_sd
+
+ def convert_lora_weights_before_load(self, state_dict):
+ # saved with a 'diffusion_model.' prefix, but keys need to start with 'transformer.' for ai-toolkit
+ new_sd = {}
+ for key, value in state_dict.items():
+ new_key = key.replace("diffusion_model.", "transformer.")
+ new_sd[new_key] = value
+ return new_sd
+
+ def get_base_model_version(self):
+ return "omnigen2"
+
diff --git a/extensions_built_in/diffusion_models/omnigen2/src/__init__.py b/extensions_built_in/diffusion_models/omnigen2/src/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/extensions_built_in/diffusion_models/omnigen2/src/models/__init__.py b/extensions_built_in/diffusion_models/omnigen2/src/models/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/extensions_built_in/diffusion_models/omnigen2/src/models/attention_processor.py b/extensions_built_in/diffusion_models/omnigen2/src/models/attention_processor.py
new file mode 100644
index 00000000..1f713c75
--- /dev/null
+++ b/extensions_built_in/diffusion_models/omnigen2/src/models/attention_processor.py
@@ -0,0 +1,357 @@
+"""
+OmniGen2 Attention Processor Module
+
+Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import warnings
+import math
+from typing import Optional, Tuple, Dict, Any
+
+import torch
+import torch.nn.functional as F
+from einops import repeat
+
+from ..utils.import_utils import is_flash_attn_available
+
+if is_flash_attn_available():
+ from flash_attn import flash_attn_varlen_func
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
+else:
+ warnings.warn("Cannot import flash_attn, install flash_attn to use Flash2Varlen attention for better performance")
+
+
+from diffusers.models.attention_processor import Attention
+from .embeddings import apply_rotary_emb
+
+
+class OmniGen2AttnProcessorFlash2Varlen:
+ """
+ Processor for implementing scaled dot-product attention with flash attention and variable length sequences.
+
+ This processor implements:
+ - Flash attention with variable length sequences
+ - Rotary position embeddings (RoPE)
+ - Query-Key normalization
+ - Proportional attention scaling
+
+ Args:
+ None
+ """
+
+ def __init__(self) -> None:
+ """Initialize the attention processor."""
+ if not is_flash_attn_available():
+ raise ImportError(
+ "OmniGen2AttnProcessorFlash2Varlen requires flash_attn. "
+ "Please install flash_attn."
+ )
+
+ def _upad_input(
+ self,
+ query_layer: torch.Tensor,
+ key_layer: torch.Tensor,
+ value_layer: torch.Tensor,
+ attention_mask: torch.Tensor,
+ query_length: int,
+ num_heads: int,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
+ """
+ Unpad the input tensors for flash attention.
+
+ Args:
+ query_layer: Query tensor of shape (batch_size, seq_len, num_heads, head_dim)
+ key_layer: Key tensor of shape (batch_size, seq_len, num_kv_heads, head_dim)
+ value_layer: Value tensor of shape (batch_size, seq_len, num_kv_heads, head_dim)
+ attention_mask: Attention mask tensor of shape (batch_size, seq_len)
+ query_length: Length of the query sequence
+ num_heads: Number of attention heads
+
+ Returns:
+ Tuple containing:
+ - Unpadded query tensor
+ - Unpadded key tensor
+ - Unpadded value tensor
+ - Query indices
+ - Tuple of cumulative sequence lengths for query and key
+ - Tuple of maximum sequence lengths for query and key
+ """
+ def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
+ """Helper function to get unpadding data from attention mask."""
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+ return indices, cu_seqlens, max_seqlen_in_batch
+
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+ # Unpad key and value layers
+ key_layer = index_first_axis(
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
+ indices_k,
+ )
+ value_layer = index_first_axis(
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
+ indices_k,
+ )
+
+ # Handle different query length cases
+ if query_length == kv_seq_len:
+ query_layer = index_first_axis(
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim),
+ indices_k,
+ )
+ cu_seqlens_q = cu_seqlens_k
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
+ indices_q = indices_k
+ elif query_length == 1:
+ max_seqlen_in_batch_q = 1
+ cu_seqlens_q = torch.arange(
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
+ )
+ indices_q = cu_seqlens_q[:-1]
+ query_layer = query_layer.squeeze(1)
+ else:
+ attention_mask = attention_mask[:, -query_length:]
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+ return (
+ query_layer,
+ key_layer,
+ value_layer,
+ indices_q,
+ (cu_seqlens_q, cu_seqlens_k),
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ base_sequence_length: Optional[int] = None,
+ ) -> torch.Tensor:
+ """
+ Process attention computation with flash attention.
+
+ Args:
+ attn: Attention module
+ hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim)
+ encoder_hidden_states: Encoder hidden states tensor
+ attention_mask: Optional attention mask tensor
+ image_rotary_emb: Optional rotary embeddings for image tokens
+ base_sequence_length: Optional base sequence length for proportional attention
+
+ Returns:
+ torch.Tensor: Processed hidden states after attention computation
+ """
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ # Get Query-Key-Value Pair
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query_dim = query.shape[-1]
+ inner_dim = key.shape[-1]
+ head_dim = query_dim // attn.heads
+ dtype = query.dtype
+
+ # Get key-value heads
+ kv_heads = inner_dim // head_dim
+
+ # Reshape tensors for attention computation
+ query = query.view(batch_size, -1, attn.heads, head_dim)
+ key = key.view(batch_size, -1, kv_heads, head_dim)
+ value = value.view(batch_size, -1, kv_heads, head_dim)
+
+ # Apply Query-Key normalization
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply Rotary Position Embeddings
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
+ key = apply_rotary_emb(key, image_rotary_emb, use_real=False)
+
+ query, key = query.to(dtype), key.to(dtype)
+
+ # Calculate attention scale
+ if base_sequence_length is not None:
+ softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
+ else:
+ softmax_scale = attn.scale
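+ # note: sqrt(log_base(seq_len)) == 1 when sequence_length == base_sequence_length, so the proportional
+ # scale matches attn.scale at the base length and grows slowly for longer sequences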
+
+ # Unpad input for flash attention
+ (
+ query_states,
+ key_states,
+ value_states,
+ indices_q,
+ cu_seq_lens,
+ max_seq_lens,
+ ) = self._upad_input(query, key, value, attention_mask, sequence_length, attn.heads)
+
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+ # Handle different number of heads
+ if kv_heads < attn.heads:
+ key_states = repeat(key_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)
+ value_states = repeat(value_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)
+
+ # Apply flash attention
+ attn_output_unpad = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ dropout_p=0.0,
+ causal=False,
+ softmax_scale=softmax_scale,
+ )
+
+ # Pad output and apply final transformations
+ hidden_states = pad_input(attn_output_unpad, indices_q, batch_size, sequence_length)
+ hidden_states = hidden_states.flatten(-2)
+ hidden_states = hidden_states.type_as(query)
+
+ # Apply output projection
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states
+
+
+class OmniGen2AttnProcessor:
+ """
+ Processor for implementing scaled dot-product attention using PyTorch's built-in SDPA.
+
+ This processor requires PyTorch 2.0 and implements:
+ - Scaled dot-product attention (F.scaled_dot_product_attention)
+ - Rotary position embeddings (RoPE)
+ - Query-Key normalization
+ - Proportional attention scaling
+
+ Args:
+ None
+
+ Raises:
+ ImportError: If PyTorch version is less than 2.0
+ """
+
+ def __init__(self) -> None:
+ """Initialize the attention processor."""
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "OmniGen2AttnProcessorFlash2Varlen requires PyTorch 2.0. "
+ "Please upgrade PyTorch to version 2.0 or later."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ base_sequence_length: Optional[int] = None,
+ ) -> torch.Tensor:
+ """
+ Process attention computation with flash attention.
+
+ Args:
+ attn: Attention module
+ hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim)
+ encoder_hidden_states: Encoder hidden states tensor
+ attention_mask: Optional attention mask tensor
+ image_rotary_emb: Optional rotary embeddings for image tokens
+ base_sequence_length: Optional base sequence length for proportional attention
+
+ Returns:
+ torch.Tensor: Processed hidden states after attention computation
+ """
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ # Get Query-Key-Value Pair
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query_dim = query.shape[-1]
+ inner_dim = key.shape[-1]
+ head_dim = query_dim // attn.heads
+ dtype = query.dtype
+
+ # Get key-value heads
+ kv_heads = inner_dim // head_dim
+
+ # Reshape tensors for attention computation
+ query = query.view(batch_size, -1, attn.heads, head_dim)
+ key = key.view(batch_size, -1, kv_heads, head_dim)
+ value = value.view(batch_size, -1, kv_heads, head_dim)
+
+ # Apply Query-Key normalization
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply Rotary Position Embeddings
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
+ key = apply_rotary_emb(key, image_rotary_emb, use_real=False)
+
+ query, key = query.to(dtype), key.to(dtype)
+
+ # Calculate attention scale
+ if base_sequence_length is not None:
+ softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
+ else:
+ softmax_scale = attn.scale
+
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ if attention_mask is not None:
+ attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
+
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+
+ # explicitly repeat key and value heads to match the number of query heads; in our tests on PyTorch 2.6, using enable_gqa=True fell back to the MATH backend of sdpa
+ key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
+ value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, scale=softmax_scale
+ )
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.type_as(query)
+
+ # Apply output projection
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states
\ No newline at end of file
diff --git a/extensions_built_in/diffusion_models/omnigen2/src/models/embeddings.py b/extensions_built_in/diffusion_models/omnigen2/src/models/embeddings.py
new file mode 100644
index 00000000..5282f2de
--- /dev/null
+++ b/extensions_built_in/diffusion_models/omnigen2/src/models/embeddings.py
@@ -0,0 +1,126 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+
+from diffusers.models.activations import get_activation
+
+
+class TimestepEmbedding(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ time_embed_dim: int,
+ act_fn: str = "silu",
+ out_dim: int = None,
+ post_act_fn: Optional[str] = None,
+ cond_proj_dim=None,
+ sample_proj_bias=True,
+ ):
+ super().__init__()
+
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
+
+ if cond_proj_dim is not None:
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
+ else:
+ self.cond_proj = None
+
+ self.act = get_activation(act_fn)
+
+ if out_dim is not None:
+ time_embed_dim_out = out_dim
+ else:
+ time_embed_dim_out = time_embed_dim
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
+
+ if post_act_fn is None:
+ self.post_act = None
+ else:
+ self.post_act = get_activation(post_act_fn)
+
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ nn.init.normal_(self.linear_1.weight, std=0.02)
+ nn.init.zeros_(self.linear_1.bias)
+ nn.init.normal_(self.linear_2.weight, std=0.02)
+ nn.init.zeros_(self.linear_2.bias)
+
+ def forward(self, sample, condition=None):
+ if condition is not None:
+ sample = sample + self.cond_proj(condition)
+ sample = self.linear_1(sample)
+
+ if self.act is not None:
+ sample = self.act(sample)
+
+ sample = self.linear_2(sample)
+
+ if self.post_act is not None:
+ sample = self.post_act(sample)
+ return sample
+
+
+def apply_rotary_emb(
+ x: torch.Tensor,
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
+ use_real: bool = True,
+ use_real_unbind_dim: int = -1,
+) -> torch.Tensor:
+ """
+ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
+ to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
+ reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
+ tensors contain rotary embeddings and are returned as real tensors.
+
+ Args:
+ x (`torch.Tensor`):
+ Query or key tensor to apply rotary embeddings to. Shape: [B, H, S, D].
+ freqs_cis (`Tuple[torch.Tensor]` or `torch.Tensor`): Precomputed frequency tensor for complex exponentials,
+ either a (cos, sin) pair of [S, D] tensors (use_real=True) or a complex tensor (use_real=False).
+
+ Returns:
+ torch.Tensor: The input tensor with rotary embeddings applied.
+ """
+ if use_real:
+ cos, sin = freqs_cis # [S, D]
+ cos = cos[None, None]
+ sin = sin[None, None]
+ cos, sin = cos.to(x.device), sin.to(x.device)
+
+ if use_real_unbind_dim == -1:
+ # Used for flux, cogvideox, hunyuan-dit
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+ elif use_real_unbind_dim == -2:
+ # Used for Stable Audio, OmniGen and CogView4
+ x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
+ x_rotated = torch.cat([-x_imag, x_real], dim=-1)
+ else:
+ raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
+
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+
+ return out
+ else:
+ # used for lumina
+ # x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
+ x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], x.shape[-1] // 2, 2))
+ freqs_cis = freqs_cis.unsqueeze(2)
+ x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
+
+ return x_out.type_as(x)
\ No newline at end of file
diff --git a/extensions_built_in/diffusion_models/omnigen2/src/models/transformers/__init__.py b/extensions_built_in/diffusion_models/omnigen2/src/models/transformers/__init__.py
new file mode 100644
index 00000000..157de1d4
--- /dev/null
+++ b/extensions_built_in/diffusion_models/omnigen2/src/models/transformers/__init__.py
@@ -0,0 +1,3 @@
+from .transformer_omnigen2 import OmniGen2Transformer2DModel
+
+__all__ = ["OmniGen2Transformer2DModel"]
\ No newline at end of file
diff --git a/extensions_built_in/diffusion_models/omnigen2/src/models/transformers/block_lumina2.py b/extensions_built_in/diffusion_models/omnigen2/src/models/transformers/block_lumina2.py
new file mode 100644
index 00000000..13739d3a
--- /dev/null
+++ b/extensions_built_in/diffusion_models/omnigen2/src/models/transformers/block_lumina2.py
@@ -0,0 +1,218 @@
+
+# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+
+from diffusers.models.embeddings import Timesteps
+from ..embeddings import TimestepEmbedding
+
+from ...utils.import_utils import is_flash_attn_available, is_triton_available
+
+if is_triton_available():
+ from ...ops.triton.layer_norm import RMSNorm
+else:
+ from torch.nn import RMSNorm
+ warnings.warn("Cannot import triton, install triton to use fused RMSNorm for better performance")
+
+if is_flash_attn_available():
+ from flash_attn.ops.activations import swiglu
+else:
+ from .components import swiglu
+ warnings.warn("Cannot import flash_attn, install flash_attn to use fused SwiGLU for better performance")
+
+# try:
+# from flash_attn.ops.activations import swiglu as fused_swiglu
+# FUSEDSWIGLU_AVALIBLE = True
+# except ImportError:
+
+# FUSEDSWIGLU_AVALIBLE = False
+# warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
+
+class LuminaRMSNormZero(nn.Module):
+ """
+ Norm layer adaptive RMS normalization zero.
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+ """
+
+ def __init__(
+ self,
+ embedding_dim: int,
+ norm_eps: float,
+ norm_elementwise_affine: bool,
+ ):
+ super().__init__()
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(
+ min(embedding_dim, 1024),
+ 4 * embedding_dim,
+ bias=True,
+ )
+
+ self.norm = RMSNorm(embedding_dim, eps=norm_eps)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ emb: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ emb = self.linear(self.silu(emb))
+ scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
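+ # scale_* modulate the normalized activations; gate_* are tanh-gated residual scales applied in the transformer block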
+ x = self.norm(x) * (1 + scale_msa[:, None])
+ return x, gate_msa, scale_mlp, gate_mlp
+
+
+class LuminaLayerNormContinuous(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ conditioning_embedding_dim: int,
+ # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
+ # because the output is immediately scaled and shifted by the projected conditioning embeddings.
+ # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
+ # However, this is how it was implemented in the original code, and it's rather likely you should
+ # set `elementwise_affine` to False.
+ elementwise_affine=True,
+ eps=1e-5,
+ bias=True,
+ norm_type="layer_norm",
+ out_dim: Optional[int] = None,
+ ):
+ super().__init__()
+
+ # AdaLN
+ self.silu = nn.SiLU()
+ self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
+
+ if norm_type == "layer_norm":
+ self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+ elif norm_type == "rms_norm":
+ self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
+ else:
+ raise ValueError(f"unknown norm_type {norm_type}")
+
+ self.linear_2 = None
+ if out_dim is not None:
+ self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ conditioning_embedding: torch.Tensor,
+ ) -> torch.Tensor:
+ # convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for hunyuanDiT)
+ emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
+ scale = emb
+ x = self.norm(x) * (1 + scale)[:, None, :]
+
+ if self.linear_2 is not None:
+ x = self.linear_2(x)
+
+ return x
+
+
+class LuminaFeedForward(nn.Module):
+ r"""
+ A feed-forward layer.
+
+ Parameters:
+ hidden_size (`int`):
+ The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
+ hidden representations.
+ intermediate_size (`int`): The intermediate dimension of the feedforward layer.
+ multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple
+ of this value.
+ ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden
+ dimension. Defaults to None.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ inner_dim: int,
+ multiple_of: Optional[int] = 256,
+ ffn_dim_multiplier: Optional[float] = None,
+ ):
+ super().__init__()
+ self.swiglu = swiglu
+
+ # custom hidden_size factor multiplier
+ if ffn_dim_multiplier is not None:
+ inner_dim = int(ffn_dim_multiplier * inner_dim)
+ inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
+
+ self.linear_1 = nn.Linear(
+ dim,
+ inner_dim,
+ bias=False,
+ )
+ self.linear_2 = nn.Linear(
+ inner_dim,
+ dim,
+ bias=False,
+ )
+ self.linear_3 = nn.Linear(
+ dim,
+ inner_dim,
+ bias=False,
+ )
+
+ def forward(self, x):
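+ # SwiGLU feed-forward: silu(linear_1(x)) * linear_3(x), projected back down with linear_2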
+ h1, h2 = self.linear_1(x), self.linear_3(x)
+ return self.linear_2(self.swiglu(h1, h2))
+
+
+class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int = 4096,
+ text_feat_dim: int = 2048,
+ frequency_embedding_size: int = 256,
+ norm_eps: float = 1e-5,
+ timestep_scale: float = 1.0,
+ ) -> None:
+ super().__init__()
+
+ self.time_proj = Timesteps(
+ num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale
+ )
+
+ self.timestep_embedder = TimestepEmbedding(
+ in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)
+ )
+
+ self.caption_embedder = nn.Sequential(
+ RMSNorm(text_feat_dim, eps=norm_eps),
+ nn.Linear(text_feat_dim, hidden_size, bias=True),
+ )
+
+ self._initialize_weights()
+
+ def _initialize_weights(self):
+ nn.init.trunc_normal_(self.caption_embedder[1].weight, std=0.02)
+ nn.init.zeros_(self.caption_embedder[1].bias)
+
+ def forward(
+ self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ timestep_proj = self.time_proj(timestep).to(dtype=dtype)
+ time_embed = self.timestep_embedder(timestep_proj)
+ caption_embed = self.caption_embedder(text_hidden_states)
+ return time_embed, caption_embed
\ No newline at end of file
diff --git a/extensions_built_in/diffusion_models/omnigen2/src/models/transformers/components.py b/extensions_built_in/diffusion_models/omnigen2/src/models/transformers/components.py
new file mode 100644
index 00000000..5e654b8c
--- /dev/null
+++ b/extensions_built_in/diffusion_models/omnigen2/src/models/transformers/components.py
@@ -0,0 +1,4 @@
+import torch.nn.functional as F
+
+def swiglu(x, y):
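+ # fallback SwiGLU gate: silu(x) * y, with silu computed in float32 for stability then cast back to x's dtype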
+ return F.silu(x.float(), inplace=False).to(x.dtype) * y
\ No newline at end of file
diff --git a/extensions_built_in/diffusion_models/omnigen2/src/models/transformers/repo.py b/extensions_built_in/diffusion_models/omnigen2/src/models/transformers/repo.py
new file mode 100644
index 00000000..ea565bf8
--- /dev/null
+++ b/extensions_built_in/diffusion_models/omnigen2/src/models/transformers/repo.py
@@ -0,0 +1,135 @@
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+
+from einops import repeat
+from diffusers.models.embeddings import get_1d_rotary_pos_embed
+
+class OmniGen2RotaryPosEmbed(nn.Module):
+ def __init__(self, theta: int,
+ axes_dim: Tuple[int, int, int],
+ axes_lens: Tuple[int, int, int] = (300, 512, 512),
+ patch_size: int = 2):
+ super().__init__()
+ self.theta = theta
+ self.axes_dim = axes_dim
+ self.axes_lens = axes_lens
+ self.patch_size = patch_size
+
+ @staticmethod
+ def get_freqs_cis(axes_dim: Tuple[int, int, int],
+ axes_lens: Tuple[int, int, int],
+ theta: int) -> List[torch.Tensor]:
+ freqs_cis = []
+ freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+ for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
+ emb = get_1d_rotary_pos_embed(d, e, theta=theta, freqs_dtype=freqs_dtype)
+ freqs_cis.append(emb)
+ return freqs_cis
+
+ def _get_freqs_cis(self, freqs_cis, ids: torch.Tensor) -> torch.Tensor:
+ device = ids.device
+ if ids.device.type == "mps":
+ ids = ids.to("cpu")
+
+ result = []
+ for i in range(len(self.axes_dim)):
+ freqs = freqs_cis[i].to(ids.device)
+ index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
+ result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
+ return torch.cat(result, dim=-1).to(device)
+
+ def forward(
+ self,
+ freqs_cis,
+ attention_mask,
+ l_effective_ref_img_len,
+ l_effective_img_len,
+ ref_img_sizes,
+ img_sizes,
+ device
+ ):
+ batch_size = len(attention_mask)
+ p = self.patch_size
+
+ encoder_seq_len = attention_mask.shape[1]
+ l_effective_cap_len = attention_mask.sum(dim=1).tolist()
+
+ seq_lengths = [cap_len + sum(ref_img_len) + img_len for cap_len, ref_img_len, img_len in zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len)]
+
+ max_seq_len = int(max(seq_lengths))
+ max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
+ max_img_len = max(l_effective_img_len)
+
+ # Create position IDs
+ position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
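+ # 3 position axes per token: text tokens use their sequence index on all three; image tokens use a constant shift on axis 0 and patch row/column on axes 1 and 2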
+
+ for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
+ cap_seq_len = int(cap_seq_len)
+ seq_len = int(seq_len)
+ # add text position ids
+ position_ids[i, :cap_seq_len] = repeat(torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3")
+
+ pe_shift = cap_seq_len
+ pe_shift_len = cap_seq_len
+
+ if ref_img_sizes[i] is not None:
+ for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
+ H, W = ref_img_size
+ ref_H_tokens, ref_W_tokens = H // p, W // p
+ assert ref_H_tokens * ref_W_tokens == ref_img_len
+ # add image position ids
+
+ row_ids = repeat(torch.arange(ref_H_tokens, dtype=torch.int32, device=device), "h -> h w", w=ref_W_tokens).flatten()
+ col_ids = repeat(torch.arange(ref_W_tokens, dtype=torch.int32, device=device), "w -> h w", h=ref_H_tokens).flatten()
+ position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift
+ position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids
+ position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids
+
+ pe_shift += max(ref_H_tokens, ref_W_tokens)
+ pe_shift_len += ref_img_len
+
+ H, W = img_sizes[i]
+ H_tokens, W_tokens = H // p, W // p
+ assert H_tokens * W_tokens == l_effective_img_len[i]
+
+ row_ids = repeat(torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens).flatten()
+ col_ids = repeat(torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens).flatten()
+
+ assert pe_shift_len + l_effective_img_len[i] == seq_len
+ position_ids[i, pe_shift_len: seq_len, 0] = pe_shift
+ position_ids[i, pe_shift_len: seq_len, 1] = row_ids
+ position_ids[i, pe_shift_len: seq_len, 2] = col_ids
+
+ # Get combined rotary embeddings
+ freqs_cis = self._get_freqs_cis(freqs_cis, position_ids)
+
+ # create separate rotary embeddings for captions and images
+ cap_freqs_cis = torch.zeros(
+ batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
+ )
+ ref_img_freqs_cis = torch.zeros(
+ batch_size, max_ref_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
+ )
+ img_freqs_cis = torch.zeros(
+ batch_size, max_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
+ )
+
+ for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)):
+ cap_seq_len = int(cap_seq_len)
+ sum_ref_img_len = int(sum(ref_img_len))
+ img_len = int(img_len)
+ seq_len = int(seq_len)
+ cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
+ ref_img_freqs_cis[i, :sum_ref_img_len] = freqs_cis[i, cap_seq_len:cap_seq_len + sum_ref_img_len]
+ img_freqs_cis[i, :img_len] = freqs_cis[i, cap_seq_len + sum_ref_img_len:cap_seq_len + sum_ref_img_len + img_len]
+
+ return (
+ cap_freqs_cis,
+ ref_img_freqs_cis,
+ img_freqs_cis,
+ freqs_cis,
+ l_effective_cap_len,
+ seq_lengths,
+ )
\ No newline at end of file
diff --git a/extensions_built_in/diffusion_models/omnigen2/src/models/transformers/transformer_omnigen2.py b/extensions_built_in/diffusion_models/omnigen2/src/models/transformers/transformer_omnigen2.py
new file mode 100644
index 00000000..8e7fef68
--- /dev/null
+++ b/extensions_built_in/diffusion_models/omnigen2/src/models/transformers/transformer_omnigen2.py
@@ -0,0 +1,621 @@
+import warnings
+import itertools
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from einops import rearrange
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import PeftAdapterMixin
+from diffusers.loaders.single_file_model import FromOriginalModelMixin
+from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.models.attention_processor import Attention
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.modeling_utils import ModelMixin
+
+from ..attention_processor import OmniGen2AttnProcessorFlash2Varlen, OmniGen2AttnProcessor
+from .repo import OmniGen2RotaryPosEmbed
+from .block_lumina2 import LuminaLayerNormContinuous, LuminaRMSNormZero, LuminaFeedForward, Lumina2CombinedTimestepCaptionEmbedding
+
+from ...utils.import_utils import is_triton_available, is_flash_attn_available
+
+if is_triton_available():
+ from ...ops.triton.layer_norm import RMSNorm
+else:
+ from torch.nn import RMSNorm
+
+logger = logging.get_logger(__name__)
+
+
+class OmniGen2TransformerBlock(nn.Module):
+ """
+ Transformer block for OmniGen2 model.
+
+ This block implements a transformer layer with:
+ - Multi-head attention with flash attention
+ - Feed-forward network with SwiGLU activation
+ - RMS normalization
+ - Optional modulation for conditional generation
+
+ Args:
+ dim: Dimension of the input and output tensors
+ num_attention_heads: Number of attention heads
+ num_kv_heads: Number of key-value heads
+ multiple_of: Value the feed-forward hidden dimension should be a multiple of
+ ffn_dim_multiplier: Multiplier for the feed-forward network dimension
+ norm_eps: Epsilon value for normalization layers
+ modulation: Whether to use modulation for conditional generation
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ num_kv_heads: int,
+ multiple_of: int,
+ ffn_dim_multiplier: float,
+ norm_eps: float,
+ modulation: bool = True,
+ ) -> None:
+ """Initialize the transformer block."""
+ super().__init__()
+ self.head_dim = dim // num_attention_heads
+ self.modulation = modulation
+
+ try:
+ processor = OmniGen2AttnProcessorFlash2Varlen()
+ except ImportError:
+ processor = OmniGen2AttnProcessor()
+
+ # Initialize attention layer
+ self.attn = Attention(
+ query_dim=dim,
+ cross_attention_dim=None,
+ dim_head=dim // num_attention_heads,
+ qk_norm="rms_norm",
+ heads=num_attention_heads,
+ kv_heads=num_kv_heads,
+ eps=1e-5,
+ bias=False,
+ out_bias=False,
+ processor=processor,
+ )
+
+ # Initialize feed-forward network
+ self.feed_forward = LuminaFeedForward(
+ dim=dim,
+ inner_dim=4 * dim,
+ multiple_of=multiple_of,
+ ffn_dim_multiplier=ffn_dim_multiplier
+ )
+
+ # Initialize normalization layers
+ if modulation:
+ self.norm1 = LuminaRMSNormZero(
+ embedding_dim=dim,
+ norm_eps=norm_eps,
+ norm_elementwise_affine=True
+ )
+ else:
+ self.norm1 = RMSNorm(dim, eps=norm_eps)
+
+ self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
+ self.norm2 = RMSNorm(dim, eps=norm_eps)
+ self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
+
+ self.initialize_weights()
+
+ def initialize_weights(self) -> None:
+ """
+ Initialize the weights of the transformer block.
+
+ Uses Xavier uniform initialization for linear layers and zero initialization for biases.
+ """
+ nn.init.xavier_uniform_(self.attn.to_q.weight)
+ nn.init.xavier_uniform_(self.attn.to_k.weight)
+ nn.init.xavier_uniform_(self.attn.to_v.weight)
+ nn.init.xavier_uniform_(self.attn.to_out[0].weight)
+
+ nn.init.xavier_uniform_(self.feed_forward.linear_1.weight)
+ nn.init.xavier_uniform_(self.feed_forward.linear_2.weight)
+ nn.init.xavier_uniform_(self.feed_forward.linear_3.weight)
+
+ if self.modulation:
+ nn.init.zeros_(self.norm1.linear.weight)
+ nn.init.zeros_(self.norm1.linear.bias)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ image_rotary_emb: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """
+ Forward pass of the transformer block.
+
+ Args:
+ hidden_states: Input hidden states tensor
+ attention_mask: Attention mask tensor
+ image_rotary_emb: Rotary embeddings for image tokens
+ temb: Optional timestep embedding tensor
+
+ Returns:
+ torch.Tensor: Output hidden states after transformer block processing
+ """
+ if self.modulation:
+ if temb is None:
+ raise ValueError("temb must be provided when modulation is enabled")
+
+ norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
+ attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+ hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
+ mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
+ hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
+ else:
+ norm_hidden_states = self.norm1(hidden_states)
+ attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+ hidden_states = hidden_states + self.norm2(attn_output)
+ mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
+ hidden_states = hidden_states + self.ffn_norm2(mlp_output)
+
+ return hidden_states
+
+
+class OmniGen2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+ """
+ OmniGen2 Transformer 2D Model.
+
+ A transformer-based diffusion model for image generation with:
+ - Patch-based image processing
+ - Rotary position embeddings
+ - Multi-head attention
+ - Conditional generation support
+
+ Args:
+ patch_size: Size of image patches
+ in_channels: Number of input channels
+ out_channels: Number of output channels (defaults to in_channels)
+ hidden_size: Size of hidden layers
+ num_layers: Number of transformer layers
+ num_refiner_layers: Number of refiner layers
+ num_attention_heads: Number of attention heads
+ num_kv_heads: Number of key-value heads
+ multiple_of: Value the feed-forward hidden dimension should be a multiple of
+ ffn_dim_multiplier: Multiplier for feed-forward network dimension
+ norm_eps: Epsilon value for normalization layers
+ axes_dim_rope: Dimensions for rotary position embeddings
+ axes_lens: Lengths for rotary position embeddings
+ text_feat_dim: Dimension of text features
+ timestep_scale: Scale factor for timestep embeddings
+ """
+
+ _supports_gradient_checkpointing = True
+ _no_split_modules = ["OmniGen2TransformerBlock"]
+ _skip_layerwise_casting_patterns = ["x_embedder", "norm"]
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: int = 2,
+ in_channels: int = 16,
+ out_channels: Optional[int] = None,
+ hidden_size: int = 2304,
+ num_layers: int = 26,
+ num_refiner_layers: int = 2,
+ num_attention_heads: int = 24,
+ num_kv_heads: int = 8,
+ multiple_of: int = 256,
+ ffn_dim_multiplier: Optional[float] = None,
+ norm_eps: float = 1e-5,
+ axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
+ axes_lens: Tuple[int, int, int] = (300, 512, 512),
+ text_feat_dim: int = 1024,
+ timestep_scale: float = 1.0
+ ) -> None:
+ """Initialize the OmniGen2 transformer model."""
+ super().__init__()
+
+ # Validate configuration
+ if (hidden_size // num_attention_heads) != sum(axes_dim_rope):
+ raise ValueError(
+ f"hidden_size // num_attention_heads ({hidden_size // num_attention_heads}) "
+ f"must equal sum(axes_dim_rope) ({sum(axes_dim_rope)})"
+ )
+
+ self.out_channels = out_channels or in_channels
+
+ # Initialize embeddings
+ self.rope_embedder = OmniGen2RotaryPosEmbed(
+ theta=10000,
+ axes_dim=axes_dim_rope,
+ axes_lens=axes_lens,
+ patch_size=patch_size,
+ )
+
+ self.x_embedder = nn.Linear(
+ in_features=patch_size * patch_size * in_channels,
+ out_features=hidden_size,
+ )
+
+ self.ref_image_patch_embedder = nn.Linear(
+ in_features=patch_size * patch_size * in_channels,
+ out_features=hidden_size,
+ )
+
+ self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
+ hidden_size=hidden_size,
+ text_feat_dim=text_feat_dim,
+ norm_eps=norm_eps,
+ timestep_scale=timestep_scale
+ )
+
+ # Initialize transformer blocks
+ self.noise_refiner = nn.ModuleList([
+ OmniGen2TransformerBlock(
+ hidden_size,
+ num_attention_heads,
+ num_kv_heads,
+ multiple_of,
+ ffn_dim_multiplier,
+ norm_eps,
+ modulation=True
+ )
+ for _ in range(num_refiner_layers)
+ ])
+
+ self.ref_image_refiner = nn.ModuleList([
+ OmniGen2TransformerBlock(
+ hidden_size,
+ num_attention_heads,
+ num_kv_heads,
+ multiple_of,
+ ffn_dim_multiplier,
+ norm_eps,
+ modulation=True
+ )
+ for _ in range(num_refiner_layers)
+ ])
+
+ self.context_refiner = nn.ModuleList(
+ [
+ OmniGen2TransformerBlock(
+ hidden_size,
+ num_attention_heads,
+ num_kv_heads,
+ multiple_of,
+ ffn_dim_multiplier,
+ norm_eps,
+ modulation=False
+ )
+ for _ in range(num_refiner_layers)
+ ]
+ )
+
+ # 3. Transformer blocks
+ self.layers = nn.ModuleList(
+ [
+ OmniGen2TransformerBlock(
+ hidden_size,
+ num_attention_heads,
+ num_kv_heads,
+ multiple_of,
+ ffn_dim_multiplier,
+ norm_eps,
+ modulation=True
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 4. Output norm & projection
+ self.norm_out = LuminaLayerNormContinuous(
+ embedding_dim=hidden_size,
+ conditioning_embedding_dim=min(hidden_size, 1024),
+ elementwise_affine=False,
+ eps=1e-6,
+ bias=True,
+ out_dim=patch_size * patch_size * self.out_channels
+ )
+
+ # Add learnable embeddings to distinguish different images
+ self.image_index_embedding = nn.Parameter(torch.randn(5, hidden_size)) # support max 5 ref images
+
+ self.gradient_checkpointing = False
+
+ self.initialize_weights()
+
+ def initialize_weights(self) -> None:
+ """
+ Initialize the weights of the model.
+
+ Uses Xavier uniform initialization for linear layers.
+ """
+ nn.init.xavier_uniform_(self.x_embedder.weight)
+ nn.init.constant_(self.x_embedder.bias, 0.0)
+
+ nn.init.xavier_uniform_(self.ref_image_patch_embedder.weight)
+ nn.init.constant_(self.ref_image_patch_embedder.bias, 0.0)
+
+ nn.init.zeros_(self.norm_out.linear_1.weight)
+ nn.init.zeros_(self.norm_out.linear_1.bias)
+ nn.init.zeros_(self.norm_out.linear_2.weight)
+ nn.init.zeros_(self.norm_out.linear_2.bias)
+
+ nn.init.normal_(self.image_index_embedding, std=0.02)
+
+ def img_patch_embed_and_refine(
+ self,
+ hidden_states,
+ ref_image_hidden_states,
+ padded_img_mask,
+ padded_ref_img_mask,
+ noise_rotary_emb,
+ ref_img_rotary_emb,
+ l_effective_ref_img_len,
+ l_effective_img_len,
+ temb
+ ):
+ batch_size = len(hidden_states)
+ max_combined_img_len = max([img_len + sum(ref_img_len) for img_len, ref_img_len in zip(l_effective_img_len, l_effective_ref_img_len)])
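+ # the combined per-sample sequence is [ref image tokens..., noise image tokens]; it is assembled at the end of this method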
+
+ hidden_states = self.x_embedder(hidden_states)
+ ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)
+
+ for i in range(batch_size):
+ shift = 0
+ for j, ref_img_len in enumerate(l_effective_ref_img_len[i]):
+ ref_image_hidden_states[i, shift:shift + ref_img_len, :] = ref_image_hidden_states[i, shift:shift + ref_img_len, :] + self.image_index_embedding[j]
+ shift += ref_img_len
+
+ for layer in self.noise_refiner:
+ hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
+
+ flat_l_effective_ref_img_len = list(itertools.chain(*l_effective_ref_img_len))
+ num_ref_images = len(flat_l_effective_ref_img_len)
+ max_ref_img_len = max(flat_l_effective_ref_img_len)
+
+ batch_ref_img_mask = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, dtype=torch.bool)
+ batch_ref_image_hidden_states = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, self.config.hidden_size)
+ batch_ref_img_rotary_emb = hidden_states.new_zeros(num_ref_images, max_ref_img_len, ref_img_rotary_emb.shape[-1], dtype=ref_img_rotary_emb.dtype)
+ batch_temb = temb.new_zeros(num_ref_images, *temb.shape[1:], dtype=temb.dtype)
+
+ # sequence of ref imgs to batch
+ idx = 0
+ for i in range(batch_size):
+ shift = 0
+ for ref_img_len in l_effective_ref_img_len[i]:
+ batch_ref_img_mask[idx, :ref_img_len] = True
+ batch_ref_image_hidden_states[idx, :ref_img_len] = ref_image_hidden_states[i, shift:shift + ref_img_len]
+ batch_ref_img_rotary_emb[idx, :ref_img_len] = ref_img_rotary_emb[i, shift:shift + ref_img_len]
+ batch_temb[idx] = temb[i]
+ shift += ref_img_len
+ idx += 1
+
+ # refine ref imgs separately
+ for layer in self.ref_image_refiner:
+ batch_ref_image_hidden_states = layer(batch_ref_image_hidden_states, batch_ref_img_mask, batch_ref_img_rotary_emb, batch_temb)
+
+ # batch of ref imgs to sequence
+ idx = 0
+ for i in range(batch_size):
+ shift = 0
+ for ref_img_len in l_effective_ref_img_len[i]:
+ ref_image_hidden_states[i, shift:shift + ref_img_len] = batch_ref_image_hidden_states[idx, :ref_img_len]
+ shift += ref_img_len
+ idx += 1
+
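+        # per sample, concatenate [ref-image tokens | noise tokens], left-aligned and zero-padded to the max combined length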
+ combined_img_hidden_states = hidden_states.new_zeros(batch_size, max_combined_img_len, self.config.hidden_size)
+ for i, (ref_img_len, img_len) in enumerate(zip(l_effective_ref_img_len, l_effective_img_len)):
+ combined_img_hidden_states[i, :sum(ref_img_len)] = ref_image_hidden_states[i, :sum(ref_img_len)]
+ combined_img_hidden_states[i, sum(ref_img_len):sum(ref_img_len) + img_len] = hidden_states[i, :img_len]
+
+ return combined_img_hidden_states
+
+ def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states):
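+        # Patchify each image (and reference image) into a flat token sequence, then left-align and zero-pad
+        # all sequences in the batch; returns padded tokens, boolean validity masks, per-sample lengths and sizes.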
+ batch_size = len(hidden_states)
+ p = self.config.patch_size
+ device = hidden_states[0].device
+
+ img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
+ l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes]
+
+ if ref_image_hidden_states is not None:
+ ref_img_sizes = [[(img.size(1), img.size(2)) for img in imgs] if imgs is not None else None for imgs in ref_image_hidden_states]
+ l_effective_ref_img_len = [[(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes] if _ref_img_sizes is not None else [0] for _ref_img_sizes in ref_img_sizes]
+ else:
+ ref_img_sizes = [None for _ in range(batch_size)]
+ l_effective_ref_img_len = [[0] for _ in range(batch_size)]
+
+ max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
+ max_img_len = max(l_effective_img_len)
+
+ # ref image patch embeddings
+ flat_ref_img_hidden_states = []
+ for i in range(batch_size):
+ if ref_img_sizes[i] is not None:
+ imgs = []
+ for ref_img in ref_image_hidden_states[i]:
+ C, H, W = ref_img.size()
+ ref_img = rearrange(ref_img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
+ imgs.append(ref_img)
+
+ img = torch.cat(imgs, dim=0)
+ flat_ref_img_hidden_states.append(img)
+ else:
+ flat_ref_img_hidden_states.append(None)
+
+ # image patch embeddings
+ flat_hidden_states = []
+ for i in range(batch_size):
+ img = hidden_states[i]
+ C, H, W = img.size()
+
+ img = rearrange(img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
+ flat_hidden_states.append(img)
+
+ padded_ref_img_hidden_states = torch.zeros(batch_size, max_ref_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype)
+ padded_ref_img_mask = torch.zeros(batch_size, max_ref_img_len, dtype=torch.bool, device=device)
+ for i in range(batch_size):
+ if ref_img_sizes[i] is not None:
+ padded_ref_img_hidden_states[i, :sum(l_effective_ref_img_len[i])] = flat_ref_img_hidden_states[i]
+ padded_ref_img_mask[i, :sum(l_effective_ref_img_len[i])] = True
+
+ padded_hidden_states = torch.zeros(batch_size, max_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype)
+ padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device)
+ for i in range(batch_size):
+ padded_hidden_states[i, :l_effective_img_len[i]] = flat_hidden_states[i]
+ padded_img_mask[i, :l_effective_img_len[i]] = True
+
+ return (
+ padded_hidden_states,
+ padded_ref_img_hidden_states,
+ padded_img_mask,
+ padded_ref_img_mask,
+ l_effective_ref_img_len,
+ l_effective_img_len,
+ ref_img_sizes,
+ img_sizes,
+ )
+
+ def forward(
+ self,
+ hidden_states: Union[torch.Tensor, List[torch.Tensor]],
+ timestep: torch.Tensor,
+ text_hidden_states: torch.Tensor,
+ freqs_cis: torch.Tensor,
+ text_attention_mask: torch.Tensor,
+ ref_image_hidden_states: Optional[List[List[torch.Tensor]]] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = False,
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ # 1. Condition, positional & patch embedding
+ batch_size = len(hidden_states)
+ is_hidden_states_tensor = isinstance(hidden_states, torch.Tensor)
+
+ if is_hidden_states_tensor:
+ assert hidden_states.ndim == 4
+ hidden_states = [_hidden_states for _hidden_states in hidden_states]
+
+ device = hidden_states[0].device
+
+ temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype)
+
+ (
+ hidden_states,
+ ref_image_hidden_states,
+ img_mask,
+ ref_img_mask,
+ l_effective_ref_img_len,
+ l_effective_img_len,
+ ref_img_sizes,
+ img_sizes,
+ ) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)
+
+ (
+ context_rotary_emb,
+ ref_img_rotary_emb,
+ noise_rotary_emb,
+ rotary_emb,
+ encoder_seq_lengths,
+ seq_lengths,
+ ) = self.rope_embedder(
+ freqs_cis,
+ text_attention_mask,
+ l_effective_ref_img_len,
+ l_effective_img_len,
+ ref_img_sizes,
+ img_sizes,
+ device,
+ )
+
+ # 2. Context refinement
+ for layer in self.context_refiner:
+ text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)
+
+ combined_img_hidden_states = self.img_patch_embed_and_refine(
+ hidden_states,
+ ref_image_hidden_states,
+ img_mask,
+ ref_img_mask,
+ noise_rotary_emb,
+ ref_img_rotary_emb,
+ l_effective_ref_img_len,
+ l_effective_img_len,
+ temb,
+ )
+
+ # 3. Joint Transformer blocks
+ max_seq_len = int(max(seq_lengths))
+
+ attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
+ joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
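+        # per sample, the joint sequence is [text tokens | ref-image tokens | noise tokens], left-aligned and padded to max_seq_len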
+ for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
+ encoder_seq_len = int(encoder_seq_len)
+ seq_len = int(seq_len)
+ attention_mask[i, :seq_len] = True
+ joint_hidden_states[i, :encoder_seq_len] = text_hidden_states[i, :encoder_seq_len]
+ joint_hidden_states[i, encoder_seq_len:seq_len] = combined_img_hidden_states[i, :seq_len - encoder_seq_len]
+
+ hidden_states = joint_hidden_states
+
+ for layer_idx, layer in enumerate(self.layers):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
+ layer, hidden_states, attention_mask, rotary_emb, temb
+ )
+ else:
+ hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
+
+ # 4. Output norm & projection
+ hidden_states = self.norm_out(hidden_states, temb)
+
+ p = self.config.patch_size
+ output = []
+ for i, (img_size, img_len, seq_len) in enumerate(zip(img_sizes, l_effective_img_len, seq_lengths)):
+ img_len = int(img_len)
+ seq_len = int(seq_len)
+ height, width = img_size
+ output.append(rearrange(hidden_states[i][seq_len - img_len:seq_len], '(h w) (p1 p2 c) -> c (h p1) (w p2)', h=height // p, w=width // p, p1=p, p2=p))
+ if is_hidden_states_tensor:
+ output = torch.stack(output, dim=0)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return output
+ return Transformer2DModelOutput(sample=output)
diff --git a/extensions_built_in/diffusion_models/omnigen2/src/ops/triton/__init__.py b/extensions_built_in/diffusion_models/omnigen2/src/ops/triton/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/extensions_built_in/diffusion_models/omnigen2/src/ops/triton/layer_norm.py b/extensions_built_in/diffusion_models/omnigen2/src/ops/triton/layer_norm.py
new file mode 100644
index 00000000..b7d12330
--- /dev/null
+++ b/extensions_built_in/diffusion_models/omnigen2/src/ops/triton/layer_norm.py
@@ -0,0 +1,1257 @@
+# Copyright (c) 2024, Tri Dao.
+# Implement dropout + residual + layer_norm / rms_norm.
+
+# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
+# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
+# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
+# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
+
+import math
+
+import torch
+import torch.nn.functional as F
+
+import triton
+import triton.language as tl
+
+
+from typing import Callable
+
+
+def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):
+ def decorator(*args, **kwargs):
+ if cuda_amp_deprecated:
+ kwargs["device_type"] = "cuda"
+ return dec(*args, **kwargs)
+ return decorator
+
+
+if hasattr(torch.amp, "custom_fwd"): # type: ignore[attr-defined]
+ deprecated = True
+ from torch.amp import custom_fwd, custom_bwd # type: ignore[attr-defined]
+else:
+ deprecated = False
+ from torch.cuda.amp import custom_fwd, custom_bwd
+
+custom_fwd = custom_amp_decorator(custom_fwd, deprecated)
+custom_bwd = custom_amp_decorator(custom_bwd, deprecated)
+
+
+def triton_autotune_configs():
+ # Return configs with a valid warp count for the current device
+    configs = []
+    # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024
+    max_threads_per_block = 1024
+    # Default to warp size 32 if not defined by device
+    warp_size = getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32)
+    # Autotune over warp counts that are powers of 2 and do not exceed the threads-per-block limit
+    warp_count = 1
+    while warp_count * warp_size <= max_threads_per_block:
+        configs.append(triton.Config({}, num_warps=warp_count))
+        warp_count *= 2
+    return configs
+
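+# Note: layer_norm_ref / rms_norm_ref below are plain (unfused) PyTorch implementations of the
+# dropout + residual + norm path; they are handy as numerical references for the Triton kernels.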
+def layer_norm_ref(
+ x,
+ weight,
+ bias,
+ residual=None,
+ x1=None,
+ weight1=None,
+ bias1=None,
+ eps=1e-6,
+ dropout_p=0.0,
+ rowscale=None,
+ prenorm=False,
+ zero_centered_weight=False,
+ dropout_mask=None,
+ dropout_mask1=None,
+ upcast=False,
+):
+ dtype = x.dtype
+ if upcast:
+ x = x.float()
+ weight = weight.float()
+ bias = bias.float() if bias is not None else None
+ residual = residual.float() if residual is not None else residual
+ x1 = x1.float() if x1 is not None else None
+ weight1 = weight1.float() if weight1 is not None else None
+ bias1 = bias1.float() if bias1 is not None else None
+ if zero_centered_weight:
+ weight = weight + 1.0
+ if weight1 is not None:
+ weight1 = weight1 + 1.0
+ if x1 is not None:
+ assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
+ if rowscale is not None:
+ x = x * rowscale[..., None]
+ if dropout_p > 0.0:
+ if dropout_mask is not None:
+ x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
+ else:
+ x = F.dropout(x, p=dropout_p)
+ if x1 is not None:
+ if dropout_mask1 is not None:
+ x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
+ else:
+ x1 = F.dropout(x1, p=dropout_p)
+ if x1 is not None:
+ x = x + x1
+ if residual is not None:
+ x = (x + residual).to(x.dtype)
+ out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
+ dtype
+ )
+ if weight1 is None:
+ return out if not prenorm else (out, x)
+ else:
+ out1 = F.layer_norm(
+ x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
+ ).to(dtype)
+ return (out, out1) if not prenorm else (out, out1, x)
+
+
+def rms_norm_ref(
+ x,
+ weight,
+ bias,
+ residual=None,
+ x1=None,
+ weight1=None,
+ bias1=None,
+ eps=1e-6,
+ dropout_p=0.0,
+ rowscale=None,
+ prenorm=False,
+ zero_centered_weight=False,
+ dropout_mask=None,
+ dropout_mask1=None,
+ upcast=False,
+):
+ dtype = x.dtype
+ if upcast:
+ x = x.float()
+ weight = weight.float()
+ bias = bias.float() if bias is not None else None
+ residual = residual.float() if residual is not None else residual
+ x1 = x1.float() if x1 is not None else None
+ weight1 = weight1.float() if weight1 is not None else None
+ bias1 = bias1.float() if bias1 is not None else None
+ if zero_centered_weight:
+ weight = weight + 1.0
+ if weight1 is not None:
+ weight1 = weight1 + 1.0
+ if x1 is not None:
+ assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
+ if rowscale is not None:
+ x = x * rowscale[..., None]
+ if dropout_p > 0.0:
+ if dropout_mask is not None:
+ x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
+ else:
+ x = F.dropout(x, p=dropout_p)
+ if x1 is not None:
+ if dropout_mask1 is not None:
+ x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
+ else:
+ x1 = F.dropout(x1, p=dropout_p)
+ if x1 is not None:
+ x = x + x1
+ if residual is not None:
+ x = (x + residual).to(x.dtype)
+ rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
+ out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype)
+ if weight1 is None:
+ return out if not prenorm else (out, x)
+ else:
+ out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to(
+ dtype
+ )
+ return (out, out1) if not prenorm else (out, out1, x)
+
+
+@triton.autotune(
+ configs=triton_autotune_configs(),
+ key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
+)
+# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
+# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
+@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
+@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
+@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
+@triton.jit
+def _layer_norm_fwd_1pass_kernel(
+ X, # pointer to the input
+ Y, # pointer to the output
+ W, # pointer to the weights
+ B, # pointer to the biases
+ RESIDUAL, # pointer to the residual
+ X1,
+ W1,
+ B1,
+ Y1,
+ RESIDUAL_OUT, # pointer to the residual
+ ROWSCALE,
+ SEEDS, # Dropout seeds for each row
+ DROPOUT_MASK,
+ Mean, # pointer to the mean
+ Rstd, # pointer to the 1/std
+ stride_x_row, # how much to increase the pointer when moving by 1 row
+ stride_y_row,
+ stride_res_row,
+ stride_res_out_row,
+ stride_x1_row,
+ stride_y1_row,
+ M, # number of rows in X
+ N, # number of columns in X
+ eps, # epsilon to avoid division by zero
+ dropout_p, # Dropout probability
+ zero_centered_weight, # If true, add 1.0 to the weight
+ IS_RMS_NORM: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ HAS_RESIDUAL: tl.constexpr,
+ STORE_RESIDUAL_OUT: tl.constexpr,
+ HAS_BIAS: tl.constexpr,
+ HAS_DROPOUT: tl.constexpr,
+ STORE_DROPOUT_MASK: tl.constexpr,
+ HAS_ROWSCALE: tl.constexpr,
+ HAS_X1: tl.constexpr,
+ HAS_W1: tl.constexpr,
+ HAS_B1: tl.constexpr,
+):
+ # Map the program id to the row of X and Y it should compute.
+ row = tl.program_id(0)
+ X += row * stride_x_row
+ Y += row * stride_y_row
+ if HAS_RESIDUAL:
+ RESIDUAL += row * stride_res_row
+ if STORE_RESIDUAL_OUT:
+ RESIDUAL_OUT += row * stride_res_out_row
+ if HAS_X1:
+ X1 += row * stride_x1_row
+ if HAS_W1:
+ Y1 += row * stride_y1_row
+ # Compute mean and variance
+ cols = tl.arange(0, BLOCK_N)
+ x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
+ if HAS_ROWSCALE:
+ rowscale = tl.load(ROWSCALE + row).to(tl.float32)
+ x *= rowscale
+ if HAS_DROPOUT:
+ # Compute dropout mask
+ # 7 rounds is good enough, and reduces register pressure
+ keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
+ x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
+ if STORE_DROPOUT_MASK:
+ tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
+ if HAS_X1:
+ x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
+ if HAS_ROWSCALE:
+ rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
+ x1 *= rowscale
+ if HAS_DROPOUT:
+ # Compute dropout mask
+ # 7 rounds is good enough, and reduces register pressure
+ keep_mask = (
+ tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
+ )
+ x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
+ if STORE_DROPOUT_MASK:
+ tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
+ x += x1
+ if HAS_RESIDUAL:
+ residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
+ x += residual
+ if STORE_RESIDUAL_OUT:
+ tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
+ if not IS_RMS_NORM:
+ mean = tl.sum(x, axis=0) / N
+ tl.store(Mean + row, mean)
+ xbar = tl.where(cols < N, x - mean, 0.0)
+ var = tl.sum(xbar * xbar, axis=0) / N
+ else:
+ xbar = tl.where(cols < N, x, 0.0)
+ var = tl.sum(xbar * xbar, axis=0) / N
+ rstd = 1 / tl.sqrt(var + eps)
+ tl.store(Rstd + row, rstd)
+ # Normalize and apply linear transformation
+ mask = cols < N
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
+ if zero_centered_weight:
+ w += 1.0
+ if HAS_BIAS:
+ b = tl.load(B + cols, mask=mask).to(tl.float32)
+ x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
+ y = x_hat * w + b if HAS_BIAS else x_hat * w
+ # Write output
+ tl.store(Y + cols, y, mask=mask)
+ if HAS_W1:
+ w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
+ if zero_centered_weight:
+ w1 += 1.0
+ if HAS_B1:
+ b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
+ y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
+ tl.store(Y1 + cols, y1, mask=mask)
+
+
+def _layer_norm_fwd(
+ x,
+ weight,
+ bias,
+ eps,
+ residual=None,
+ x1=None,
+ weight1=None,
+ bias1=None,
+ dropout_p=0.0,
+ rowscale=None,
+ out_dtype=None,
+ residual_dtype=None,
+ zero_centered_weight=False,
+ is_rms_norm=False,
+ return_dropout_mask=False,
+ out=None,
+ residual_out=None
+):
+ if residual is not None:
+ residual_dtype = residual.dtype
+ M, N = x.shape
+ assert x.stride(-1) == 1
+ if residual is not None:
+ assert residual.stride(-1) == 1
+ assert residual.shape == (M, N)
+ assert weight.shape == (N,)
+ assert weight.stride(-1) == 1
+ if bias is not None:
+ assert bias.stride(-1) == 1
+ assert bias.shape == (N,)
+ if x1 is not None:
+ assert x1.shape == x.shape
+ assert rowscale is None
+ assert x1.stride(-1) == 1
+ if weight1 is not None:
+ assert weight1.shape == (N,)
+ assert weight1.stride(-1) == 1
+ if bias1 is not None:
+ assert bias1.shape == (N,)
+ assert bias1.stride(-1) == 1
+ if rowscale is not None:
+ assert rowscale.is_contiguous()
+ assert rowscale.shape == (M,)
+ # allocate output
+ if out is None:
+ out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
+ else:
+ assert out.shape == x.shape
+ assert out.stride(-1) == 1
+ if weight1 is not None:
+ y1 = torch.empty_like(out)
+ assert y1.stride(-1) == 1
+ else:
+ y1 = None
+ if (
+ residual is not None
+ or (residual_dtype is not None and residual_dtype != x.dtype)
+ or dropout_p > 0.0
+ or rowscale is not None
+ or x1 is not None
+ ):
+ if residual_out is None:
+ residual_out = torch.empty(
+ M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
+ )
+ else:
+ assert residual_out.shape == x.shape
+ assert residual_out.stride(-1) == 1
+ else:
+ residual_out = None
+ mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
+ rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
+ if dropout_p > 0.0:
+ seeds = torch.randint(
+ 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
+ )
+ else:
+ seeds = None
+ if return_dropout_mask and dropout_p > 0.0:
+ dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)
+ else:
+ dropout_mask = None
+ # Less than 64KB per feature: enqueue fused kernel
+ MAX_FUSED_SIZE = 65536 // x.element_size()
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
+ if N > BLOCK_N:
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+ with torch.cuda.device(x.device.index):
+ _layer_norm_fwd_1pass_kernel[(M,)](
+ x,
+ out,
+ weight,
+ bias,
+ residual,
+ x1,
+ weight1,
+ bias1,
+ y1,
+ residual_out,
+ rowscale,
+ seeds,
+ dropout_mask,
+ mean,
+ rstd,
+ x.stride(0),
+ out.stride(0),
+ residual.stride(0) if residual is not None else 0,
+ residual_out.stride(0) if residual_out is not None else 0,
+ x1.stride(0) if x1 is not None else 0,
+ y1.stride(0) if y1 is not None else 0,
+ M,
+ N,
+ eps,
+ dropout_p,
+ zero_centered_weight,
+ is_rms_norm,
+ BLOCK_N,
+ residual is not None,
+ residual_out is not None,
+ bias is not None,
+ dropout_p > 0.0,
+ dropout_mask is not None,
+ rowscale is not None,
+ )
+ # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
+ if dropout_mask is not None and x1 is not None:
+ dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
+ else:
+ dropout_mask1 = None
+ return (
+ out,
+ y1,
+ mean,
+ rstd,
+ residual_out if residual_out is not None else x,
+ seeds,
+ dropout_mask,
+ dropout_mask1,
+ )
+
+
+@triton.autotune(
+ configs=triton_autotune_configs(),
+ key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
+)
+# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
+# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
+# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
+@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
+@triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
+@triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
+@triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
+@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
+@triton.jit
+def _layer_norm_bwd_kernel(
+ X, # pointer to the input
+ W, # pointer to the weights
+ B, # pointer to the biases
+ Y, # pointer to the output to be recomputed
+ DY, # pointer to the output gradient
+ DX, # pointer to the input gradient
+ DW, # pointer to the partial sum of weights gradient
+ DB, # pointer to the partial sum of biases gradient
+ DRESIDUAL,
+ W1,
+ DY1,
+ DX1,
+ DW1,
+ DB1,
+ DRESIDUAL_IN,
+ ROWSCALE,
+ SEEDS,
+ Mean, # pointer to the mean
+ Rstd, # pointer to the 1/std
+ stride_x_row, # how much to increase the pointer when moving by 1 row
+ stride_y_row,
+ stride_dy_row,
+ stride_dx_row,
+ stride_dres_row,
+ stride_dy1_row,
+ stride_dx1_row,
+ stride_dres_in_row,
+ M, # number of rows in X
+ N, # number of columns in X
+ eps, # epsilon to avoid division by zero
+ dropout_p,
+ zero_centered_weight,
+ rows_per_program,
+ IS_RMS_NORM: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ HAS_DRESIDUAL: tl.constexpr,
+ STORE_DRESIDUAL: tl.constexpr,
+ HAS_BIAS: tl.constexpr,
+ HAS_DROPOUT: tl.constexpr,
+ HAS_ROWSCALE: tl.constexpr,
+ HAS_DY1: tl.constexpr,
+ HAS_DX1: tl.constexpr,
+ HAS_B1: tl.constexpr,
+ RECOMPUTE_OUTPUT: tl.constexpr,
+):
+ # Map the program id to the elements of X, DX, and DY it should compute.
+ row_block_id = tl.program_id(0)
+ row_start = row_block_id * rows_per_program
+ # Do not early exit if row_start >= M, because we need to write DW and DB
+ cols = tl.arange(0, BLOCK_N)
+ mask = cols < N
+ X += row_start * stride_x_row
+ if HAS_DRESIDUAL:
+ DRESIDUAL += row_start * stride_dres_row
+ if STORE_DRESIDUAL:
+ DRESIDUAL_IN += row_start * stride_dres_in_row
+ DY += row_start * stride_dy_row
+ DX += row_start * stride_dx_row
+ if HAS_DY1:
+ DY1 += row_start * stride_dy1_row
+ if HAS_DX1:
+ DX1 += row_start * stride_dx1_row
+ if RECOMPUTE_OUTPUT:
+ Y += row_start * stride_y_row
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
+ if zero_centered_weight:
+ w += 1.0
+ if RECOMPUTE_OUTPUT and HAS_BIAS:
+ b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
+ if HAS_DY1:
+ w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
+ if zero_centered_weight:
+ w1 += 1.0
+ dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
+ if HAS_BIAS:
+ db = tl.zeros((BLOCK_N,), dtype=tl.float32)
+ if HAS_DY1:
+ dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
+ if HAS_B1:
+ db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
+ row_end = min((row_block_id + 1) * rows_per_program, M)
+ for row in range(row_start, row_end):
+ # Load data to SRAM
+ x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
+ dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
+ if HAS_DY1:
+ dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
+ if not IS_RMS_NORM:
+ mean = tl.load(Mean + row)
+ rstd = tl.load(Rstd + row)
+ # Compute dx
+ xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
+ xhat = tl.where(mask, xhat, 0.0)
+ if RECOMPUTE_OUTPUT:
+ y = xhat * w + b if HAS_BIAS else xhat * w
+ tl.store(Y + cols, y, mask=mask)
+ wdy = w * dy
+ dw += dy * xhat
+ if HAS_BIAS:
+ db += dy
+ if HAS_DY1:
+ wdy += w1 * dy1
+ dw1 += dy1 * xhat
+ if HAS_B1:
+ db1 += dy1
+ if not IS_RMS_NORM:
+ c1 = tl.sum(xhat * wdy, axis=0) / N
+ c2 = tl.sum(wdy, axis=0) / N
+ dx = (wdy - (xhat * c1 + c2)) * rstd
+ else:
+ c1 = tl.sum(xhat * wdy, axis=0) / N
+ dx = (wdy - xhat * c1) * rstd
+ if HAS_DRESIDUAL:
+ dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
+ dx += dres
+ # Write dx
+ if STORE_DRESIDUAL:
+ tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
+ if HAS_DX1:
+ if HAS_DROPOUT:
+ keep_mask = (
+ tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
+ )
+ dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
+ else:
+ dx1 = dx
+ tl.store(DX1 + cols, dx1, mask=mask)
+ if HAS_DROPOUT:
+ keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
+ dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
+ if HAS_ROWSCALE:
+ rowscale = tl.load(ROWSCALE + row).to(tl.float32)
+ dx *= rowscale
+ tl.store(DX + cols, dx, mask=mask)
+
+ X += stride_x_row
+ if HAS_DRESIDUAL:
+ DRESIDUAL += stride_dres_row
+ if STORE_DRESIDUAL:
+ DRESIDUAL_IN += stride_dres_in_row
+ if RECOMPUTE_OUTPUT:
+ Y += stride_y_row
+ DY += stride_dy_row
+ DX += stride_dx_row
+ if HAS_DY1:
+ DY1 += stride_dy1_row
+ if HAS_DX1:
+ DX1 += stride_dx1_row
+ tl.store(DW + row_block_id * N + cols, dw, mask=mask)
+ if HAS_BIAS:
+ tl.store(DB + row_block_id * N + cols, db, mask=mask)
+ if HAS_DY1:
+ tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
+ if HAS_B1:
+ tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)
+
+
+def _layer_norm_bwd(
+ dy,
+ x,
+ weight,
+ bias,
+ eps,
+ mean,
+ rstd,
+ dresidual=None,
+ dy1=None,
+ weight1=None,
+ bias1=None,
+ seeds=None,
+ dropout_p=0.0,
+ rowscale=None,
+ has_residual=False,
+ has_x1=False,
+ zero_centered_weight=False,
+ is_rms_norm=False,
+ x_dtype=None,
+ recompute_output=False,
+):
+ M, N = x.shape
+ assert x.stride(-1) == 1
+ assert dy.stride(-1) == 1
+ assert dy.shape == (M, N)
+ if dresidual is not None:
+ assert dresidual.stride(-1) == 1
+ assert dresidual.shape == (M, N)
+ assert weight.shape == (N,)
+ assert weight.stride(-1) == 1
+ if bias is not None:
+ assert bias.stride(-1) == 1
+ assert bias.shape == (N,)
+ if dy1 is not None:
+ assert weight1 is not None
+ assert dy1.shape == dy.shape
+ assert dy1.stride(-1) == 1
+ if weight1 is not None:
+ assert weight1.shape == (N,)
+ assert weight1.stride(-1) == 1
+ if bias1 is not None:
+ assert bias1.shape == (N,)
+ assert bias1.stride(-1) == 1
+ if seeds is not None:
+ assert seeds.is_contiguous()
+ assert seeds.shape == (M if not has_x1 else M * 2,)
+ if rowscale is not None:
+ assert rowscale.is_contiguous()
+ assert rowscale.shape == (M,)
+ # allocate output
+ dx = (
+ torch.empty_like(x)
+ if x_dtype is None
+ else torch.empty(M, N, dtype=x_dtype, device=x.device)
+ )
+ dresidual_in = (
+ torch.empty_like(x)
+ if has_residual
+ and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
+ else None
+ )
+ dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
+ y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
+ if recompute_output:
+ assert weight1 is None, "recompute_output is not supported with parallel LayerNorm"
+
+ # Less than 64KB per feature: enqueue fused kernel
+ MAX_FUSED_SIZE = 65536 // x.element_size()
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
+ if N > BLOCK_N:
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+ # Increasing the multiple (e.g. 8) will allow more thread blocks to be launched and hide the
+ # latency of the gmem reads/writes, but will increase the time of summing up dw / db.
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 8
+ _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
+ _db = (
+ torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
+ if bias is not None
+ else None
+ )
+ _dw1 = torch.empty_like(_dw) if weight1 is not None else None
+ _db1 = torch.empty_like(_db) if bias1 is not None else None
+ rows_per_program = math.ceil(M / sm_count)
+ grid = (sm_count,)
+ with torch.cuda.device(x.device.index):
+ _layer_norm_bwd_kernel[grid](
+ x,
+ weight,
+ bias,
+ y,
+ dy,
+ dx,
+ _dw,
+ _db,
+ dresidual,
+ weight1,
+ dy1,
+ dx1,
+ _dw1,
+ _db1,
+ dresidual_in,
+ rowscale,
+ seeds,
+ mean,
+ rstd,
+ x.stride(0),
+ 0 if not recompute_output else y.stride(0),
+ dy.stride(0),
+ dx.stride(0),
+ dresidual.stride(0) if dresidual is not None else 0,
+ dy1.stride(0) if dy1 is not None else 0,
+ dx1.stride(0) if dx1 is not None else 0,
+ dresidual_in.stride(0) if dresidual_in is not None else 0,
+ M,
+ N,
+ eps,
+ dropout_p,
+ zero_centered_weight,
+ rows_per_program,
+ is_rms_norm,
+ BLOCK_N,
+ dresidual is not None,
+ dresidual_in is not None,
+ bias is not None,
+ dropout_p > 0.0,
+ )
+ dw = _dw.sum(0).to(weight.dtype)
+ db = _db.sum(0).to(bias.dtype) if bias is not None else None
+ dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
+ db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
+ # Don't need to compute dresidual_in separately in this case
+ if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
+ dresidual_in = dx
+ if has_x1 and dropout_p == 0.0:
+ dx1 = dx
+ return (
+ (dx, dw, db, dresidual_in, dx1, dw1, db1)
+ if not recompute_output
+ else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
+ )
+
+
+class LayerNormFn(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x,
+ weight,
+ bias,
+ residual=None,
+ x1=None,
+ weight1=None,
+ bias1=None,
+ eps=1e-6,
+ dropout_p=0.0,
+ rowscale=None,
+ prenorm=False,
+ residual_in_fp32=False,
+ zero_centered_weight=False,
+ is_rms_norm=False,
+ return_dropout_mask=False,
+ out=None,
+ residual_out=None
+ ):
+ x_shape_og = x.shape
+ # Check for zero sequence length
+ if x.numel() == 0:
+ ctx.zero_seq_length = True
+ # Only save minimal required tensors for backward
+ # ctx.save_for_backward(weight, bias, weight1, bias1)
+ ctx.x_shape_og = x_shape_og
+ ctx.weight_shape = weight.shape
+ ctx.weight_dtype = weight.dtype
+ ctx.weight_device = weight.device
+
+ ctx.has_bias = bias is not None
+ ctx.bias_shape = bias.shape if bias is not None else None
+ ctx.bias_dtype = bias.dtype if bias is not None else None
+ ctx.bias_device = bias.device if bias is not None else None
+
+ ctx.has_weight1 = weight1 is not None
+ ctx.weight1_shape = weight1.shape if weight1 is not None else None
+ ctx.weight1_dtype = weight1.dtype if weight1 is not None else None
+ ctx.weight1_device = weight1.device if weight1 is not None else None
+
+ ctx.has_bias1 = bias1 is not None
+ ctx.bias1_shape = bias1.shape if bias1 is not None else None
+ ctx.bias1_dtype = bias1.dtype if bias1 is not None else None
+ ctx.bias1_device = bias1.device if bias1 is not None else None
+
+ ctx.has_residual = residual is not None
+ ctx.has_x1 = x1 is not None
+ ctx.dropout_p = dropout_p
+
+ # Handle output tensors with correct dtype
+ y = x # Preserve input tensor properties
+ y1 = torch.empty_like(x) if x1 is not None else None
+
+ # Only create residual_out if prenorm is True
+ residual_out = torch.empty(x.shape,
+ dtype=torch.float32 if residual_in_fp32 else x.dtype,
+ device=x.device) if prenorm else None
+
+ # Handle dropout masks
+ dropout_mask = None
+ dropout_mask1 = None
+ if return_dropout_mask:
+ dropout_mask = torch.empty_like(x, dtype=torch.uint8)
+ if x1 is not None:
+ dropout_mask1 = torch.empty_like(x, dtype=torch.uint8)
+
+ # Return based on configuration
+ if not return_dropout_mask:
+ if weight1 is None:
+ return y if not prenorm else (y, residual_out)
+ else:
+ return (y, y1) if not prenorm else (y, y1, residual_out)
+ else:
+ if weight1 is None:
+ return ((y, dropout_mask, dropout_mask1) if not prenorm
+ else (y, residual_out, dropout_mask, dropout_mask1))
+ else:
+ return ((y, y1, dropout_mask, dropout_mask1) if not prenorm
+ else (y, y1, residual_out, dropout_mask, dropout_mask1))
+
+ ctx.zero_seq_length = False
+ # reshape input data into 2D tensor
+ x = x.reshape(-1, x.shape[-1])
+ if x.stride(-1) != 1:
+ x = x.contiguous()
+ if residual is not None:
+ assert residual.shape == x_shape_og
+ residual = residual.reshape(-1, residual.shape[-1])
+ if residual.stride(-1) != 1:
+ residual = residual.contiguous()
+ if x1 is not None:
+ assert x1.shape == x_shape_og
+ assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
+ x1 = x1.reshape(-1, x1.shape[-1])
+ if x1.stride(-1) != 1:
+ x1 = x1.contiguous()
+ weight = weight.contiguous()
+ if bias is not None:
+ bias = bias.contiguous()
+ if weight1 is not None:
+ weight1 = weight1.contiguous()
+ if bias1 is not None:
+ bias1 = bias1.contiguous()
+ if rowscale is not None:
+ rowscale = rowscale.reshape(-1).contiguous()
+ residual_dtype = (
+ residual.dtype
+ if residual is not None
+ else (torch.float32 if residual_in_fp32 else None)
+ )
+ if out is not None:
+ out = out.reshape(-1, out.shape[-1])
+ if residual_out is not None:
+ residual_out = residual_out.reshape(-1, residual_out.shape[-1])
+ y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
+ x,
+ weight,
+ bias,
+ eps,
+ residual,
+ x1,
+ weight1,
+ bias1,
+ dropout_p=dropout_p,
+ rowscale=rowscale,
+ residual_dtype=residual_dtype,
+ zero_centered_weight=zero_centered_weight,
+ is_rms_norm=is_rms_norm,
+ return_dropout_mask=return_dropout_mask,
+ out=out,
+ residual_out=residual_out
+ )
+ ctx.save_for_backward(
+ residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
+ )
+ ctx.x_shape_og = x_shape_og
+ ctx.eps = eps
+ ctx.dropout_p = dropout_p
+ ctx.is_rms_norm = is_rms_norm
+ ctx.has_residual = residual is not None
+ ctx.has_x1 = x1 is not None
+ ctx.prenorm = prenorm
+ ctx.x_dtype = x.dtype
+ ctx.zero_centered_weight = zero_centered_weight
+ y = y.reshape(x_shape_og)
+ y1 = y1.reshape(x_shape_og) if y1 is not None else None
+ residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
+ dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
+ dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
+ if not return_dropout_mask:
+ if weight1 is None:
+ return y if not prenorm else (y, residual_out)
+ else:
+ return (y, y1) if not prenorm else (y, y1, residual_out)
+ else:
+ if weight1 is None:
+ return (
+ (y, dropout_mask, dropout_mask1)
+ if not prenorm
+ else (y, residual_out, dropout_mask, dropout_mask1)
+ )
+ else:
+ return (
+ (y, y1, dropout_mask, dropout_mask1)
+ if not prenorm
+ else (y, y1, residual_out, dropout_mask, dropout_mask1)
+ )
+
+ @staticmethod
+ def backward(ctx, dy, *args):
+ if ctx.zero_seq_length:
+ return (
+ torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device),
+ torch.zeros(ctx.weight_shape, dtype=ctx.weight_dtype, device=ctx.weight_device),
+ torch.zeros(ctx.bias_shape, dtype=ctx.bias_dtype, device=ctx.bias_device) if ctx.has_bias else None,
+ torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_residual else None,
+ torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_x1 and ctx.dropout_p > 0.0 else None,
+ torch.zeros(ctx.weight1_shape, dtype=ctx.weight1_dtype, device=ctx.weight1_device) if ctx.has_weight1 else None,
+ torch.zeros(ctx.bias1_shape, dtype=ctx.bias1_dtype, device=ctx.bias1_device) if ctx.has_bias1 else None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
+
+ x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
+ dy = dy.reshape(-1, dy.shape[-1])
+ if dy.stride(-1) != 1:
+ dy = dy.contiguous()
+ assert dy.shape == x.shape
+ if weight1 is not None:
+ dy1, args = args[0], args[1:]
+ dy1 = dy1.reshape(-1, dy1.shape[-1])
+ if dy1.stride(-1) != 1:
+ dy1 = dy1.contiguous()
+ assert dy1.shape == x.shape
+ else:
+ dy1 = None
+ if ctx.prenorm:
+ dresidual = args[0]
+ dresidual = dresidual.reshape(-1, dresidual.shape[-1])
+ if dresidual.stride(-1) != 1:
+ dresidual = dresidual.contiguous()
+ assert dresidual.shape == x.shape
+ else:
+ dresidual = None
+
+ dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
+ dy,
+ x,
+ weight,
+ bias,
+ ctx.eps,
+ mean,
+ rstd,
+ dresidual,
+ dy1,
+ weight1,
+ bias1,
+ seeds,
+ ctx.dropout_p,
+ rowscale,
+ ctx.has_residual,
+ ctx.has_x1,
+ ctx.zero_centered_weight,
+ ctx.is_rms_norm,
+ x_dtype=ctx.x_dtype,
+ )
+ return (
+ dx.reshape(ctx.x_shape_og),
+ dw,
+ db,
+ dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
+ dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
+ dw1,
+ db1,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
+
+
+def layer_norm_fn(
+ x,
+ weight,
+ bias,
+ residual=None,
+ x1=None,
+ weight1=None,
+ bias1=None,
+ eps=1e-6,
+ dropout_p=0.0,
+ rowscale=None,
+ prenorm=False,
+ residual_in_fp32=False,
+ zero_centered_weight=False,
+ is_rms_norm=False,
+ return_dropout_mask=False,
+ out=None,
+ residual_out=None
+):
+ return LayerNormFn.apply(
+ x,
+ weight,
+ bias,
+ residual,
+ x1,
+ weight1,
+ bias1,
+ eps,
+ dropout_p,
+ rowscale,
+ prenorm,
+ residual_in_fp32,
+ zero_centered_weight,
+ is_rms_norm,
+ return_dropout_mask,
+ out,
+ residual_out
+ )
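+
+# Example usage of `layer_norm_fn` (illustrative sketch, assuming CUDA tensors and Triton are available):
+#   hidden = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
+#   weight = torch.ones(4096, device="cuda", dtype=torch.bfloat16)
+#   bias = torch.zeros(4096, device="cuda", dtype=torch.bfloat16)
+#   y, residual_out = layer_norm_fn(hidden, weight, bias, residual=hidden, prenorm=True)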
+
+
+def rms_norm_fn(
+ x,
+ weight,
+ bias,
+ residual=None,
+ x1=None,
+ weight1=None,
+ bias1=None,
+ eps=1e-6,
+ dropout_p=0.0,
+ rowscale=None,
+ prenorm=False,
+ residual_in_fp32=False,
+ zero_centered_weight=False,
+ return_dropout_mask=False,
+ out=None,
+ residual_out=None
+):
+ return LayerNormFn.apply(
+ x,
+ weight,
+ bias,
+ residual,
+ x1,
+ weight1,
+ bias1,
+ eps,
+ dropout_p,
+ rowscale,
+ prenorm,
+ residual_in_fp32,
+ zero_centered_weight,
+        True,  # is_rms_norm
+ return_dropout_mask,
+ out,
+ residual_out
+ )
+
+
+class RMSNorm(torch.nn.Module):
+
+ def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, zero_centered_weight=False,
+ device=None, dtype=None):
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+ self.eps = eps
+ if dropout_p > 0.0:
+ self.drop = torch.nn.Dropout(dropout_p)
+ else:
+ self.drop = None
+ self.zero_centered_weight = zero_centered_weight
+ self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
+ self.register_parameter("bias", None)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ if not self.zero_centered_weight:
+ torch.nn.init.ones_(self.weight)
+ else:
+ torch.nn.init.zeros_(self.weight)
+
+ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
+ return rms_norm_fn(
+ x,
+ self.weight,
+ self.bias,
+ residual=residual,
+ eps=self.eps,
+ dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
+ prenorm=prenorm,
+ residual_in_fp32=residual_in_fp32,
+ zero_centered_weight=self.zero_centered_weight,
+ )
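+
+# Example usage of `RMSNorm` (illustrative sketch, assuming a CUDA device):
+#   norm = RMSNorm(hidden_size=2560, eps=1e-5).to(device="cuda", dtype=torch.bfloat16)
+#   x = torch.randn(2, 128, 2560, device="cuda", dtype=torch.bfloat16)
+#   y = norm(x)                                  # RMSNorm only
+#   y, res = norm(x, residual=x, prenorm=True)   # fused residual add + RMSNorm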
+
+
+class LayerNormLinearFn(torch.autograd.Function):
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx,
+ x,
+ norm_weight,
+ norm_bias,
+ linear_weight,
+ linear_bias,
+ residual=None,
+ eps=1e-6,
+ prenorm=False,
+ residual_in_fp32=False,
+ is_rms_norm=False,
+ ):
+ x_shape_og = x.shape
+ # reshape input data into 2D tensor
+ x = x.reshape(-1, x.shape[-1])
+ if x.stride(-1) != 1:
+ x = x.contiguous()
+ if residual is not None:
+ assert residual.shape == x_shape_og
+ residual = residual.reshape(-1, residual.shape[-1])
+ if residual.stride(-1) != 1:
+ residual = residual.contiguous()
+ norm_weight = norm_weight.contiguous()
+ if norm_bias is not None:
+ norm_bias = norm_bias.contiguous()
+ residual_dtype = (
+ residual.dtype
+ if residual is not None
+ else (torch.float32 if residual_in_fp32 else None)
+ )
+ y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
+ x,
+ norm_weight,
+ norm_bias,
+ eps,
+ residual,
+ out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_dtype("cuda"),
+ residual_dtype=residual_dtype,
+ is_rms_norm=is_rms_norm,
+ )
+ y = y.reshape(x_shape_og)
+ dtype = torch.get_autocast_dtype("cuda") if torch.is_autocast_enabled() else y.dtype
+ linear_weight = linear_weight.to(dtype)
+ linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
+ out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
+ # We don't store y, will be recomputed in the backward pass to save memory
+ ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
+ ctx.x_shape_og = x_shape_og
+ ctx.eps = eps
+ ctx.is_rms_norm = is_rms_norm
+ ctx.has_residual = residual is not None
+ ctx.prenorm = prenorm
+ ctx.x_dtype = x.dtype
+ ctx.linear_bias_is_none = linear_bias is None
+ return out if not prenorm else (out, residual_out.reshape(x_shape_og))
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dout, *args):
+ x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
+ dout = dout.reshape(-1, dout.shape[-1])
+ dy = F.linear(dout, linear_weight.t())
+ dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
+ if dy.stride(-1) != 1:
+ dy = dy.contiguous()
+ assert dy.shape == x.shape
+ if ctx.prenorm:
+ dresidual = args[0]
+ dresidual = dresidual.reshape(-1, dresidual.shape[-1])
+ if dresidual.stride(-1) != 1:
+ dresidual = dresidual.contiguous()
+ assert dresidual.shape == x.shape
+ else:
+ dresidual = None
+ dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
+ dy,
+ x,
+ norm_weight,
+ norm_bias,
+ ctx.eps,
+ mean,
+ rstd,
+ dresidual=dresidual,
+ has_residual=ctx.has_residual,
+ is_rms_norm=ctx.is_rms_norm,
+ x_dtype=ctx.x_dtype,
+ recompute_output=True,
+ )
+ dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
+ return (
+ dx.reshape(ctx.x_shape_og),
+ dnorm_weight,
+ dnorm_bias,
+ dlinear_weight,
+ dlinear_bias,
+ dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
+ None,
+ None,
+ None,
+ None,
+ )
+
+
+def layer_norm_linear_fn(
+ x,
+ norm_weight,
+ norm_bias,
+ linear_weight,
+ linear_bias,
+ residual=None,
+ eps=1e-6,
+ prenorm=False,
+ residual_in_fp32=False,
+ is_rms_norm=False,
+):
+ return LayerNormLinearFn.apply(
+ x,
+ norm_weight,
+ norm_bias,
+ linear_weight,
+ linear_bias,
+ residual,
+ eps,
+ prenorm,
+ residual_in_fp32,
+ is_rms_norm,
+ )
diff --git a/extensions_built_in/diffusion_models/omnigen2/src/pipelines/__init__.py b/extensions_built_in/diffusion_models/omnigen2/src/pipelines/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/extensions_built_in/diffusion_models/omnigen2/src/pipelines/image_processor.py b/extensions_built_in/diffusion_models/omnigen2/src/pipelines/image_processor.py
new file mode 100644
index 00000000..ec66dcdb
--- /dev/null
+++ b/extensions_built_in/diffusion_models/omnigen2/src/pipelines/image_processor.py
@@ -0,0 +1,266 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor, is_valid_image_imagelist
+from diffusers.configuration_utils import register_to_config
+
+class OmniGen2ImageProcessor(VaeImageProcessor):
+ """
+    Image processor for OmniGen2 image resize and crop.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
+ `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
+        vae_scale_factor (`int`, *optional*, defaults to `16`):
+            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
+        resample (`str`, *optional*, defaults to `lanczos`):
+            Resampling filter to use when resizing the image.
+        max_pixels (`int`, *optional*, defaults to `None`):
+            If set, images whose total pixel count exceeds this value are scaled down to fit.
+        max_side_length (`int`, *optional*, defaults to `None`):
+            If set, images whose longer side exceeds this value are scaled down to fit.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image to [-1,1].
+ do_binarize (`bool`, *optional*, defaults to `False`):
+ Whether to binarize the image to 0/1.
+        do_convert_rgb (`bool`, *optional*, defaults to `False`):
+            Whether to convert the images to RGB format.
+        do_convert_grayscale (`bool`, *optional*, defaults to `False`):
+            Whether to convert the images to grayscale format.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ do_resize: bool = True,
+ vae_scale_factor: int = 16,
+ resample: str = "lanczos",
+ max_pixels: Optional[int] = None,
+ max_side_length: Optional[int] = None,
+ do_normalize: bool = True,
+ do_binarize: bool = False,
+ do_convert_grayscale: bool = False,
+ ):
+ super().__init__(
+ do_resize=do_resize,
+ vae_scale_factor=vae_scale_factor,
+ resample=resample,
+ do_normalize=do_normalize,
+ do_binarize=do_binarize,
+ do_convert_grayscale=do_convert_grayscale,
+ )
+
+ self.max_pixels = max_pixels
+ self.max_side_length = max_side_length
+
+ def get_new_height_width(
+ self,
+ image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ max_pixels: Optional[int] = None,
+ max_side_length: Optional[int] = None,
+ ) -> Tuple[int, int]:
+ r"""
+        Returns the target height and width of the image, scaled down as needed and rounded to a multiple of `vae_scale_factor`.
+
+ Args:
+ image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
+ The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it
+ should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch
+ tensor, it should have shape `[batch, channels, height, width]`.
+            height (`Optional[int]`, *optional*, defaults to `None`):
+                The height of the preprocessed image. If `None`, the height of the `image` input will be used.
+            width (`Optional[int]`, *optional*, defaults to `None`):
+                The width of the preprocessed image. If `None`, the width of the `image` input will be used.
+            max_pixels (`Optional[int]`, *optional*, defaults to `None`):
+                Maximum number of pixels in the resized image. If `None`, the processor's configured `max_pixels` is used.
+            max_side_length (`Optional[int]`, *optional*, defaults to `None`):
+                Maximum length of the longer image side. If `None`, the processor's configured `max_side_length` is used.
+
+ Returns:
+ `Tuple[int, int]`:
+                A tuple containing the new height and width, scaled down if necessary and rounded down to a
+                multiple of `vae_scale_factor`.
+ """
+
+ if height is None:
+ if isinstance(image, PIL.Image.Image):
+ height = image.height
+ elif isinstance(image, torch.Tensor):
+ height = image.shape[2]
+ else:
+ height = image.shape[1]
+
+ if width is None:
+ if isinstance(image, PIL.Image.Image):
+ width = image.width
+ elif isinstance(image, torch.Tensor):
+ width = image.shape[3]
+ else:
+ width = image.shape[2]
+
+ if max_side_length is None:
+ max_side_length = self.max_side_length
+
+ if max_pixels is None:
+ max_pixels = self.max_pixels
+
+        ratio = 1.0  # never upscale the input image
+        if max_side_length is not None:
+            # constrain the longer side
+            ratio = min(ratio, max_side_length / max(height, width))
+
+        if max_pixels is not None:
+            # constrain the total pixel count
+            cur_pixels = height * width
+            ratio = min(ratio, (max_pixels / cur_pixels) ** 0.5)
+
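+        # e.g. with max_side_length=1024, max_pixels=1024*1024 and vae_scale_factor=16, a 2048x2048 input
+        # yields ratio 0.5 from both constraints, so this method returns (1024, 1024)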
+ new_height, new_width = int(height * ratio) // self.config.vae_scale_factor * self.config.vae_scale_factor, int(width * ratio) // self.config.vae_scale_factor * self.config.vae_scale_factor
+ return new_height, new_width
+
+ def preprocess(
+ self,
+ image: PipelineImageInput,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ max_pixels: Optional[int] = None,
+ max_side_length: Optional[int] = None,
+ resize_mode: str = "default", # "default", "fill", "crop"
+ crops_coords: Optional[Tuple[int, int, int, int]] = None,
+ ) -> torch.Tensor:
+ """
+ Preprocess the image input.
+
+ Args:
+ image (`PipelineImageInput`):
+ The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of
+ supported formats.
+            height (`int`, *optional*):
+                The height of the preprocessed image. If `None`, `get_new_height_width()` is used to determine the
+                default height.
+            width (`int`, *optional*):
+                The width of the preprocessed image. If `None`, `get_new_height_width()` is used to determine the
+                default width.
+            resize_mode (`str`, *optional*, defaults to `default`):
+                The resize mode, one of `default`, `fill`, or `crop`. If `default`, the image is resized to fit within
+                the specified width and height, possibly without preserving the original aspect ratio. If `fill`, the
+                image is resized to fit within the specified width and height while maintaining the aspect ratio, then
+                centered within the dimensions with the empty space filled with data from the image. If `crop`, the
+                image is resized to fit within the specified width and height while maintaining the aspect ratio, then
+                centered within the dimensions with the excess cropped. Note that `fill` and `crop` are only
+                supported for PIL image input.
+            crops_coords (`Tuple[int, int, int, int]`, *optional*, defaults to `None`):
+                The crop box (left, upper, right, lower) applied to every image in the batch before resizing. If
+                `None`, the images are not cropped.
+
+ Returns:
+ `torch.Tensor`:
+ The preprocessed image.
+ """
+ supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
+
+ # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
+ if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
+ if isinstance(image, torch.Tensor):
+                # if image is a pytorch tensor, it could have 2 possible shapes:
+                #    1. batch x height x width: we should insert the channel dimension at position 1
+                #    2. channel x height x width: we should insert the batch dimension at position 0,
+                #    however, since both the channel and batch dimensions have size 1, inserting at position 1 is equivalent
+ # for simplicity, we insert a dimension of size 1 at position 1 for both cases
+ image = image.unsqueeze(1)
+ else:
+ # if it is a numpy array, it could have 2 possible shapes:
+ # 1. batch x height x width: insert channel dimension on last position
+ # 2. height x width x channel: insert batch dimension on first position
+ if image.shape[-1] == 1:
+ image = np.expand_dims(image, axis=0)
+ else:
+ image = np.expand_dims(image, axis=-1)
+
+ if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
+ warnings.warn(
+ "Passing `image` as a list of 4d np.ndarray is deprecated."
+ "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
+ FutureWarning,
+ )
+ image = np.concatenate(image, axis=0)
+ if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
+ warnings.warn(
+ "Passing `image` as a list of 4d torch.Tensor is deprecated."
+ "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
+ FutureWarning,
+ )
+ image = torch.cat(image, axis=0)
+
+ if not is_valid_image_imagelist(image):
+ raise ValueError(
+ f"Input is in incorrect format. Currently, we only support {', '.join(str(x) for x in supported_formats)}"
+ )
+ if not isinstance(image, list):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ if crops_coords is not None:
+ image = [i.crop(crops_coords) for i in image]
+ if self.config.do_resize:
+ height, width = self.get_new_height_width(image[0], height, width, max_pixels, max_side_length)
+ image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image]
+ if self.config.do_convert_rgb:
+ image = [self.convert_to_rgb(i) for i in image]
+ elif self.config.do_convert_grayscale:
+ image = [self.convert_to_grayscale(i) for i in image]
+ image = self.pil_to_numpy(image) # to np
+ image = self.numpy_to_pt(image) # to pt
+
+ elif isinstance(image[0], np.ndarray):
+ image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
+
+ image = self.numpy_to_pt(image)
+
+ height, width = self.get_new_height_width(image, height, width, max_pixels, max_side_length)
+ if self.config.do_resize:
+ image = self.resize(image, height, width)
+
+ elif isinstance(image[0], torch.Tensor):
+ image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
+
+ if self.config.do_convert_grayscale and image.ndim == 3:
+ image = image.unsqueeze(1)
+
+ channel = image.shape[1]
+ # don't need any preprocess if the image is latents
+ if channel == self.config.vae_latent_channels:
+ return image
+
+ height, width = self.get_new_height_width(image, height, width, max_pixels, max_side_length)
+ if self.config.do_resize:
+ image = self.resize(image, height, width)
+
+ # expected range [0,1], normalize to [-1,1]
+ do_normalize = self.config.do_normalize
+ if do_normalize and image.min() < 0:
+ warnings.warn(
+ "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
+ f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
+ FutureWarning,
+ )
+ do_normalize = False
+ if do_normalize:
+ image = self.normalize(image)
+
+ if self.config.do_binarize:
+ image = self.binarize(image)
+
+ return image
\ No newline at end of file
diff --git a/extensions_built_in/diffusion_models/omnigen2/src/pipelines/omnigen2/pipeline_omnigen2.py b/extensions_built_in/diffusion_models/omnigen2/src/pipelines/omnigen2/pipeline_omnigen2.py
new file mode 100644
index 00000000..45e509bc
--- /dev/null
+++ b/extensions_built_in/diffusion_models/omnigen2/src/pipelines/omnigen2/pipeline_omnigen2.py
@@ -0,0 +1,728 @@
+"""
+OmniGen2 Diffusion Pipeline
+
+Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import math
+
+from PIL import Image
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from transformers import Qwen2_5_VLForConditionalGeneration
+
+from diffusers.models.autoencoders import AutoencoderKL
+from ...models.transformers import OmniGen2Transformer2DModel
+from ...models.transformers.repo import OmniGen2RotaryPosEmbed
+from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+from diffusers.utils import (
+ is_torch_xla_available,
+ logging,
+)
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+
+from dataclasses import dataclass
+
+import PIL.Image
+
+from diffusers.utils import BaseOutput
+
+from ....src.pipelines.image_processor import OmniGen2ImageProcessor
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+@dataclass
+class FMPipelineOutput(BaseOutput):
+ """
+ Output class for OmniGen2 pipeline.
+
+ Args:
+ images (Union[List[PIL.Image.Image], np.ndarray]):
+ List of denoised PIL images of length `batch_size` or numpy array of shape
+ `(batch_size, height, width, num_channels)`. Contains the generated images.
+ """
+ images: Union[List[PIL.Image.Image], np.ndarray]
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class OmniGen2Pipeline(DiffusionPipeline):
+ """
+ Pipeline for text-to-image generation using OmniGen2.
+
+ This pipeline implements a text-to-image generation model that uses:
+ - Qwen2.5-VL for text encoding
+ - A custom transformer architecture for image generation
+ - VAE for image encoding/decoding
+ - FlowMatchEulerDiscreteScheduler for noise scheduling
+
+ Args:
+ transformer (OmniGen2Transformer2DModel): The transformer model for image generation.
+ vae (AutoencoderKL): The VAE model for image encoding/decoding.
+ scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for noise scheduling.
+        mllm (Qwen2_5_VLForConditionalGeneration): The multimodal LLM used to encode prompts.
+        processor: The Qwen2.5-VL processor (tokenizer and image processor) used for text processing.
+ """
+
+ model_cpu_offload_seq = "mllm->transformer->vae"
+
+ def __init__(
+ self,
+ transformer: OmniGen2Transformer2DModel,
+ vae: AutoencoderKL,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ mllm: Qwen2_5_VLForConditionalGeneration,
+ processor,
+ ) -> None:
+ """
+ Initialize the OmniGen2 pipeline.
+
+ Args:
+ transformer: The transformer model for image generation.
+ vae: The VAE model for image encoding/decoding.
+ scheduler: The scheduler for noise scheduling.
+            mllm: The Qwen2.5-VL multimodal LLM used to encode prompts.
+            processor: The processor/tokenizer paired with the mllm.
+ """
+ super().__init__()
+
+ self.register_modules(
+ transformer=transformer,
+ vae=vae,
+ scheduler=scheduler,
+ mllm=mllm,
+ processor=processor
+ )
+ self.vae_scale_factor = (
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+ )
+ self.image_processor = OmniGen2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2, do_resize=True)
+ self.default_sample_size = 128
+
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: int,
+ height: int,
+ width: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ generator: Optional[torch.Generator],
+ latents: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ """
+ Prepare the initial latents for the diffusion process.
+
+ Args:
+ batch_size: The number of images to generate.
+ num_channels_latents: The number of channels in the latent space.
+ height: The height of the generated image.
+ width: The width of the generated image.
+ dtype: The data type of the latents.
+ device: The device to place the latents on.
+ generator: The random number generator to use.
+ latents: Optional pre-computed latents to use instead of random initialization.
+
+ Returns:
+ torch.FloatTensor: The prepared latents tensor.
+ """
+ height = int(height) // self.vae_scale_factor
+ width = int(width) // self.vae_scale_factor
+
+ shape = (batch_size, num_channels_latents, height, width)
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+ return latents
+
+ def encode_vae(self, img: torch.FloatTensor) -> torch.FloatTensor:
+ """
+ Encode an image into the VAE latent space.
+
+ Args:
+ img: The input image tensor to encode.
+
+ Returns:
+ torch.FloatTensor: The encoded latent representation.
+ """
+ z0 = self.vae.encode(img.to(dtype=self.vae.dtype)).latent_dist.sample()
+ if self.vae.config.shift_factor is not None:
+ z0 = z0 - self.vae.config.shift_factor
+ if self.vae.config.scaling_factor is not None:
+ z0 = z0 * self.vae.config.scaling_factor
+ z0 = z0.to(dtype=self.vae.dtype)
+ return z0
+
+ def prepare_image(
+ self,
+ images: Union[List[PIL.Image.Image], PIL.Image.Image],
+ batch_size: int,
+ num_images_per_prompt: int,
+ max_pixels: int,
+ max_side_length: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ ) -> List[Optional[torch.FloatTensor]]:
+ """
+ Prepare input images for processing by encoding them into the VAE latent space.
+
+ Args:
+ images: Single image or list of images to process.
+ batch_size: The number of images to generate per prompt.
+ num_images_per_prompt: The number of images to generate for each prompt.
+ device: The device to place the encoded latents on.
+ dtype: The data type of the encoded latents.
+
+ Returns:
+ List[Optional[torch.FloatTensor]]: List of encoded latent representations for each image.
+ """
+ if batch_size == 1:
+ images = [images]
+ latents = []
+ for i, img in enumerate(images):
+ if img is not None and len(img) > 0:
+ ref_latents = []
+ for j, img_j in enumerate(img):
+ img_j = self.image_processor.preprocess(img_j, max_pixels=max_pixels, max_side_length=max_side_length)
+ ref_latents.append(self.encode_vae(img_j.to(device=device)).squeeze(0))
+ else:
+ ref_latents = None
+ for _ in range(num_images_per_prompt):
+ latents.append(ref_latents)
+
+ return latents
+
+ def _get_qwen2_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ max_sequence_length: int = 256,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Get prompt embeddings from the Qwen2 text encoder.
+
+ Args:
+ prompt: The prompt or list of prompts to encode.
+ device: The device to place the embeddings on. If None, uses the pipeline's device.
+ max_sequence_length: Maximum sequence length for tokenization.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+ - The prompt embeddings tensor
+ - The attention mask tensor
+
+ Raises:
+ Warning: If the input text is truncated due to sequence length limitations.
+ """
+ device = device or self._execution_device
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ # text_inputs = self.processor.tokenizer(
+ # prompt,
+ # padding="max_length",
+ # max_length=max_sequence_length,
+ # truncation=True,
+ # return_tensors="pt",
+ # )
+ text_inputs = self.processor.tokenizer(
+ prompt,
+ padding="longest",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids.to(device)
+ untruncated_ids = self.processor.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device)
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.processor.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+                "The following part of your input was truncated because the Qwen2.5-VL text encoder can only handle sequences up to"
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_attention_mask = text_inputs.attention_mask.to(device)
+ prompt_embeds = self.mllm(
+ text_input_ids,
+ attention_mask=prompt_attention_mask,
+ output_hidden_states=True,
+ ).hidden_states[-1]
+
+ if self.mllm is not None:
+ dtype = self.mllm.dtype
+ elif self.transformer is not None:
+ dtype = self.transformer.dtype
+ else:
+ dtype = None
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ return prompt_embeds, prompt_attention_mask
+
+ def _apply_chat_template(self, prompt: str):
+ prompt = [
+ {
+ "role": "system",
+ "content": "You are a helpful assistant that generates high-quality images based on user instructions.",
+ },
+ {"role": "user", "content": prompt},
+ ]
+ prompt = self.processor.tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=False)
+ return prompt
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 256,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
+ Lumina-T2I, this should be "".
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ number of images that should be generated per prompt
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated negative text embeddings. For Lumina-T2I, this should be the embeddings of the "" string.
+ max_sequence_length (`int`, defaults to `256`):
+ Maximum sequence length to use for the prompt.
+ """
+ device = device or self._execution_device
+
+ if prompt is not None:
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [self._apply_chat_template(_prompt) for _prompt in prompt]
+
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_qwen2_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ max_sequence_length=max_sequence_length
+ )
+
+ batch_size, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+ prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1)
+
+ # Get negative embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt if negative_prompt is not None else ""
+
+ # Normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt = [self._apply_chat_template(_negative_prompt) for _negative_prompt in negative_prompt]
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ negative_prompt = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_qwen2_prompt_embeds(
+ prompt=negative_prompt,
+ device=device,
+ max_sequence_length=max_sequence_length,
+ )
+
+ batch_size, seq_len, _ = negative_prompt_embeds.shape
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(
+ batch_size * num_images_per_prompt, -1
+ )
+
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def text_guidance_scale(self):
+ return self._text_guidance_scale
+
+ @property
+ def image_guidance_scale(self):
+ return self._image_guidance_scale
+
+ @property
+ def cfg_range(self):
+ return self._cfg_range
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ prompt_attention_mask: Optional[torch.LongTensor] = None,
+ negative_prompt_attention_mask: Optional[torch.LongTensor] = None,
+ max_sequence_length: Optional[int] = None,
+ callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
+ input_images: Optional[List[PIL.Image.Image]] = None,
+ num_images_per_prompt: int = 1,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ max_pixels: int = 1024 * 1024,
+ max_input_image_side_length: int = 1024,
+ align_res: bool = True,
+ num_inference_steps: int = 28,
+ text_guidance_scale: float = 4.0,
+ image_guidance_scale: float = 1.0,
+ cfg_range: Tuple[float, float] = (0.0, 1.0),
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ timesteps: List[int] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ verbose: bool = False,
+ step_func=None,
+ ):
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ self._text_guidance_scale = text_guidance_scale
+ self._image_guidance_scale = image_guidance_scale
+ self._cfg_range = cfg_range
+ self._attention_kwargs = attention_kwargs
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt,
+ self.text_guidance_scale > 1.0,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ max_sequence_length=max_sequence_length,
+ )
+
+ dtype = self.vae.dtype
+        # 4. Prepare control image
+ ref_latents = self.prepare_image(
+ images=input_images,
+ batch_size=batch_size,
+ num_images_per_prompt=num_images_per_prompt,
+ max_pixels=max_pixels,
+ max_side_length=max_input_image_side_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if input_images is None:
+ input_images = []
+
+ if len(input_images) == 1 and align_res:
+ width, height = ref_latents[0][0].shape[-1] * self.vae_scale_factor, ref_latents[0][0].shape[-2] * self.vae_scale_factor
+ ori_width, ori_height = width, height
+ else:
+ ori_width, ori_height = width, height
+
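+        # downscale (never upscale) so the target area stays within max_pixels, then snap height/width to multiples of 16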
+ cur_pixels = height * width
+ ratio = (max_pixels / cur_pixels) ** 0.5
+ ratio = min(ratio, 1.0)
+
+ height, width = int(height * ratio) // 16 * 16, int(width * ratio) // 16 * 16
+
+ if len(input_images) == 0:
+ self._image_guidance_scale = 1
+
+        # 5. Prepare latents.
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ latent_channels,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ freqs_cis = OmniGen2RotaryPosEmbed.get_freqs_cis(
+ self.transformer.config.axes_dim_rope,
+ self.transformer.config.axes_lens,
+ theta=10000,
+ )
+
+ image = self.processing(
+ latents=latents,
+ ref_latents=ref_latents,
+ prompt_embeds=prompt_embeds,
+ freqs_cis=freqs_cis,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ num_inference_steps=num_inference_steps,
+ timesteps=timesteps,
+ device=device,
+ dtype=dtype,
+ verbose=verbose,
+ step_func=step_func,
+ )
+
+ image = F.interpolate(image, size=(ori_height, ori_width), mode='bilinear')
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return image
+ else:
+ return FMPipelineOutput(images=image)
+
+ def processing(
+ self,
+ latents,
+ ref_latents,
+ prompt_embeds,
+ freqs_cis,
+ negative_prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ num_inference_steps,
+ timesteps,
+ device,
+ dtype,
+ verbose,
+ step_func=None
+ ):
+ batch_size = latents.shape[0]
+
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ num_tokens=latents.shape[-2] * latents.shape[-1]
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ model_pred = self.predict(
+ t=t,
+ latents=latents,
+ prompt_embeds=prompt_embeds,
+ freqs_cis=freqs_cis,
+ prompt_attention_mask=prompt_attention_mask,
+ ref_image_hidden_states=ref_latents,
+ )
+ text_guidance_scale = self.text_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
+ image_guidance_scale = self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
+
+ if text_guidance_scale > 1.0 and image_guidance_scale > 1.0:
+ model_pred_ref = self.predict(
+ t=t,
+ latents=latents,
+ prompt_embeds=negative_prompt_embeds,
+ freqs_cis=freqs_cis,
+ prompt_attention_mask=negative_prompt_attention_mask,
+ ref_image_hidden_states=ref_latents,
+ )
+
+ if image_guidance_scale != 1:
+ model_pred_uncond = self.predict(
+ t=t,
+ latents=latents,
+ prompt_embeds=negative_prompt_embeds,
+ freqs_cis=freqs_cis,
+ prompt_attention_mask=negative_prompt_attention_mask,
+ ref_image_hidden_states=None,
+ )
+ else:
+ model_pred_uncond = torch.zeros_like(model_pred)
+
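+                    # dual CFG: image guidance pushes from the unconditional branch toward the
+                    # reference-conditioned branch, then text guidance pushes toward the fully conditioned prediction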
+ model_pred = model_pred_uncond + image_guidance_scale * (model_pred_ref - model_pred_uncond) + \
+ text_guidance_scale * (model_pred - model_pred_ref)
+ elif text_guidance_scale > 1.0:
+ model_pred_uncond = self.predict(
+ t=t,
+ latents=latents,
+ prompt_embeds=negative_prompt_embeds,
+ freqs_cis=freqs_cis,
+ prompt_attention_mask=negative_prompt_attention_mask,
+ ref_image_hidden_states=None,
+ )
+ model_pred = model_pred_uncond + text_guidance_scale * (model_pred - model_pred_uncond)
+
+ latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0]
+
+ latents = latents.to(dtype=dtype)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if step_func is not None:
+ step_func(i, self._num_timesteps)
+
+ latents = latents.to(dtype=dtype)
+ if self.vae.config.scaling_factor is not None:
+ latents = latents / self.vae.config.scaling_factor
+ if self.vae.config.shift_factor is not None:
+ latents = latents + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+
+ return image
+
+ def predict(
+ self,
+ t,
+ latents,
+ prompt_embeds,
+ freqs_cis,
+ prompt_attention_mask,
+ ref_image_hidden_states,
+ ):
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ batch_size, num_channels_latents, height, width = latents.shape
+
+ optional_kwargs = {}
+ if 'ref_image_hidden_states' in set(inspect.signature(self.transformer.forward).parameters.keys()):
+ optional_kwargs['ref_image_hidden_states'] = ref_image_hidden_states
+
+ model_pred = self.transformer(
+ latents,
+ timestep,
+ prompt_embeds,
+ freqs_cis,
+ prompt_attention_mask,
+ **optional_kwargs
+ )
+ return model_pred
\ No newline at end of file
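A hedged usage sketch of the pipeline defined above, for reviewers. The component objects (transformer, vae, scheduler, mllm, processor) are assumed to be loaded elsewhere by the extension's model wrapper; only the constructor and call signature come from this file.

    import torch

    pipe = OmniGen2Pipeline(
        transformer=transformer,  # OmniGen2Transformer2DModel, assumed pre-loaded
        vae=vae,                  # AutoencoderKL
        scheduler=scheduler,      # FlowMatchEulerDiscreteScheduler
        mllm=mllm,                # Qwen2_5_VLForConditionalGeneration
        processor=processor,      # the matching Qwen2.5-VL processor
    )
    pipe.to("cuda")

    out = pipe(
        prompt="a photo of a corgi wearing sunglasses on a beach",
        negative_prompt="blurry, low quality",
        height=1024,
        width=1024,
        num_inference_steps=28,
        text_guidance_scale=4.0,
        image_guidance_scale=1.0,  # only matters above 1.0 when input_images are supplied
        generator=torch.Generator(device="cuda").manual_seed(0),
    )
    image = out.images[0]  # FMPipelineOutput.images is a list of PIL images when output_type="pil"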
diff --git a/extensions_built_in/diffusion_models/omnigen2/src/pipelines/omnigen2/pipeline_omnigen2_chat.py b/extensions_built_in/diffusion_models/omnigen2/src/pipelines/omnigen2/pipeline_omnigen2_chat.py
new file mode 100644
index 00000000..43d88402
--- /dev/null
+++ b/extensions_built_in/diffusion_models/omnigen2/src/pipelines/omnigen2/pipeline_omnigen2_chat.py
@@ -0,0 +1,830 @@
+"""
+OmniGen2 Diffusion Pipeline
+
+Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import math
+
+from PIL import Image
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from transformers import Qwen2_5_VLForConditionalGeneration
+
+from diffusers.models.autoencoders import AutoencoderKL
+from ...models.transformers import OmniGen2Transformer2DModel
+from ...models.transformers.repo import OmniGen2RotaryPosEmbed
+from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+from diffusers.utils import (
+ is_torch_xla_available,
+ logging,
+)
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+
+from dataclasses import dataclass
+
+import PIL.Image
+
+from diffusers.utils import BaseOutput
+
+from ....src.pipelines.image_processor import OmniGen2ImageProcessor
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+@dataclass
+class OmniGen2PipelineOutput(BaseOutput):
+ """
+ Output class for OmniGen2 pipeline.
+
+ Args:
+        text (str):
+            The text response generated by the multimodal LLM for this chat turn.
+        images (Union[List[PIL.Image.Image], np.ndarray]):
+ List of denoised PIL images of length `batch_size` or numpy array of shape
+ `(batch_size, height, width, num_channels)`. Contains the generated images.
+ """
+ text: str
+ images: Union[List[PIL.Image.Image], np.ndarray]
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class OmniGen2ChatPipeline(DiffusionPipeline):
+ """
+ Pipeline for text-to-image generation using OmniGen2.
+
+ This pipeline implements a text-to-image generation model that uses:
+ - Qwen2.5-VL for text encoding
+ - A custom transformer architecture for image generation
+ - VAE for image encoding/decoding
+ - FlowMatchEulerDiscreteScheduler for noise scheduling
+
+ Args:
+ transformer (OmniGen2Transformer2DModel): The transformer model for image generation.
+ vae (AutoencoderKL): The VAE model for image encoding/decoding.
+ scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for noise scheduling.
+        mllm (Qwen2_5_VLForConditionalGeneration): The multimodal LLM used to encode prompts.
+        processor: The Qwen2.5-VL processor (tokenizer and image processor) used for text processing.
+ """
+
+ model_cpu_offload_seq = "mllm->transformer->vae"
+ def __init__(
+ self,
+ transformer: OmniGen2Transformer2DModel,
+ vae: AutoencoderKL,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ mllm: Qwen2_5_VLForConditionalGeneration,
+ processor,
+ ) -> None:
+ """
+ Initialize the OmniGen2 pipeline.
+
+ Args:
+ transformer: The transformer model for image generation.
+ vae: The VAE model for image encoding/decoding.
+ scheduler: The scheduler for noise scheduling.
+            mllm: The Qwen2.5-VL multimodal LLM used to encode prompts.
+            processor: The processor/tokenizer paired with the mllm.
+ """
+ super().__init__()
+
+ self.register_modules(
+ transformer=transformer,
+ vae=vae,
+ scheduler=scheduler,
+ mllm=mllm,
+ processor=processor
+ )
+ self.vae_scale_factor = (
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+ )
+ self.image_processor = OmniGen2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2, do_resize=True)
+ self.default_sample_size = 128
+
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: int,
+ height: int,
+ width: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ generator: Optional[torch.Generator],
+ latents: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ """
+ Prepare the initial latents for the diffusion process.
+
+ Args:
+ batch_size: The number of images to generate.
+ num_channels_latents: The number of channels in the latent space.
+ height: The height of the generated image.
+ width: The width of the generated image.
+ dtype: The data type of the latents.
+ device: The device to place the latents on.
+ generator: The random number generator to use.
+ latents: Optional pre-computed latents to use instead of random initialization.
+
+ Returns:
+ torch.FloatTensor: The prepared latents tensor.
+ """
+ height = int(height) // self.vae_scale_factor
+ width = int(width) // self.vae_scale_factor
+
+ shape = (batch_size, num_channels_latents, height, width)
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+ return latents
+
+ def encode_vae(self, img: torch.FloatTensor) -> torch.FloatTensor:
+ """
+ Encode an image into the VAE latent space.
+
+ Args:
+ img: The input image tensor to encode.
+
+ Returns:
+ torch.FloatTensor: The encoded latent representation.
+ """
+ z0 = self.vae.encode(img.to(dtype=self.vae.dtype)).latent_dist.sample()
+ if self.vae.config.shift_factor is not None:
+ z0 = z0 - self.vae.config.shift_factor
+ if self.vae.config.scaling_factor is not None:
+ z0 = z0 * self.vae.config.scaling_factor
+ z0 = z0.to(dtype=self.vae.dtype)
+ return z0
+
+ def prepare_image(
+ self,
+ images: Union[List[PIL.Image.Image], PIL.Image.Image],
+ batch_size: int,
+ num_images_per_prompt: int,
+ max_pixels: int,
+ max_side_length: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ ) -> List[Optional[torch.FloatTensor]]:
+ """
+ Prepare input images for processing by encoding them into the VAE latent space.
+
+ Args:
+ images: Single image or list of images to process.
+ batch_size: The number of images to generate per prompt.
+ num_images_per_prompt: The number of images to generate for each prompt.
+ device: The device to place the encoded latents on.
+ dtype: The data type of the encoded latents.
+
+ Returns:
+ List[Optional[torch.FloatTensor]]: List of encoded latent representations for each image.
+ """
+ if batch_size == 1:
+ images = [images]
+ latents = []
+ for i, img in enumerate(images):
+ if img is not None and len(img) > 0:
+ ref_latents = []
+ for j, img_j in enumerate(img):
+ img_j = self.image_processor.preprocess(img_j, max_pixels=max_pixels, max_side_length=max_side_length)
+ ref_latents.append(self.encode_vae(img_j.to(device=device)).squeeze(0))
+ else:
+ ref_latents = None
+ for _ in range(num_images_per_prompt):
+ latents.append(ref_latents)
+
+ return latents
+
+ def _apply_chat_template(self, prompt: str, images: List = None):
+ if images is not None:
+ prompt = "".join(
+ [
+                    # one Qwen2.5-VL image placeholder per reference image; the "Picture {i}" label is an assumption
+                    f"Picture {i}: <|vision_start|><|image_pad|><|vision_end|>"
+ for i in range(1, len(images) + 1)
+ ]
+ ) + prompt
+ prompt = f"<|im_start|>system\nYou are a helpful assistant that generates high-quality images based on user instructions.<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
+ return prompt
+
+ def _get_qwen2_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ input_images = None,
+ device: Optional[torch.device] = None,
+ use_only_text_hidden_states: bool = True,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Get prompt embeddings from the Qwen2 text encoder.
+
+ Args:
+ prompt: The prompt or list of prompts to encode.
+ device: The device to place the embeddings on. If None, uses the pipeline's device.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+ - The prompt embeddings tensor
+ - The attention mask tensor
+
+ Raises:
+ Warning: If the input text is truncated due to sequence length limitations.
+ """
+ device = device or self._execution_device
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ inputs = self.processor(
+ text=prompt,
+ images=input_images,
+ videos=None,
+ padding=True,
+ return_tensors="pt",
+ )
+ inputs = inputs.to(device)
+
+ prompt_embeds = self.mllm(
+ **inputs,
+ output_hidden_states=True,
+ ).hidden_states[-1]
+
+ text_input_ids = inputs.input_ids
+ text_mask = inputs.attention_mask
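+        # keep only hidden states at text-token positions so image-pad tokens do not condition the diffusion transformer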
+ if use_only_text_hidden_states:
+ mask = text_input_ids != self.mllm.config.image_token_id
+ mask = mask & text_mask
+ mask = mask.bool()
+
+ text_l = mask.sum(dim=-1)
+ max_l = text_l.max()
+ text_batch_size = prompt_embeds.size(0)
+ new_prompt_embeds = torch.zeros((text_batch_size, max_l, prompt_embeds.size(-1)), device=prompt_embeds.device, dtype=prompt_embeds.dtype)
+ new_text_mask = torch.zeros((text_batch_size, max_l), dtype=text_mask.dtype, device=text_mask.device)
+ for i in range(text_batch_size):
+ new_prompt_embeds[i, :text_l[i]] = prompt_embeds[i, mask[i]]
+ new_text_mask[i, :text_l[i]] = 1
+
+ prompt_embeds = new_prompt_embeds
+ text_mask = new_text_mask
+
+ prompt_embeds = prompt_embeds.to(dtype=self.mllm.dtype, device=device)
+ return prompt_embeds, text_mask
+
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ input_images: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 256,
+ use_text_encoder_penultimate_layer_feats: bool = False
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
+ Lumina-T2I, this should be "".
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ number of images that should be generated per prompt
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated negative text embeddings. For Lumina-T2I, this should be the embeddings of the "" string.
+ max_sequence_length (`int`, defaults to `256`):
+ Maximum sequence length to use for the prompt.
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_qwen2_prompt_embeds(
+ prompt=prompt,
+ input_images=input_images,
+ device=device,
+ )
+
+ batch_size, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+ prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1)
+
+ # Get negative embeddings for classifier free guidance
+ negative_prompt_embeds, negative_prompt_attention_mask = None, None
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt if negative_prompt is not None else ""
+
+ # Normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt = [self._apply_chat_template(_negative_prompt) for _negative_prompt in negative_prompt]
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ negative_prompt = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_qwen2_prompt_embeds(
+ prompt=negative_prompt,
+ device=device,
+ )
+
+ batch_size, seq_len, _ = negative_prompt_embeds.shape
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(
+ batch_size * num_images_per_prompt, -1
+ )
+
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def text_guidance_scale(self):
+ return self._text_guidance_scale
+
+ @property
+ def image_guidance_scale(self):
+ return self._image_guidance_scale
+
+ @property
+ def cfg_range(self):
+ return self._cfg_range
+
+ def prepare_inputs_for_text_generation(self, prompts, input_images, device):
+ if isinstance(prompts, str):
+ prompts = [prompts]
+
+ ori_padding_side = self.processor.tokenizer.padding_side
+ self.processor.tokenizer.padding_side = "left"
+ inputs = self.processor(
+ text=prompts,
+ images=input_images,
+ videos=None,
+ padding=True,
+ return_tensors="pt",
+ ).to(device)
+ self.processor.tokenizer.padding_side = ori_padding_side
+ return inputs
+
+ def generate_text(self, prompt, input_images):
+ inputs = self.prepare_inputs_for_text_generation(
+ prompt, input_images, self.mllm.device
+ )
+ generated_ids = self.mllm.generate(
+ **inputs,
+ tokenizer=self.processor.tokenizer,
+ max_new_tokens=256,
+ stop_strings=["<|im_end|>", "<|img|>", "<|endoftext|>"],
+ ) # stop_words=[151643, 151645, 151665]
+ generated_ids_trimmed = [
+ out_ids[len(in_ids) :]
+ for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+ ]
+ output_texts = self.processor.batch_decode(
+ generated_ids_trimmed,
+ # skip_special_tokens=True,
+ skip_special_tokens=False,
+ clean_up_tokenization_spaces=False,
+ )
+ return output_texts
+
+ def generate_image(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ prompt_attention_mask: Optional[torch.LongTensor] = None,
+ negative_prompt_attention_mask: Optional[torch.LongTensor] = None,
+ use_text_encoder_penultimate_layer_feats: bool = False,
+ max_sequence_length: Optional[int] = None,
+ callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
+ input_images: Optional[List[PIL.Image.Image]] = None,
+ num_images_per_prompt: int = 1,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ max_pixels: int = 1024 * 1024,
+ max_input_image_side_length: int = 1024,
+ align_res: bool = True,
+ num_inference_steps: int = 28,
+ text_guidance_scale: float = 4.0,
+ image_guidance_scale: float = 1.0,
+ cfg_range: Tuple[float, float] = (0.0, 1.0),
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ timesteps: List[int] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ verbose: bool = False,
+ step_func=None,
+ ):
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ self._text_guidance_scale = text_guidance_scale
+ self._image_guidance_scale = image_guidance_scale
+ self._cfg_range = cfg_range
+ self._attention_kwargs = attention_kwargs
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+        # 3. Encode input prompt
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt,
+ input_images,
+ self.text_guidance_scale > 1.0,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ max_sequence_length=max_sequence_length,
+ use_text_encoder_penultimate_layer_feats=use_text_encoder_penultimate_layer_feats
+ )
+
+ dtype = self.vae.dtype
+        # 4. Prepare control image
+ ref_latents = self.prepare_image(
+ images=input_images,
+ batch_size=batch_size,
+ num_images_per_prompt=num_images_per_prompt,
+ max_pixels=max_pixels,
+ max_side_length=max_input_image_side_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if input_images is None:
+ input_images = []
+
+ if len(input_images) == 1 and align_res:
+ width, height = ref_latents[0][0].shape[-1] * self.vae_scale_factor, ref_latents[0][0].shape[-2] * self.vae_scale_factor
+ ori_width, ori_height = width, height
+ else:
+ ori_width, ori_height = width, height
+
+ cur_pixels = height * width
+ ratio = (max_pixels / cur_pixels) ** 0.5
+ ratio = min(ratio, 1.0)
+
+ height, width = int(height * ratio) // 16 * 16, int(width * ratio) // 16 * 16
+
+ if len(input_images) == 0:
+ self._image_guidance_scale = 1
+
+        # 5. Prepare latents.
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ latent_channels,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ freqs_cis = OmniGen2RotaryPosEmbed.get_freqs_cis(
+ self.transformer.config.axes_dim_rope,
+ self.transformer.config.axes_lens,
+ theta=10000,
+ )
+
+ image = self.processing(
+ latents=latents,
+ ref_latents=ref_latents,
+ prompt_embeds=prompt_embeds,
+ freqs_cis=freqs_cis,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ num_inference_steps=num_inference_steps,
+ timesteps=timesteps,
+ device=device,
+ dtype=dtype,
+ verbose=verbose,
+ step_func=step_func,
+ )
+
+ image = F.interpolate(image, size=(ori_height, ori_width), mode='bilinear')
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+ return image
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ prompt_attention_mask: Optional[torch.LongTensor] = None,
+ negative_prompt_attention_mask: Optional[torch.LongTensor] = None,
+ use_text_encoder_penultimate_layer_feats: bool = False,
+ max_sequence_length: Optional[int] = None,
+ callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
+ input_images: Optional[List[PIL.Image.Image]] = None,
+ num_images_per_prompt: int = 1,
+ height: Optional[int] = 1024,
+ width: Optional[int] = 1024,
+ max_pixels: Optional[int] = 1024 * 1024,
+ max_input_image_side_length: int = 1024,
+ align_res: bool = True,
+ num_inference_steps: int = 28,
+ text_guidance_scale: float = 4.0,
+ image_guidance_scale: float = 1.0,
+ cfg_range: Tuple[float, float] = (0.0, 1.0),
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ timesteps: List[int] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ verbose: bool = False,
+ step_func=None,
+ ):
+        assert isinstance(prompt, str), "prompt must be a string since chat mode only supports one prompt per turn"
+
+ # input_images = self.preprocess_images(input_images, max_input_image_size)
+ prompt = self._apply_chat_template(prompt, input_images)
+ generated_text = self.generate_text(prompt, input_images)[0]
+
+ images = None
+ if generated_text.startswith("<|img|>"):
+ #TODO: reuse the hidden state when generate text instead of re-generating
+ prompt = prompt + generated_text.split("<|img|>")[0]
+ images = self.generate_image(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ use_text_encoder_penultimate_layer_feats=use_text_encoder_penultimate_layer_feats,
+ max_sequence_length=max_sequence_length,
+ input_images=input_images,
+ num_images_per_prompt=num_images_per_prompt,
+ height=height,
+ width=width,
+ max_pixels=max_pixels,
+ max_input_image_side_length=max_input_image_side_length,
+ align_res=align_res,
+ num_inference_steps=num_inference_steps,
+ text_guidance_scale=text_guidance_scale,
+ image_guidance_scale=image_guidance_scale,
+ cfg_range=cfg_range,
+ timesteps=timesteps,
+ generator=generator,
+ latents=latents,
+ return_dict=False,
+ verbose=verbose,
+ step_func=step_func,
+ )
+
+ generated_text = generated_text.replace("<|im_end|>", "")
+ if not return_dict:
+ return generated_text, images
+ else:
+ return OmniGen2PipelineOutput(text=generated_text, images=images)
+
+ def processing(
+ self,
+ latents,
+ ref_latents,
+ prompt_embeds,
+ freqs_cis,
+ negative_prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ num_inference_steps,
+ timesteps,
+ device,
+ dtype,
+ verbose,
+ step_func=None
+ ):
+ batch_size = latents.shape[0]
+
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ num_tokens=latents.shape[-2] * latents.shape[-1]
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ model_pred = self.predict(
+ t=t,
+ latents=latents,
+ prompt_embeds=prompt_embeds,
+ freqs_cis=freqs_cis,
+ prompt_attention_mask=prompt_attention_mask,
+ ref_image_hidden_states=ref_latents,
+ )
+
+ text_guidance_scale = self.text_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
+ image_guidance_scale = self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
+ if text_guidance_scale > 1.0 and image_guidance_scale > 1.0:
+ model_pred_ref = self.predict(
+ t=t,
+ latents=latents,
+ prompt_embeds=negative_prompt_embeds,
+ freqs_cis=freqs_cis,
+ prompt_attention_mask=negative_prompt_attention_mask,
+ ref_image_hidden_states=ref_latents,
+ )
+
+ if image_guidance_scale != 1:
+ model_pred_uncond = self.predict(
+ t=t,
+ latents=latents,
+ prompt_embeds=negative_prompt_embeds,
+ freqs_cis=freqs_cis,
+ prompt_attention_mask=negative_prompt_attention_mask,
+ ref_image_hidden_states=None,
+ )
+ else:
+ model_pred_uncond = torch.zeros_like(model_pred)
+
+ model_pred = model_pred_uncond + image_guidance_scale * (model_pred_ref - model_pred_uncond) + \
+ text_guidance_scale * (model_pred - model_pred_ref)
+ elif text_guidance_scale > 1.0:
+ model_pred_uncond = self.predict(
+ t=t,
+ latents=latents,
+ prompt_embeds=negative_prompt_embeds,
+ freqs_cis=freqs_cis,
+ prompt_attention_mask=negative_prompt_attention_mask,
+ ref_image_hidden_states=None,
+ )
+ model_pred = model_pred_uncond + text_guidance_scale * (model_pred - model_pred_uncond)
+
+ latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0]
+
+ latents = latents.to(dtype=dtype)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if step_func is not None:
+ step_func(i, self._num_timesteps)
+
+ latents = latents.to(dtype=dtype)
+ if self.vae.config.scaling_factor is not None:
+ latents = latents / self.vae.config.scaling_factor
+ if self.vae.config.shift_factor is not None:
+ latents = latents + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+
+ return image
+
+ def predict(
+ self,
+ t,
+ latents,
+ prompt_embeds,
+ freqs_cis,
+ prompt_attention_mask,
+ ref_image_hidden_states,
+ ):
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ batch_size, num_channels_latents, height, width = latents.shape
+
+ optional_kwargs = {}
+ if 'ref_image_hidden_states' in set(inspect.signature(self.transformer.forward).parameters.keys()):
+ optional_kwargs['ref_image_hidden_states'] = ref_image_hidden_states
+
+ model_pred = self.transformer(
+ latents,
+ timestep,
+ prompt_embeds,
+ freqs_cis,
+ prompt_attention_mask,
+ **optional_kwargs
+ )
+ return model_pred
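A similarly hedged sketch of the chat pipeline's single-turn interface. As above, the model components are assumed to be loaded elsewhere, and the file paths are placeholders.

    from PIL import Image

    chat_pipe = OmniGen2ChatPipeline(
        transformer=transformer, vae=vae, scheduler=scheduler, mllm=mllm, processor=processor
    )

    result = chat_pipe(
        prompt="Make the sky in this photo look like a sunset.",  # chat mode accepts one prompt per turn
        input_images=[Image.open("photo.jpg")],                   # optional reference images
        num_inference_steps=28,
        text_guidance_scale=4.0,
        image_guidance_scale=1.5,
    )
    print(result.text)                     # the mllm's reply, with <|im_end|> stripped
    if result.images is not None:          # images are only produced when the reply starts with <|img|>
        result.images[0].save("edited.png")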
diff --git a/extensions_built_in/diffusion_models/omnigen2/src/pipelines/pipeline_utils.py b/extensions_built_in/diffusion_models/omnigen2/src/pipelines/pipeline_utils.py
new file mode 100644
index 00000000..de31ff4e
--- /dev/null
+++ b/extensions_built_in/diffusion_models/omnigen2/src/pipelines/pipeline_utils.py
@@ -0,0 +1,62 @@
+import torch
+
+
+def get_pipeline_embeds(pipeline, prompt, negative_prompt, device):
+    """ Get pipeline embeddings for prompts longer than the max length of the pipeline's tokenizer.
+ :param pipeline:
+ :param prompt:
+ :param negative_prompt:
+ :param device:
+ :return:
+ """
+ max_length = pipeline.tokenizer.model_max_length
+
+ # simple way to determine length of tokens
+ # count_prompt = len(prompt.split(" "))
+ # count_negative_prompt = len(negative_prompt.split(" "))
+
+ # create the tensor based on which prompt is longer
+ # if count_prompt >= count_negative_prompt:
+    text_inputs = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False, padding='longest')
+    input_ids = text_inputs.input_ids.to(device)
+    prompt_attention_mask = text_inputs.attention_mask.to(device)
+ # input_ids = pipeline.tokenizer(prompt, padding="max_length",
+ # max_length=pipeline.tokenizer.model_max_length,
+ # truncation=True,
+ # return_tensors="pt",).input_ids.to(device)
+ shape_max_length = input_ids.shape[-1]
+
+ if negative_prompt is not None:
+        negative_inputs = pipeline.tokenizer(negative_prompt, truncation=True, padding="max_length",
+                                             max_length=shape_max_length, return_tensors="pt")
+        negative_ids = negative_inputs.input_ids.to(device)
+        negative_attention_mask = negative_inputs.attention_mask.to(device)
+
+ # else:
+ # negative_ids = pipeline.tokenizer(negative_prompt, return_tensors="pt", truncation=False).input_ids.to(device)
+ # shape_max_length = negative_ids.shape[-1]
+ # input_ids = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False, padding="max_length",
+ # max_length=shape_max_length).input_ids.to(device)
+
+ concat_embeds = []
+ neg_embeds = []
+ for i in range(0, shape_max_length, max_length):
+ if hasattr(pipeline.text_encoder.config, "use_attention_mask") and pipeline.text_encoder.config.use_attention_mask:
+            # slice the tokenizer's attention mask; a sliced ids tensor has no `.attention_mask` attribute
+            attention_mask = prompt_attention_mask[:, i: i + max_length]
+ else:
+ attention_mask = None
+ concat_embeds.append(pipeline.text_encoder(input_ids[:, i: i + max_length],
+ attention_mask=attention_mask)[0])
+
+ if negative_prompt is not None:
+ if hasattr(pipeline.text_encoder.config, "use_attention_mask") and pipeline.text_encoder.config.use_attention_mask:
+                attention_mask = negative_attention_mask[:, i: i + max_length]
+ else:
+ attention_mask = None
+ neg_embeds.append(pipeline.text_encoder(negative_ids[:, i: i + max_length],
+ attention_mask=attention_mask)[0])
+
+ concat_embeds = torch.cat(concat_embeds, dim=1)
+
+ if negative_prompt is not None:
+ neg_embeds = torch.cat(neg_embeds, dim=1)
+ else:
+ neg_embeds = None
+
+ return concat_embeds, neg_embeds
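This helper targets pipelines with a classic tokenizer/text_encoder pair (e.g. Stable Diffusion), not the OmniGen2 pipelines above. A hedged sketch of how it might be used; the checkpoint id is only an example.

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16  # example checkpoint
    ).to("cuda")

    long_prompt = ", ".join(["highly detailed ornate clockwork city"] * 20)  # well past the 77-token CLIP window
    prompt_embeds, negative_prompt_embeds = get_pipeline_embeds(pipe, long_prompt, "blurry, low quality", "cuda")

    # the chunked-and-concatenated embeddings are passed instead of raw text
    image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds).images[0]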
diff --git a/extensions_built_in/diffusion_models/omnigen2/src/schedulers/__init__.py b/extensions_built_in/diffusion_models/omnigen2/src/schedulers/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/extensions_built_in/diffusion_models/omnigen2/src/schedulers/scheduling_dpmsolver_multistep.py b/extensions_built_in/diffusion_models/omnigen2/src/schedulers/scheduling_dpmsolver_multistep.py
new file mode 100644
index 00000000..1c85e930
--- /dev/null
+++ b/extensions_built_in/diffusion_models/omnigen2/src/schedulers/scheduling_dpmsolver_multistep.py
@@ -0,0 +1,1052 @@
+# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils import deprecate, is_scipy_available
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
+
+
+if is_scipy_available():
+ import scipy.stats
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
+def rescale_zero_terminal_snr(betas):
+ """
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+
+
+ Args:
+ betas (`torch.Tensor`):
+ the betas that the scheduler is being initialized with.
+
+ Returns:
+ `torch.Tensor`: rescaled betas with zero terminal SNR
+ """
+ # Convert betas to alphas_bar_sqrt
+ alphas = 1.0 - betas
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ # Convert alphas_bar_sqrt to betas
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
+ alphas = torch.cat([alphas_bar[0:1], alphas])
+ betas = 1 - alphas
+
+ return betas
+
+
+class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
+ """
+ `DPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.
+
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+ beta_start (`float`, defaults to 0.0001):
+ The starting `beta` value of inference.
+ beta_end (`float`, defaults to 0.02):
+ The final `beta` value.
+ beta_schedule (`str`, defaults to `"linear"`):
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, *optional*):
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
+ solver_order (`int`, defaults to 2):
+ The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
+ sampling, and `solver_order=3` for unconditional sampling.
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
+ `sample` (directly predicts the noisy sample), `v_prediction` (see section 2.4 of [Imagen
+ Video](https://imagen.research.google/video/paper.pdf) paper), or `flow_prediction`.
+ thresholding (`bool`, defaults to `False`):
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
+ as Stable Diffusion.
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
+ sample_max_value (`float`, defaults to 1.0):
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
+ `algorithm_type="dpmsolver++"`.
+ algorithm_type (`str`, defaults to `dpmsolver++`):
+ Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
+ `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
+ paper, and the `dpmsolver++` type implements the algorithms in the
+ [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
+ `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
+ solver_type (`str`, defaults to `midpoint`):
+ Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
+ sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
+ lower_order_final (`bool`, defaults to `True`):
+ Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
+ stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
+ euler_at_final (`bool`, defaults to `False`):
+ Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
+ richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
+ steps, but sometimes may result in blurring.
+ final_sigmas_type (`str`, defaults to `"zero"`):
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
+ sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
+ dynamic_time_shift (`bool`, defaults to `True`):
+ Whether to shift the timestep schedule according to the number of image tokens, so that larger inputs
+ spend more of the sampling budget at high noise levels.
+ """
+
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ solver_order: int = 2,
+ prediction_type: str = "epsilon",
+ thresholding: bool = False,
+ dynamic_thresholding_ratio: float = 0.995,
+ sample_max_value: float = 1.0,
+ algorithm_type: str = "dpmsolver++",
+ solver_type: str = "midpoint",
+ lower_order_final: bool = True,
+ euler_at_final: bool = False,
+ final_sigmas_type: str = 'zero',
+ dynamic_time_shift: bool = True
+ ):
+ if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
+ deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
+ deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
+
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ # Currently we only support VP-type noise schedule
+ self.alpha_t = torch.sqrt(self.alphas_cumprod)
+ self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
+ self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
+ self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # settings for DPM-Solver
+ if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
+ if algorithm_type == "deis":
+ self.register_to_config(algorithm_type="dpmsolver++")
+ else:
+ raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
+
+ if solver_type not in ["midpoint", "heun"]:
+ if solver_type in ["logrho", "bh1", "bh2"]:
+ self.register_to_config(solver_type="midpoint")
+ else:
+ raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
+
+ # if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
+ # raise ValueError(
+ # f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
+ # )
+
+ # setable values
+ self.num_inference_steps = None
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
+ self.timesteps = torch.from_numpy(timesteps)
+ self.model_outputs = [None] * solver_order
+ self.lower_order_nums = 0
+ self._step_index = None
+ self._begin_index = None
+ self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
+
+ @property
+ def step_index(self):
+ """
+ The index counter for current timestep. It will increase 1 after each scheduler step.
+ """
+ return self._step_index
+
+ @property
+ def begin_index(self):
+ """
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+ """
+ return self._begin_index
+
+ def set_begin_index(self, begin_index: int = 0):
+ """
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+ Args:
+ begin_index (`int`):
+ The begin index for the scheduler.
+ """
+ self._begin_index = begin_index
+
+ def set_timesteps(
+ self,
+ num_inference_steps: int = None,
+ device: Union[str, torch.device] = None,
+ timesteps: Optional[List[int]] = None,
+ num_tokens: Optional[int] = None
+ ):
+ if timesteps is None:
+ self.num_inference_steps = num_inference_steps
+ timesteps = np.linspace(0, 1, num_inference_steps + 1, dtype=np.float32)[:-1]
+ if self.config.dynamic_time_shift and num_tokens is not None:
+ m = np.sqrt(num_tokens) / 40 # when input resolution is 320 * 320, m = 1, when input resolution is 1024 * 1024, m = 3.2
+ timesteps = timesteps / (m - m * timesteps + timesteps)
+
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device)
+ sigmas = torch.cat([1 - timesteps, torch.zeros(1, device=timesteps.device)])
+
+ self.sigmas = sigmas
+ self.timesteps = timesteps
+
+ self.num_inference_steps = len(timesteps)
+
+ self.model_outputs = [
+ None,
+ ] * self.config.solver_order
+ self.lower_order_nums = 0
+
+ # add an index counter for schedulers that allow duplicated timesteps
+ self._step_index = None
+ self._begin_index = None
+ self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
+ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
+ """
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+ https://arxiv.org/abs/2205.11487
+ """
+ dtype = sample.dtype
+ batch_size, channels, *remaining_dims = sample.shape
+
+ if dtype not in (torch.float32, torch.float64):
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
+
+ # Flatten sample for doing quantile calculation along each image
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
+
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
+
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+ s = torch.clamp(
+ s, min=1, max=self.config.sample_max_value
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
+ sample = sample.to(dtype)
+
+ return sample
+
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
+ def _sigma_to_t(self, sigma, log_sigmas):
+ # get log sigma
+ log_sigma = np.log(np.maximum(sigma, 1e-10))
+
+ # get distribution
+ dists = log_sigma - log_sigmas[:, np.newaxis]
+
+ # get sigmas range
+ low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
+ high_idx = low_idx + 1
+
+ low = log_sigmas[low_idx]
+ high = log_sigmas[high_idx]
+
+ # interpolate sigmas
+ w = (low - log_sigma) / (low - high)
+ w = np.clip(w, 0, 1)
+
+ # transform interpolation to time range
+ t = (1 - w) * low_idx + w * high_idx
+ t = t.reshape(sigma.shape)
+ return t
+
+ def _sigma_to_alpha_sigma_t(self, sigma):
+ alpha_t = 1 - sigma
+ sigma_t = sigma
+
+ return alpha_t, sigma_t
+
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
+ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
+ """Constructs the noise schedule of Karras et al. (2022)."""
+
+ # Hack to make sure that other schedulers which copy this function don't break
+ # TODO: Add this logic to the other schedulers
+ if hasattr(self.config, "sigma_min"):
+ sigma_min = self.config.sigma_min
+ else:
+ sigma_min = None
+
+ if hasattr(self.config, "sigma_max"):
+ sigma_max = self.config.sigma_max
+ else:
+ sigma_max = None
+
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+ rho = 7.0 # 7.0 is the value used in the paper
+ ramp = np.linspace(0, 1, num_inference_steps)
+ min_inv_rho = sigma_min ** (1 / rho)
+ max_inv_rho = sigma_max ** (1 / rho)
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+ return sigmas
+
+ def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor:
+ """Constructs the noise schedule of Lu et al. (2022)."""
+
+ lambda_min: float = in_lambdas[-1].item()
+ lambda_max: float = in_lambdas[0].item()
+
+ rho = 1.0 # 1.0 is the value used in the paper
+ ramp = np.linspace(0, 1, num_inference_steps)
+ min_inv_rho = lambda_min ** (1 / rho)
+ max_inv_rho = lambda_max ** (1 / rho)
+ lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+ return lambdas
+
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+ """Constructs an exponential noise schedule."""
+
+ # Hack to make sure that other schedulers which copy this function don't break
+ # TODO: Add this logic to the other schedulers
+ if hasattr(self.config, "sigma_min"):
+ sigma_min = self.config.sigma_min
+ else:
+ sigma_min = None
+
+ if hasattr(self.config, "sigma_max"):
+ sigma_max = self.config.sigma_max
+ else:
+ sigma_max = None
+
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+ sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
+ return sigmas
+
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+ def _convert_to_beta(
+ self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+ ) -> torch.Tensor:
+ """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+ # Hack to make sure that other schedulers which copy this function don't break
+ # TODO: Add this logic to the other schedulers
+ if hasattr(self.config, "sigma_min"):
+ sigma_min = self.config.sigma_min
+ else:
+ sigma_min = None
+
+ if hasattr(self.config, "sigma_max"):
+ sigma_max = self.config.sigma_max
+ else:
+ sigma_max = None
+
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+ sigmas = np.array(
+ [
+ sigma_min + (ppf * (sigma_max - sigma_min))
+ for ppf in [
+ scipy.stats.beta.ppf(timestep, alpha, beta)
+ for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+ ]
+ ]
+ )
+ return sigmas
+
+ def convert_model_output(
+ self,
+ model_output: torch.Tensor,
+ *args,
+ sample: torch.Tensor = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
+ designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
+ integral of the data prediction model.
+
+
+
+ The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
+ prediction and data prediction models.
+
+
+
+ Args:
+ model_output (`torch.Tensor`):
+ The direct output from the learned diffusion model.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+
+ Returns:
+ `torch.Tensor`:
+ The converted model output.
+ """
+ timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
+ if sample is None:
+ if len(args) > 1:
+ sample = args[1]
+ else:
+ raise ValueError("missing `sample` as a required keyward argument")
+ if timestep is not None:
+ deprecate(
+ "timesteps",
+ "1.0.0",
+ "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
+ )
+
+ # DPM-Solver++ needs to solve an integral of the data prediction model.
+ if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
+ if self.config.prediction_type == "epsilon":
+ # DPM-Solver and DPM-Solver++ only need the "mean" output.
+ if getattr(self.config, "variance_type", None) in ["learned", "learned_range"]:
+ model_output = model_output[:, :3]
+ sigma = self.sigmas[self.step_index]
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
+ x0_pred = (sample - sigma_t * model_output) / alpha_t
+ elif self.config.prediction_type == "sample":
+ x0_pred = model_output
+ elif self.config.prediction_type == "v_prediction":
+ sigma = self.sigmas[self.step_index]
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
+ x0_pred = alpha_t * sample - sigma_t * model_output
+ elif self.config.prediction_type == "flow_prediction":
+ sigma_t = self.sigmas[self.step_index]
+ x0_pred = sample + sigma_t * model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
+ "`v_prediction`, or `flow_prediction` for the DPMSolverMultistepScheduler."
+ )
+
+ if self.config.thresholding:
+ x0_pred = self._threshold_sample(x0_pred)
+
+ return x0_pred
+
+ # DPM-Solver needs to solve an integral of the noise prediction model.
+ elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
+ if self.config.prediction_type == "epsilon":
+ # DPM-Solver and DPM-Solver++ only need the "mean" output.
+ if getattr(self.config, "variance_type", None) in ["learned", "learned_range"]:
+ epsilon = model_output[:, :3]
+ else:
+ epsilon = model_output
+ elif self.config.prediction_type == "sample":
+ sigma = self.sigmas[self.step_index]
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
+ epsilon = (sample - alpha_t * model_output) / sigma_t
+ elif self.config.prediction_type == "v_prediction":
+ sigma = self.sigmas[self.step_index]
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
+ epsilon = alpha_t * model_output + sigma_t * sample
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+ " `v_prediction` for the DPMSolverMultistepScheduler."
+ )
+
+ if self.config.thresholding:
+ sigma = self.sigmas[self.step_index]
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
+ x0_pred = (sample - sigma_t * epsilon) / alpha_t
+ x0_pred = self._threshold_sample(x0_pred)
+ epsilon = (sample - alpha_t * x0_pred) / sigma_t
+
+ return epsilon
+
+ def dpm_solver_first_order_update(
+ self,
+ model_output: torch.Tensor,
+ *args,
+ sample: torch.Tensor = None,
+ noise: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ One step for the first-order DPMSolver (equivalent to DDIM).
+
+ Args:
+ model_output (`torch.Tensor`):
+ The direct output from the learned diffusion model.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+
+ Returns:
+ `torch.Tensor`:
+ The sample tensor at the previous timestep.
+ """
+ timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
+ prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
+ if sample is None:
+ if len(args) > 2:
+ sample = args[2]
+ else:
+ raise ValueError(" missing `sample` as a required keyward argument")
+ if timestep is not None:
+ deprecate(
+ "timesteps",
+ "1.0.0",
+ "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
+ )
+
+ if prev_timestep is not None:
+ deprecate(
+ "prev_timestep",
+ "1.0.0",
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
+ )
+
+ sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
+ alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
+ lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
+
+ h = lambda_t - lambda_s
+ if self.config.algorithm_type == "dpmsolver++":
+ x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
+ elif self.config.algorithm_type == "dpmsolver":
+ x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
+ elif self.config.algorithm_type == "sde-dpmsolver++":
+ assert noise is not None
+ x_t = (
+ (sigma_t / sigma_s * torch.exp(-h)) * sample
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
+ )
+ elif self.config.algorithm_type == "sde-dpmsolver":
+ assert noise is not None
+ x_t = (
+ (alpha_t / alpha_s) * sample
+ - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
+ + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
+ )
+ return x_t
+
+ def multistep_dpm_solver_second_order_update(
+ self,
+ model_output_list: List[torch.Tensor],
+ *args,
+ sample: torch.Tensor = None,
+ noise: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ One step for the second-order multistep DPMSolver.
+
+ Args:
+ model_output_list (`List[torch.Tensor]`):
+ The direct outputs from the learned diffusion model at the current and earlier solver steps.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+
+ Returns:
+ `torch.Tensor`:
+ The sample tensor at the previous timestep.
+ """
+ timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
+ prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
+ if sample is None:
+ if len(args) > 2:
+ sample = args[2]
+ else:
+ raise ValueError(" missing `sample` as a required keyward argument")
+ if timestep_list is not None:
+ deprecate(
+ "timestep_list",
+ "1.0.0",
+ "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
+ )
+
+ if prev_timestep is not None:
+ deprecate(
+ "prev_timestep",
+ "1.0.0",
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
+ )
+
+ sigma_t, sigma_s0, sigma_s1 = (
+ self.sigmas[self.step_index + 1],
+ self.sigmas[self.step_index],
+ self.sigmas[self.step_index - 1],
+ )
+
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
+ alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
+
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
+ lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
+
+ m0, m1 = model_output_list[-1], model_output_list[-2]
+
+ h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
+ r0 = h_0 / h
+ D0, D1 = m0, (1.0 / r0) * (m0 - m1)
+ if self.config.algorithm_type == "dpmsolver++":
+ # See https://arxiv.org/abs/2211.01095 for detailed derivations
+ if self.config.solver_type == "midpoint":
+ x_t = (
+ (sigma_t / sigma_s0) * sample
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
+ - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
+ )
+ elif self.config.solver_type == "heun":
+ x_t = (
+ (sigma_t / sigma_s0) * sample
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
+ + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
+ )
+ elif self.config.algorithm_type == "dpmsolver":
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ if self.config.solver_type == "midpoint":
+ x_t = (
+ (alpha_t / alpha_s0) * sample
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
+ - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
+ )
+ elif self.config.solver_type == "heun":
+ x_t = (
+ (alpha_t / alpha_s0) * sample
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
+ - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
+ )
+ elif self.config.algorithm_type == "sde-dpmsolver++":
+ assert noise is not None
+ if self.config.solver_type == "midpoint":
+ x_t = (
+ (sigma_t / sigma_s0 * torch.exp(-h)) * sample
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
+ + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
+ )
+ elif self.config.solver_type == "heun":
+ x_t = (
+ (sigma_t / sigma_s0 * torch.exp(-h)) * sample
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
+ + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
+ )
+ elif self.config.algorithm_type == "sde-dpmsolver":
+ assert noise is not None
+ if self.config.solver_type == "midpoint":
+ x_t = (
+ (alpha_t / alpha_s0) * sample
+ - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
+ - (sigma_t * (torch.exp(h) - 1.0)) * D1
+ + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
+ )
+ elif self.config.solver_type == "heun":
+ x_t = (
+ (alpha_t / alpha_s0) * sample
+ - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
+ - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
+ + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
+ )
+ return x_t
+
+ def multistep_dpm_solver_third_order_update(
+ self,
+ model_output_list: List[torch.Tensor],
+ *args,
+ sample: torch.Tensor = None,
+ noise: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ One step for the third-order multistep DPMSolver.
+
+ Args:
+ model_output_list (`List[torch.Tensor]`):
+ The direct outputs from the learned diffusion model at the current and earlier solver steps.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+
+ Returns:
+ `torch.Tensor`:
+ The sample tensor at the previous timestep.
+ """
+
+ timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
+ prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
+ if sample is None:
+ if len(args) > 2:
+ sample = args[2]
+ else:
+ raise ValueError(" missing`sample` as a required keyward argument")
+ if timestep_list is not None:
+ deprecate(
+ "timestep_list",
+ "1.0.0",
+ "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
+ )
+
+ if prev_timestep is not None:
+ deprecate(
+ "prev_timestep",
+ "1.0.0",
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
+ )
+
+ sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
+ self.sigmas[self.step_index + 1],
+ self.sigmas[self.step_index],
+ self.sigmas[self.step_index - 1],
+ self.sigmas[self.step_index - 2],
+ )
+
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
+ alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
+ alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
+
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
+ lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
+ lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
+
+ m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
+
+ h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
+ r0, r1 = h_0 / h, h_1 / h
+ D0 = m0
+ D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
+ D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
+ D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
+ if self.config.algorithm_type == "dpmsolver++":
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ x_t = (
+ (sigma_t / sigma_s0) * sample
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
+ + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
+ - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
+ )
+ elif self.config.algorithm_type == "dpmsolver":
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ x_t = (
+ (alpha_t / alpha_s0) * sample
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
+ - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
+ - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
+ )
+ elif self.config.algorithm_type == "sde-dpmsolver++":
+ assert noise is not None
+ x_t = (
+ (sigma_t / sigma_s0 * torch.exp(-h)) * sample
+ + (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
+ + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+ + (alpha_t * ((1.0 - torch.exp(-2.0 * h) - 2.0 * h) / (2.0 * h) ** 2 - 0.5)) * D2
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
+ )
+ return x_t
+
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
+ if schedule_timesteps is None:
+ schedule_timesteps = self.timesteps
+
+ index_candidates = (schedule_timesteps == timestep).nonzero()
+
+ if len(index_candidates) == 0:
+ step_index = len(self.timesteps) - 1
+ # The sigma index that is taken for the **very** first `step`
+ # is always the second index (or the last index if there is only 1)
+ # This way we can ensure we don't accidentally skip a sigma in
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+ elif len(index_candidates) > 1:
+ step_index = index_candidates[1].item()
+ else:
+ step_index = index_candidates[0].item()
+
+ return step_index
+
+ def _init_step_index(self, timestep):
+ """
+ Initialize the step_index counter for the scheduler.
+ """
+
+ if self.begin_index is None:
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ self._step_index = self.index_for_timestep(timestep)
+ else:
+ self._step_index = self._begin_index
+
+ def step(
+ self,
+ model_output: torch.Tensor,
+ timestep: Union[int, torch.Tensor],
+ sample: torch.Tensor,
+ generator=None,
+ variance_noise: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
+ the multistep DPMSolver.
+
+ Args:
+ model_output (`torch.Tensor`):
+ The direct output from the learned diffusion model.
+ timestep (`int`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+ generator (`torch.Generator`, *optional*):
+ A random number generator.
+ variance_noise (`torch.Tensor`):
+ Alternative to generating noise with `generator` by directly providing the noise for the variance
+ itself. Useful for methods such as [`LEdits++`].
+ return_dict (`bool`):
+ Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
+
+ Returns:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
+ If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
+ tuple is returned where the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ if self.step_index is None:
+ self._init_step_index(timestep)
+
+ # Improve numerical stability for small number of steps
+ lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
+ self.config.euler_at_final
+ or (self.config.lower_order_final and len(self.timesteps) < 15)
+ or self.config.final_sigmas_type == "zero"
+ )
+ lower_order_second = (
+ (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
+ )
+
+ model_output = self.convert_model_output(model_output, sample=sample)
+ for i in range(self.config.solver_order - 1):
+ self.model_outputs[i] = self.model_outputs[i + 1]
+ self.model_outputs[-1] = model_output
+
+ # Upcast to avoid precision issues when computing prev_sample
+ sample = sample.to(torch.float32)
+ if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
+ noise = randn_tensor(
+ model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
+ )
+ elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
+ noise = variance_noise.to(device=model_output.device, dtype=torch.float32)
+ else:
+ noise = None
+
+ if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
+ prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
+ elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
+ prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
+ else:
+ prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample, noise=noise)
+
+ if self.lower_order_nums < self.config.solver_order:
+ self.lower_order_nums += 1
+
+ # Cast sample back to expected dtype
+ prev_sample = prev_sample.to(model_output.dtype)
+
+ # upon completion increase step index by one
+ self._step_index += 1
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample.
+
+ Returns:
+ `torch.Tensor`:
+ A scaled input sample.
+ """
+ return sample
+
+ def add_noise(
+ self,
+ original_samples: torch.Tensor,
+ noise: torch.Tensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.Tensor:
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
+ sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
+ # mps does not support float64
+ schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
+ else:
+ schedule_timesteps = self.timesteps.to(original_samples.device)
+ timesteps = timesteps.to(original_samples.device)
+
+ # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
+ if self.begin_index is None:
+ step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+ elif self.step_index is not None:
+ # add_noise is called after first denoising step (for inpainting)
+ step_indices = [self.step_index] * timesteps.shape[0]
+ else:
+ # add noise is called before first denoising step to create initial latent(img2img)
+ step_indices = [self.begin_index] * timesteps.shape[0]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < len(original_samples.shape):
+ sigma = sigma.unsqueeze(-1)
+
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
+ noisy_samples = alpha_t * original_samples + sigma_t * noise
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
\ No newline at end of file
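For context, here is a rough, standalone sketch of how this scheduler might be driven with `prediction_type="flow_prediction"`. The import path is the one added in this diff; the latent shape and the zero-velocity model are placeholders, not OmniGen2's real transformer or latent layout.

```python
import torch

from extensions_built_in.diffusion_models.omnigen2.src.schedulers.scheduling_dpmsolver_multistep import (
    DPMSolverMultistepScheduler,
)

scheduler = DPMSolverMultistepScheduler(prediction_type="flow_prediction", solver_order=2)
# dynamic_time_shift uses the token count to bend the schedule toward high noise
scheduler.set_timesteps(num_inference_steps=20, num_tokens=128 * 128)

latents = torch.randn(1, 16, 128, 128)  # illustrative shape only

def dummy_velocity(x, t):
    # Stand-in for the model's flow/velocity prediction.
    return torch.zeros_like(x)

for t in scheduler.timesteps:
    v = dummy_velocity(latents, t)
    latents = scheduler.step(v, t, latents).prev_sample
```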
diff --git a/extensions_built_in/diffusion_models/omnigen2/src/schedulers/scheduling_flow_match_euler_discrete.py b/extensions_built_in/diffusion_models/omnigen2/src/schedulers/scheduling_flow_match_euler_discrete.py
new file mode 100644
index 00000000..91513476
--- /dev/null
+++ b/extensions_built_in/diffusion_models/omnigen2/src/schedulers/scheduling_flow_match_euler_discrete.py
@@ -0,0 +1,229 @@
+# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils import BaseOutput, logging
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's `step` function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ """
+
+ prev_sample: torch.FloatTensor
+
+
+class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Euler scheduler.
+
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+ dynamic_time_shift (`bool`, defaults to `True`):
+ Whether to shift the timestep schedule according to the number of image tokens, so that larger inputs
+ spend more of the sampling budget at high noise levels.
+ """
+
+ _compatibles = []
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ dynamic_time_shift: bool = True
+ ):
+ timesteps = torch.linspace(0, 1, num_train_timesteps + 1, dtype=torch.float32)[:-1]
+
+ self.timesteps = timesteps
+
+ self._step_index = None
+ self._begin_index = None
+
+ @property
+ def step_index(self):
+ """
+ The index counter for current timestep. It will increase 1 after each scheduler step.
+ """
+ return self._step_index
+
+ @property
+ def begin_index(self):
+ """
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+ """
+ return self._begin_index
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+ def set_begin_index(self, begin_index: int = 0):
+ """
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+ Args:
+ begin_index (`int`):
+ The begin index for the scheduler.
+ """
+ self._begin_index = begin_index
+
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
+ if schedule_timesteps is None:
+ schedule_timesteps = self._timesteps
+
+ indices = (schedule_timesteps == timestep).nonzero()
+
+ # The sigma index that is taken for the **very** first `step`
+ # is always the second index (or the last index if there is only 1)
+ # This way we can ensure we don't accidentally skip a sigma in
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+ pos = 1 if len(indices) > 1 else 0
+
+ return indices[pos].item()
+
+ # def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
+ # return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+ def set_timesteps(
+ self,
+ num_inference_steps: int = None,
+ device: Union[str, torch.device] = None,
+ timesteps: Optional[List[float]] = None,
+ num_tokens: Optional[int] = None
+ ):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+ Args:
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ """
+
+ if timesteps is None:
+ self.num_inference_steps = num_inference_steps
+ timesteps = np.linspace(0, 1, num_inference_steps + 1, dtype=np.float32)[:-1]
+ if self.config.dynamic_time_shift and num_tokens is not None:
+ m = np.sqrt(num_tokens) / 40 # when input resolution is 320 * 320, m = 1, when input resolution is 1024 * 1024, m = 3.2
+ timesteps = timesteps / (m - m * timesteps + timesteps)
+
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device)
+ _timesteps = torch.cat([timesteps, torch.ones(1, device=timesteps.device)])
+
+ self.timesteps = timesteps
+ self._timesteps = _timesteps
+ self._step_index = None
+ self._begin_index = None
+
+ def _init_step_index(self, timestep):
+ if self.begin_index is None:
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ self._step_index = self.index_for_timestep(timestep)
+ else:
+ self._step_index = self._begin_index
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: Union[float, torch.FloatTensor],
+ sample: torch.FloatTensor,
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep with an explicit Euler step. This function propagates the
+ diffusion process using the learned model output (the predicted velocity under flow matching).
+
+ Args:
+ model_output (`torch.FloatTensor`):
+ The direct output from the learned diffusion model.
+ timestep (`float`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ A current instance of a sample created by the diffusion process.
+ generator (`torch.Generator`, *optional*):
+ A random number generator.
+ return_dict (`bool`):
+ Whether or not to return a [`FlowMatchEulerDiscreteSchedulerOutput`] or tuple.
+
+ Returns:
+ [`FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
+ If return_dict is `True`, [`FlowMatchEulerDiscreteSchedulerOutput`] is returned, otherwise a tuple is
+ returned where the first element is the sample tensor.
+ """
+
+ if (
+ isinstance(timestep, int)
+ or isinstance(timestep, torch.IntTensor)
+ or isinstance(timestep, torch.LongTensor)
+ ):
+ raise ValueError(
+ (
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
+ " one of the `scheduler.timesteps` as a timestep."
+ ),
+ )
+
+ if self.step_index is None:
+ self._init_step_index(timestep)
+ # Upcast to avoid precision issues when computing prev_sample
+ sample = sample.to(torch.float32)
+ t = self._timesteps[self.step_index]
+ t_next = self._timesteps[self.step_index + 1]
+
+ prev_sample = sample + (t_next - t) * model_output
+
+ # Cast sample back to model compatible dtype
+ prev_sample = prev_sample.to(model_output.dtype)
+
+ # upon completion increase step index by one
+ self._step_index += 1
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
+
+ def __len__(self):
+ return self.config.num_train_timesteps
\ No newline at end of file
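The same loop works with the flow-matching Euler scheduler above, which takes a plain Euler step `x_next = x + (t_next - t) * v`. The sketch below also prints the shift factor `m = sqrt(num_tokens) / 40` for two latent sizes to show how `dynamic_time_shift` reallocates steps; the sizes assume a /8 VAE and are illustrative only.

```python
import math
import torch

from extensions_built_in.diffusion_models.omnigen2.src.schedulers.scheduling_flow_match_euler_discrete import (
    FlowMatchEulerDiscreteScheduler,
)

# Shift factor for 320px (40x40 latent) and 1024px (128x128 latent) inputs.
for side in (40, 128):
    print(f"latent {side}x{side}: m = {math.sqrt(side * side) / 40:.2f}")

scheduler = FlowMatchEulerDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=20, num_tokens=128 * 128)

latents = torch.randn(1, 16, 128, 128)  # illustrative shape only
for t in scheduler.timesteps:
    v = torch.zeros_like(latents)       # stand-in for the predicted velocity
    latents = scheduler.step(v, t, latents).prev_sample
```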
diff --git a/extensions_built_in/diffusion_models/omnigen2/src/utils/__init__.py b/extensions_built_in/diffusion_models/omnigen2/src/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/extensions_built_in/diffusion_models/omnigen2/src/utils/img_util.py b/extensions_built_in/diffusion_models/omnigen2/src/utils/img_util.py
new file mode 100644
index 00000000..d9026a63
--- /dev/null
+++ b/extensions_built_in/diffusion_models/omnigen2/src/utils/img_util.py
@@ -0,0 +1,31 @@
+from typing import List
+
+from PIL import Image
+
+import torch
+from torchvision.transforms.functional import to_pil_image
+
+def resize_image(image, max_pixels, img_scale_num):
+ width, height = image.size
+ cur_pixels = height * width
+ ratio = (max_pixels / cur_pixels) ** 0.5
+ ratio = min(ratio, 1.0) # do not upscale input image
+
+ new_height, new_width = int(height * ratio) // img_scale_num * img_scale_num, int(width * ratio) // img_scale_num * img_scale_num
+
+ image = image.resize((new_width, new_height), resample=Image.BICUBIC)
+ return image
+
+def create_collage(images: List[torch.Tensor]) -> Image.Image:
+ """Create a horizontal collage from a list of images."""
+ max_height = max(img.shape[-2] for img in images)
+ total_width = sum(img.shape[-1] for img in images)
+ canvas = torch.zeros((3, max_height, total_width), device=images[0].device)
+
+ current_x = 0
+ for img in images:
+ h, w = img.shape[-2:]
+ canvas[:, :h, current_x:current_x+w] = img * 0.5 + 0.5
+ current_x += w
+
+ return to_pil_image(canvas)
\ No newline at end of file
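A quick sketch of the two helpers above on dummy data. The pixel budget and scale multiple passed to `resize_image` are made-up values, not OmniGen2 defaults; `create_collage` expects CHW tensors in [-1, 1].

```python
import torch
from PIL import Image

from extensions_built_in.diffusion_models.omnigen2.src.utils.img_util import create_collage, resize_image

img = Image.new("RGB", (1000, 700), color=(128, 64, 32))
resized = resize_image(img, max_pixels=512 * 512, img_scale_num=16)
print(resized.size)  # both sides are multiples of 16 and the area stays under the budget

tensors = [torch.rand(3, 256, 256) * 2 - 1, torch.rand(3, 128, 192) * 2 - 1]
collage = create_collage(tensors)  # images are placed side by side on a black canvas
print(collage.size)
```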
diff --git a/extensions_built_in/diffusion_models/omnigen2/src/utils/import_utils.py b/extensions_built_in/diffusion_models/omnigen2/src/utils/import_utils.py
new file mode 100644
index 00000000..dc946d77
--- /dev/null
+++ b/extensions_built_in/diffusion_models/omnigen2/src/utils/import_utils.py
@@ -0,0 +1,46 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Import utilities: Utilities related to imports and our lazy inits.
+"""
+
+import importlib.util
+import sys
+
+# The package importlib_metadata is in a different place, depending on the python version.
+if sys.version_info < (3, 8):
+ import importlib_metadata
+else:
+ import importlib.metadata as importlib_metadata
+
+def _is_package_available(pkg_name: str):
+ pkg_exists = importlib.util.find_spec(pkg_name) is not None
+ pkg_version = "N/A"
+
+ if pkg_exists:
+ try:
+ pkg_version = importlib_metadata.version(pkg_name)
+ except (ImportError, importlib_metadata.PackageNotFoundError):
+ pkg_exists = False
+
+ return pkg_exists, pkg_version
+
+_triton_available, _triton_version = _is_package_available("triton")
+_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
+
+def is_triton_available():
+ return _triton_available
+
+def is_flash_attn_available():
+ return _flash_attn_available
\ No newline at end of file
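And a small example of how these availability flags could be consumed, e.g. to pick an attention backend; the backend names are placeholders rather than anything defined in this diff.

```python
from extensions_built_in.diffusion_models.omnigen2.src.utils.import_utils import (
    is_flash_attn_available,
    is_triton_available,
)

def pick_attention_backend() -> str:
    # Prefer flash-attn when installed, otherwise fall back to PyTorch's
    # built-in scaled_dot_product_attention.
    return "flash_attn" if is_flash_attn_available() else "sdpa"

print("backend:", pick_attention_backend(), "| triton:", is_triton_available())
```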
diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py
index 0c7d7562..9078fbbb 100644
--- a/toolkit/data_transfer_object/data_loader.py
+++ b/toolkit/data_transfer_object/data_loader.py
@@ -91,6 +91,8 @@ class FileItemDTO(
except image_utils.UnknownImageFormat:
print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \
f'This process is faster for png, jpeg')
+ img = exif_transpose(Image.open(self.path))
+ w, h = img.size
else:
img = exif_transpose(Image.open(self.path))
w, h = img.size
diff --git a/toolkit/extension.py b/toolkit/extension.py
index 8d1f38e5..f4f10e9d 100644
--- a/toolkit/extension.py
+++ b/toolkit/extension.py
@@ -34,7 +34,7 @@ def get_all_extensions() -> List[Extension]:
for sub_dir in extension_folders:
extensions_dir = os.path.join(TOOLKIT_ROOT, sub_dir)
for (_, name, _) in pkgutil.iter_modules([extensions_dir]):
- try:
+ # try:
# Import the module
module = importlib.import_module(f"{sub_dir}.{name}")
# Get the value of the AI_TOOLKIT_EXTENSIONS variable
@@ -43,8 +43,8 @@ def get_all_extensions() -> List[Extension]:
if isinstance(extensions, list):
# Iterate over the list and add the classes to the main list
all_extension_classes.extend(extensions)
- except ImportError as e:
- print(f"Failed to import the {name} module. Error: {str(e)}")
+ # except ImportError as e:
+ # print(f"Failed to import the {name} module. Error: {str(e)}")
return all_extension_classes
diff --git a/ui/src/app/jobs/new/options.ts b/ui/src/app/jobs/new/options.ts
index e84bc4a0..2f4e235c 100644
--- a/ui/src/app/jobs/new/options.ts
+++ b/ui/src/app/jobs/new/options.ts
@@ -170,6 +170,19 @@ export const modelArchs: ModelArch[] = [
},
disableSections: ['model.quantize', 'train.timestep_type'],
},
+ {
+ name: 'omnigen2',
+ label: 'OmniGen2',
+ defaults: {
+ // default updates when [selected, unselected] in the UI
+ 'config.process[0].model.name_or_path': ['OmniGen2/OmniGen2', defaultNameOrPath],
+ 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
+ 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
+ 'config.process[0].model.quantize': [false, false],
+ 'config.process[0].model.quantize_te': [true, false],
+ },
+ disableSections: ['network.conv'],
+ },
].sort((a, b) => {
// Sort by label, case-insensitive
return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' })
diff --git a/version.py b/version.py
index 0db14ed3..60504c97 100644
--- a/version.py
+++ b/version.py
@@ -1 +1 @@
-VERSION = "0.3.1"
\ No newline at end of file
+VERSION = "0.3.2"
\ No newline at end of file