Added support for finetuning OmniGen2.

This commit is contained in:
Jaret Burkett
2025-06-25 13:58:16 -06:00
parent 5e733764aa
commit 19ea8ecc38
28 changed files with 6405 additions and 5 deletions

View File

@@ -0,0 +1,94 @@
---
job: extension
config:
# this name will be used for the output folder and file names
name: "my_first_omnigen2_lora_v1"
process:
- type: 'sd_trainer'
# root folder to save training sessions/samples/weights
training_folder: "output"
# uncomment to see performance stats in the terminal every N steps
# performance_log_every: 1000
device: cuda:0
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
# trigger_word: "p3r5on"
network:
type: "lora"
linear: 16
linear_alpha: 16
save:
dtype: float16 # precision to save
save_every: 250 # save every this many steps
max_step_saves_to_keep: 4 # how many intermittent saves to keep
push_to_hub: false # change this to true to push your trained model to Hugging Face.
# You can either set an HF_TOKEN env variable or you'll be prompted to log in
# hf_repo_id: your-username/your-model-slug
# hf_private: true #whether the repo is private or public
datasets:
# datasets are a folder of images. captions need to be txt files with the same name as the image
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
# images will automatically be resized and bucketed into the resolution specified
# on windows, escape back slashes with another backslash so
# "C:\\path\\to\\images\\folder"
- folder_path: "/path/to/images/folder"
caption_ext: "txt"
caption_dropout_rate: 0.05 # will drop out the caption 5% of the time
shuffle_tokens: false # shuffle caption order, split by commas
cache_latents_to_disk: true # leave this true unless you know what you're doing
resolution: [ 512, 768, 1024 ] # omnigen2 should work with multiple resolutions
train:
batch_size: 1
steps: 3000 # total number of steps to train; 500 - 4000 is a good range
gradient_accumulation: 1
train_unet: true
train_text_encoder: false # probably won't work with omnigen2
gradient_checkpointing: true # need this on unless you have a ton of vram
noise_scheduler: "flowmatch" # for training only
optimizer: "adamw8bit"
lr: 1e-4
timestep_type: 'sigmoid' # sigmoid, linear, shift
# uncomment this to skip the pre training sample
# skip_first_sample: true
# uncomment to completely disable sampling
# disable_sampling: true
# ema will smooth out learning, but could slow it down.
# ema_config:
# use_ema: true
# ema_decay: 0.99
# you will probably need bf16 for omnigen2 if your gpu supports it; other dtypes may not work correctly
dtype: bf16
model:
name_or_path: "OmniGen2/OmniGen2
arch: "omnigen2"
quantize_te: true # quantize only the text encoder
# quantize: true # quantize transformer
sample:
sampler: "flowmatch" # must match train.noise_scheduler
sample_every: 250 # sample every this many steps
width: 1024
height: 1024
prompts:
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
- "woman with red hair, playing chess at the park, bomb going off in the background"
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
- "a bear building a log cabin in the snow covered mountains"
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
- "hipster man with a beard, building a chair, in a wood shop"
- "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
- "a man holding a sign that says, 'this is a sign'"
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
neg: "" # negative prompt, optional
seed: 42
walk_seed: true
guidance_scale: 4
sample_steps: 25
# you can add any additional meta info here. [name] is replaced with config name at top
meta:
name: "[name]"
version: '1.0'
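
As the comments in the sample block note, sample.sampler must match train.noise_scheduler, and the config name doubles as the output folder name. Below is a minimal sketch, not part of the toolkit itself, that loads the YAML and sanity-checks that constraint before starting a run; the config path is hypothetical and should point at wherever you saved the file above.

import yaml

# hypothetical path; adjust to wherever you saved the config above
config_path = "config/my_first_omnigen2_lora_v1.yaml"

with open(config_path) as f:
    cfg = yaml.safe_load(f)

process = cfg["config"]["process"][0]
train_cfg = process["train"]
sample_cfg = process["sample"]

# the sample sampler must match the training noise scheduler ("flowmatch" here)
assert sample_cfg["sampler"] == train_cfg["noise_scheduler"], (
    f'sampler {sample_cfg["sampler"]!r} does not match '
    f'noise_scheduler {train_cfg["noise_scheduler"]!r}'
)
print("config OK:", process["type"], "->", cfg["config"]["name"])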

View File

@@ -1,8 +1,9 @@
from .chroma import ChromaModel
from .hidream import HidreamModel
from .f_light import FLiteModel
from .omnigen2 import OmniGen2Model

AI_TOOLKIT_MODELS = [
    # put a list of models here
    ChromaModel, HidreamModel, FLiteModel, OmniGen2Model
]
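
A minimal sketch of how a registry list like AI_TOOLKIT_MODELS can be resolved against the arch string used in the YAML config ("omnigen2"). The classes here are illustrative stand-ins, not the toolkit's real lookup code; the real classes carry an arch class attribute, as OmniGen2Model does below.

# illustrative stand-ins for the registered model classes
class ChromaModel:   arch = "chroma"
class HidreamModel:  arch = "hidream"
class OmniGen2Model: arch = "omnigen2"

AI_TOOLKIT_MODELS = [ChromaModel, HidreamModel, OmniGen2Model]

def resolve_model_class(arch: str):
    """Return the registered model class whose `arch` matches, else None."""
    return next((m for m in AI_TOOLKIT_MODELS if m.arch == arch), None)

print(resolve_model_class("omnigen2").__name__)  # OmniGen2Model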

View File

@@ -0,0 +1,327 @@
import inspect
import os
from typing import TYPE_CHECKING, List, Optional
import einops
import torch
import yaml
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.models.base_model import BaseModel
from diffusers import AutoencoderKL
from toolkit.basic import flush
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
from toolkit.accelerator import unwrap_model
from optimum.quanto import freeze
from toolkit.util.quantize import quantize, get_qtype
from .src.pipelines.omnigen2.pipeline_omnigen2 import OmniGen2Pipeline
from .src.models.transformers import OmniGen2Transformer2DModel
from .src.models.transformers.repo import OmniGen2RotaryPosEmbed
from .src.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler as OmniFlowMatchEuler
from transformers import CLIPProcessor, Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration
if TYPE_CHECKING:
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
scheduler_config = {
"num_train_timesteps": 1000
}
BASE_MODEL_PATH = "OmniGen2/OmniGen2"
class OmniGen2Model(BaseModel):
arch = "omnigen2"
def __init__(
self,
device,
model_config: ModelConfig,
dtype='bf16',
custom_pipeline=None,
noise_scheduler=None,
**kwargs
):
super().__init__(
device,
model_config,
dtype,
custom_pipeline,
noise_scheduler,
**kwargs
)
self.is_flow_matching = True
self.is_transformer = True
self.target_lora_modules = ['OmniGen2Transformer2DModel']
# static method to get the noise scheduler
@staticmethod
def get_train_scheduler():
return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
def get_bucket_divisibility(self):
return 16
def load_model(self):
dtype = self.torch_dtype
# HiDream-ai/HiDream-I1-Full
self.print_and_status_update("Loading OmniGen2 model")
# will be updated if we detect an existing checkpoint in the training folder
model_path = self.model_config.name_or_path
extras_path = self.model_config.extras_name_or_path
scheduler = OmniGen2Model.get_train_scheduler()
self.print_and_status_update("Loading Qwen2.5 VL")
processor = CLIPProcessor.from_pretrained(
extras_path,
subfolder="processor",
use_fast=True
)
mllm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
extras_path,
subfolder="mllm",
torch_dtype=torch.bfloat16
)
if self.model_config.quantize_te:
self.print_and_status_update("Quantizing Qwen2.5 VL model")
quantization_type = get_qtype(self.model_config.qtype_te)
quantize(mllm, weights=quantization_type)
freeze(mllm)
if self.low_vram:
# unload it for now
mllm.to('cpu')
flush()
self.print_and_status_update("Loading transformer")
transformer = OmniGen2Transformer2DModel.from_pretrained(
model_path,
subfolder="transformer",
torch_dtype=torch.bfloat16
)
if not self.low_vram:
transformer.to(self.device_torch, dtype=dtype)
if self.model_config.quantize:
self.print_and_status_update("Quantizing transformer")
quantization_type = get_qtype(self.model_config.qtype)
quantize(transformer, weights=quantization_type)
freeze(transformer)
if self.low_vram:
# unload it for now
transformer.to('cpu')
flush()
self.print_and_status_update("Loading vae")
vae = AutoencoderKL.from_pretrained(
extras_path,
subfolder="vae",
torch_dtype=torch.bfloat16
).to(self.device_torch, dtype=dtype)
flush()
self.print_and_status_update("Loading Qwen2.5 VLProcessor")
flush()
if self.low_vram:
self.print_and_status_update("Moving everything to device")
# move it all back
transformer.to(self.device_torch, dtype=dtype)
vae.to(self.device_torch, dtype=dtype)
mllm.to(self.device_torch, dtype=dtype)
# set to eval mode
# transformer.eval()
vae.eval()
mllm.eval()
mllm.requires_grad_(False)
pipe: OmniGen2Pipeline = OmniGen2Pipeline(
transformer=transformer,
vae=vae,
scheduler=scheduler,
mllm=mllm,
processor=processor,
)
# pipe: OmniGen2Pipeline = OmniGen2Pipeline.from_pretrained(
# model_path,
# transformer=transformer,
# vae=vae,
# scheduler=scheduler,
# mllm=mllm,
# trust_remote_code=True,
# )
# processor = pipe.processor
flush()
text_encoder_list = [mllm]
tokenizer_list = [processor]
flush()
# save it to the model class
self.vae = vae
self.text_encoder = text_encoder_list # list of text encoders
self.tokenizer = tokenizer_list # list of tokenizers
self.model = pipe.transformer
self.pipeline = pipe
self.freqs_cis = OmniGen2RotaryPosEmbed.get_freqs_cis(
transformer.config.axes_dim_rope,
transformer.config.axes_lens,
theta=10000,
)
self.print_and_status_update("Model Loaded")
def get_generation_pipeline(self):
scheduler = OmniFlowMatchEuler(
dynamic_time_shift=True,
num_train_timesteps=1000
)
pipeline: OmniGen2Pipeline = OmniGen2Pipeline(
transformer=self.model,
vae=self.vae,
scheduler=scheduler,
mllm=self.text_encoder[0],
processor=self.tokenizer[0],
)
pipeline = pipeline.to(self.device_torch)
return pipeline
def generate_single_image(
self,
pipeline: OmniGen2Pipeline,
gen_config: GenerateImageConfig,
conditional_embeds: PromptEmbeds,
unconditional_embeds: PromptEmbeds,
generator: torch.Generator,
extra: dict,
):
img = pipeline(
prompt_embeds=conditional_embeds.text_embeds,
prompt_attention_mask=conditional_embeds.attention_mask,
negative_prompt_embeds=unconditional_embeds.text_embeds,
negative_prompt_attention_mask=unconditional_embeds.attention_mask,
height=gen_config.height,
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
text_guidance_scale=gen_config.guidance_scale,
image_guidance_scale=1.0, # reference image guidance scale. Add this for controls
latents=gen_config.latents,
generator=generator,
**extra
).images[0]
return img
def get_noise_prediction(
self,
latent_model_input: torch.Tensor,
timestep: torch.Tensor, # 0 to 1000 scale
text_embeddings: PromptEmbeds,
**kwargs
):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = timestep.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
# optional_kwargs = {}
# if 'ref_image_hidden_states' in set(inspect.signature(self.model.forward).parameters.keys()):
# optional_kwargs['ref_image_hidden_states'] = ref_image_hidden_states
timesteps = timestep / 1000 # convert to 0 to 1 scale
# timestep for model starts at 0 instead of 1. So we need to reverse them
timestep = 1 - timesteps
model_pred = self.model(
latent_model_input,
timestep,
text_embeddings.text_embeds,
self.freqs_cis,
text_embeddings.attention_mask,
ref_image_hidden_states=None, # todo add ref latent ability
)
return model_pred
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt = [self.pipeline._apply_chat_template(_prompt) for _prompt in prompt]
self.text_encoder_to(self.device_torch, dtype=self.torch_dtype)
max_sequence_length = 256
prompt_embeds, prompt_attention_mask, _, _ = self.pipeline.encode_prompt(
prompt=prompt,
do_classifier_free_guidance=False,
device=self.device_torch,
max_sequence_length=max_sequence_length,
)
pe = PromptEmbeds(prompt_embeds)
pe.attention_mask = prompt_attention_mask
return pe
def get_model_has_grad(self):
# return from a weight if it has grad
return False
def get_te_has_grad(self):
# assume no one wants to finetune the text encoder.
return False
def save_model(self, output_path, meta, save_dtype):
# only save the transformer
transformer: OmniGen2Transformer2DModel = unwrap_model(self.model)
transformer.save_pretrained(
save_directory=os.path.join(output_path, 'transformer'),
safe_serialization=True,
)
meta_path = os.path.join(output_path, 'aitk_meta.yaml')
with open(meta_path, 'w') as f:
yaml.dump(meta, f)
def get_loss_target(self, *args, **kwargs):
noise = kwargs.get('noise')
batch = kwargs.get('batch')
# return (noise - batch.latents).detach()
return (batch.latents - noise).detach()
def get_transformer_block_names(self) -> Optional[List[str]]:
# omnigen2 has a few block groups: noise_refiner, ref_image_refiner, context_refiner, and layers.
# let's do all but the ref image refiner until we add reference image support
return ['noise_refiner', 'context_refiner', 'layers']
# return ['layers']
def convert_lora_weights_before_save(self, state_dict):
# keys currently start with transformer. but need to start with diffusion_model. for comfyui
new_sd = {}
for key, value in state_dict.items():
new_key = key.replace("transformer.", "diffusion_model.")
new_sd[new_key] = value
return new_sd
def convert_lora_weights_before_load(self, state_dict):
# saved as diffusion_model. but needs to start with transformer. for ai-toolkit
new_sd = {}
for key, value in state_dict.items():
new_key = key.replace("diffusion_model.", "transformer.")
new_sd[new_key] = value
return new_sd
def get_base_model_version(self):
return "omnigen2"

View File

@@ -0,0 +1,357 @@
"""
OmniGen2 Attention Processor Module
Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import warnings
import math
from typing import Optional, Tuple, Dict, Any
import torch
import torch.nn.functional as F
from einops import repeat
from ..utils.import_utils import is_flash_attn_available
if is_flash_attn_available():
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
else:
warnings.warn("Cannot import flash_attn, install flash_attn to use Flash2Varlen attention for better performance")
from diffusers.models.attention_processor import Attention
from .embeddings import apply_rotary_emb
class OmniGen2AttnProcessorFlash2Varlen:
"""
Processor for implementing scaled dot-product attention with flash attention and variable length sequences.
This processor implements:
- Flash attention with variable length sequences
- Rotary position embeddings (RoPE)
- Query-Key normalization
- Proportional attention scaling
Args:
None
"""
def __init__(self) -> None:
"""Initialize the attention processor."""
if not is_flash_attn_available():
raise ImportError(
"OmniGen2AttnProcessorFlash2Varlen requires flash_attn. "
"Please install flash_attn."
)
def _upad_input(
self,
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor,
attention_mask: torch.Tensor,
query_length: int,
num_heads: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
"""
Unpad the input tensors for flash attention.
Args:
query_layer: Query tensor of shape (batch_size, seq_len, num_heads, head_dim)
key_layer: Key tensor of shape (batch_size, seq_len, num_kv_heads, head_dim)
value_layer: Value tensor of shape (batch_size, seq_len, num_kv_heads, head_dim)
attention_mask: Attention mask tensor of shape (batch_size, seq_len)
query_length: Length of the query sequence
num_heads: Number of attention heads
Returns:
Tuple containing:
- Unpadded query tensor
- Unpadded key tensor
- Unpadded value tensor
- Query indices
- Tuple of cumulative sequence lengths for query and key
- Tuple of maximum sequence lengths for query and key
"""
def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
"""Helper function to get unpadding data from attention mask."""
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return indices, cu_seqlens, max_seqlen_in_batch
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
# Unpad key and value layers
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
indices_k,
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
indices_k,
)
# Handle different query length cases
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim),
indices_k,
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
)
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
base_sequence_length: Optional[int] = None,
) -> torch.Tensor:
"""
Process attention computation with flash attention.
Args:
attn: Attention module
hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim)
encoder_hidden_states: Encoder hidden states tensor
attention_mask: Optional attention mask tensor
image_rotary_emb: Optional rotary embeddings for image tokens
base_sequence_length: Optional base sequence length for proportional attention
Returns:
torch.Tensor: Processed hidden states after attention computation
"""
batch_size, sequence_length, _ = hidden_states.shape
# Get Query-Key-Value Pair
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query_dim = query.shape[-1]
inner_dim = key.shape[-1]
head_dim = query_dim // attn.heads
dtype = query.dtype
# Get key-value heads
kv_heads = inner_dim // head_dim
# Reshape tensors for attention computation
query = query.view(batch_size, -1, attn.heads, head_dim)
key = key.view(batch_size, -1, kv_heads, head_dim)
value = value.view(batch_size, -1, kv_heads, head_dim)
# Apply Query-Key normalization
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply Rotary Position Embeddings
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
key = apply_rotary_emb(key, image_rotary_emb, use_real=False)
query, key = query.to(dtype), key.to(dtype)
# Calculate attention scale
if base_sequence_length is not None:
softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
else:
softmax_scale = attn.scale
# Unpad input for flash attention
(
query_states,
key_states,
value_states,
indices_q,
cu_seq_lens,
max_seq_lens,
) = self._upad_input(query, key, value, attention_mask, sequence_length, attn.heads)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
# Handle different number of heads
if kv_heads < attn.heads:
key_states = repeat(key_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)
value_states = repeat(value_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)
# Apply flash attention
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=0.0,
causal=False,
softmax_scale=softmax_scale,
)
# Pad output and apply final transformations
hidden_states = pad_input(attn_output_unpad, indices_q, batch_size, sequence_length)
hidden_states = hidden_states.flatten(-2)
hidden_states = hidden_states.type_as(query)
# Apply output projection
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class OmniGen2AttnProcessor:
"""
Processor implementing scaled dot-product attention with PyTorch 2.0's F.scaled_dot_product_attention.
This processor is optimized for PyTorch 2.0 and implements:
- Scaled dot-product attention (SDPA)
- Rotary position embeddings (RoPE)
- Query-Key normalization
- Proportional attention scaling
Args:
None
Raises:
ImportError: If PyTorch version is less than 2.0
"""
def __init__(self) -> None:
"""Initialize the attention processor."""
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"OmniGen2AttnProcessorFlash2Varlen requires PyTorch 2.0. "
"Please upgrade PyTorch to version 2.0 or later."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
base_sequence_length: Optional[int] = None,
) -> torch.Tensor:
"""
Process attention computation with flash attention.
Args:
attn: Attention module
hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim)
encoder_hidden_states: Encoder hidden states tensor
attention_mask: Optional attention mask tensor
image_rotary_emb: Optional rotary embeddings for image tokens
base_sequence_length: Optional base sequence length for proportional attention
Returns:
torch.Tensor: Processed hidden states after attention computation
"""
batch_size, sequence_length, _ = hidden_states.shape
# Get Query-Key-Value Pair
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query_dim = query.shape[-1]
inner_dim = key.shape[-1]
head_dim = query_dim // attn.heads
dtype = query.dtype
# Get key-value heads
kv_heads = inner_dim // head_dim
# Reshape tensors for attention computation
query = query.view(batch_size, -1, attn.heads, head_dim)
key = key.view(batch_size, -1, kv_heads, head_dim)
value = value.view(batch_size, -1, kv_heads, head_dim)
# Apply Query-Key normalization
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply Rotary Position Embeddings
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
key = apply_rotary_emb(key, image_rotary_emb, use_real=False)
query, key = query.to(dtype), key.to(dtype)
# Calculate attention scale
if base_sequence_length is not None:
softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
else:
softmax_scale = attn.scale
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
if attention_mask is not None:
attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
# explicitly repeat key and value to match the number of query heads; otherwise enable_gqa=True falls back to the MATH backend of sdpa in our tests with pytorch 2.6
key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, scale=softmax_scale
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.type_as(query)
# Apply output projection
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
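
Both processors use the same proportional attention scaling: when base_sequence_length is given, the usual 1/sqrt(head_dim) scale (attn.scale) is multiplied by sqrt(log_base(sequence_length)), so the attention sharpness stays roughly stable as the token count grows. A quick numeric sketch of that branch (the head dimension and sequence lengths are just examples):

from typing import Optional
import math

def proportional_scale(attn_scale: float, seq_len: int, base_seq_len: Optional[int]) -> float:
    # mirrors the scaling branch in both processors above
    if base_seq_len is None:
        return attn_scale
    return math.sqrt(math.log(seq_len, base_seq_len)) * attn_scale

head_dim = 96                     # e.g. hidden_size 2304 / 24 heads
attn_scale = head_dim ** -0.5     # the default 1/sqrt(head_dim) scale
for seq_len in (1024, 4096, 16384):
    print(seq_len, round(proportional_scale(attn_scale, seq_len, base_seq_len=1024), 5))
# the multiplier is 1.0 at seq_len == base_seq_len and grows slowly (sqrt of the log ratio) beyond it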

View File

@@ -0,0 +1,126 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
from diffusers.models.activations import get_activation
class TimestepEmbedding(nn.Module):
def __init__(
self,
in_channels: int,
time_embed_dim: int,
act_fn: str = "silu",
out_dim: int = None,
post_act_fn: Optional[str] = None,
cond_proj_dim=None,
sample_proj_bias=True,
):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
if cond_proj_dim is not None:
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
else:
self.cond_proj = None
self.act = get_activation(act_fn)
if out_dim is not None:
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
if post_act_fn is None:
self.post_act = None
else:
self.post_act = get_activation(post_act_fn)
self.initialize_weights()
def initialize_weights(self):
nn.init.normal_(self.linear_1.weight, std=0.02)
nn.init.zeros_(self.linear_1.bias)
nn.init.normal_(self.linear_2.weight, std=0.02)
nn.init.zeros_(self.linear_2.bias)
def forward(self, sample, condition=None):
if condition is not None:
sample = sample + self.cond_proj(condition)
sample = self.linear_1(sample)
if self.act is not None:
sample = self.act(sample)
sample = self.linear_2(sample)
if self.post_act is not None:
sample = self.post_act(sample)
return sample
def apply_rotary_emb(
x: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
use_real: bool = True,
use_real_unbind_dim: int = -1,
) -> torch.Tensor:
"""
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
tensors contain rotary embeddings and are returned as real tensors.
Args:
x (`torch.Tensor`):
Query or key tensor to apply rotary embeddings to, of shape [B, H, S, D].
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
Returns:
torch.Tensor: The input tensor with rotary embeddings applied.
"""
if use_real:
cos, sin = freqs_cis # [S, D]
cos = cos[None, None]
sin = sin[None, None]
cos, sin = cos.to(x.device), sin.to(x.device)
if use_real_unbind_dim == -1:
# Used for flux, cogvideox, hunyuan-dit
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
elif use_real_unbind_dim == -2:
# Used for Stable Audio, OmniGen and CogView4
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
else:
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
else:
# used for lumina
# x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], x.shape[-1] // 2, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
return x_out.type_as(x)
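
The attention processors call apply_rotary_emb with use_real=False, i.e. the complex path at the bottom of this file: channel pairs are viewed as complex numbers and multiplied by precomputed unit-magnitude rotations. A tiny self-contained sketch of that path; the shapes and the frequency recipe here are illustrative, not the model's actual axes_dim/axes_lens.

import torch

B, S, H, D = 1, 4, 2, 8                      # batch, seq, heads, head_dim (toy sizes)
x = torch.randn(B, S, H, D)

# one complex rotation per (position, channel-pair): shape [S, D//2]
pos = torch.arange(S, dtype=torch.float64)
inv_freq = 1.0 / (10000 ** (torch.arange(0, D, 2, dtype=torch.float64) / D))
angles = torch.outer(pos, inv_freq)
freqs_cis = torch.polar(torch.ones_like(angles), angles)   # unit-magnitude complex rotations

# the use_real=False path: view channel pairs as complex numbers and rotate them
x_complex = torch.view_as_complex(x.float().reshape(B, S, H, D // 2, 2))
rotated = x_complex * freqs_cis[None, :, None, :]          # broadcast over batch and heads
out = torch.view_as_real(rotated).flatten(3).type_as(x)

print(out.shape)                                               # torch.Size([1, 4, 2, 8])
print(torch.allclose(out.norm(dim=-1), x.norm(dim=-1), atol=1e-5))  # rotations preserve norms -> True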

View File

@@ -0,0 +1,3 @@
from .transformer_omnigen2 import OmniGen2Transformer2DModel
__all__ = ["OmniGen2Transformer2DModel"]

View File

@@ -0,0 +1,218 @@
# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Optional, Tuple
import torch
import torch.nn as nn
from diffusers.models.embeddings import Timesteps
from ..embeddings import TimestepEmbedding
from ...utils.import_utils import is_flash_attn_available, is_triton_available
if is_triton_available():
from ...ops.triton.layer_norm import RMSNorm
else:
from torch.nn import RMSNorm
warnings.warn("Cannot import triton, install triton to use fused RMSNorm for better performance")
if is_flash_attn_available():
from flash_attn.ops.activations import swiglu
else:
from .components import swiglu
warnings.warn("Cannot import flash_attn, install flash_attn to use fused SwiGLU for better performance")
# try:
# from flash_attn.ops.activations import swiglu as fused_swiglu
# FUSEDSWIGLU_AVALIBLE = True
# except ImportError:
# FUSEDSWIGLU_AVALIBLE = False
# warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
class LuminaRMSNormZero(nn.Module):
"""
Norm layer adaptive RMS normalization zero.
Parameters:
embedding_dim (`int`): The size of each embedding vector.
"""
def __init__(
self,
embedding_dim: int,
norm_eps: float,
norm_elementwise_affine: bool,
):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(
min(embedding_dim, 1024),
4 * embedding_dim,
bias=True,
)
self.norm = RMSNorm(embedding_dim, eps=norm_eps)
def forward(
self,
x: torch.Tensor,
emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
emb = self.linear(self.silu(emb))
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None])
return x, gate_msa, scale_mlp, gate_mlp
class LuminaLayerNormContinuous(nn.Module):
def __init__(
self,
embedding_dim: int,
conditioning_embedding_dim: int,
# NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
# because the output is immediately scaled and shifted by the projected conditioning embeddings.
# Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
# However, this is how it was implemented in the original code, and it's rather likely you should
# set `elementwise_affine` to False.
elementwise_affine=True,
eps=1e-5,
bias=True,
norm_type="layer_norm",
out_dim: Optional[int] = None,
):
super().__init__()
# AdaLN
self.silu = nn.SiLU()
self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
elif norm_type == "rms_norm":
self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
else:
raise ValueError(f"unknown norm_type {norm_type}")
self.linear_2 = None
if out_dim is not None:
self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)
def forward(
self,
x: torch.Tensor,
conditioning_embedding: torch.Tensor,
) -> torch.Tensor:
# convert back to the original dtype in case `conditioning_embedding` is upcasted to float32 (needed for hunyuanDiT)
emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
scale = emb
x = self.norm(x) * (1 + scale)[:, None, :]
if self.linear_2 is not None:
x = self.linear_2(x)
return x
class LuminaFeedForward(nn.Module):
r"""
A feed-forward layer.
Parameters:
hidden_size (`int`):
The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
hidden representations.
intermediate_size (`int`): The intermediate dimension of the feedforward layer.
multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple
of this value.
ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden
dimension. Defaults to None.
"""
def __init__(
self,
dim: int,
inner_dim: int,
multiple_of: Optional[int] = 256,
ffn_dim_multiplier: Optional[float] = None,
):
super().__init__()
self.swiglu = swiglu
# custom hidden_size factor multiplier
if ffn_dim_multiplier is not None:
inner_dim = int(ffn_dim_multiplier * inner_dim)
inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
self.linear_1 = nn.Linear(
dim,
inner_dim,
bias=False,
)
self.linear_2 = nn.Linear(
inner_dim,
dim,
bias=False,
)
self.linear_3 = nn.Linear(
dim,
inner_dim,
bias=False,
)
def forward(self, x):
h1, h2 = self.linear_1(x), self.linear_3(x)
return self.linear_2(self.swiglu(h1, h2))
class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
def __init__(
self,
hidden_size: int = 4096,
text_feat_dim: int = 2048,
frequency_embedding_size: int = 256,
norm_eps: float = 1e-5,
timestep_scale: float = 1.0,
) -> None:
super().__init__()
self.time_proj = Timesteps(
num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale
)
self.timestep_embedder = TimestepEmbedding(
in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)
)
self.caption_embedder = nn.Sequential(
RMSNorm(text_feat_dim, eps=norm_eps),
nn.Linear(text_feat_dim, hidden_size, bias=True),
)
self._initialize_weights()
def _initialize_weights(self):
nn.init.trunc_normal_(self.caption_embedder[1].weight, std=0.02)
nn.init.zeros_(self.caption_embedder[1].bias)
def forward(
self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
timestep_proj = self.time_proj(timestep).to(dtype=dtype)
time_embed = self.timestep_embedder(timestep_proj)
caption_embed = self.caption_embedder(text_hidden_states)
return time_embed, caption_embed
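
LuminaFeedForward optionally rescales the inner dimension by ffn_dim_multiplier and then rounds it up to a multiple of multiple_of, and its forward pass is a standard SwiGLU: silu(W1 x) * W3 x projected back through W2. A small sketch of just the sizing arithmetic, assuming it is called with inner_dim = 4 * dim as the transformer does; the numbers are illustrative.

from typing import Optional

def lumina_ffn_inner_dim(dim: int, multiple_of: int = 256,
                         ffn_dim_multiplier: Optional[float] = None) -> int:
    # mirrors LuminaFeedForward.__init__ when called with inner_dim = 4 * dim
    inner_dim = 4 * dim
    if ffn_dim_multiplier is not None:
        inner_dim = int(ffn_dim_multiplier * inner_dim)
    return multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)

print(lumina_ffn_inner_dim(2304))                            # 9216, already a multiple of 256
print(lumina_ffn_inner_dim(2304, ffn_dim_multiplier=2 / 3))  # 6144 after scaling and rounding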

View File

@@ -0,0 +1,4 @@
import torch.nn.functional as F
def swiglu(x, y):
return F.silu(x.float(), inplace=False).to(x.dtype) * y

View File

@@ -0,0 +1,135 @@
from typing import List, Tuple
import torch
import torch.nn as nn
from einops import repeat
from diffusers.models.embeddings import get_1d_rotary_pos_embed
class OmniGen2RotaryPosEmbed(nn.Module):
def __init__(self, theta: int,
axes_dim: Tuple[int, int, int],
axes_lens: Tuple[int, int, int] = (300, 512, 512),
patch_size: int = 2):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
self.axes_lens = axes_lens
self.patch_size = patch_size
@staticmethod
def get_freqs_cis(axes_dim: Tuple[int, int, int],
axes_lens: Tuple[int, int, int],
theta: int) -> List[torch.Tensor]:
freqs_cis = []
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
emb = get_1d_rotary_pos_embed(d, e, theta=theta, freqs_dtype=freqs_dtype)
freqs_cis.append(emb)
return freqs_cis
def _get_freqs_cis(self, freqs_cis, ids: torch.Tensor) -> torch.Tensor:
device = ids.device
if ids.device.type == "mps":
ids = ids.to("cpu")
result = []
for i in range(len(self.axes_dim)):
freqs = freqs_cis[i].to(ids.device)
index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
return torch.cat(result, dim=-1).to(device)
def forward(
self,
freqs_cis,
attention_mask,
l_effective_ref_img_len,
l_effective_img_len,
ref_img_sizes,
img_sizes,
device
):
batch_size = len(attention_mask)
p = self.patch_size
encoder_seq_len = attention_mask.shape[1]
l_effective_cap_len = attention_mask.sum(dim=1).tolist()
seq_lengths = [cap_len + sum(ref_img_len) + img_len for cap_len, ref_img_len, img_len in zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len)]
max_seq_len = int(max(seq_lengths))
max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
max_img_len = max(l_effective_img_len)
# Create position IDs
position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
cap_seq_len = int(cap_seq_len)
seq_len = int(seq_len)
# add text position ids
position_ids[i, :cap_seq_len] = repeat(torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3")
pe_shift = cap_seq_len
pe_shift_len = cap_seq_len
if ref_img_sizes[i] is not None:
for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
H, W = ref_img_size
ref_H_tokens, ref_W_tokens = H // p, W // p
assert ref_H_tokens * ref_W_tokens == ref_img_len
# add image position ids
row_ids = repeat(torch.arange(ref_H_tokens, dtype=torch.int32, device=device), "h -> h w", w=ref_W_tokens).flatten()
col_ids = repeat(torch.arange(ref_W_tokens, dtype=torch.int32, device=device), "w -> h w", h=ref_H_tokens).flatten()
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids
pe_shift += max(ref_H_tokens, ref_W_tokens)
pe_shift_len += ref_img_len
H, W = img_sizes[i]
H_tokens, W_tokens = H // p, W // p
assert H_tokens * W_tokens == l_effective_img_len[i]
row_ids = repeat(torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens).flatten()
col_ids = repeat(torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens).flatten()
assert pe_shift_len + l_effective_img_len[i] == seq_len
position_ids[i, pe_shift_len: seq_len, 0] = pe_shift
position_ids[i, pe_shift_len: seq_len, 1] = row_ids
position_ids[i, pe_shift_len: seq_len, 2] = col_ids
# Get combined rotary embeddings
freqs_cis = self._get_freqs_cis(freqs_cis, position_ids)
# create separate rotary embeddings for captions and images
cap_freqs_cis = torch.zeros(
batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
)
ref_img_freqs_cis = torch.zeros(
batch_size, max_ref_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
)
img_freqs_cis = torch.zeros(
batch_size, max_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
)
for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)):
cap_seq_len = int(cap_seq_len)
sum_ref_img_len = int(sum(ref_img_len))
img_len = int(img_len)
seq_len = int(seq_len)
cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
ref_img_freqs_cis[i, :sum_ref_img_len] = freqs_cis[i, cap_seq_len:cap_seq_len + sum_ref_img_len]
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_seq_len + sum_ref_img_len:cap_seq_len + sum_ref_img_len + img_len]
return (
cap_freqs_cis,
ref_img_freqs_cis,
img_freqs_cis,
freqs_cis,
l_effective_cap_len,
seq_lengths,
)
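
Each token gets a 3-component position id: text tokens use (t, t, t) for t = 0..cap_len-1, and image patch tokens use (pe_shift, row, col), where pe_shift continues after the caption (and after any reference images). A toy sketch of that layout, outside the module, for one caption plus a single 4x4-patch target image with no reference images:

import torch

cap_len, H_tokens, W_tokens = 3, 4, 4
img_len = H_tokens * W_tokens
position_ids = torch.zeros(cap_len + img_len, 3, dtype=torch.int32)

# text tokens: the same index on all three axes
position_ids[:cap_len] = torch.arange(cap_len, dtype=torch.int32)[:, None]

# image tokens: axis 0 is a constant shift past the caption, axes 1/2 are the patch row/column
row_ids = torch.arange(H_tokens, dtype=torch.int32).repeat_interleave(W_tokens)
col_ids = torch.arange(W_tokens, dtype=torch.int32).repeat(H_tokens)
position_ids[cap_len:, 0] = cap_len          # pe_shift == cap_len when there are no ref images
position_ids[cap_len:, 1] = row_ids
position_ids[cap_len:, 2] = col_ids

print(position_ids[:6].tolist())
# [[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 0, 0], [3, 0, 1], [3, 0, 2]]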

View File

@@ -0,0 +1,621 @@
import warnings
import itertools
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from einops import rearrange
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import PeftAdapterMixin
from diffusers.loaders.single_file_model import FromOriginalModelMixin
from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from diffusers.models.attention_processor import Attention
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from ..attention_processor import OmniGen2AttnProcessorFlash2Varlen, OmniGen2AttnProcessor
from .repo import OmniGen2RotaryPosEmbed
from .block_lumina2 import LuminaLayerNormContinuous, LuminaRMSNormZero, LuminaFeedForward, Lumina2CombinedTimestepCaptionEmbedding
from ...utils.import_utils import is_triton_available, is_flash_attn_available
if is_triton_available():
from ...ops.triton.layer_norm import RMSNorm
else:
from torch.nn import RMSNorm
logger = logging.get_logger(__name__)
class OmniGen2TransformerBlock(nn.Module):
"""
Transformer block for OmniGen2 model.
This block implements a transformer layer with:
- Multi-head attention with flash attention
- Feed-forward network with SwiGLU activation
- RMS normalization
- Optional modulation for conditional generation
Args:
dim: Dimension of the input and output tensors
num_attention_heads: Number of attention heads
num_kv_heads: Number of key-value heads
multiple_of: Value the hidden dimension should be a multiple of
ffn_dim_multiplier: Multiplier for the feed-forward network dimension
norm_eps: Epsilon value for normalization layers
modulation: Whether to use modulation for conditional generation
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
num_kv_heads: int,
multiple_of: int,
ffn_dim_multiplier: float,
norm_eps: float,
modulation: bool = True,
) -> None:
"""Initialize the transformer block."""
super().__init__()
self.head_dim = dim // num_attention_heads
self.modulation = modulation
try:
processor = OmniGen2AttnProcessorFlash2Varlen()
except ImportError:
processor = OmniGen2AttnProcessor()
# Initialize attention layer
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=dim // num_attention_heads,
qk_norm="rms_norm",
heads=num_attention_heads,
kv_heads=num_kv_heads,
eps=1e-5,
bias=False,
out_bias=False,
processor=processor,
)
# Initialize feed-forward network
self.feed_forward = LuminaFeedForward(
dim=dim,
inner_dim=4 * dim,
multiple_of=multiple_of,
ffn_dim_multiplier=ffn_dim_multiplier
)
# Initialize normalization layers
if modulation:
self.norm1 = LuminaRMSNormZero(
embedding_dim=dim,
norm_eps=norm_eps,
norm_elementwise_affine=True
)
else:
self.norm1 = RMSNorm(dim, eps=norm_eps)
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
self.norm2 = RMSNorm(dim, eps=norm_eps)
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
self.initialize_weights()
def initialize_weights(self) -> None:
"""
Initialize the weights of the transformer block.
Uses Xavier uniform initialization for linear layers and zero initialization for biases.
"""
nn.init.xavier_uniform_(self.attn.to_q.weight)
nn.init.xavier_uniform_(self.attn.to_k.weight)
nn.init.xavier_uniform_(self.attn.to_v.weight)
nn.init.xavier_uniform_(self.attn.to_out[0].weight)
nn.init.xavier_uniform_(self.feed_forward.linear_1.weight)
nn.init.xavier_uniform_(self.feed_forward.linear_2.weight)
nn.init.xavier_uniform_(self.feed_forward.linear_3.weight)
if self.modulation:
nn.init.zeros_(self.norm1.linear.weight)
nn.init.zeros_(self.norm1.linear.bias)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
image_rotary_emb: torch.Tensor,
temb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Forward pass of the transformer block.
Args:
hidden_states: Input hidden states tensor
attention_mask: Attention mask tensor
image_rotary_emb: Rotary embeddings for image tokens
temb: Optional timestep embedding tensor
Returns:
torch.Tensor: Output hidden states after transformer block processing
"""
if self.modulation:
if temb is None:
raise ValueError("temb must be provided when modulation is enabled")
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
else:
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
hidden_states = hidden_states + self.ffn_norm2(mlp_output)
return hidden_states
class OmniGen2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
"""
OmniGen2 Transformer 2D Model.
A transformer-based diffusion model for image generation with:
- Patch-based image processing
- Rotary position embeddings
- Multi-head attention
- Conditional generation support
Args:
patch_size: Size of image patches
in_channels: Number of input channels
out_channels: Number of output channels (defaults to in_channels)
hidden_size: Size of hidden layers
num_layers: Number of transformer layers
num_refiner_layers: Number of refiner layers
num_attention_heads: Number of attention heads
num_kv_heads: Number of key-value heads
multiple_of: Value the hidden dimension should be a multiple of
ffn_dim_multiplier: Multiplier for feed-forward network dimension
norm_eps: Epsilon value for normalization layers
axes_dim_rope: Dimensions for rotary position embeddings
axes_lens: Lengths for rotary position embeddings
text_feat_dim: Dimension of text features
timestep_scale: Scale factor for timestep embeddings
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["Omnigen2TransformerBlock"]
_skip_layerwise_casting_patterns = ["x_embedder", "norm"]
@register_to_config
def __init__(
self,
patch_size: int = 2,
in_channels: int = 16,
out_channels: Optional[int] = None,
hidden_size: int = 2304,
num_layers: int = 26,
num_refiner_layers: int = 2,
num_attention_heads: int = 24,
num_kv_heads: int = 8,
multiple_of: int = 256,
ffn_dim_multiplier: Optional[float] = None,
norm_eps: float = 1e-5,
axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
axes_lens: Tuple[int, int, int] = (300, 512, 512),
text_feat_dim: int = 1024,
timestep_scale: float = 1.0
) -> None:
"""Initialize the OmniGen2 transformer model."""
super().__init__()
# Validate configuration
if (hidden_size // num_attention_heads) != sum(axes_dim_rope):
raise ValueError(
f"hidden_size // num_attention_heads ({hidden_size // num_attention_heads}) "
f"must equal sum(axes_dim_rope) ({sum(axes_dim_rope)})"
)
self.out_channels = out_channels or in_channels
# Initialize embeddings
self.rope_embedder = OmniGen2RotaryPosEmbed(
theta=10000,
axes_dim=axes_dim_rope,
axes_lens=axes_lens,
patch_size=patch_size,
)
self.x_embedder = nn.Linear(
in_features=patch_size * patch_size * in_channels,
out_features=hidden_size,
)
self.ref_image_patch_embedder = nn.Linear(
in_features=patch_size * patch_size * in_channels,
out_features=hidden_size,
)
self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
hidden_size=hidden_size,
text_feat_dim=text_feat_dim,
norm_eps=norm_eps,
timestep_scale=timestep_scale
)
# Initialize transformer blocks
self.noise_refiner = nn.ModuleList([
OmniGen2TransformerBlock(
hidden_size,
num_attention_heads,
num_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
modulation=True
)
for _ in range(num_refiner_layers)
])
self.ref_image_refiner = nn.ModuleList([
OmniGen2TransformerBlock(
hidden_size,
num_attention_heads,
num_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
modulation=True
)
for _ in range(num_refiner_layers)
])
self.context_refiner = nn.ModuleList(
[
OmniGen2TransformerBlock(
hidden_size,
num_attention_heads,
num_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
modulation=False
)
for _ in range(num_refiner_layers)
]
)
# 3. Transformer blocks
self.layers = nn.ModuleList(
[
OmniGen2TransformerBlock(
hidden_size,
num_attention_heads,
num_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
modulation=True
)
for _ in range(num_layers)
]
)
# 4. Output norm & projection
self.norm_out = LuminaLayerNormContinuous(
embedding_dim=hidden_size,
conditioning_embedding_dim=min(hidden_size, 1024),
elementwise_affine=False,
eps=1e-6,
bias=True,
out_dim=patch_size * patch_size * self.out_channels
)
# Add learnable embeddings to distinguish different images
self.image_index_embedding = nn.Parameter(torch.randn(5, hidden_size)) # support max 5 ref images
self.gradient_checkpointing = False
self.initialize_weights()
def initialize_weights(self) -> None:
"""
Initialize the weights of the model.
Uses Xavier uniform initialization for linear layers.
"""
nn.init.xavier_uniform_(self.x_embedder.weight)
nn.init.constant_(self.x_embedder.bias, 0.0)
nn.init.xavier_uniform_(self.ref_image_patch_embedder.weight)
nn.init.constant_(self.ref_image_patch_embedder.bias, 0.0)
nn.init.zeros_(self.norm_out.linear_1.weight)
nn.init.zeros_(self.norm_out.linear_1.bias)
nn.init.zeros_(self.norm_out.linear_2.weight)
nn.init.zeros_(self.norm_out.linear_2.bias)
nn.init.normal_(self.image_index_embedding, std=0.02)
def img_patch_embed_and_refine(
self,
hidden_states,
ref_image_hidden_states,
padded_img_mask,
padded_ref_img_mask,
noise_rotary_emb,
ref_img_rotary_emb,
l_effective_ref_img_len,
l_effective_img_len,
temb
):
batch_size = len(hidden_states)
max_combined_img_len = max([img_len + sum(ref_img_len) for img_len, ref_img_len in zip(l_effective_img_len, l_effective_ref_img_len)])
hidden_states = self.x_embedder(hidden_states)
ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)
for i in range(batch_size):
shift = 0
for j, ref_img_len in enumerate(l_effective_ref_img_len[i]):
ref_image_hidden_states[i, shift:shift + ref_img_len, :] = ref_image_hidden_states[i, shift:shift + ref_img_len, :] + self.image_index_embedding[j]
shift += ref_img_len
for layer in self.noise_refiner:
hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
flat_l_effective_ref_img_len = list(itertools.chain(*l_effective_ref_img_len))
num_ref_images = len(flat_l_effective_ref_img_len)
max_ref_img_len = max(flat_l_effective_ref_img_len)
batch_ref_img_mask = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, dtype=torch.bool)
batch_ref_image_hidden_states = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, self.config.hidden_size)
batch_ref_img_rotary_emb = hidden_states.new_zeros(num_ref_images, max_ref_img_len, ref_img_rotary_emb.shape[-1], dtype=ref_img_rotary_emb.dtype)
batch_temb = temb.new_zeros(num_ref_images, *temb.shape[1:], dtype=temb.dtype)
# sequence of ref imgs to batch
idx = 0
for i in range(batch_size):
shift = 0
for ref_img_len in l_effective_ref_img_len[i]:
batch_ref_img_mask[idx, :ref_img_len] = True
batch_ref_image_hidden_states[idx, :ref_img_len] = ref_image_hidden_states[i, shift:shift + ref_img_len]
batch_ref_img_rotary_emb[idx, :ref_img_len] = ref_img_rotary_emb[i, shift:shift + ref_img_len]
batch_temb[idx] = temb[i]
shift += ref_img_len
idx += 1
# refine ref imgs separately
for layer in self.ref_image_refiner:
batch_ref_image_hidden_states = layer(batch_ref_image_hidden_states, batch_ref_img_mask, batch_ref_img_rotary_emb, batch_temb)
# batch of ref imgs to sequence
idx = 0
for i in range(batch_size):
shift = 0
for ref_img_len in l_effective_ref_img_len[i]:
ref_image_hidden_states[i, shift:shift + ref_img_len] = batch_ref_image_hidden_states[idx, :ref_img_len]
shift += ref_img_len
idx += 1
combined_img_hidden_states = hidden_states.new_zeros(batch_size, max_combined_img_len, self.config.hidden_size)
for i, (ref_img_len, img_len) in enumerate(zip(l_effective_ref_img_len, l_effective_img_len)):
combined_img_hidden_states[i, :sum(ref_img_len)] = ref_image_hidden_states[i, :sum(ref_img_len)]
combined_img_hidden_states[i, sum(ref_img_len):sum(ref_img_len) + img_len] = hidden_states[i, :img_len]
return combined_img_hidden_states
def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states):
batch_size = len(hidden_states)
p = self.config.patch_size
device = hidden_states[0].device
img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes]
if ref_image_hidden_states is not None:
ref_img_sizes = [[(img.size(1), img.size(2)) for img in imgs] if imgs is not None else None for imgs in ref_image_hidden_states]
l_effective_ref_img_len = [[(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes] if _ref_img_sizes is not None else [0] for _ref_img_sizes in ref_img_sizes]
else:
ref_img_sizes = [None for _ in range(batch_size)]
l_effective_ref_img_len = [[0] for _ in range(batch_size)]
max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
max_img_len = max(l_effective_img_len)
# ref image patch embeddings
flat_ref_img_hidden_states = []
for i in range(batch_size):
if ref_img_sizes[i] is not None:
imgs = []
for ref_img in ref_image_hidden_states[i]:
C, H, W = ref_img.size()
ref_img = rearrange(ref_img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
imgs.append(ref_img)
img = torch.cat(imgs, dim=0)
flat_ref_img_hidden_states.append(img)
else:
flat_ref_img_hidden_states.append(None)
# image patch embeddings
flat_hidden_states = []
for i in range(batch_size):
img = hidden_states[i]
C, H, W = img.size()
img = rearrange(img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
flat_hidden_states.append(img)
padded_ref_img_hidden_states = torch.zeros(batch_size, max_ref_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype)
padded_ref_img_mask = torch.zeros(batch_size, max_ref_img_len, dtype=torch.bool, device=device)
for i in range(batch_size):
if ref_img_sizes[i] is not None:
padded_ref_img_hidden_states[i, :sum(l_effective_ref_img_len[i])] = flat_ref_img_hidden_states[i]
padded_ref_img_mask[i, :sum(l_effective_ref_img_len[i])] = True
padded_hidden_states = torch.zeros(batch_size, max_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype)
padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device)
for i in range(batch_size):
padded_hidden_states[i, :l_effective_img_len[i]] = flat_hidden_states[i]
padded_img_mask[i, :l_effective_img_len[i]] = True
return (
padded_hidden_states,
padded_ref_img_hidden_states,
padded_img_mask,
padded_ref_img_mask,
l_effective_ref_img_len,
l_effective_img_len,
ref_img_sizes,
img_sizes,
)
def forward(
self,
hidden_states: Union[torch.Tensor, List[torch.Tensor]],
timestep: torch.Tensor,
text_hidden_states: torch.Tensor,
freqs_cis: torch.Tensor,
text_attention_mask: torch.Tensor,
ref_image_hidden_states: Optional[List[List[torch.Tensor]]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = False,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
# 1. Condition, positional & patch embedding
batch_size = len(hidden_states)
is_hidden_states_tensor = isinstance(hidden_states, torch.Tensor)
if is_hidden_states_tensor:
assert hidden_states.ndim == 4
hidden_states = [_hidden_states for _hidden_states in hidden_states]
device = hidden_states[0].device
temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype)
(
hidden_states,
ref_image_hidden_states,
img_mask,
ref_img_mask,
l_effective_ref_img_len,
l_effective_img_len,
ref_img_sizes,
img_sizes,
) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)
(
context_rotary_emb,
ref_img_rotary_emb,
noise_rotary_emb,
rotary_emb,
encoder_seq_lengths,
seq_lengths,
) = self.rope_embedder(
freqs_cis,
text_attention_mask,
l_effective_ref_img_len,
l_effective_img_len,
ref_img_sizes,
img_sizes,
device,
)
# 2. Context refinement
for layer in self.context_refiner:
text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)
combined_img_hidden_states = self.img_patch_embed_and_refine(
hidden_states,
ref_image_hidden_states,
img_mask,
ref_img_mask,
noise_rotary_emb,
ref_img_rotary_emb,
l_effective_ref_img_len,
l_effective_img_len,
temb,
)
# 3. Joint Transformer blocks
max_seq_len = int(max(seq_lengths))
attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
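# pack each sample as [text tokens | image tokens]: positions [0, encoder_seq_len) hold the refined text states, [encoder_seq_len, seq_len) hold the combined image states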
for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
encoder_seq_len = int(encoder_seq_len)
seq_len = int(seq_len)
attention_mask[i, :seq_len] = True
joint_hidden_states[i, :encoder_seq_len] = text_hidden_states[i, :encoder_seq_len]
joint_hidden_states[i, encoder_seq_len:seq_len] = combined_img_hidden_states[i, :seq_len - encoder_seq_len]
hidden_states = joint_hidden_states
for layer_idx, layer in enumerate(self.layers):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
layer, hidden_states, attention_mask, rotary_emb, temb
)
else:
hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
# 4. Output norm & projection
hidden_states = self.norm_out(hidden_states, temb)
p = self.config.patch_size
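# unpatchify: the last img_len tokens of each sequence are reshaped back into a c x (h*p) x (w*p) latent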
output = []
for i, (img_size, img_len, seq_len) in enumerate(zip(img_sizes, l_effective_img_len, seq_lengths)):
img_len = int(img_len)
seq_len = int(seq_len)
height, width = img_size
output.append(rearrange(hidden_states[i][seq_len - img_len:seq_len], '(h w) (p1 p2 c) -> c (h p1) (w p2)', h=height // p, w=width // p, p1=p, p2=p))
if is_hidden_states_tensor:
output = torch.stack(output, dim=0)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return output
return Transformer2DModelOutput(sample=output)

File diff suppressed because it is too large

View File

@@ -0,0 +1,266 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from typing import List, Optional, Tuple, Union
import numpy as np
import PIL.Image
import torch
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor, is_valid_image_imagelist
from diffusers.configuration_utils import register_to_config
class OmniGen2ImageProcessor(VaeImageProcessor):
"""
Image processor for OmniGen2 image resizing and cropping.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
`height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
vae_scale_factor (`int`, *optional*, defaults to `8`):
VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
resample (`str`, *optional*, defaults to `lanczos`):
Resampling filter to use when resizing the image.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image to [-1,1].
do_binarize (`bool`, *optional*, defaults to `False`):
Whether to binarize the image to 0/1.
do_convert_rgb (`bool`, *optional*, defaults to `False`):
Whether to convert the images to RGB format.
do_convert_grayscale (`bool`, *optional*, defaults to `False`):
Whether to convert the images to grayscale format.
"""
@register_to_config
def __init__(
self,
do_resize: bool = True,
vae_scale_factor: int = 16,
resample: str = "lanczos",
max_pixels: Optional[int] = None,
max_side_length: Optional[int] = None,
do_normalize: bool = True,
do_binarize: bool = False,
do_convert_grayscale: bool = False,
):
super().__init__(
do_resize=do_resize,
vae_scale_factor=vae_scale_factor,
resample=resample,
do_normalize=do_normalize,
do_binarize=do_binarize,
do_convert_grayscale=do_convert_grayscale,
)
self.max_pixels = max_pixels
self.max_side_length = max_side_length
def get_new_height_width(
self,
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
height: Optional[int] = None,
width: Optional[int] = None,
max_pixels: Optional[int] = None,
max_side_length: Optional[int] = None,
) -> Tuple[int, int]:
r"""
Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`.
Args:
image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it
should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch
tensor, it should have shape `[batch, channels, height, width]`.
height (`Optional[int]`, *optional*, defaults to `None`):
The height of the preprocessed image. If `None`, the height of the `image` input will be used.
width (`Optional[int]`, *optional*, defaults to `None`):
The width of the preprocessed image. If `None`, the width of the `image` input will be used.
Returns:
`Tuple[int, int]`:
A tuple containing the height and width, both resized to the nearest integer multiple of
`vae_scale_factor`.
"""
if height is None:
if isinstance(image, PIL.Image.Image):
height = image.height
elif isinstance(image, torch.Tensor):
height = image.shape[2]
else:
height = image.shape[1]
if width is None:
if isinstance(image, PIL.Image.Image):
width = image.width
elif isinstance(image, torch.Tensor):
width = image.shape[3]
else:
width = image.shape[2]
if max_side_length is None:
max_side_length = self.max_side_length
if max_pixels is None:
max_pixels = self.max_pixels
ratio = 1.0
if max_side_length is not None:
if height > width:
ratio = min(ratio, max_side_length / height)
else:
ratio = min(ratio, max_side_length / width)
if max_pixels is not None:
cur_pixels = height * width
ratio = min(ratio, (max_pixels / cur_pixels) ** 0.5)
ratio = min(ratio, 1.0) # do not upscale the input image
new_height = int(height * ratio) // self.config.vae_scale_factor * self.config.vae_scale_factor
new_width = int(width * ratio) // self.config.vae_scale_factor * self.config.vae_scale_factor
return new_height, new_width
def preprocess(
self,
image: PipelineImageInput,
height: Optional[int] = None,
width: Optional[int] = None,
max_pixels: Optional[int] = None,
max_side_length: Optional[int] = None,
resize_mode: str = "default", # "default", "fill", "crop"
crops_coords: Optional[Tuple[int, int, int, int]] = None,
) -> torch.Tensor:
"""
Preprocess the image input.
Args:
image (`PipelineImageInput`):
The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of
supported formats.
height (`int`, *optional*):
The height of the preprocessed image. If `None`, `get_default_height_width()` is used to get the default
height.
width (`int`, *optional*):
The width of the preprocessed image. If `None`, `get_default_height_width()` is used to get the default
width.
resize_mode (`str`, *optional*, defaults to `default`):
The resize mode, can be one of `default`, `fill`, or `crop`. If `default`, will resize the image to fit
within the specified width and height, and it may not maintain the original aspect ratio. If `fill`, will
resize the image to fit within the specified width and height, maintaining the aspect ratio, and then
center the image within the dimensions, filling empty with data from image. If `crop`, will resize the
image to fit within the specified width and height, maintaining the aspect ratio, and then center the
image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
supported for PIL image input.
crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
The crop coordinates for each image in the batch. If `None`, will not crop the image.
Returns:
`torch.Tensor`:
The preprocessed image.
"""
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
# Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
if isinstance(image, torch.Tensor):
# if image is a pytorch tensor could have 2 possible shapes:
# 1. batch x height x width: we should insert the channel dimension at position 1
# 2. channel x height x width: we should insert batch dimension at position 0,
# however, since both the channel and batch dimensions have size 1, it is equivalent to insert at position 1
# for simplicity, we insert a dimension of size 1 at position 1 for both cases
image = image.unsqueeze(1)
else:
# if it is a numpy array, it could have 2 possible shapes:
# 1. batch x height x width: insert channel dimension on last position
# 2. height x width x channel: insert batch dimension on first position
if image.shape[-1] == 1:
image = np.expand_dims(image, axis=0)
else:
image = np.expand_dims(image, axis=-1)
if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
warnings.warn(
"Passing `image` as a list of 4d np.ndarray is deprecated."
"Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
FutureWarning,
)
image = np.concatenate(image, axis=0)
if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
warnings.warn(
"Passing `image` as a list of 4d torch.Tensor is deprecated."
"Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
FutureWarning,
)
image = torch.cat(image, axis=0)
if not is_valid_image_imagelist(image):
raise ValueError(
f"Input is in incorrect format. Currently, we only support {', '.join(str(x) for x in supported_formats)}"
)
if not isinstance(image, list):
image = [image]
if isinstance(image[0], PIL.Image.Image):
if crops_coords is not None:
image = [i.crop(crops_coords) for i in image]
if self.config.do_resize:
height, width = self.get_new_height_width(image[0], height, width, max_pixels, max_side_length)
image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image]
if self.config.do_convert_rgb:
image = [self.convert_to_rgb(i) for i in image]
elif self.config.do_convert_grayscale:
image = [self.convert_to_grayscale(i) for i in image]
image = self.pil_to_numpy(image) # to np
image = self.numpy_to_pt(image) # to pt
elif isinstance(image[0], np.ndarray):
image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
image = self.numpy_to_pt(image)
height, width = self.get_new_height_width(image, height, width, max_pixels, max_side_length)
if self.config.do_resize:
image = self.resize(image, height, width)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
if self.config.do_convert_grayscale and image.ndim == 3:
image = image.unsqueeze(1)
channel = image.shape[1]
# don't need any preprocess if the image is latents
if channel == self.config.vae_latent_channels:
return image
height, width = self.get_new_height_width(image, height, width, max_pixels, max_side_length)
if self.config.do_resize:
image = self.resize(image, height, width)
# expected range [0,1], normalize to [-1,1]
do_normalize = self.config.do_normalize
if do_normalize and image.min() < 0:
warnings.warn(
"Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
FutureWarning,
)
do_normalize = False
if do_normalize:
image = self.normalize(image)
if self.config.do_binarize:
image = self.binarize(image)
return image

View File

@@ -0,0 +1,728 @@
"""
OmniGen2 Diffusion Pipeline
Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import math
from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
from transformers import Qwen2_5_VLForConditionalGeneration
from diffusers.models.autoencoders import AutoencoderKL
from ...models.transformers import OmniGen2Transformer2DModel
from ...models.transformers.repo import OmniGen2RotaryPosEmbed
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
is_torch_xla_available,
logging,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from dataclasses import dataclass
import PIL.Image
from diffusers.utils import BaseOutput
from ....src.pipelines.image_processor import OmniGen2ImageProcessor
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class FMPipelineOutput(BaseOutput):
"""
Output class for OmniGen2 pipeline.
Args:
images (Union[List[PIL.Image.Image], np.ndarray]):
List of denoised PIL images of length `batch_size` or numpy array of shape
`(batch_size, height, width, num_channels)`. Contains the generated images.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class OmniGen2Pipeline(DiffusionPipeline):
"""
Pipeline for text-to-image generation using OmniGen2.
This pipeline implements a text-to-image generation model that uses:
- Qwen2.5-VL for text encoding
- A custom transformer architecture for image generation
- VAE for image encoding/decoding
- FlowMatchEulerDiscreteScheduler for noise scheduling
Args:
transformer (OmniGen2Transformer2DModel): The transformer model for image generation.
vae (AutoencoderKL): The VAE model for image encoding/decoding.
scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for noise scheduling.
text_encoder (Qwen2_5_VLModel): The text encoder model.
tokenizer (Union[Qwen2Tokenizer, Qwen2TokenizerFast]): The tokenizer for text processing.
"""
model_cpu_offload_seq = "mllm->transformer->vae"
def __init__(
self,
transformer: OmniGen2Transformer2DModel,
vae: AutoencoderKL,
scheduler: FlowMatchEulerDiscreteScheduler,
mllm: Qwen2_5_VLForConditionalGeneration,
processor,
) -> None:
"""
Initialize the OmniGen2 pipeline.
Args:
transformer: The transformer model for image generation.
vae: The VAE model for image encoding/decoding.
scheduler: The scheduler for noise scheduling.
text_encoder: The text encoder model.
tokenizer: The tokenizer for text processing.
"""
super().__init__()
self.register_modules(
transformer=transformer,
vae=vae,
scheduler=scheduler,
mllm=mllm,
processor=processor
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
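# pass vae_scale_factor * 2 so processed image sizes stay divisible by vae_scale_factor * 2, matching the transformer's latent patch size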
self.image_processor = OmniGen2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2, do_resize=True)
self.default_sample_size = 128
def prepare_latents(
self,
batch_size: int,
num_channels_latents: int,
height: int,
width: int,
dtype: torch.dtype,
device: torch.device,
generator: Optional[torch.Generator],
latents: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
"""
Prepare the initial latents for the diffusion process.
Args:
batch_size: The number of images to generate.
num_channels_latents: The number of channels in the latent space.
height: The height of the generated image.
width: The width of the generated image.
dtype: The data type of the latents.
device: The device to place the latents on.
generator: The random number generator to use.
latents: Optional pre-computed latents to use instead of random initialization.
Returns:
torch.FloatTensor: The prepared latents tensor.
"""
height = int(height) // self.vae_scale_factor
width = int(width) // self.vae_scale_factor
shape = (batch_size, num_channels_latents, height, width)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
return latents
def encode_vae(self, img: torch.FloatTensor) -> torch.FloatTensor:
"""
Encode an image into the VAE latent space.
Args:
img: The input image tensor to encode.
Returns:
torch.FloatTensor: The encoded latent representation.
"""
z0 = self.vae.encode(img.to(dtype=self.vae.dtype)).latent_dist.sample()
if self.vae.config.shift_factor is not None:
z0 = z0 - self.vae.config.shift_factor
if self.vae.config.scaling_factor is not None:
z0 = z0 * self.vae.config.scaling_factor
z0 = z0.to(dtype=self.vae.dtype)
return z0
def prepare_image(
self,
images: Union[List[PIL.Image.Image], PIL.Image.Image],
batch_size: int,
num_images_per_prompt: int,
max_pixels: int,
max_side_length: int,
device: torch.device,
dtype: torch.dtype,
) -> List[Optional[torch.FloatTensor]]:
"""
Prepare input images for processing by encoding them into the VAE latent space.
Args:
images: Single image or list of images to process.
batch_size: The number of images to generate per prompt.
num_images_per_prompt: The number of images to generate for each prompt.
device: The device to place the encoded latents on.
dtype: The data type of the encoded latents.
Returns:
List[Optional[torch.FloatTensor]]: List of encoded latent representations for each image.
"""
if batch_size == 1:
images = [images]
latents = []
for i, img in enumerate(images):
if img is not None and len(img) > 0:
ref_latents = []
for j, img_j in enumerate(img):
img_j = self.image_processor.preprocess(img_j, max_pixels=max_pixels, max_side_length=max_side_length)
ref_latents.append(self.encode_vae(img_j.to(device=device)).squeeze(0))
else:
ref_latents = None
for _ in range(num_images_per_prompt):
latents.append(ref_latents)
return latents
def _get_qwen2_prompt_embeds(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
max_sequence_length: int = 256,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Get prompt embeddings from the Qwen2 text encoder.
Args:
prompt: The prompt or list of prompts to encode.
device: The device to place the embeddings on. If None, uses the pipeline's device.
max_sequence_length: Maximum sequence length for tokenization.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- The prompt embeddings tensor
- The attention mask tensor
Raises:
Warning: If the input text is truncated due to sequence length limitations.
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
# text_inputs = self.processor.tokenizer(
# prompt,
# padding="max_length",
# max_length=max_sequence_length,
# truncation=True,
# return_tensors="pt",
# )
text_inputs = self.processor.tokenizer(
prompt,
padding="longest",
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(device)
untruncated_ids = self.processor.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device)
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.processor.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because Gemma can only handle sequences up to"
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_attention_mask = text_inputs.attention_mask.to(device)
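# the conditioning comes from the last hidden state of the Qwen2.5-VL language model rather than a dedicated text encoder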
prompt_embeds = self.mllm(
text_input_ids,
attention_mask=prompt_attention_mask,
output_hidden_states=True,
).hidden_states[-1]
if self.mllm is not None:
dtype = self.mllm.dtype
elif self.transformer is not None:
dtype = self.transformer.dtype
else:
dtype = None
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
return prompt_embeds, prompt_attention_mask
def _apply_chat_template(self, prompt: str):
prompt = [
{
"role": "system",
"content": "You are a helpful assistant that generates high-quality images based on user instructions.",
},
{"role": "user", "content": prompt},
]
prompt = self.processor.tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=False)
return prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
max_sequence_length: int = 256,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
instead. Ignored when not using guidance (i.e., ignored if `text_guidance_scale` is `1` or lower). If not
provided, an empty string is used.
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
whether to use classifier free guidance or not
num_images_per_prompt (`int`, *optional*, defaults to 1):
number of images that should be generated per prompt
device: (`torch.device`, *optional*):
torch device to place the resulting embeddings on
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. If not provided, they are generated from the empty string.
max_sequence_length (`int`, defaults to `256`):
Maximum sequence length to use for the prompt.
"""
device = device or self._execution_device
if prompt is not None:
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt = [self._apply_chat_template(_prompt) for _prompt in prompt]
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds, prompt_attention_mask = self._get_qwen2_prompt_embeds(
prompt=prompt,
device=device,
max_sequence_length=max_sequence_length
)
batch_size, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1)
# Get negative embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt if negative_prompt is not None else ""
# Normalize str to list
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
negative_prompt = [self._apply_chat_template(_negative_prompt) for _negative_prompt in negative_prompt]
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
negative_prompt = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds, negative_prompt_attention_mask = self._get_qwen2_prompt_embeds(
prompt=negative_prompt,
device=device,
max_sequence_length=max_sequence_length,
)
batch_size, seq_len, _ = negative_prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
negative_prompt_attention_mask = negative_prompt_attention_mask.view(
batch_size * num_images_per_prompt, -1
)
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
@property
def num_timesteps(self):
return self._num_timesteps
@property
def text_guidance_scale(self):
return self._text_guidance_scale
@property
def image_guidance_scale(self):
return self._image_guidance_scale
@property
def cfg_range(self):
return self._cfg_range
@torch.no_grad()
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.LongTensor] = None,
negative_prompt_attention_mask: Optional[torch.LongTensor] = None,
max_sequence_length: Optional[int] = None,
callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
input_images: Optional[List[PIL.Image.Image]] = None,
num_images_per_prompt: int = 1,
height: Optional[int] = None,
width: Optional[int] = None,
max_pixels: int = 1024 * 1024,
max_input_image_side_length: int = 1024,
align_res: bool = True,
num_inference_steps: int = 28,
text_guidance_scale: float = 4.0,
image_guidance_scale: float = 1.0,
cfg_range: Tuple[float, float] = (0.0, 1.0),
attention_kwargs: Optional[Dict[str, Any]] = None,
timesteps: List[int] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
verbose: bool = False,
step_func=None,
):
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
self._text_guidance_scale = text_guidance_scale
self._image_guidance_scale = image_guidance_scale
self._cfg_range = cfg_range
self._attention_kwargs = attention_kwargs
# 1. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# 2. Encode input prompt
(
prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt,
self.text_guidance_scale > 1.0,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
max_sequence_length=max_sequence_length,
)
dtype = self.vae.dtype
# 3. Prepare reference image latents
ref_latents = self.prepare_image(
images=input_images,
batch_size=batch_size,
num_images_per_prompt=num_images_per_prompt,
max_pixels=max_pixels,
max_side_length=max_input_image_side_length,
device=device,
dtype=dtype,
)
if input_images is None:
input_images = []
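# with a single reference image and align_res enabled, the output resolution follows the reference latent; otherwise the requested size is capped at max_pixels and snapped to a multiple of 16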
if len(input_images) == 1 and align_res:
width, height = ref_latents[0][0].shape[-1] * self.vae_scale_factor, ref_latents[0][0].shape[-2] * self.vae_scale_factor
ori_width, ori_height = width, height
else:
ori_width, ori_height = width, height
cur_pixels = height * width
ratio = (max_pixels / cur_pixels) ** 0.5
ratio = min(ratio, 1.0)
height, width = int(height * ratio) // 16 * 16, int(width * ratio) // 16 * 16
if len(input_images) == 0:
self._image_guidance_scale = 1
# 4. Prepare latents.
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
latent_channels,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
freqs_cis = OmniGen2RotaryPosEmbed.get_freqs_cis(
self.transformer.config.axes_dim_rope,
self.transformer.config.axes_lens,
theta=10000,
)
image = self.processing(
latents=latents,
ref_latents=ref_latents,
prompt_embeds=prompt_embeds,
freqs_cis=freqs_cis,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
num_inference_steps=num_inference_steps,
timesteps=timesteps,
device=device,
dtype=dtype,
verbose=verbose,
step_func=step_func,
)
image = F.interpolate(image, size=(ori_height, ori_width), mode='bilinear')
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return image
else:
return FMPipelineOutput(images=image)
def processing(
self,
latents,
ref_latents,
prompt_embeds,
freqs_cis,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
num_inference_steps,
timesteps,
device,
dtype,
verbose,
step_func=None
):
batch_size = latents.shape[0]
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps,
num_tokens=latents.shape[-2] * latents.shape[-1]
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
model_pred = self.predict(
t=t,
latents=latents,
prompt_embeds=prompt_embeds,
freqs_cis=freqs_cis,
prompt_attention_mask=prompt_attention_mask,
ref_image_hidden_states=ref_latents,
)
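# guidance is only applied inside the cfg_range fraction of the schedule; outside it both scales fall back to 1.0 (no guidance)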
text_guidance_scale = self.text_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
image_guidance_scale = self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
if text_guidance_scale > 1.0 and image_guidance_scale > 1.0:
model_pred_ref = self.predict(
t=t,
latents=latents,
prompt_embeds=negative_prompt_embeds,
freqs_cis=freqs_cis,
prompt_attention_mask=negative_prompt_attention_mask,
ref_image_hidden_states=ref_latents,
)
if image_guidance_scale != 1:
model_pred_uncond = self.predict(
t=t,
latents=latents,
prompt_embeds=negative_prompt_embeds,
freqs_cis=freqs_cis,
prompt_attention_mask=negative_prompt_attention_mask,
ref_image_hidden_states=None,
)
else:
model_pred_uncond = torch.zeros_like(model_pred)
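# dual CFG: start from the unconditional prediction, move toward the image-conditioned prediction by image_guidance_scale, then toward the fully conditioned prediction by text_guidance_scale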
model_pred = model_pred_uncond + image_guidance_scale * (model_pred_ref - model_pred_uncond) + \
text_guidance_scale * (model_pred - model_pred_ref)
elif text_guidance_scale > 1.0:
model_pred_uncond = self.predict(
t=t,
latents=latents,
prompt_embeds=negative_prompt_embeds,
freqs_cis=freqs_cis,
prompt_attention_mask=negative_prompt_attention_mask,
ref_image_hidden_states=None,
)
model_pred = model_pred_uncond + text_guidance_scale * (model_pred - model_pred_uncond)
latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0]
latents = latents.to(dtype=dtype)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if step_func is not None:
step_func(i, self._num_timesteps)
latents = latents.to(dtype=dtype)
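# undo the VAE scaling and shift applied at encode time before decoding the latents to pixel space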
if self.vae.config.scaling_factor is not None:
latents = latents / self.vae.config.scaling_factor
if self.vae.config.shift_factor is not None:
latents = latents + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
return image
def predict(
self,
t,
latents,
prompt_embeds,
freqs_cis,
prompt_attention_mask,
ref_image_hidden_states,
):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
batch_size, num_channels_latents, height, width = latents.shape
optional_kwargs = {}
if 'ref_image_hidden_states' in set(inspect.signature(self.transformer.forward).parameters.keys()):
optional_kwargs['ref_image_hidden_states'] = ref_image_hidden_states
model_pred = self.transformer(
latents,
timestep,
prompt_embeds,
freqs_cis,
prompt_attention_mask,
**optional_kwargs
)
return model_pred

View File

@@ -0,0 +1,830 @@
"""
OmniGen2 Diffusion Pipeline
Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import math
from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
from transformers import Qwen2_5_VLForConditionalGeneration
from diffusers.models.autoencoders import AutoencoderKL
from ...models.transformers import OmniGen2Transformer2DModel
from ...models.transformers.repo import OmniGen2RotaryPosEmbed
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
is_torch_xla_available,
logging,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from dataclasses import dataclass
import PIL.Image
from diffusers.utils import BaseOutput
from src.pipelines.image_processor import OmniGen2ImageProcessor
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class OmniGen2PipelineOutput(BaseOutput):
"""
Output class for the OmniGen2 chat pipeline.
Args:
text (`str`):
The text response generated by the multimodal language model.
images (Union[List[PIL.Image.Image], np.ndarray]):
List of denoised PIL images of length `batch_size` or numpy array of shape
`(batch_size, height, width, num_channels)`. Contains the generated images.
"""
text: str
images: Union[List[PIL.Image.Image], np.ndarray]
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class OmniGen2ChatPipeline(DiffusionPipeline):
"""
Pipeline for text-to-image generation using OmniGen2.
This pipeline implements a text-to-image generation model that uses:
- Qwen2.5-VL for text encoding
- A custom transformer architecture for image generation
- VAE for image encoding/decoding
- FlowMatchEulerDiscreteScheduler for noise scheduling
Args:
transformer (OmniGen2Transformer2DModel): The transformer model for image generation.
vae (AutoencoderKL): The VAE model for image encoding/decoding.
scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for noise scheduling.
text_encoder (Qwen2_5_VLModel): The text encoder model.
tokenizer (Union[Qwen2Tokenizer, Qwen2TokenizerFast]): The tokenizer for text processing.
"""
model_cpu_offload_seq = "mllm->transformer->vae"
def __init__(
self,
transformer: OmniGen2Transformer2DModel,
vae: AutoencoderKL,
scheduler: FlowMatchEulerDiscreteScheduler,
mllm: Qwen2_5_VLForConditionalGeneration,
processor,
) -> None:
"""
Initialize the OmniGen2 pipeline.
Args:
transformer: The transformer model for image generation.
vae: The VAE model for image encoding/decoding.
scheduler: The scheduler for noise scheduling.
text_encoder: The text encoder model.
tokenizer: The tokenizer for text processing.
"""
super().__init__()
self.register_modules(
transformer=transformer,
vae=vae,
scheduler=scheduler,
mllm=mllm,
processor=processor
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.image_processor = OmniGen2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2, do_resize=True)
self.default_sample_size = 128
def prepare_latents(
self,
batch_size: int,
num_channels_latents: int,
height: int,
width: int,
dtype: torch.dtype,
device: torch.device,
generator: Optional[torch.Generator],
latents: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
"""
Prepare the initial latents for the diffusion process.
Args:
batch_size: The number of images to generate.
num_channels_latents: The number of channels in the latent space.
height: The height of the generated image.
width: The width of the generated image.
dtype: The data type of the latents.
device: The device to place the latents on.
generator: The random number generator to use.
latents: Optional pre-computed latents to use instead of random initialization.
Returns:
torch.FloatTensor: The prepared latents tensor.
"""
height = int(height) // self.vae_scale_factor
width = int(width) // self.vae_scale_factor
shape = (batch_size, num_channels_latents, height, width)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
return latents
def encode_vae(self, img: torch.FloatTensor) -> torch.FloatTensor:
"""
Encode an image into the VAE latent space.
Args:
img: The input image tensor to encode.
Returns:
torch.FloatTensor: The encoded latent representation.
"""
z0 = self.vae.encode(img.to(dtype=self.vae.dtype)).latent_dist.sample()
if self.vae.config.shift_factor is not None:
z0 = z0 - self.vae.config.shift_factor
if self.vae.config.scaling_factor is not None:
z0 = z0 * self.vae.config.scaling_factor
z0 = z0.to(dtype=self.vae.dtype)
return z0
def prepare_image(
self,
images: Union[List[PIL.Image.Image], PIL.Image.Image],
batch_size: int,
num_images_per_prompt: int,
max_pixels: int,
max_side_length: int,
device: torch.device,
dtype: torch.dtype,
) -> List[Optional[torch.FloatTensor]]:
"""
Prepare input images for processing by encoding them into the VAE latent space.
Args:
images: Single image or list of images to process.
batch_size: The number of images to generate per prompt.
num_images_per_prompt: The number of images to generate for each prompt.
device: The device to place the encoded latents on.
dtype: The data type of the encoded latents.
Returns:
List[Optional[torch.FloatTensor]]: List of encoded latent representations for each image.
"""
if batch_size == 1:
images = [images]
latents = []
for i, img in enumerate(images):
if img is not None and len(img) > 0:
ref_latents = []
for j, img_j in enumerate(img):
img_j = self.image_processor.preprocess(img_j, max_pixels=max_pixels, max_side_length=max_side_length)
ref_latents.append(self.encode_vae(img_j.to(device=device)).squeeze(0))
else:
ref_latents = None
for _ in range(num_images_per_prompt):
latents.append(ref_latents)
return latents
def _apply_chat_template(self, prompt: str, images: List = None):
if images is not None:
prompt = "".join(
[
f"<img{i}>: <|vision_start|><|image_pad|><|vision_end|>"
for i in range(1, len(images) + 1)
]
) + prompt
prompt = f"<|im_start|>system\nYou are a helpful assistant that generates high-quality images based on user instructions.<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
return prompt
def _get_qwen2_prompt_embeds(
self,
prompt: Union[str, List[str]],
input_images = None,
device: Optional[torch.device] = None,
use_only_text_hidden_states: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Get prompt embeddings from the Qwen2 text encoder.
Args:
prompt: The prompt or list of prompts to encode.
device: The device to place the embeddings on. If None, uses the pipeline's device.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- The prompt embeddings tensor
- The attention mask tensor
Raises:
Warning: If the input text is truncated due to sequence length limitations.
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
inputs = self.processor(
text=prompt,
images=input_images,
videos=None,
padding=True,
return_tensors="pt",
)
inputs = inputs.to(device)
prompt_embeds = self.mllm(
**inputs,
output_hidden_states=True,
).hidden_states[-1]
text_input_ids = inputs.input_ids
text_mask = inputs.attention_mask
if use_only_text_hidden_states:
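# keep only hidden states at real text positions (dropping <|image_pad|> placeholders) and re-pack them left-aligned with a fresh attention mask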
mask = text_input_ids != self.mllm.config.image_token_id
mask = mask & text_mask
mask = mask.bool()
text_l = mask.sum(dim=-1)
max_l = text_l.max()
text_batch_size = prompt_embeds.size(0)
new_prompt_embeds = torch.zeros((text_batch_size, max_l, prompt_embeds.size(-1)), device=prompt_embeds.device, dtype=prompt_embeds.dtype)
new_text_mask = torch.zeros((text_batch_size, max_l), dtype=text_mask.dtype, device=text_mask.device)
for i in range(text_batch_size):
new_prompt_embeds[i, :text_l[i]] = prompt_embeds[i, mask[i]]
new_text_mask[i, :text_l[i]] = 1
prompt_embeds = new_prompt_embeds
text_mask = new_text_mask
prompt_embeds = prompt_embeds.to(dtype=self.mllm.dtype, device=device)
return prompt_embeds, text_mask
def encode_prompt(
self,
prompt: Union[str, List[str]],
input_images: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
max_sequence_length: int = 256,
use_text_encoder_penultimate_layer_feats: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
instead. Ignored when not using guidance (i.e., ignored if `text_guidance_scale` is `1` or lower). If not
provided, an empty string is used.
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
whether to use classifier free guidance or not
num_images_per_prompt (`int`, *optional*, defaults to 1):
number of images that should be generated per prompt
device: (`torch.device`, *optional*):
torch device to place the resulting embeddings on
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. If not provided, they are generated from the empty string.
max_sequence_length (`int`, defaults to `256`):
Maximum sequence length to use for the prompt.
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds, prompt_attention_mask = self._get_qwen2_prompt_embeds(
prompt=prompt,
input_images=input_images,
device=device,
)
batch_size, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1)
# Get negative embeddings for classifier free guidance
negative_prompt_embeds, negative_prompt_attention_mask = None, None
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt if negative_prompt is not None else ""
# Normalize str to list
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
negative_prompt = [self._apply_chat_template(_negative_prompt) for _negative_prompt in negative_prompt]
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
negative_prompt = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds, negative_prompt_attention_mask = self._get_qwen2_prompt_embeds(
prompt=negative_prompt,
device=device,
)
batch_size, seq_len, _ = negative_prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
negative_prompt_attention_mask = negative_prompt_attention_mask.view(
batch_size * num_images_per_prompt, -1
)
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
@property
def num_timesteps(self):
return self._num_timesteps
@property
def text_guidance_scale(self):
return self._text_guidance_scale
@property
def image_guidance_scale(self):
return self._image_guidance_scale
@property
def cfg_range(self):
return self._cfg_range
def prepare_inputs_for_text_generation(self, prompts, input_images, device):
if isinstance(prompts, str):
prompts = [prompts]
ori_padding_side = self.processor.tokenizer.padding_side
self.processor.tokenizer.padding_side = "left"
inputs = self.processor(
text=prompts,
images=input_images,
videos=None,
padding=True,
return_tensors="pt",
).to(device)
self.processor.tokenizer.padding_side = ori_padding_side
return inputs
def generate_text(self, prompt, input_images):
inputs = self.prepare_inputs_for_text_generation(
prompt, input_images, self.mllm.device
)
generated_ids = self.mllm.generate(
**inputs,
tokenizer=self.processor.tokenizer,
max_new_tokens=256,
stop_strings=["<|im_end|>", "<|img|>", "<|endoftext|>"],
) # stop_words=[151643, 151645, 151665]
generated_ids_trimmed = [
out_ids[len(in_ids) :]
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_texts = self.processor.batch_decode(
generated_ids_trimmed,
# skip_special_tokens=True,
skip_special_tokens=False,
clean_up_tokenization_spaces=False,
)
return output_texts
def generate_image(
self,
prompt: Optional[Union[str, List[str]]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.LongTensor] = None,
negative_prompt_attention_mask: Optional[torch.LongTensor] = None,
use_text_encoder_penultimate_layer_feats: bool = False,
max_sequence_length: Optional[int] = None,
callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
input_images: Optional[List[PIL.Image.Image]] = None,
num_images_per_prompt: int = 1,
height: Optional[int] = None,
width: Optional[int] = None,
max_pixels: int = 1024 * 1024,
max_input_image_side_length: int = 1024,
align_res: bool = True,
num_inference_steps: int = 28,
text_guidance_scale: float = 4.0,
image_guidance_scale: float = 1.0,
cfg_range: Tuple[float, float] = (0.0, 1.0),
attention_kwargs: Optional[Dict[str, Any]] = None,
timesteps: List[int] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
verbose: bool = False,
step_func=None,
):
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
self._text_guidance_scale = text_guidance_scale
self._image_guidance_scale = image_guidance_scale
self._cfg_range = cfg_range
self._attention_kwargs = attention_kwargs
# 1. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# 2. Encode input prompt
(
prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt,
input_images,
self.text_guidance_scale > 1.0,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
max_sequence_length=max_sequence_length,
use_text_encoder_penultimate_layer_feats=use_text_encoder_penultimate_layer_feats
)
dtype = self.vae.dtype
# 3. Prepare reference image latents
ref_latents = self.prepare_image(
images=input_images,
batch_size=batch_size,
num_images_per_prompt=num_images_per_prompt,
max_pixels=max_pixels,
max_side_length=max_input_image_side_length,
device=device,
dtype=dtype,
)
if input_images is None:
input_images = []
if len(input_images) == 1 and align_res:
width, height = ref_latents[0][0].shape[-1] * self.vae_scale_factor, ref_latents[0][0].shape[-2] * self.vae_scale_factor
ori_width, ori_height = width, height
else:
ori_width, ori_height = width, height
cur_pixels = height * width
ratio = (max_pixels / cur_pixels) ** 0.5
ratio = min(ratio, 1.0)
height, width = int(height * ratio) // 16 * 16, int(width * ratio) // 16 * 16
if len(input_images) == 0:
self._image_guidance_scale = 1
# 4. Prepare latents.
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
latent_channels,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
freqs_cis = OmniGen2RotaryPosEmbed.get_freqs_cis(
self.transformer.config.axes_dim_rope,
self.transformer.config.axes_lens,
theta=10000,
)
image = self.processing(
latents=latents,
ref_latents=ref_latents,
prompt_embeds=prompt_embeds,
freqs_cis=freqs_cis,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
num_inference_steps=num_inference_steps,
timesteps=timesteps,
device=device,
dtype=dtype,
verbose=verbose,
step_func=step_func,
)
image = F.interpolate(image, size=(ori_height, ori_width), mode='bilinear')
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
return image
@torch.no_grad()
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.LongTensor] = None,
negative_prompt_attention_mask: Optional[torch.LongTensor] = None,
use_text_encoder_penultimate_layer_feats: bool = False,
max_sequence_length: Optional[int] = None,
callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
input_images: Optional[List[PIL.Image.Image]] = None,
num_images_per_prompt: int = 1,
height: Optional[int] = 1024,
width: Optional[int] = 1024,
max_pixels: Optional[int] = 1024 * 1024,
max_input_image_side_length: int = 1024,
align_res: bool = True,
num_inference_steps: int = 28,
text_guidance_scale: float = 4.0,
image_guidance_scale: float = 1.0,
cfg_range: Tuple[float, float] = (0.0, 1.0),
attention_kwargs: Optional[Dict[str, Any]] = None,
timesteps: List[int] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
verbose: bool = False,
step_func=None,
):
assert isinstance(prompt, str), "prompt must be a string since chat mode only supports one prompt per turn"
# input_images = self.preprocess_images(input_images, max_input_image_size)
prompt = self._apply_chat_template(prompt, input_images)
generated_text = self.generate_text(prompt, input_images)[0]
images = None
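        # If the generated text starts with the image token, hand the prompt off to the diffusion branch.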
if generated_text.startswith("<|img|>"):
            # TODO: reuse the hidden state from text generation instead of re-generating it
prompt = prompt + generated_text.split("<|img|>")[0]
images = self.generate_image(
prompt=prompt,
negative_prompt=negative_prompt,
use_text_encoder_penultimate_layer_feats=use_text_encoder_penultimate_layer_feats,
max_sequence_length=max_sequence_length,
input_images=input_images,
num_images_per_prompt=num_images_per_prompt,
height=height,
width=width,
max_pixels=max_pixels,
max_input_image_side_length=max_input_image_side_length,
align_res=align_res,
num_inference_steps=num_inference_steps,
text_guidance_scale=text_guidance_scale,
image_guidance_scale=image_guidance_scale,
cfg_range=cfg_range,
timesteps=timesteps,
generator=generator,
latents=latents,
return_dict=False,
verbose=verbose,
step_func=step_func,
)
generated_text = generated_text.replace("<|im_end|>", "")
if not return_dict:
return generated_text, images
else:
return OmniGen2PipelineOutput(text=generated_text, images=images)
def processing(
self,
latents,
ref_latents,
prompt_embeds,
freqs_cis,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
num_inference_steps,
timesteps,
device,
dtype,
verbose,
step_func=None
):
batch_size = latents.shape[0]
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps,
num_tokens=latents.shape[-2] * latents.shape[-1]
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
model_pred = self.predict(
t=t,
latents=latents,
prompt_embeds=prompt_embeds,
freqs_cis=freqs_cis,
prompt_attention_mask=prompt_attention_mask,
ref_image_hidden_states=ref_latents,
)
text_guidance_scale = self.text_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
image_guidance_scale = self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
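                # Dual classifier-free guidance: pred = uncond + image_gs * (ref - uncond) + text_gs * (cond - ref).
                # Outside cfg_range both scales fall back to 1.0 and the conditional prediction is used as-is.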
if text_guidance_scale > 1.0 and image_guidance_scale > 1.0:
model_pred_ref = self.predict(
t=t,
latents=latents,
prompt_embeds=negative_prompt_embeds,
freqs_cis=freqs_cis,
prompt_attention_mask=negative_prompt_attention_mask,
ref_image_hidden_states=ref_latents,
)
if image_guidance_scale != 1:
model_pred_uncond = self.predict(
t=t,
latents=latents,
prompt_embeds=negative_prompt_embeds,
freqs_cis=freqs_cis,
prompt_attention_mask=negative_prompt_attention_mask,
ref_image_hidden_states=None,
)
else:
model_pred_uncond = torch.zeros_like(model_pred)
model_pred = model_pred_uncond + image_guidance_scale * (model_pred_ref - model_pred_uncond) + \
text_guidance_scale * (model_pred - model_pred_ref)
elif text_guidance_scale > 1.0:
model_pred_uncond = self.predict(
t=t,
latents=latents,
prompt_embeds=negative_prompt_embeds,
freqs_cis=freqs_cis,
prompt_attention_mask=negative_prompt_attention_mask,
ref_image_hidden_states=None,
)
model_pred = model_pred_uncond + text_guidance_scale * (model_pred - model_pred_uncond)
latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0]
latents = latents.to(dtype=dtype)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if step_func is not None:
step_func(i, self._num_timesteps)
latents = latents.to(dtype=dtype)
if self.vae.config.scaling_factor is not None:
latents = latents / self.vae.config.scaling_factor
if self.vae.config.shift_factor is not None:
latents = latents + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
return image
def predict(
self,
t,
latents,
prompt_embeds,
freqs_cis,
prompt_attention_mask,
ref_image_hidden_states,
):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
batch_size, num_channels_latents, height, width = latents.shape
optional_kwargs = {}
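        # Only pass reference latents if the transformer's forward() accepts them, so the call also works
        # with transformer variants that have no image conditioning.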
if 'ref_image_hidden_states' in set(inspect.signature(self.transformer.forward).parameters.keys()):
optional_kwargs['ref_image_hidden_states'] = ref_image_hidden_states
model_pred = self.transformer(
latents,
timestep,
prompt_embeds,
freqs_cis,
prompt_attention_mask,
**optional_kwargs
)
return model_pred

View File

@@ -0,0 +1,62 @@
import torch
def get_pipeline_embeds(pipeline, prompt, negative_prompt, device):
""" Get pipeline embeds for prompts bigger than the maxlength of the pipe
:param pipeline:
:param prompt:
:param negative_prompt:
:param device:
:return:
"""
max_length = pipeline.tokenizer.model_max_length
# simple way to determine length of tokens
# count_prompt = len(prompt.split(" "))
# count_negative_prompt = len(negative_prompt.split(" "))
# create the tensor based on which prompt is longer
# if count_prompt >= count_negative_prompt:
    # keep the full tokenizer output so the attention mask can be sliced alongside the input ids
    text_inputs = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False, padding='longest')
    input_ids = text_inputs.input_ids.to(device)
# input_ids = pipeline.tokenizer(prompt, padding="max_length",
# max_length=pipeline.tokenizer.model_max_length,
# truncation=True,
# return_tensors="pt",).input_ids.to(device)
shape_max_length = input_ids.shape[-1]
if negative_prompt is not None:
        negative_inputs = pipeline.tokenizer(negative_prompt, truncation=True, padding="max_length",
                                             max_length=shape_max_length, return_tensors="pt")
        negative_ids = negative_inputs.input_ids.to(device)
# else:
# negative_ids = pipeline.tokenizer(negative_prompt, return_tensors="pt", truncation=False).input_ids.to(device)
# shape_max_length = negative_ids.shape[-1]
# input_ids = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False, padding="max_length",
# max_length=shape_max_length).input_ids.to(device)
concat_embeds = []
neg_embeds = []
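    # Encode the prompt in windows of at most max_length tokens and concatenate the per-window embeddings
    # along the sequence dimension.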
for i in range(0, shape_max_length, max_length):
if hasattr(pipeline.text_encoder.config, "use_attention_mask") and pipeline.text_encoder.config.use_attention_mask:
            # a tensor slice has no .attention_mask attribute; slice the tokenizer's mask instead
            attention_mask = text_inputs.attention_mask[:, i: i + max_length].to(device)
else:
attention_mask = None
concat_embeds.append(pipeline.text_encoder(input_ids[:, i: i + max_length],
attention_mask=attention_mask)[0])
if negative_prompt is not None:
if hasattr(pipeline.text_encoder.config, "use_attention_mask") and pipeline.text_encoder.config.use_attention_mask:
                attention_mask = negative_inputs.attention_mask[:, i: i + max_length].to(device)
else:
attention_mask = None
neg_embeds.append(pipeline.text_encoder(negative_ids[:, i: i + max_length],
attention_mask=attention_mask)[0])
concat_embeds = torch.cat(concat_embeds, dim=1)
if negative_prompt is not None:
neg_embeds = torch.cat(neg_embeds, dim=1)
else:
neg_embeds = None
return concat_embeds, neg_embeds
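
# Hedged usage sketch (illustrative only; `pipe`, the prompts, and passing the embeddings straight to the
# pipeline call are assumptions, not part of this module):
#   prompt_embeds, negative_prompt_embeds = get_pipeline_embeds(pipe, long_prompt, "blurry, low quality", "cuda")
#   image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds).images[0]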

View File

@@ -0,0 +1,229 @@
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, logging
from diffusers.schedulers.scheduling_utils import SchedulerMixin
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: torch.FloatTensor
class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
Euler scheduler.
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
        dynamic_time_shift (`bool`, defaults to `True`):
            Whether to shift the timestep schedule based on the number of image tokens, so larger inputs use a
            resolution-dependent schedule.
"""
_compatibles = []
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
dynamic_time_shift: bool = True
):
timesteps = torch.linspace(0, 1, num_train_timesteps + 1, dtype=torch.float32)[:-1]
self.timesteps = timesteps
self._step_index = None
self._begin_index = None
@property
def step_index(self):
"""
The index counter for current timestep. It will increase 1 after each scheduler step.
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self._timesteps
indices = (schedule_timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
return indices[pos].item()
# def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
# return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def set_timesteps(
self,
num_inference_steps: int = None,
device: Union[str, torch.device] = None,
timesteps: Optional[List[float]] = None,
num_tokens: Optional[int] = None
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
            timesteps (`List[float]`, *optional*):
                Custom timesteps to use instead of the generated schedule.
            num_tokens (`int`, *optional*):
                Number of latent tokens in the target image; used for the dynamic time shift when it is enabled.
        """
if timesteps is None:
self.num_inference_steps = num_inference_steps
timesteps = np.linspace(0, 1, num_inference_steps + 1, dtype=np.float32)[:-1]
if self.config.dynamic_time_shift and num_tokens is not None:
m = np.sqrt(num_tokens) / 40 # when input resolution is 320 * 320, m = 1, when input resolution is 1024 * 1024, m = 3.2
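                # monotonic remap t -> t / (m + (1 - m) * t); identity when m == 1, larger m pushes t toward 0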
timesteps = timesteps / (m - m * timesteps + timesteps)
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device)
_timesteps = torch.cat([timesteps, torch.ones(1, device=timesteps.device)])
self.timesteps = timesteps
self._timesteps = _timesteps
self._step_index = None
self._begin_index = None
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def step(
self,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator (unused here; kept for interface compatibility).
            return_dict (`bool`):
                Whether or not to return a [`FlowMatchEulerDiscreteSchedulerOutput`] or a plain tuple.
        Returns:
            [`FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
                If `return_dict` is `True`, [`FlowMatchEulerDiscreteSchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.
"""
if (
isinstance(timestep, int)
or isinstance(timestep, torch.IntTensor)
or isinstance(timestep, torch.LongTensor)
):
raise ValueError(
(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep."
),
)
if self.step_index is None:
self._init_step_index(timestep)
# Upcast to avoid precision issues when computing prev_sample
sample = sample.to(torch.float32)
t = self._timesteps[self.step_index]
t_next = self._timesteps[self.step_index + 1]
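        # first-order Euler update of the flow ODE: x_next = x + (t_next - t) * v_theta(x, t)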
prev_sample = sample + (t_next - t) * model_output
# Cast sample back to model compatible dtype
prev_sample = prev_sample.to(model_output.dtype)
# upon completion increase step index by one
self._step_index += 1
if not return_dict:
return (prev_sample,)
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
def __len__(self):
return self.config.num_train_timesteps

View File

@@ -0,0 +1,31 @@
from typing import List
from PIL import Image
import torch
from torchvision.transforms.functional import to_pil_image
def resize_image(image, max_pixels, img_scale_num):
width, height = image.size
cur_pixels = height * width
ratio = (max_pixels / cur_pixels) ** 0.5
ratio = min(ratio, 1.0) # do not upscale input image
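    # round both sides down to a multiple of img_scale_num so downstream latents have whole-number dimensions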
new_height, new_width = int(height * ratio) // img_scale_num * img_scale_num, int(width * ratio) // img_scale_num * img_scale_num
image = image.resize((new_width, new_height), resample=Image.BICUBIC)
return image
def create_collage(images: List[torch.Tensor]) -> Image.Image:
"""Create a horizontal collage from a list of images."""
max_height = max(img.shape[-2] for img in images)
total_width = sum(img.shape[-1] for img in images)
canvas = torch.zeros((3, max_height, total_width), device=images[0].device)
current_x = 0
for img in images:
h, w = img.shape[-2:]
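        # img is expected in [-1, 1]; rescale to [0, 1] before converting the canvas to a PIL image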
canvas[:, :h, current_x:current_x+w] = img * 0.5 + 0.5
current_x += w
return to_pil_image(canvas)

View File

@@ -0,0 +1,46 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Import utilities: Utilities related to imports and our lazy inits.
"""
import importlib.util
import sys
# The package importlib_metadata is in a different place, depending on the python version.
if sys.version_info < (3, 8):
import importlib_metadata
else:
import importlib.metadata as importlib_metadata
def _is_package_available(pkg_name: str):
pkg_exists = importlib.util.find_spec(pkg_name) is not None
pkg_version = "N/A"
if pkg_exists:
try:
pkg_version = importlib_metadata.version(pkg_name)
except (ImportError, importlib_metadata.PackageNotFoundError):
pkg_exists = False
return pkg_exists, pkg_version
_triton_available, _triton_version = _is_package_available("triton")
_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
def is_triton_available():
return _triton_available
def is_flash_attn_available():
return _flash_attn_available

View File

@@ -34,7 +34,7 @@ def get_all_extensions() -> List[Extension]:
     for sub_dir in extension_folders:
         extensions_dir = os.path.join(TOOLKIT_ROOT, sub_dir)
         for (_, name, _) in pkgutil.iter_modules([extensions_dir]):
-            try:
+            # try:
             # Import the module
             module = importlib.import_module(f"{sub_dir}.{name}")
             # Get the value of the AI_TOOLKIT_EXTENSIONS variable
@@ -43,8 +43,8 @@ def get_all_extensions() -> List[Extension]:
             if isinstance(extensions, list):
                 # Iterate over the list and add the classes to the main list
                 all_extension_classes.extend(extensions)
-            except ImportError as e:
-                print(f"Failed to import the {name} module. Error: {str(e)}")
+            # except ImportError as e:
+            #     print(f"Failed to import the {name} module. Error: {str(e)}")
     return all_extension_classes

View File

@@ -170,6 +170,19 @@ export const modelArchs: ModelArch[] = [
     },
     disableSections: ['model.quantize', 'train.timestep_type'],
   },
+  {
+    name: 'omnigen2',
+    label: 'OmniGen2',
+    defaults: {
+      // default updates when [selected, unselected] in the UI
+      'config.process[0].model.name_or_path': ['OmniGen2/OmniGen2', defaultNameOrPath],
+      'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
+      'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
+      'config.process[0].model.quantize': [false, false],
+      'config.process[0].model.quantize_te': [true, false],
+    },
+    disableSections: ['network.conv'],
+  },
 ].sort((a, b) => {
   // Sort by label, case-insensitive
   return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' })

View File

@@ -1 +1 @@
-VERSION = "0.3.1"
+VERSION = "0.3.2"