mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 08:29:45 +00:00

Added support for finetuning OmniGen2.

94  config/examples/train_lora_omnigen2_24gb.yaml  Normal file
@@ -0,0 +1,94 @@
---
job: extension
config:
  # this name will be the folder and filename name
  name: "my_first_omnigen2_lora_v1"
  process:
    - type: 'sd_trainer'
      # root folder to save training sessions/samples/weights
      training_folder: "output"
      # uncomment to see performance stats in the terminal every N steps
      # performance_log_every: 1000
      device: cuda:0
      # if a trigger word is specified, it will be added to captions of training data if it does not already exist
      # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
      # trigger_word: "p3r5on"
      network:
        type: "lora"
        linear: 16
        linear_alpha: 16
      save:
        dtype: float16 # precision to save
        save_every: 250 # save every this many steps
        max_step_saves_to_keep: 4 # how many intermittent saves to keep
        push_to_hub: false # change this to true to push your trained model to Hugging Face.
        # You can either set up a HF_TOKEN env variable or you'll be prompted to log in
        # hf_repo_id: your-username/your-model-slug
        # hf_private: true # whether the repo is private or public
      datasets:
        # datasets are a folder of images. captions need to be txt files with the same name as the image
        # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
        # images will automatically be resized and bucketed into the resolution specified
        # on windows, escape back slashes with another backslash so
        # "C:\\path\\to\\images\\folder"
        - folder_path: "/path/to/images/folder"
          caption_ext: "txt"
          caption_dropout_rate: 0.05 # will drop out the caption 5% of the time
          shuffle_tokens: false # shuffle caption order, split by commas
          cache_latents_to_disk: true # leave this true unless you know what you're doing
          resolution: [ 512, 768, 1024 ] # omnigen2 should work with multiple resolutions
      train:
        batch_size: 1
        steps: 3000 # total number of steps to train; 500 - 4000 is a good range
        gradient_accumulation: 1
        train_unet: true
        train_text_encoder: false # probably won't work with omnigen2
        gradient_checkpointing: true # need this on unless you have a ton of vram
        noise_scheduler: "flowmatch" # for training only
        optimizer: "adamw8bit"
        lr: 1e-4
        timestep_type: 'sigmoid' # sigmoid, linear, shift
        # uncomment this to skip the pre training sample
        # skip_first_sample: true
        # uncomment to completely disable sampling
        # disable_sampling: true

        # ema will smooth out learning, but could slow it down.
        # ema_config:
        #   use_ema: true
        #   ema_decay: 0.99

        # will probably need this if gpu supports it for omnigen2, other dtypes may not work correctly
        dtype: bf16
      model:
        name_or_path: "OmniGen2/OmniGen2"
        arch: "omnigen2"
        quantize_te: true # quantize only the text encoder
        # quantize: true # quantize the transformer
      sample:
        sampler: "flowmatch" # must match train.noise_scheduler
        sample_every: 250 # sample every this many steps
        width: 1024
        height: 1024
        prompts:
          # you can add [trigger] to the prompts here and it will be replaced with the trigger word
          # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"
          - "woman with red hair, playing chess at the park, bomb going off in the background"
          - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
          - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
          - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
          - "a bear building a log cabin in the snow covered mountains"
          - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
          - "hipster man with a beard, building a chair, in a wood shop"
          - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
          - "a man holding a sign that says, 'this is a sign'"
          - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
        neg: "" # negative prompt, optional
        seed: 42
        walk_seed: true
        guidance_scale: 4
        sample_steps: 25
# you can add any additional meta info here. [name] is replaced with config name at top
meta:
  name: "[name]"
  version: '1.0'
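Usage note: assuming ai-toolkit's standard run.py entry point, an example like this is typically launched with "python run.py config/examples/train_lora_omnigen2_24gb.yaml"; check the repository README for the exact command.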
@@ -1,8 +1,9 @@
 from .chroma import ChromaModel
 from .hidream import HidreamModel
 from .f_light import FLiteModel
+from .omnigen2 import OmniGen2Model

 AI_TOOLKIT_MODELS = [
     # put a list of models here
-    ChromaModel, HidreamModel, FLiteModel
+    ChromaModel, HidreamModel, FLiteModel, OmniGen2Model
 ]
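For context, the arch: "omnigen2" value in the example config selects the model class through its arch attribute. A minimal sketch of such a lookup (a hypothetical helper, not the toolkit's actual loader):

def resolve_model_class(arch: str):
    # scan the registered model classes for a matching `arch` attribute
    for cls in AI_TOOLKIT_MODELS:
        if getattr(cls, "arch", None) == arch:
            return cls
    raise ValueError(f"unknown arch: {arch}")

# resolve_model_class("omnigen2") would return the OmniGen2Model class added below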
327  extensions_built_in/diffusion_models/omnigen2/__init__.py  Normal file
@@ -0,0 +1,327 @@
import inspect
import os
from typing import TYPE_CHECKING, List, Optional

import einops
import torch
import yaml
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.models.base_model import BaseModel
from diffusers import AutoencoderKL
from toolkit.basic import flush
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
from toolkit.accelerator import unwrap_model
from optimum.quanto import freeze
from toolkit.util.quantize import quantize, get_qtype
from .src.pipelines.omnigen2.pipeline_omnigen2 import OmniGen2Pipeline
from .src.models.transformers import OmniGen2Transformer2DModel
from .src.models.transformers.repo import OmniGen2RotaryPosEmbed
from .src.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler as OmniFlowMatchEuler
from transformers import CLIPProcessor, Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration

if TYPE_CHECKING:
    from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO

scheduler_config = {
    "num_train_timesteps": 1000
}

BASE_MODEL_PATH = "OmniGen2/OmniGen2"


class OmniGen2Model(BaseModel):
    arch = "omnigen2"

    def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype='bf16',
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs
    ):
        super().__init__(
            device,
            model_config,
            dtype,
            custom_pipeline,
            noise_scheduler,
            **kwargs
        )
        self.is_flow_matching = True
        self.is_transformer = True
        self.target_lora_modules = ['OmniGen2Transformer2DModel']

    # static method to get the noise scheduler
    @staticmethod
    def get_train_scheduler():
        return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)

    def get_bucket_divisibility(self):
        return 16

    def load_model(self):
        dtype = self.torch_dtype
        self.print_and_status_update("Loading OmniGen2 model")
        # will be updated if we detect an existing checkpoint in the training folder
        model_path = self.model_config.name_or_path
        extras_path = self.model_config.extras_name_or_path

        scheduler = OmniGen2Model.get_train_scheduler()

        self.print_and_status_update("Loading Qwen2.5 VL")
        processor = CLIPProcessor.from_pretrained(
            extras_path,
            subfolder="processor",
            use_fast=True
        )

        mllm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            extras_path,
            subfolder="mllm",
            torch_dtype=torch.bfloat16
        )

        if self.model_config.quantize_te:
            self.print_and_status_update("Quantizing Qwen2.5 VL model")
            quantization_type = get_qtype(self.model_config.qtype_te)
            quantize(mllm, weights=quantization_type)
            freeze(mllm)

        if self.low_vram:
            # unload it for now
            mllm.to('cpu')

        flush()

        self.print_and_status_update("Loading transformer")

        transformer = OmniGen2Transformer2DModel.from_pretrained(
            model_path,
            subfolder="transformer",
            torch_dtype=torch.bfloat16
        )

        if not self.low_vram:
            transformer.to(self.device_torch, dtype=dtype)

        if self.model_config.quantize:
            self.print_and_status_update("Quantizing transformer")
            quantization_type = get_qtype(self.model_config.qtype)
            quantize(transformer, weights=quantization_type)
            freeze(transformer)

        if self.low_vram:
            # unload it for now
            transformer.to('cpu')

        flush()

        self.print_and_status_update("Loading vae")

        vae = AutoencoderKL.from_pretrained(
            extras_path,
            subfolder="vae",
            torch_dtype=torch.bfloat16
        ).to(self.device_torch, dtype=dtype)

        flush()
        self.print_and_status_update("Loading Qwen2.5 VLProcessor")

        flush()

        if self.low_vram:
            self.print_and_status_update("Moving everything to device")
            # move it all back
            transformer.to(self.device_torch, dtype=dtype)
            vae.to(self.device_torch, dtype=dtype)
            mllm.to(self.device_torch, dtype=dtype)

        # set to eval mode
        # transformer.eval()
        vae.eval()
        mllm.eval()
        mllm.requires_grad_(False)

        pipe: OmniGen2Pipeline = OmniGen2Pipeline(
            transformer=transformer,
            vae=vae,
            scheduler=scheduler,
            mllm=mllm,
            processor=processor,
        )

        # pipe: OmniGen2Pipeline = OmniGen2Pipeline.from_pretrained(
        #     model_path,
        #     transformer=transformer,
        #     vae=vae,
        #     scheduler=scheduler,
        #     mllm=mllm,
        #     trust_remote_code=True,
        # )
        # processor = pipe.processor

        flush()

        text_encoder_list = [mllm]
        tokenizer_list = [processor]

        flush()

        # save it to the model class
        self.vae = vae
        self.text_encoder = text_encoder_list  # list of text encoders
        self.tokenizer = tokenizer_list  # list of tokenizers
        self.model = pipe.transformer
        self.pipeline = pipe

        self.freqs_cis = OmniGen2RotaryPosEmbed.get_freqs_cis(
            transformer.config.axes_dim_rope,
            transformer.config.axes_lens,
            theta=10000,
        )

        self.print_and_status_update("Model Loaded")

    def get_generation_pipeline(self):
        scheduler = OmniFlowMatchEuler(
            dynamic_time_shift=True,
            num_train_timesteps=1000
        )

        pipeline: OmniGen2Pipeline = OmniGen2Pipeline(
            transformer=self.model,
            vae=self.vae,
            scheduler=scheduler,
            mllm=self.text_encoder[0],
            processor=self.tokenizer[0],
        )

        pipeline = pipeline.to(self.device_torch)

        return pipeline

    def generate_single_image(
        self,
        pipeline: OmniGen2Pipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        img = pipeline(
            prompt_embeds=conditional_embeds.text_embeds,
            prompt_attention_mask=conditional_embeds.attention_mask,
            negative_prompt_embeds=unconditional_embeds.text_embeds,
            negative_prompt_attention_mask=unconditional_embeds.attention_mask,
            height=gen_config.height,
            width=gen_config.width,
            num_inference_steps=gen_config.num_inference_steps,
            text_guidance_scale=gen_config.guidance_scale,
            image_guidance_scale=1.0,  # reference image guidance scale. Add this for controls
            latents=gen_config.latents,
            generator=generator,
            **extra
        ).images[0]
        return img

    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        **kwargs
    ):
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timestep = timestep.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)

        # optional_kwargs = {}
        # if 'ref_image_hidden_states' in set(inspect.signature(self.model.forward).parameters.keys()):
        #     optional_kwargs['ref_image_hidden_states'] = ref_image_hidden_states

        timesteps = timestep / 1000  # convert to 0 to 1 scale
        # the model's timestep starts at 0 instead of 1, so we need to reverse it
        timestep = 1 - timesteps
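        # e.g. a trainer timestep of 250 becomes 1 - 250/1000 = 0.75 for the transformer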
        model_pred = self.model(
            latent_model_input,
            timestep,
            text_embeddings.text_embeds,
            self.freqs_cis,
            text_embeddings.attention_mask,
            ref_image_hidden_states=None,  # todo add ref latent ability
        )

        return model_pred

    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
        prompt = [prompt] if isinstance(prompt, str) else prompt
        prompt = [self.pipeline._apply_chat_template(_prompt) for _prompt in prompt]
        self.text_encoder_to(self.device_torch, dtype=self.torch_dtype)
        max_sequence_length = 256
        prompt_embeds, prompt_attention_mask, _, _ = self.pipeline.encode_prompt(
            prompt=prompt,
            do_classifier_free_guidance=False,
            device=self.device_torch,
            max_sequence_length=max_sequence_length,
        )
        pe = PromptEmbeds(prompt_embeds)
        pe.attention_mask = prompt_attention_mask
        return pe

    def get_model_has_grad(self):
        # return from a weight if it has grad
        return False

    def get_te_has_grad(self):
        # assume no one wants to finetune the text encoder
        return False

    def save_model(self, output_path, meta, save_dtype):
        # only save the transformer
        transformer: OmniGen2Transformer2DModel = unwrap_model(self.model)
        transformer.save_pretrained(
            save_directory=os.path.join(output_path, 'transformer'),
            safe_serialization=True,
        )

        meta_path = os.path.join(output_path, 'aitk_meta.yaml')
        with open(meta_path, 'w') as f:
            yaml.dump(meta, f)

    def get_loss_target(self, *args, **kwargs):
        noise = kwargs.get('noise')
        batch = kwargs.get('batch')
        # return (noise - batch.latents).detach()
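        # note: returns (latents - noise) rather than (noise - latents); presumably this sign pairs with the reversed timestep convention in get_noise_prediction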
        return (batch.latents - noise).detach()

    def get_transformer_block_names(self) -> Optional[List[str]]:
        # omnigen2 has a few block groups: noise_refiner, ref_image_refiner, context_refiner, and layers.
        # let's do all but the ref image refiner until we add it
        return ['noise_refiner', 'context_refiner', 'layers']
        # return ['layers']

    def convert_lora_weights_before_save(self, state_dict):
        # keys currently start with "transformer." but need to start with "diffusion_model." for comfyui
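        # e.g. (illustrative key name) "transformer.layers.0.attn.to_q.lora_A.weight" -> "diffusion_model.layers.0.attn.to_q.lora_A.weight"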
        new_sd = {}
        for key, value in state_dict.items():
            new_key = key.replace("transformer.", "diffusion_model.")
            new_sd[new_key] = value
        return new_sd

    def convert_lora_weights_before_load(self, state_dict):
        # saved as "diffusion_model." but needs to be "transformer." for ai-toolkit
        new_sd = {}
        for key, value in state_dict.items():
            new_key = key.replace("diffusion_model.", "transformer.")
            new_sd[new_key] = value
        return new_sd

    def get_base_model_version(self):
        return "omnigen2"
@@ -0,0 +1,357 @@
"""
OmniGen2 Attention Processor Module

Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import warnings
import math
from typing import Optional, Tuple, Dict, Any

import torch
import torch.nn.functional as F
from einops import repeat

from ..utils.import_utils import is_flash_attn_available

if is_flash_attn_available():
    from flash_attn import flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
else:
    warnings.warn("Cannot import flash_attn, install flash_attn to use Flash2Varlen attention for better performance")


from diffusers.models.attention_processor import Attention
from .embeddings import apply_rotary_emb


class OmniGen2AttnProcessorFlash2Varlen:
    """
    Processor for implementing scaled dot-product attention with flash attention and variable length sequences.

    This processor implements:
    - Flash attention with variable length sequences
    - Rotary position embeddings (RoPE)
    - Query-Key normalization
    - Proportional attention scaling

    Args:
        None
    """

    def __init__(self) -> None:
        """Initialize the attention processor."""
        if not is_flash_attn_available():
            raise ImportError(
                "OmniGen2AttnProcessorFlash2Varlen requires flash_attn. "
                "Please install flash_attn."
            )

    def _upad_input(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: torch.Tensor,
        query_length: int,
        num_heads: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
        """
        Unpad the input tensors for flash attention.

        Args:
            query_layer: Query tensor of shape (batch_size, seq_len, num_heads, head_dim)
            key_layer: Key tensor of shape (batch_size, seq_len, num_kv_heads, head_dim)
            value_layer: Value tensor of shape (batch_size, seq_len, num_kv_heads, head_dim)
            attention_mask: Attention mask tensor of shape (batch_size, seq_len)
            query_length: Length of the query sequence
            num_heads: Number of attention heads

        Returns:
            Tuple containing:
                - Unpadded query tensor
                - Unpadded key tensor
                - Unpadded value tensor
                - Query indices
                - Tuple of cumulative sequence lengths for query and key
                - Tuple of maximum sequence lengths for query and key
        """
        def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
            """Helper function to get unpadding data from attention mask."""
            seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
            indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
            max_seqlen_in_batch = seqlens_in_batch.max().item()
            cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
            return indices, cu_seqlens, max_seqlen_in_batch

        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        # Unpad key and value layers
        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )

        # Handle different query length cases
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim),
                indices_k,
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        base_sequence_length: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Process attention computation with flash attention.

        Args:
            attn: Attention module
            hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim)
            encoder_hidden_states: Encoder hidden states tensor
            attention_mask: Optional attention mask tensor
            image_rotary_emb: Optional rotary embeddings for image tokens
            base_sequence_length: Optional base sequence length for proportional attention

        Returns:
            torch.Tensor: Processed hidden states after attention computation
        """
        batch_size, sequence_length, _ = hidden_states.shape

        # Get Query-Key-Value Pair
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query_dim = query.shape[-1]
        inner_dim = key.shape[-1]
        head_dim = query_dim // attn.heads
        dtype = query.dtype

        # Get key-value heads
        kv_heads = inner_dim // head_dim

        # Reshape tensors for attention computation
        query = query.view(batch_size, -1, attn.heads, head_dim)
        key = key.view(batch_size, -1, kv_heads, head_dim)
        value = value.view(batch_size, -1, kv_heads, head_dim)

        # Apply Query-Key normalization
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply Rotary Position Embeddings
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
            key = apply_rotary_emb(key, image_rotary_emb, use_real=False)

        query, key = query.to(dtype), key.to(dtype)

        # Calculate attention scale
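        # proportional attention: with a base sequence length, the softmax scale is multiplied by
        # sqrt(log(sequence_length) / log(base_sequence_length)), so it grows slowly as the token count grows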
        if base_sequence_length is not None:
            softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
        else:
            softmax_scale = attn.scale

        # Unpad input for flash attention
        (
            query_states,
            key_states,
            value_states,
            indices_q,
            cu_seq_lens,
            max_seq_lens,
        ) = self._upad_input(query, key, value, attention_mask, sequence_length, attn.heads)

        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

        # Handle different number of heads
        if kv_heads < attn.heads:
            key_states = repeat(key_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)
            value_states = repeat(value_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)

        # Apply flash attention
        attn_output_unpad = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            dropout_p=0.0,
            causal=False,
            softmax_scale=softmax_scale,
        )

        # Pad output and apply final transformations
        hidden_states = pad_input(attn_output_unpad, indices_q, batch_size, sequence_length)
        hidden_states = hidden_states.flatten(-2)
        hidden_states = hidden_states.type_as(query)

        # Apply output projection
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


class OmniGen2AttnProcessor:
    """
    Processor for implementing scaled dot-product attention via PyTorch 2.0's F.scaled_dot_product_attention.

    This processor is optimized for PyTorch 2.0 and implements:
    - Scaled dot-product attention with attention masking
    - Rotary position embeddings (RoPE)
    - Query-Key normalization
    - Proportional attention scaling

    Args:
        None

    Raises:
        ImportError: If PyTorch version is less than 2.0
    """

    def __init__(self) -> None:
        """Initialize the attention processor."""
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "OmniGen2AttnProcessor requires PyTorch 2.0. "
                "Please upgrade PyTorch to version 2.0 or later."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        base_sequence_length: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Process attention computation with scaled dot-product attention.

        Args:
            attn: Attention module
            hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim)
            encoder_hidden_states: Encoder hidden states tensor
            attention_mask: Optional attention mask tensor
            image_rotary_emb: Optional rotary embeddings for image tokens
            base_sequence_length: Optional base sequence length for proportional attention

        Returns:
            torch.Tensor: Processed hidden states after attention computation
        """
        batch_size, sequence_length, _ = hidden_states.shape

        # Get Query-Key-Value Pair
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query_dim = query.shape[-1]
        inner_dim = key.shape[-1]
        head_dim = query_dim // attn.heads
        dtype = query.dtype

        # Get key-value heads
        kv_heads = inner_dim // head_dim

        # Reshape tensors for attention computation
        query = query.view(batch_size, -1, attn.heads, head_dim)
        key = key.view(batch_size, -1, kv_heads, head_dim)
        value = value.view(batch_size, -1, kv_heads, head_dim)

        # Apply Query-Key normalization
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply Rotary Position Embeddings
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
            key = apply_rotary_emb(key, image_rotary_emb, use_real=False)

        query, key = query.to(dtype), key.to(dtype)

        # Calculate attention scale
        if base_sequence_length is not None:
            softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
        else:
            softmax_scale = attn.scale

        # scaled_dot_product_attention expects attention_mask shape to be
        # (batch, heads, source_length, target_length)
        if attention_mask is not None:
            attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)

        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        # explicitly repeat key and value to match the number of query heads; using enable_gqa=True
        # fell back to the MATH backend of sdpa in our tests with PyTorch 2.6
        key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
        value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, scale=softmax_scale
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.type_as(query)

        # Apply output projection
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
@@ -0,0 +1,126 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union

import torch
from torch import nn

from diffusers.models.activations import get_activation


class TimestepEmbedding(nn.Module):
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)

        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        self.act = get_activation(act_fn)

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)

        if post_act_fn is None:
            self.post_act = None
        else:
            self.post_act = get_activation(post_act_fn)

        self.initialize_weights()

    def initialize_weights(self):
        nn.init.normal_(self.linear_1.weight, std=0.02)
        nn.init.zeros_(self.linear_1.bias)
        nn.init.normal_(self.linear_2.weight, std=0.02)
        nn.init.zeros_(self.linear_2.bias)

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample


def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
    to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
    reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
    tensors contain rotary embeddings and are returned as real tensors.

    Args:
        x (`torch.Tensor`):
            Query or key tensor to apply rotary embeddings to, of shape [B, H, S, D].
        freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)

    Returns:
        The input tensor with rotary embeddings applied.
    """
    if use_real:
        cos, sin = freqs_cis  # [S, D]
        cos = cos[None, None]
        sin = sin[None, None]
        cos, sin = cos.to(x.device), sin.to(x.device)

        if use_real_unbind_dim == -1:
            # Used for flux, cogvideox, hunyuan-dit
            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        elif use_real_unbind_dim == -2:
            # Used for Stable Audio, OmniGen and CogView4
            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
        else:
            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out
    else:
        # used for lumina
        # x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], x.shape[-1] // 2, 2))
        freqs_cis = freqs_cis.unsqueeze(2)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

        return x_out.type_as(x)
@@ -0,0 +1,3 @@
from .transformer_omnigen2 import OmniGen2Transformer2DModel

__all__ = ["OmniGen2Transformer2DModel"]
@@ -0,0 +1,218 @@
# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Optional, Tuple

import torch
import torch.nn as nn

from diffusers.models.embeddings import Timesteps
from ..embeddings import TimestepEmbedding

from ...utils.import_utils import is_flash_attn_available, is_triton_available

if is_triton_available():
    from ...ops.triton.layer_norm import RMSNorm
else:
    from torch.nn import RMSNorm
    warnings.warn("Cannot import triton, install triton to use fused RMSNorm for better performance")

if is_flash_attn_available():
    from flash_attn.ops.activations import swiglu
else:
    from .components import swiglu
    warnings.warn("Cannot import flash_attn, install flash_attn to use fused SwiGLU for better performance")

# try:
#     from flash_attn.ops.activations import swiglu as fused_swiglu
#     FUSEDSWIGLU_AVAILABLE = True
# except ImportError:

#     FUSEDSWIGLU_AVAILABLE = False
#     warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")

class LuminaRMSNormZero(nn.Module):
    """
    Norm layer adaptive RMS normalization zero.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
    """

    def __init__(
        self,
        embedding_dim: int,
        norm_eps: float,
        norm_elementwise_affine: bool,
    ):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(
            min(embedding_dim, 1024),
            4 * embedding_dim,
            bias=True,
        )

        self.norm = RMSNorm(embedding_dim, eps=norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        emb = self.linear(self.silu(emb))
        scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
        x = self.norm(x) * (1 + scale_msa[:, None])
        return x, gate_msa, scale_mlp, gate_mlp


class LuminaLayerNormContinuous(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        conditioning_embedding_dim: int,
        # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
        # because the output is immediately scaled and shifted by the projected conditioning embeddings.
        # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
        # However, this is how it was implemented in the original code, and it's rather likely you should
        # set `elementwise_affine` to False.
        elementwise_affine=True,
        eps=1e-5,
        bias=True,
        norm_type="layer_norm",
        out_dim: Optional[int] = None,
    ):
        super().__init__()

        # AdaLN
        self.silu = nn.SiLU()
        self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)

        if norm_type == "layer_norm":
            self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
        elif norm_type == "rms_norm":
            self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
        else:
            raise ValueError(f"unknown norm_type {norm_type}")

        self.linear_2 = None
        if out_dim is not None:
            self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)

    def forward(
        self,
        x: torch.Tensor,
        conditioning_embedding: torch.Tensor,
    ) -> torch.Tensor:
        # convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for hunyuanDiT)
        emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
        scale = emb
        x = self.norm(x) * (1 + scale)[:, None, :]

        if self.linear_2 is not None:
            x = self.linear_2(x)

        return x


class LuminaFeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        hidden_size (`int`):
            The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
            hidden representations.
        intermediate_size (`int`): The intermediate dimension of the feedforward layer.
        multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple
            of this value.
        ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden
            dimension. Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        inner_dim: int,
        multiple_of: Optional[int] = 256,
        ffn_dim_multiplier: Optional[float] = None,
    ):
        super().__init__()
        self.swiglu = swiglu

        # custom hidden_size factor multiplier
        if ffn_dim_multiplier is not None:
            inner_dim = int(ffn_dim_multiplier * inner_dim)
        inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)

        self.linear_1 = nn.Linear(
            dim,
            inner_dim,
            bias=False,
        )
        self.linear_2 = nn.Linear(
            inner_dim,
            dim,
            bias=False,
        )
        self.linear_3 = nn.Linear(
            dim,
            inner_dim,
            bias=False,
        )

    def forward(self, x):
        h1, h2 = self.linear_1(x), self.linear_3(x)
        return self.linear_2(self.swiglu(h1, h2))


class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
    def __init__(
        self,
        hidden_size: int = 4096,
        text_feat_dim: int = 2048,
        frequency_embedding_size: int = 256,
        norm_eps: float = 1e-5,
        timestep_scale: float = 1.0,
    ) -> None:
        super().__init__()

        self.time_proj = Timesteps(
            num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale
        )

        self.timestep_embedder = TimestepEmbedding(
            in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)
        )

        self.caption_embedder = nn.Sequential(
            RMSNorm(text_feat_dim, eps=norm_eps),
            nn.Linear(text_feat_dim, hidden_size, bias=True),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        nn.init.trunc_normal_(self.caption_embedder[1].weight, std=0.02)
        nn.init.zeros_(self.caption_embedder[1].bias)

    def forward(
        self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        timestep_proj = self.time_proj(timestep).to(dtype=dtype)
        time_embed = self.timestep_embedder(timestep_proj)
        caption_embed = self.caption_embedder(text_hidden_states)
        return time_embed, caption_embed
@@ -0,0 +1,4 @@
import torch.nn.functional as F

def swiglu(x, y):
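    # fallback SwiGLU used when flash_attn's fused kernel is unavailable: silu(x) * y, computed in float32 for stability and cast back to x's dtype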
    return F.silu(x.float(), inplace=False).to(x.dtype) * y
@@ -0,0 +1,135 @@
from typing import List, Tuple

import torch
import torch.nn as nn

from einops import repeat
from diffusers.models.embeddings import get_1d_rotary_pos_embed

class OmniGen2RotaryPosEmbed(nn.Module):
    def __init__(self, theta: int,
                 axes_dim: Tuple[int, int, int],
                 axes_lens: Tuple[int, int, int] = (300, 512, 512),
                 patch_size: int = 2):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        self.axes_lens = axes_lens
        self.patch_size = patch_size

    @staticmethod
    def get_freqs_cis(axes_dim: Tuple[int, int, int],
                      axes_lens: Tuple[int, int, int],
                      theta: int) -> List[torch.Tensor]:
        freqs_cis = []
        freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
        for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
            emb = get_1d_rotary_pos_embed(d, e, theta=theta, freqs_dtype=freqs_dtype)
            freqs_cis.append(emb)
        return freqs_cis

    def _get_freqs_cis(self, freqs_cis, ids: torch.Tensor) -> torch.Tensor:
        device = ids.device
        if ids.device.type == "mps":
            ids = ids.to("cpu")

        result = []
        for i in range(len(self.axes_dim)):
            freqs = freqs_cis[i].to(ids.device)
            index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
            result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
        return torch.cat(result, dim=-1).to(device)

    def forward(
        self,
        freqs_cis,
        attention_mask,
        l_effective_ref_img_len,
        l_effective_img_len,
        ref_img_sizes,
        img_sizes,
        device
    ):
        batch_size = len(attention_mask)
        p = self.patch_size

        encoder_seq_len = attention_mask.shape[1]
        l_effective_cap_len = attention_mask.sum(dim=1).tolist()

        seq_lengths = [cap_len + sum(ref_img_len) + img_len for cap_len, ref_img_len, img_len in zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len)]

        max_seq_len = int(max(seq_lengths))
        max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
        max_img_len = max(l_effective_img_len)

        # Create position IDs
        position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
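        # the three position axes are: [0] a running shift shared by all tokens of an image, [1] image row, [2] image column;
        # text tokens reuse their 1D sequence index on all three axes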

        for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
            cap_seq_len = int(cap_seq_len)
            seq_len = int(seq_len)
            # add text position ids
            position_ids[i, :cap_seq_len] = repeat(torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3")

            pe_shift = cap_seq_len
            pe_shift_len = cap_seq_len

            if ref_img_sizes[i] is not None:
                for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
                    H, W = ref_img_size
                    ref_H_tokens, ref_W_tokens = H // p, W // p
                    assert ref_H_tokens * ref_W_tokens == ref_img_len
                    # add image position ids

                    row_ids = repeat(torch.arange(ref_H_tokens, dtype=torch.int32, device=device), "h -> h w", w=ref_W_tokens).flatten()
                    col_ids = repeat(torch.arange(ref_W_tokens, dtype=torch.int32, device=device), "w -> h w", h=ref_H_tokens).flatten()
                    position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift
                    position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids
                    position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids

                    pe_shift += max(ref_H_tokens, ref_W_tokens)
                    pe_shift_len += ref_img_len

            H, W = img_sizes[i]
            H_tokens, W_tokens = H // p, W // p
            assert H_tokens * W_tokens == l_effective_img_len[i]

            row_ids = repeat(torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens).flatten()
            col_ids = repeat(torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens).flatten()

            assert pe_shift_len + l_effective_img_len[i] == seq_len
            position_ids[i, pe_shift_len: seq_len, 0] = pe_shift
            position_ids[i, pe_shift_len: seq_len, 1] = row_ids
            position_ids[i, pe_shift_len: seq_len, 2] = col_ids

        # Get combined rotary embeddings
        freqs_cis = self._get_freqs_cis(freqs_cis, position_ids)

        # create separate rotary embeddings for captions and images
        cap_freqs_cis = torch.zeros(
            batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
        )
        ref_img_freqs_cis = torch.zeros(
            batch_size, max_ref_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
        )
        img_freqs_cis = torch.zeros(
            batch_size, max_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
        )

        for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)):
            cap_seq_len = int(cap_seq_len)
            sum_ref_img_len = int(sum(ref_img_len))
            img_len = int(img_len)
            seq_len = int(seq_len)
            cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
            ref_img_freqs_cis[i, :sum_ref_img_len] = freqs_cis[i, cap_seq_len:cap_seq_len + sum_ref_img_len]
            img_freqs_cis[i, :img_len] = freqs_cis[i, cap_seq_len + sum_ref_img_len:cap_seq_len + sum_ref_img_len + img_len]

        return (
            cap_freqs_cis,
            ref_img_freqs_cis,
            img_freqs_cis,
            freqs_cis,
            l_effective_cap_len,
            seq_lengths,
        )
@@ -0,0 +1,621 @@
|
||||
import warnings
|
||||
import itertools
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.loaders import PeftAdapterMixin
|
||||
from diffusers.loaders.single_file_model import FromOriginalModelMixin
|
||||
from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from diffusers.models.attention_processor import Attention
|
||||
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
|
||||
from ..attention_processor import OmniGen2AttnProcessorFlash2Varlen, OmniGen2AttnProcessor
|
||||
from .repo import OmniGen2RotaryPosEmbed
|
||||
from .block_lumina2 import LuminaLayerNormContinuous, LuminaRMSNormZero, LuminaFeedForward, Lumina2CombinedTimestepCaptionEmbedding
|
||||
|
||||
from ...utils.import_utils import is_triton_available, is_flash_attn_available
|
||||
|
||||
if is_triton_available():
|
||||
from ...ops.triton.layer_norm import RMSNorm
|
||||
else:
|
||||
from torch.nn import RMSNorm
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class OmniGen2TransformerBlock(nn.Module):
|
||||
"""
|
||||
Transformer block for OmniGen2 model.
|
||||
|
||||
This block implements a transformer layer with:
|
||||
- Multi-head attention with flash attention
|
||||
- Feed-forward network with SwiGLU activation
|
||||
- RMS normalization
|
||||
- Optional modulation for conditional generation
|
||||
|
||||
Args:
|
||||
dim: Dimension of the input and output tensors
|
||||
num_attention_heads: Number of attention heads
|
||||
num_kv_heads: Number of key-value heads
|
||||
multiple_of: Multiple of which the hidden dimension should be
|
||||
ffn_dim_multiplier: Multiplier for the feed-forward network dimension
|
||||
norm_eps: Epsilon value for normalization layers
|
||||
modulation: Whether to use modulation for conditional generation
|
||||
use_fused_rms_norm: Whether to use fused RMS normalization
|
||||
use_fused_swiglu: Whether to use fused SwiGLU activation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
num_kv_heads: int,
|
||||
multiple_of: int,
|
||||
ffn_dim_multiplier: float,
|
||||
norm_eps: float,
|
||||
modulation: bool = True,
|
||||
) -> None:
|
||||
"""Initialize the transformer block."""
|
||||
super().__init__()
|
||||
self.head_dim = dim // num_attention_heads
|
||||
self.modulation = modulation
|
||||
|
||||
try:
|
||||
processor = OmniGen2AttnProcessorFlash2Varlen()
|
||||
except ImportError:
|
||||
processor = OmniGen2AttnProcessor()
|
||||
|
||||
# Initialize attention layer
|
||||
self.attn = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None,
|
||||
dim_head=dim // num_attention_heads,
|
||||
qk_norm="rms_norm",
|
||||
heads=num_attention_heads,
|
||||
kv_heads=num_kv_heads,
|
||||
eps=1e-5,
|
||||
bias=False,
|
||||
out_bias=False,
|
||||
processor=processor,
|
||||
)
|
||||
|
||||
# Initialize feed-forward network
|
||||
self.feed_forward = LuminaFeedForward(
|
||||
dim=dim,
|
||||
inner_dim=4 * dim,
|
||||
multiple_of=multiple_of,
|
||||
ffn_dim_multiplier=ffn_dim_multiplier
|
||||
)
|
||||
|
||||
# Initialize normalization layers
|
||||
if modulation:
|
||||
self.norm1 = LuminaRMSNormZero(
|
||||
embedding_dim=dim,
|
||||
norm_eps=norm_eps,
|
||||
norm_elementwise_affine=True
|
||||
)
|
||||
else:
|
||||
self.norm1 = RMSNorm(dim, eps=norm_eps)
|
||||
|
||||
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
|
||||
self.norm2 = RMSNorm(dim, eps=norm_eps)
|
||||
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
|
||||
|
||||
self.initialize_weights()
|
||||
|
||||
def initialize_weights(self) -> None:
|
||||
"""
|
||||
Initialize the weights of the transformer block.
|
||||
|
||||
Uses Xavier uniform initialization for linear layers and zero initialization for biases.
|
||||
"""
|
||||
nn.init.xavier_uniform_(self.attn.to_q.weight)
|
||||
nn.init.xavier_uniform_(self.attn.to_k.weight)
|
||||
nn.init.xavier_uniform_(self.attn.to_v.weight)
|
||||
nn.init.xavier_uniform_(self.attn.to_out[0].weight)
|
||||
|
||||
nn.init.xavier_uniform_(self.feed_forward.linear_1.weight)
|
||||
nn.init.xavier_uniform_(self.feed_forward.linear_2.weight)
|
||||
nn.init.xavier_uniform_(self.feed_forward.linear_3.weight)
|
||||
|
||||
if self.modulation:
|
||||
nn.init.zeros_(self.norm1.linear.weight)
|
||||
nn.init.zeros_(self.norm1.linear.bias)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
image_rotary_emb: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass of the transformer block.
|
||||
|
||||
Args:
|
||||
hidden_states: Input hidden states tensor
|
||||
attention_mask: Attention mask tensor
|
||||
image_rotary_emb: Rotary embeddings for image tokens
|
||||
temb: Optional timestep embedding tensor
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output hidden states after transformer block processing
|
||||
"""
|
||||
import time
|
||||
if self.modulation:
|
||||
if temb is None:
|
||||
raise ValueError("temb must be provided when modulation is enabled")
|
||||
|
||||
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
|
||||
attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
|
||||
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
|
||||
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
|
||||
else:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
hidden_states = hidden_states + self.norm2(attn_output)
|
||||
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
|
||||
hidden_states = hidden_states + self.ffn_norm2(mlp_output)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class OmniGen2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
"""
|
||||
OmniGen2 Transformer 2D Model.
|
||||
|
||||
A transformer-based diffusion model for image generation with:
|
||||
- Patch-based image processing
|
||||
- Rotary position embeddings
|
||||
- Multi-head attention
|
||||
- Conditional generation support
|
||||
|
||||
Args:
|
||||
patch_size: Size of image patches
|
||||
in_channels: Number of input channels
|
||||
out_channels: Number of output channels (defaults to in_channels)
|
||||
hidden_size: Size of hidden layers
|
||||
num_layers: Number of transformer layers
|
||||
num_refiner_layers: Number of refiner layers
|
||||
num_attention_heads: Number of attention heads
|
||||
num_kv_heads: Number of key-value heads
|
||||
multiple_of: Multiple of which the hidden dimension should be
|
||||
ffn_dim_multiplier: Multiplier for feed-forward network dimension
|
||||
norm_eps: Epsilon value for normalization layers
|
||||
axes_dim_rope: Dimensions for rotary position embeddings
|
||||
axes_lens: Lengths for rotary position embeddings
|
||||
text_feat_dim: Dimension of text features
|
||||
timestep_scale: Scale factor for timestep embeddings
|
||||
use_fused_rms_norm: Whether to use fused RMS normalization
|
||||
use_fused_swiglu: Whether to use fused SwiGLU activation
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["Omnigen2TransformerBlock"]
|
||||
_skip_layerwise_casting_patterns = ["x_embedder", "norm"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 16,
|
||||
out_channels: Optional[int] = None,
|
||||
hidden_size: int = 2304,
|
||||
num_layers: int = 26,
|
||||
num_refiner_layers: int = 2,
|
||||
num_attention_heads: int = 24,
|
||||
num_kv_heads: int = 8,
|
||||
multiple_of: int = 256,
|
||||
ffn_dim_multiplier: Optional[float] = None,
|
||||
norm_eps: float = 1e-5,
|
||||
axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
|
||||
axes_lens: Tuple[int, int, int] = (300, 512, 512),
|
||||
text_feat_dim: int = 1024,
|
||||
timestep_scale: float = 1.0
|
||||
) -> None:
|
||||
"""Initialize the OmniGen2 transformer model."""
|
||||
super().__init__()
|
||||
|
||||
# Validate configuration
|
||||
if (hidden_size // num_attention_heads) != sum(axes_dim_rope):
|
||||
raise ValueError(
|
||||
f"hidden_size // num_attention_heads ({hidden_size // num_attention_heads}) "
|
||||
f"must equal sum(axes_dim_rope) ({sum(axes_dim_rope)})"
|
||||
)
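# Sanity check on the per-head dimension; with the defaults above the constraint holds:
# 2304 // 24 == 96 == 32 + 32 + 32 == sum(axes_dim_rope).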
|
||||
|
||||
self.out_channels = out_channels or in_channels
|
||||
|
||||
# Initialize embeddings
|
||||
self.rope_embedder = OmniGen2RotaryPosEmbed(
|
||||
theta=10000,
|
||||
axes_dim=axes_dim_rope,
|
||||
axes_lens=axes_lens,
|
||||
patch_size=patch_size,
|
||||
)
|
||||
|
||||
self.x_embedder = nn.Linear(
|
||||
in_features=patch_size * patch_size * in_channels,
|
||||
out_features=hidden_size,
|
||||
)
|
||||
|
||||
self.ref_image_patch_embedder = nn.Linear(
|
||||
in_features=patch_size * patch_size * in_channels,
|
||||
out_features=hidden_size,
|
||||
)
|
||||
|
||||
self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
|
||||
hidden_size=hidden_size,
|
||||
text_feat_dim=text_feat_dim,
|
||||
norm_eps=norm_eps,
|
||||
timestep_scale=timestep_scale
|
||||
)
|
||||
|
||||
# Initialize transformer blocks
|
||||
self.noise_refiner = nn.ModuleList([
|
||||
OmniGen2TransformerBlock(
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
num_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
modulation=True
|
||||
)
|
||||
for _ in range(num_refiner_layers)
|
||||
])
|
||||
|
||||
self.ref_image_refiner = nn.ModuleList([
|
||||
OmniGen2TransformerBlock(
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
num_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
modulation=True
|
||||
)
|
||||
for _ in range(num_refiner_layers)
|
||||
])
|
||||
|
||||
self.context_refiner = nn.ModuleList(
|
||||
[
|
||||
OmniGen2TransformerBlock(
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
num_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
modulation=False
|
||||
)
|
||||
for _ in range(num_refiner_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 3. Transformer blocks
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
OmniGen2TransformerBlock(
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
num_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
modulation=True
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 4. Output norm & projection
|
||||
self.norm_out = LuminaLayerNormContinuous(
|
||||
embedding_dim=hidden_size,
|
||||
conditioning_embedding_dim=min(hidden_size, 1024),
|
||||
elementwise_affine=False,
|
||||
eps=1e-6,
|
||||
bias=True,
|
||||
out_dim=patch_size * patch_size * self.out_channels
|
||||
)
|
||||
|
||||
# Add learnable embeddings to distinguish different images
|
||||
self.image_index_embedding = nn.Parameter(torch.randn(5, hidden_size)) # support max 5 ref images
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.initialize_weights()
|
||||
|
||||
def initialize_weights(self) -> None:
|
||||
"""
|
||||
Initialize the weights of the model.
|
||||
|
||||
Uses Xavier uniform initialization for linear layers.
|
||||
"""
|
||||
nn.init.xavier_uniform_(self.x_embedder.weight)
|
||||
nn.init.constant_(self.x_embedder.bias, 0.0)
|
||||
|
||||
nn.init.xavier_uniform_(self.ref_image_patch_embedder.weight)
|
||||
nn.init.constant_(self.ref_image_patch_embedder.bias, 0.0)
|
||||
|
||||
nn.init.zeros_(self.norm_out.linear_1.weight)
|
||||
nn.init.zeros_(self.norm_out.linear_1.bias)
|
||||
nn.init.zeros_(self.norm_out.linear_2.weight)
|
||||
nn.init.zeros_(self.norm_out.linear_2.bias)
|
||||
|
||||
nn.init.normal_(self.image_index_embedding, std=0.02)
|
||||
|
||||
def img_patch_embed_and_refine(
|
||||
self,
|
||||
hidden_states,
|
||||
ref_image_hidden_states,
|
||||
padded_img_mask,
|
||||
padded_ref_img_mask,
|
||||
noise_rotary_emb,
|
||||
ref_img_rotary_emb,
|
||||
l_effective_ref_img_len,
|
||||
l_effective_img_len,
|
||||
temb
|
||||
):
|
||||
batch_size = len(hidden_states)
|
||||
max_combined_img_len = max([img_len + sum(ref_img_len) for img_len, ref_img_len in zip(l_effective_img_len, l_effective_ref_img_len)])
|
||||
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)
|
||||
|
||||
for i in range(batch_size):
|
||||
shift = 0
|
||||
for j, ref_img_len in enumerate(l_effective_ref_img_len[i]):
|
||||
ref_image_hidden_states[i, shift:shift + ref_img_len, :] = ref_image_hidden_states[i, shift:shift + ref_img_len, :] + self.image_index_embedding[j]
|
||||
shift += ref_img_len
|
||||
|
||||
for layer in self.noise_refiner:
|
||||
hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
|
||||
|
||||
flat_l_effective_ref_img_len = list(itertools.chain(*l_effective_ref_img_len))
|
||||
num_ref_images = len(flat_l_effective_ref_img_len)
|
||||
max_ref_img_len = max(flat_l_effective_ref_img_len)
|
||||
|
||||
batch_ref_img_mask = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, dtype=torch.bool)
|
||||
batch_ref_image_hidden_states = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, self.config.hidden_size)
|
||||
batch_ref_img_rotary_emb = hidden_states.new_zeros(num_ref_images, max_ref_img_len, ref_img_rotary_emb.shape[-1], dtype=ref_img_rotary_emb.dtype)
|
||||
batch_temb = temb.new_zeros(num_ref_images, *temb.shape[1:], dtype=temb.dtype)
|
||||
|
||||
# sequence of ref imgs to batch
|
||||
idx = 0
|
||||
for i in range(batch_size):
|
||||
shift = 0
|
||||
for ref_img_len in l_effective_ref_img_len[i]:
|
||||
batch_ref_img_mask[idx, :ref_img_len] = True
|
||||
batch_ref_image_hidden_states[idx, :ref_img_len] = ref_image_hidden_states[i, shift:shift + ref_img_len]
|
||||
batch_ref_img_rotary_emb[idx, :ref_img_len] = ref_img_rotary_emb[i, shift:shift + ref_img_len]
|
||||
batch_temb[idx] = temb[i]
|
||||
shift += ref_img_len
|
||||
idx += 1
|
||||
|
||||
# refine ref imgs separately
|
||||
for layer in self.ref_image_refiner:
|
||||
batch_ref_image_hidden_states = layer(batch_ref_image_hidden_states, batch_ref_img_mask, batch_ref_img_rotary_emb, batch_temb)
|
||||
|
||||
# batch of ref imgs to sequence
|
||||
idx = 0
|
||||
for i in range(batch_size):
|
||||
shift = 0
|
||||
for ref_img_len in l_effective_ref_img_len[i]:
|
||||
ref_image_hidden_states[i, shift:shift + ref_img_len] = batch_ref_image_hidden_states[idx, :ref_img_len]
|
||||
shift += ref_img_len
|
||||
idx += 1
|
||||
|
||||
combined_img_hidden_states = hidden_states.new_zeros(batch_size, max_combined_img_len, self.config.hidden_size)
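# Each sample is packed as [all reference-image tokens, then the noise-image tokens];
# any remaining tail of the combined sequence stays zero-padded.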
|
||||
for i, (ref_img_len, img_len) in enumerate(zip(l_effective_ref_img_len, l_effective_img_len)):
|
||||
combined_img_hidden_states[i, :sum(ref_img_len)] = ref_image_hidden_states[i, :sum(ref_img_len)]
|
||||
combined_img_hidden_states[i, sum(ref_img_len):sum(ref_img_len) + img_len] = hidden_states[i, :img_len]
|
||||
|
||||
return combined_img_hidden_states
|
||||
|
||||
def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states):
|
||||
batch_size = len(hidden_states)
|
||||
p = self.config.patch_size
|
||||
device = hidden_states[0].device
|
||||
|
||||
img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
|
||||
l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes]
|
||||
|
||||
if ref_image_hidden_states is not None:
|
||||
ref_img_sizes = [[(img.size(1), img.size(2)) for img in imgs] if imgs is not None else None for imgs in ref_image_hidden_states]
|
||||
l_effective_ref_img_len = [[(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes] if _ref_img_sizes is not None else [0] for _ref_img_sizes in ref_img_sizes]
|
||||
else:
|
||||
ref_img_sizes = [None for _ in range(batch_size)]
|
||||
l_effective_ref_img_len = [[0] for _ in range(batch_size)]
|
||||
|
||||
max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
|
||||
max_img_len = max(l_effective_img_len)
|
||||
|
||||
# ref image patch embeddings
|
||||
flat_ref_img_hidden_states = []
|
||||
for i in range(batch_size):
|
||||
if ref_img_sizes[i] is not None:
|
||||
imgs = []
|
||||
for ref_img in ref_image_hidden_states[i]:
|
||||
C, H, W = ref_img.size()
|
||||
ref_img = rearrange(ref_img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
|
||||
imgs.append(ref_img)
|
||||
|
||||
img = torch.cat(imgs, dim=0)
|
||||
flat_ref_img_hidden_states.append(img)
|
||||
else:
|
||||
flat_ref_img_hidden_states.append(None)
|
||||
|
||||
# image patch embeddings
|
||||
flat_hidden_states = []
|
||||
for i in range(batch_size):
|
||||
img = hidden_states[i]
|
||||
C, H, W = img.size()
|
||||
|
||||
img = rearrange(img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
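# Rough example: a 16x128x128 latent with patch_size=2 flattens to 64*64 = 4096 tokens,
# each of dimension 2*2*16 = 64.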
|
||||
flat_hidden_states.append(img)
|
||||
|
||||
padded_ref_img_hidden_states = torch.zeros(batch_size, max_ref_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype)
|
||||
padded_ref_img_mask = torch.zeros(batch_size, max_ref_img_len, dtype=torch.bool, device=device)
|
||||
for i in range(batch_size):
|
||||
if ref_img_sizes[i] is not None:
|
||||
padded_ref_img_hidden_states[i, :sum(l_effective_ref_img_len[i])] = flat_ref_img_hidden_states[i]
|
||||
padded_ref_img_mask[i, :sum(l_effective_ref_img_len[i])] = True
|
||||
|
||||
padded_hidden_states = torch.zeros(batch_size, max_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype)
|
||||
padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device)
|
||||
for i in range(batch_size):
|
||||
padded_hidden_states[i, :l_effective_img_len[i]] = flat_hidden_states[i]
|
||||
padded_img_mask[i, :l_effective_img_len[i]] = True
|
||||
|
||||
return (
|
||||
padded_hidden_states,
|
||||
padded_ref_img_hidden_states,
|
||||
padded_img_mask,
|
||||
padded_ref_img_mask,
|
||||
l_effective_ref_img_len,
|
||||
l_effective_img_len,
|
||||
ref_img_sizes,
|
||||
img_sizes,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: Union[torch.Tensor, List[torch.Tensor]],
|
||||
timestep: torch.Tensor,
|
||||
text_hidden_states: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
text_attention_mask: torch.Tensor,
|
||||
ref_image_hidden_states: Optional[List[List[torch.Tensor]]] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = False,
|
||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
# 1. Condition, positional & patch embedding
|
||||
batch_size = len(hidden_states)
|
||||
is_hidden_states_tensor = isinstance(hidden_states, torch.Tensor)
|
||||
|
||||
if is_hidden_states_tensor:
|
||||
assert hidden_states.ndim == 4
|
||||
hidden_states = [_hidden_states for _hidden_states in hidden_states]
|
||||
|
||||
device = hidden_states[0].device
|
||||
|
||||
temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype)
|
||||
|
||||
(
|
||||
hidden_states,
|
||||
ref_image_hidden_states,
|
||||
img_mask,
|
||||
ref_img_mask,
|
||||
l_effective_ref_img_len,
|
||||
l_effective_img_len,
|
||||
ref_img_sizes,
|
||||
img_sizes,
|
||||
) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)
|
||||
|
||||
(
|
||||
context_rotary_emb,
|
||||
ref_img_rotary_emb,
|
||||
noise_rotary_emb,
|
||||
rotary_emb,
|
||||
encoder_seq_lengths,
|
||||
seq_lengths,
|
||||
) = self.rope_embedder(
|
||||
freqs_cis,
|
||||
text_attention_mask,
|
||||
l_effective_ref_img_len,
|
||||
l_effective_img_len,
|
||||
ref_img_sizes,
|
||||
img_sizes,
|
||||
device,
|
||||
)
|
||||
|
||||
# 2. Context refinement
|
||||
for layer in self.context_refiner:
|
||||
text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)
|
||||
|
||||
combined_img_hidden_states = self.img_patch_embed_and_refine(
|
||||
hidden_states,
|
||||
ref_image_hidden_states,
|
||||
img_mask,
|
||||
ref_img_mask,
|
||||
noise_rotary_emb,
|
||||
ref_img_rotary_emb,
|
||||
l_effective_ref_img_len,
|
||||
l_effective_img_len,
|
||||
temb,
|
||||
)
|
||||
|
||||
# 3. Joint Transformer blocks
|
||||
max_seq_len = int(max(seq_lengths))
|
||||
|
||||
attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
|
||||
joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
|
||||
for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
|
||||
encoder_seq_len = int(encoder_seq_len)
|
||||
seq_len = int(seq_len)
|
||||
attention_mask[i, :seq_len] = True
|
||||
joint_hidden_states[i, :encoder_seq_len] = text_hidden_states[i, :encoder_seq_len]
|
||||
joint_hidden_states[i, encoder_seq_len:seq_len] = combined_img_hidden_states[i, :seq_len - encoder_seq_len]
|
||||
|
||||
hidden_states = joint_hidden_states
|
||||
|
||||
for layer_idx, layer in enumerate(self.layers):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
layer, hidden_states, attention_mask, rotary_emb, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
|
||||
|
||||
# 4. Output norm & projection
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
|
||||
p = self.config.patch_size
|
||||
output = []
|
||||
for i, (img_size, img_len, seq_len) in enumerate(zip(img_sizes, l_effective_img_len, seq_lengths)):
|
||||
img_len = int(img_len)
|
||||
seq_len = int(seq_len)
|
||||
height, width = img_size
|
||||
output.append(rearrange(hidden_states[i][seq_len - img_len:seq_len], '(h w) (p1 p2 c) -> c (h p1) (w p2)', h=height // p, w=width // p, p1=p, p2=p))
|
||||
if is_hidden_states_tensor:
|
||||
output = torch.stack(output, dim=0)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return output
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
@@ -0,0 +1,266 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor, is_valid_image_imagelist
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
|
||||
class OmniGen2ImageProcessor(VaeImageProcessor):
|
||||
"""
|
||||
Image processor for OmniGen2 image resize and crop.
|
||||
|
||||
Args:
|
||||
do_resize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
|
||||
`height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
|
||||
vae_scale_factor (`int`, *optional*, defaults to `8`):
|
||||
VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
|
||||
resample (`str`, *optional*, defaults to `lanczos`):
|
||||
Resampling filter to use when resizing the image.
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the image to [-1,1].
|
||||
do_binarize (`bool`, *optional*, defaults to `False`):
|
||||
Whether to binarize the image to 0/1.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to be `False`):
|
||||
Whether to convert the images to RGB format.
|
||||
do_convert_grayscale (`bool`, *optional*, defaults to be `False`):
|
||||
Whether to convert the images to grayscale format.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
do_resize: bool = True,
|
||||
vae_scale_factor: int = 16,
|
||||
resample: str = "lanczos",
|
||||
max_pixels: Optional[int] = None,
|
||||
max_side_length: Optional[int] = None,
|
||||
do_normalize: bool = True,
|
||||
do_binarize: bool = False,
|
||||
do_convert_grayscale: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
do_resize=do_resize,
|
||||
vae_scale_factor=vae_scale_factor,
|
||||
resample=resample,
|
||||
do_normalize=do_normalize,
|
||||
do_binarize=do_binarize,
|
||||
do_convert_grayscale=do_convert_grayscale,
|
||||
)
|
||||
|
||||
self.max_pixels = max_pixels
|
||||
self.max_side_length = max_side_length
|
||||
|
||||
def get_new_height_width(
|
||||
self,
|
||||
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
max_pixels: Optional[int] = None,
|
||||
max_side_length: Optional[int] = None,
|
||||
) -> Tuple[int, int]:
|
||||
r"""
|
||||
Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`.
|
||||
|
||||
Args:
|
||||
image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
|
||||
The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it
|
||||
should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch
|
||||
tensor, it should have shape `[batch, channels, height, width]`.
|
||||
height (`Optional[int]`, *optional*, defaults to `None`):
|
||||
The height of the preprocessed image. If `None`, the height of the `image` input will be used.
|
||||
width (`Optional[int]`, *optional*, defaults to `None`):
|
||||
The width of the preprocessed image. If `None`, the width of the `image` input will be used.
|
||||
|
||||
Returns:
|
||||
`Tuple[int, int]`:
|
||||
A tuple containing the height and width, both resized to the nearest integer multiple of
|
||||
`vae_scale_factor`.
|
||||
"""
|
||||
|
||||
if height is None:
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
height = image.height
|
||||
elif isinstance(image, torch.Tensor):
|
||||
height = image.shape[2]
|
||||
else:
|
||||
height = image.shape[1]
|
||||
|
||||
if width is None:
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
width = image.width
|
||||
elif isinstance(image, torch.Tensor):
|
||||
width = image.shape[3]
|
||||
else:
|
||||
width = image.shape[2]
|
||||
|
||||
if max_side_length is None:
|
||||
max_side_length = self.max_side_length
|
||||
|
||||
if max_pixels is None:
|
||||
max_pixels = self.max_pixels
|
||||
|
||||
ratio = 1.0
|
||||
if max_side_length is not None:
|
||||
if height > width:
|
||||
max_side_length_ratio = max_side_length / height
|
||||
else:
|
||||
max_side_length_ratio = max_side_length / width
|
||||
|
||||
cur_pixels = height * width
|
||||
max_pixels_ratio = (max_pixels / cur_pixels) ** 0.5
|
||||
ratio = min(max_pixels_ratio, max_side_length_ratio, 1.0) # do not upscale input image
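# Rough example (assuming vae_scale_factor=16): a 1536x1024 input with max_pixels=1024*1024 and a
# non-binding max_side_length=2048 gives ratio = sqrt(1024*1024 / (1536*1024)) ~= 0.816,
# i.e. roughly 1248x832 after snapping down to multiples of 16 below.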
|
||||
|
||||
new_height, new_width = int(height * ratio) // self.config.vae_scale_factor * self.config.vae_scale_factor, int(width * ratio) // self.config.vae_scale_factor * self.config.vae_scale_factor
|
||||
return new_height, new_width
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
image: PipelineImageInput,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
max_pixels: Optional[int] = None,
|
||||
max_side_length: Optional[int] = None,
|
||||
resize_mode: str = "default", # "default", "fill", "crop"
|
||||
crops_coords: Optional[Tuple[int, int, int, int]] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Preprocess the image input.
|
||||
|
||||
Args:
|
||||
image (`PipelineImageInput`):
|
||||
The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of
|
||||
supported formats.
|
||||
height (`int`, *optional*):
|
||||
The height of the preprocessed image. If `None`, will use `get_default_height_width()` to get the default
|
||||
height.
|
||||
width (`int`, *optional*):
|
||||
The width of the preprocessed image. If `None`, will use `get_default_height_width()` to get the default width.
|
||||
resize_mode (`str`, *optional*, defaults to `default`):
|
||||
The resize mode, can be one of `default`, `fill`, or `crop`. If `default`, will resize the image to fit within
the specified width and height, and it may not maintain the original aspect ratio. If `fill`, will
|
||||
resize the image to fit within the specified width and height, maintaining the aspect ratio, and then
|
||||
center the image within the dimensions, filling empty with data from image. If `crop`, will resize the
|
||||
image to fit within the specified width and height, maintaining the aspect ratio, and then center the
|
||||
image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
|
||||
supported for PIL image input.
|
||||
crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
|
||||
The crop coordinates for each image in the batch. If `None`, will not crop the image.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The preprocessed image.
|
||||
"""
|
||||
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
|
||||
|
||||
# Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
|
||||
if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
|
||||
if isinstance(image, torch.Tensor):
|
||||
# if image is a pytorch tensor could have 2 possible shapes:
|
||||
# 1. batch x height x width: we should insert the channel dimension at position 1
|
||||
# 2. channel x height x width: we should insert batch dimension at position 0,
|
||||
# however, since both the channel and batch dimensions have size 1, it is the same to insert at position 1
|
||||
# for simplicity, we insert a dimension of size 1 at position 1 for both cases
|
||||
image = image.unsqueeze(1)
|
||||
else:
|
||||
# if it is a numpy array, it could have 2 possible shapes:
|
||||
# 1. batch x height x width: insert channel dimension on last position
|
||||
# 2. height x width x channel: insert batch dimension on first position
|
||||
if image.shape[-1] == 1:
|
||||
image = np.expand_dims(image, axis=0)
|
||||
else:
|
||||
image = np.expand_dims(image, axis=-1)
|
||||
|
||||
if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
|
||||
warnings.warn(
|
||||
"Passing `image` as a list of 4d np.ndarray is deprecated."
|
||||
"Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
|
||||
FutureWarning,
|
||||
)
|
||||
image = np.concatenate(image, axis=0)
|
||||
if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
|
||||
warnings.warn(
|
||||
"Passing `image` as a list of 4d torch.Tensor is deprecated."
|
||||
"Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
|
||||
FutureWarning,
|
||||
)
|
||||
image = torch.cat(image, axis=0)
|
||||
|
||||
if not is_valid_image_imagelist(image):
|
||||
raise ValueError(
|
||||
f"Input is in incorrect format. Currently, we only support {', '.join(str(x) for x in supported_formats)}"
|
||||
)
|
||||
if not isinstance(image, list):
|
||||
image = [image]
|
||||
|
||||
if isinstance(image[0], PIL.Image.Image):
|
||||
if crops_coords is not None:
|
||||
image = [i.crop(crops_coords) for i in image]
|
||||
if self.config.do_resize:
|
||||
height, width = self.get_new_height_width(image[0], height, width, max_pixels, max_side_length)
|
||||
image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image]
|
||||
if self.config.do_convert_rgb:
|
||||
image = [self.convert_to_rgb(i) for i in image]
|
||||
elif self.config.do_convert_grayscale:
|
||||
image = [self.convert_to_grayscale(i) for i in image]
|
||||
image = self.pil_to_numpy(image) # to np
|
||||
image = self.numpy_to_pt(image) # to pt
|
||||
|
||||
elif isinstance(image[0], np.ndarray):
|
||||
image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
|
||||
|
||||
image = self.numpy_to_pt(image)
|
||||
|
||||
height, width = self.get_new_height_width(image, height, width, max_pixels, max_side_length)
|
||||
if self.config.do_resize:
|
||||
image = self.resize(image, height, width)
|
||||
|
||||
elif isinstance(image[0], torch.Tensor):
|
||||
image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
|
||||
|
||||
if self.config.do_convert_grayscale and image.ndim == 3:
|
||||
image = image.unsqueeze(1)
|
||||
|
||||
channel = image.shape[1]
|
||||
# don't need any preprocess if the image is latents
|
||||
if channel == self.config.vae_latent_channels:
|
||||
return image
|
||||
|
||||
height, width = self.get_new_height_width(image, height, width, max_pixels, max_side_length)
|
||||
if self.config.do_resize:
|
||||
image = self.resize(image, height, width)
|
||||
|
||||
# expected range [0,1], normalize to [-1,1]
|
||||
do_normalize = self.config.do_normalize
|
||||
if do_normalize and image.min() < 0:
|
||||
warnings.warn(
|
||||
"Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
|
||||
f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
|
||||
FutureWarning,
|
||||
)
|
||||
do_normalize = False
|
||||
if do_normalize:
|
||||
image = self.normalize(image)
|
||||
|
||||
if self.config.do_binarize:
|
||||
image = self.binarize(image)
|
||||
|
||||
return image
|
||||
@@ -0,0 +1,728 @@
|
||||
"""
|
||||
OmniGen2 Diffusion Pipeline
|
||||
|
||||
Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import math
|
||||
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration
|
||||
|
||||
from diffusers.models.autoencoders import AutoencoderKL
|
||||
from ...models.transformers import OmniGen2Transformer2DModel
|
||||
from ...models.transformers.repo import OmniGen2RotaryPosEmbed
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import (
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import PIL.Image
|
||||
|
||||
from diffusers.utils import BaseOutput
|
||||
|
||||
from ....src.pipelines.image_processor import OmniGen2ImageProcessor
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@dataclass
|
||||
class FMPipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for OmniGen2 pipeline.
|
||||
|
||||
Args:
|
||||
images (Union[List[PIL.Image.Image], np.ndarray]):
|
||||
List of denoised PIL images of length `batch_size` or numpy array of shape
|
||||
`(batch_size, height, width, num_channels)`. Contains the generated images.
|
||||
"""
|
||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class OmniGen2Pipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for text-to-image generation using OmniGen2.
|
||||
|
||||
This pipeline implements a text-to-image generation model that uses:
|
||||
- Qwen2.5-VL for text encoding
|
||||
- A custom transformer architecture for image generation
|
||||
- VAE for image encoding/decoding
|
||||
- FlowMatchEulerDiscreteScheduler for noise scheduling
|
||||
|
||||
Args:
|
||||
transformer (OmniGen2Transformer2DModel): The transformer model for image generation.
|
||||
vae (AutoencoderKL): The VAE model for image encoding/decoding.
|
||||
scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for noise scheduling.
|
||||
mllm (Qwen2_5_VLForConditionalGeneration): The multimodal language model used to encode prompts.
processor: The processor that provides the tokenizer and chat template for text processing.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "mllm->transformer->vae"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformer: OmniGen2Transformer2DModel,
|
||||
vae: AutoencoderKL,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
mllm: Qwen2_5_VLForConditionalGeneration,
|
||||
processor,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the OmniGen2 pipeline.
|
||||
|
||||
Args:
|
||||
transformer: The transformer model for image generation.
|
||||
vae: The VAE model for image encoding/decoding.
|
||||
scheduler: The scheduler for noise scheduling.
|
||||
mllm: The Qwen2.5-VL multimodal language model used for text encoding.
processor: The processor providing the tokenizer for text processing.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
transformer=transformer,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
mllm=mllm,
|
||||
processor=processor
|
||||
)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
)
|
||||
self.image_processor = OmniGen2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2, do_resize=True)
|
||||
self.default_sample_size = 128
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_channels_latents: int,
|
||||
height: int,
|
||||
width: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
generator: Optional[torch.Generator],
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Prepare the initial latents for the diffusion process.
|
||||
|
||||
Args:
|
||||
batch_size: The number of images to generate.
|
||||
num_channels_latents: The number of channels in the latent space.
|
||||
height: The height of the generated image.
|
||||
width: The width of the generated image.
|
||||
dtype: The data type of the latents.
|
||||
device: The device to place the latents on.
|
||||
generator: The random number generator to use.
|
||||
latents: Optional pre-computed latents to use instead of random initialization.
|
||||
|
||||
Returns:
|
||||
torch.FloatTensor: The prepared latents tensor.
|
||||
"""
|
||||
height = int(height) // self.vae_scale_factor
|
||||
width = int(width) // self.vae_scale_factor
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
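# e.g. a 1024x1024 target with the usual 8x VAE downsampling and 16 latent channels
# gives a latent of shape (batch_size, 16, 128, 128).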
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
return latents
|
||||
|
||||
def encode_vae(self, img: torch.FloatTensor) -> torch.FloatTensor:
|
||||
"""
|
||||
Encode an image into the VAE latent space.
|
||||
|
||||
Args:
|
||||
img: The input image tensor to encode.
|
||||
|
||||
Returns:
|
||||
torch.FloatTensor: The encoded latent representation.
|
||||
"""
|
||||
z0 = self.vae.encode(img.to(dtype=self.vae.dtype)).latent_dist.sample()
|
||||
if self.vae.config.shift_factor is not None:
|
||||
z0 = z0 - self.vae.config.shift_factor
|
||||
if self.vae.config.scaling_factor is not None:
|
||||
z0 = z0 * self.vae.config.scaling_factor
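# Decoding reverses this in `processing`: latents / scaling_factor + shift_factor before vae.decode.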
|
||||
z0 = z0.to(dtype=self.vae.dtype)
|
||||
return z0
|
||||
|
||||
def prepare_image(
|
||||
self,
|
||||
images: Union[List[PIL.Image.Image], PIL.Image.Image],
|
||||
batch_size: int,
|
||||
num_images_per_prompt: int,
|
||||
max_pixels: int,
|
||||
max_side_length: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
) -> List[Optional[torch.FloatTensor]]:
|
||||
"""
|
||||
Prepare input images for processing by encoding them into the VAE latent space.
|
||||
|
||||
Args:
|
||||
images: Single image or list of images to process.
|
||||
batch_size: The number of prompts in the batch.
num_images_per_prompt: The number of images to generate for each prompt.
max_pixels: Maximum pixel count allowed for each reference image before downscaling.
max_side_length: Maximum side length allowed for each reference image before downscaling.
|
||||
device: The device to place the encoded latents on.
|
||||
dtype: The data type of the encoded latents.
|
||||
|
||||
Returns:
|
||||
List[Optional[torch.FloatTensor]]: List of encoded latent representations for each image.
|
||||
"""
|
||||
if batch_size == 1:
|
||||
images = [images]
|
||||
latents = []
|
||||
for i, img in enumerate(images):
|
||||
if img is not None and len(img) > 0:
|
||||
ref_latents = []
|
||||
for j, img_j in enumerate(img):
|
||||
img_j = self.image_processor.preprocess(img_j, max_pixels=max_pixels, max_side_length=max_side_length)
|
||||
ref_latents.append(self.encode_vae(img_j.to(device=device)).squeeze(0))
|
||||
else:
|
||||
ref_latents = None
|
||||
for _ in range(num_images_per_prompt):
|
||||
latents.append(ref_latents)
|
||||
|
||||
return latents
|
||||
|
||||
def _get_qwen2_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
max_sequence_length: int = 256,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Get prompt embeddings from the Qwen2 text encoder.
|
||||
|
||||
Args:
|
||||
prompt: The prompt or list of prompts to encode.
|
||||
device: The device to place the embeddings on. If None, uses the pipeline's device.
|
||||
max_sequence_length: Maximum sequence length for tokenization.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
||||
- The prompt embeddings tensor
|
||||
- The attention mask tensor
|
||||
|
||||
Raises:
|
||||
Warning: If the input text is truncated due to sequence length limitations.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
# text_inputs = self.processor.tokenizer(
|
||||
# prompt,
|
||||
# padding="max_length",
|
||||
# max_length=max_sequence_length,
|
||||
# truncation=True,
|
||||
# return_tensors="pt",
|
||||
# )
|
||||
text_inputs = self.processor.tokenizer(
|
||||
prompt,
|
||||
padding="longest",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids.to(device)
|
||||
untruncated_ids = self.processor.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device)
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.processor.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because Gemma can only handle sequences up to"
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_attention_mask = text_inputs.attention_mask.to(device)
|
||||
prompt_embeds = self.mllm(
|
||||
text_input_ids,
|
||||
attention_mask=prompt_attention_mask,
|
||||
output_hidden_states=True,
|
||||
).hidden_states[-1]
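# The final-layer hidden states of the MLLM are used directly as the text conditioning.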
|
||||
|
||||
if self.mllm is not None:
|
||||
dtype = self.mllm.dtype
|
||||
elif self.transformer is not None:
|
||||
dtype = self.transformer.dtype
|
||||
else:
|
||||
dtype = None
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask
|
||||
|
||||
def _apply_chat_template(self, prompt: str):
|
||||
prompt = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant that generates high-quality images based on user instructions.",
|
||||
},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
prompt = self.processor.tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=False)
|
||||
return prompt
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 256,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
|
||||
instead. Ignored when not using guidance (i.e., ignored if `text_guidance_scale` is less than `1`). For
OmniGen2, this defaults to the empty string "".
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
whether to use classifier free guidance or not
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
number of images that should be generated per prompt
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For OmniGen2, this should be the embeddings of the empty string "".
|
||||
max_sequence_length (`int`, defaults to `256`):
|
||||
Maximum sequence length to use for the prompt.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
if prompt is not None:
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
prompt = [self._apply_chat_template(_prompt) for _prompt in prompt]
|
||||
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, prompt_attention_mask = self._get_qwen2_prompt_embeds(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length
|
||||
)
|
||||
|
||||
batch_size, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
||||
prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
# Get negative embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt if negative_prompt is not None else ""
|
||||
|
||||
# Normalize str to list
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
negative_prompt = [self._apply_chat_template(_negative_prompt) for _negative_prompt in negative_prompt]
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
negative_prompt = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
negative_prompt_embeds, negative_prompt_attention_mask = self._get_qwen2_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
batch_size, seq_len, _ = negative_prompt_embeds.shape
|
||||
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.view(
|
||||
batch_size * num_images_per_prompt, -1
|
||||
)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def text_guidance_scale(self):
|
||||
return self._text_guidance_scale
|
||||
|
||||
@property
|
||||
def image_guidance_scale(self):
|
||||
return self._image_guidance_scale
|
||||
|
||||
@property
|
||||
def cfg_range(self):
|
||||
return self._cfg_range
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
prompt_attention_mask: Optional[torch.LongTensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.LongTensor] = None,
|
||||
max_sequence_length: Optional[int] = None,
|
||||
callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
|
||||
input_images: Optional[List[PIL.Image.Image]] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
max_pixels: int = 1024 * 1024,
|
||||
max_input_image_side_length: int = 1024,
|
||||
align_res: bool = True,
|
||||
num_inference_steps: int = 28,
|
||||
text_guidance_scale: float = 4.0,
|
||||
image_guidance_scale: float = 1.0,
|
||||
cfg_range: Tuple[float, float] = (0.0, 1.0),
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
timesteps: List[int] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
verbose: bool = False,
|
||||
step_func=None,
|
||||
):
|
||||
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
self._text_guidance_scale = text_guidance_scale
|
||||
self._image_guidance_scale = image_guidance_scale
|
||||
self._cfg_range = cfg_range
|
||||
self._attention_kwargs = attention_kwargs
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 3. Encode input prompt
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_attention_mask,
|
||||
) = self.encode_prompt(
|
||||
prompt,
|
||||
self.text_guidance_scale > 1.0,
|
||||
negative_prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
dtype = self.vae.dtype
|
||||
# 4. Prepare reference image latents
|
||||
ref_latents = self.prepare_image(
|
||||
images=input_images,
|
||||
batch_size=batch_size,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_pixels=max_pixels,
|
||||
max_side_length=max_input_image_side_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if input_images is None:
|
||||
input_images = []
|
||||
|
||||
if len(input_images) == 1 and align_res:
|
||||
width, height = ref_latents[0][0].shape[-1] * self.vae_scale_factor, ref_latents[0][0].shape[-2] * self.vae_scale_factor
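# With a single reference image and align_res=True, the output resolution follows that
# reference image's latent size instead of the requested height/width.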
|
||||
ori_width, ori_height = width, height
|
||||
else:
|
||||
ori_width, ori_height = width, height
|
||||
|
||||
cur_pixels = height * width
|
||||
ratio = (max_pixels / cur_pixels) ** 0.5
|
||||
ratio = min(ratio, 1.0)
|
||||
|
||||
height, width = int(height * ratio) // 16 * 16, int(width * ratio) // 16 * 16
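# The final resolution is snapped down to multiples of 16 (presumably the 8x VAE downsampling
# times the transformer patch size of 2).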
|
||||
|
||||
if len(input_images) == 0:
|
||||
self._image_guidance_scale = 1
|
||||
|
||||
# 5. Prepare noise latents.
|
||||
latent_channels = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
latent_channels,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
freqs_cis = OmniGen2RotaryPosEmbed.get_freqs_cis(
|
||||
self.transformer.config.axes_dim_rope,
|
||||
self.transformer.config.axes_lens,
|
||||
theta=10000,
|
||||
)
|
||||
|
||||
image = self.processing(
|
||||
latents=latents,
|
||||
ref_latents=ref_latents,
|
||||
prompt_embeds=prompt_embeds,
|
||||
freqs_cis=freqs_cis,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
num_inference_steps=num_inference_steps,
|
||||
timesteps=timesteps,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
verbose=verbose,
|
||||
step_func=step_func,
|
||||
)
|
||||
|
||||
image = F.interpolate(image, size=(ori_height, ori_width), mode='bilinear')
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return image
|
||||
else:
|
||||
return FMPipelineOutput(images=image)
|
||||
|
||||
def processing(
|
||||
self,
|
||||
latents,
|
||||
ref_latents,
|
||||
prompt_embeds,
|
||||
freqs_cis,
|
||||
negative_prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_attention_mask,
|
||||
num_inference_steps,
|
||||
timesteps,
|
||||
device,
|
||||
dtype,
|
||||
verbose,
|
||||
step_func=None
|
||||
):
|
||||
batch_size = latents.shape[0]
|
||||
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
timesteps,
|
||||
num_tokens=latents.shape[-2] * latents.shape[-1]
|
||||
)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
model_pred = self.predict(
|
||||
t=t,
|
||||
latents=latents,
|
||||
prompt_embeds=prompt_embeds,
|
||||
freqs_cis=freqs_cis,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
ref_image_hidden_states=ref_latents,
|
||||
)
|
||||
text_guidance_scale = self.text_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
|
||||
image_guidance_scale = self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
|
||||
|
||||
if text_guidance_scale > 1.0 and image_guidance_scale > 1.0:
|
||||
model_pred_ref = self.predict(
|
||||
t=t,
|
||||
latents=latents,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
freqs_cis=freqs_cis,
|
||||
prompt_attention_mask=negative_prompt_attention_mask,
|
||||
ref_image_hidden_states=ref_latents,
|
||||
)
|
||||
|
||||
if image_guidance_scale != 1:
|
||||
model_pred_uncond = self.predict(
|
||||
t=t,
|
||||
latents=latents,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
freqs_cis=freqs_cis,
|
||||
prompt_attention_mask=negative_prompt_attention_mask,
|
||||
ref_image_hidden_states=None,
|
||||
)
|
||||
else:
|
||||
model_pred_uncond = torch.zeros_like(model_pred)
|
||||
|
||||
model_pred = model_pred_uncond + image_guidance_scale * (model_pred_ref - model_pred_uncond) + \
|
||||
text_guidance_scale * (model_pred - model_pred_ref)
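# Dual guidance: pred = uncond + s_img * (img_cond - uncond) + s_text * (text_and_img_cond - img_cond),
# combining reference-image guidance and text guidance in one update.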
|
||||
elif text_guidance_scale > 1.0:
|
||||
model_pred_uncond = self.predict(
|
||||
t=t,
|
||||
latents=latents,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
freqs_cis=freqs_cis,
|
||||
prompt_attention_mask=negative_prompt_attention_mask,
|
||||
ref_image_hidden_states=None,
|
||||
)
|
||||
model_pred = model_pred_uncond + text_guidance_scale * (model_pred - model_pred_uncond)
|
||||
|
||||
latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
latents = latents.to(dtype=dtype)
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if step_func is not None:
|
||||
step_func(i, self._num_timesteps)
|
||||
|
||||
latents = latents.to(dtype=dtype)
|
||||
if self.vae.config.scaling_factor is not None:
|
||||
latents = latents / self.vae.config.scaling_factor
|
||||
if self.vae.config.shift_factor is not None:
|
||||
latents = latents + self.vae.config.shift_factor
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
|
||||
return image
|
||||
|
||||
def predict(
|
||||
self,
|
||||
t,
|
||||
latents,
|
||||
prompt_embeds,
|
||||
freqs_cis,
|
||||
prompt_attention_mask,
|
||||
ref_image_hidden_states,
|
||||
):
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
batch_size, num_channels_latents, height, width = latents.shape
|
||||
|
||||
optional_kwargs = {}
|
||||
if 'ref_image_hidden_states' in set(inspect.signature(self.transformer.forward).parameters.keys()):
|
||||
optional_kwargs['ref_image_hidden_states'] = ref_image_hidden_states
|
||||
|
||||
model_pred = self.transformer(
|
||||
latents,
|
||||
timestep,
|
||||
prompt_embeds,
|
||||
freqs_cis,
|
||||
prompt_attention_mask,
|
||||
**optional_kwargs
|
||||
)
|
||||
return model_pred
|
||||
@@ -0,0 +1,830 @@
|
||||
"""
|
||||
OmniGen2 Diffusion Pipeline
|
||||
|
||||
Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import math
|
||||
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration
|
||||
|
||||
from diffusers.models.autoencoders import AutoencoderKL
|
||||
from ...models.transformers import OmniGen2Transformer2DModel
|
||||
from ...models.transformers.repo import OmniGen2RotaryPosEmbed
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import (
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import PIL.Image
|
||||
|
||||
from diffusers.utils import BaseOutput
|
||||
|
||||
from src.pipelines.image_processor import OmniGen2ImageProcessor
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@dataclass
|
||||
class OmniGen2PipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for OmniGen2 pipeline.
|
||||
|
||||
Args:
|
||||
images (Union[List[PIL.Image.Image], np.ndarray]):
|
||||
List of denoised PIL images of length `batch_size` or numpy array of shape
|
||||
`(batch_size, height, width, num_channels)`. Contains the generated images.
|
||||
"""
|
||||
text: str
|
||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class OmniGen2ChatPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for text-to-image generation using OmniGen2.
|
||||
|
||||
This pipeline implements a text-to-image generation model that uses:
|
||||
- Qwen2.5-VL for text encoding
|
||||
- A custom transformer architecture for image generation
|
||||
- VAE for image encoding/decoding
|
||||
- FlowMatchEulerDiscreteScheduler for noise scheduling
|
||||
|
||||
Args:
|
||||
transformer (OmniGen2Transformer2DModel): The transformer model for image generation.
|
||||
vae (AutoencoderKL): The VAE model for image encoding/decoding.
|
||||
scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for noise scheduling.
|
||||
text_encoder (Qwen2_5_VLModel): The text encoder model.
|
||||
tokenizer (Union[Qwen2Tokenizer, Qwen2TokenizerFast]): The tokenizer for text processing.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "mllm->transformer->vae"
|
||||
def __init__(
|
||||
self,
|
||||
transformer: OmniGen2Transformer2DModel,
|
||||
vae: AutoencoderKL,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
mllm: Qwen2_5_VLForConditionalGeneration,
|
||||
processor,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the OmniGen2 pipeline.
|
||||
|
||||
Args:
|
||||
transformer: The transformer model for image generation.
|
||||
vae: The VAE model for image encoding/decoding.
|
||||
scheduler: The scheduler for noise scheduling.
|
||||
text_encoder: The text encoder model.
|
||||
tokenizer: The tokenizer for text processing.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
transformer=transformer,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
mllm=mllm,
|
||||
processor=processor
|
||||
)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
)
|
||||
self.image_processor = OmniGen2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2, do_resize=True)
|
||||
self.default_sample_size = 128
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_channels_latents: int,
|
||||
height: int,
|
||||
width: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
generator: Optional[torch.Generator],
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Prepare the initial latents for the diffusion process.
|
||||
|
||||
Args:
|
||||
batch_size: The number of images to generate.
|
||||
num_channels_latents: The number of channels in the latent space.
|
||||
height: The height of the generated image.
|
||||
width: The width of the generated image.
|
||||
dtype: The data type of the latents.
|
||||
device: The device to place the latents on.
|
||||
generator: The random number generator to use.
|
||||
latents: Optional pre-computed latents to use instead of random initialization.
|
||||
|
||||
Returns:
|
||||
torch.FloatTensor: The prepared latents tensor.
|
||||
"""
|
||||
height = int(height) // self.vae_scale_factor
|
||||
width = int(width) // self.vae_scale_factor
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
return latents
|
||||
|
||||
def encode_vae(self, img: torch.FloatTensor) -> torch.FloatTensor:
|
||||
"""
|
||||
Encode an image into the VAE latent space.
|
||||
|
||||
Args:
|
||||
img: The input image tensor to encode.
|
||||
|
||||
Returns:
|
||||
torch.FloatTensor: The encoded latent representation.
|
||||
"""
|
||||
z0 = self.vae.encode(img.to(dtype=self.vae.dtype)).latent_dist.sample()
|
||||
if self.vae.config.shift_factor is not None:
|
||||
z0 = z0 - self.vae.config.shift_factor
|
||||
if self.vae.config.scaling_factor is not None:
|
||||
z0 = z0 * self.vae.config.scaling_factor
|
||||
z0 = z0.to(dtype=self.vae.dtype)
|
||||
return z0
|
||||
|
||||
def prepare_image(
|
||||
self,
|
||||
images: Union[List[PIL.Image.Image], PIL.Image.Image],
|
||||
batch_size: int,
|
||||
num_images_per_prompt: int,
|
||||
max_pixels: int,
|
||||
max_side_length: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
) -> List[Optional[torch.FloatTensor]]:
|
||||
"""
|
||||
Prepare input images for processing by encoding them into the VAE latent space.
|
||||
|
||||
Args:
|
||||
images: Single image or list of images to process.
|
||||
batch_size: The number of images to generate per prompt.
|
||||
num_images_per_prompt: The number of images to generate for each prompt.
|
||||
device: The device to place the encoded latents on.
|
||||
dtype: The data type of the encoded latents.
|
||||
|
||||
Returns:
|
||||
List[Optional[torch.FloatTensor]]: List of encoded latent representations for each image.
|
||||
"""
|
||||
if batch_size == 1:
|
||||
images = [images]
|
||||
latents = []
|
||||
for i, img in enumerate(images):
|
||||
if img is not None and len(img) > 0:
|
||||
ref_latents = []
|
||||
for j, img_j in enumerate(img):
|
||||
img_j = self.image_processor.preprocess(img_j, max_pixels=max_pixels, max_side_length=max_side_length)
|
||||
ref_latents.append(self.encode_vae(img_j.to(device=device)).squeeze(0))
|
||||
else:
|
||||
ref_latents = None
|
||||
for _ in range(num_images_per_prompt):
|
||||
latents.append(ref_latents)
|
||||
|
||||
return latents
|
||||
|
||||
def _apply_chat_template(self, prompt: str, images: List = None):
|
||||
if images is not None:
|
||||
prompt = "".join(
|
||||
[
|
||||
f"<img{i}>: <|vision_start|><|image_pad|><|vision_end|>"
|
||||
for i in range(1, len(images) + 1)
|
||||
]
|
||||
) + prompt
|
||||
prompt = f"<|im_start|>system\nYou are a helpful assistant that generates high-quality images based on user instructions.<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
|
||||
return prompt
|
||||
|
||||
def _get_qwen2_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
input_images = None,
|
||||
device: Optional[torch.device] = None,
|
||||
use_only_text_hidden_states: bool = True,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Get prompt embeddings from the Qwen2 text encoder.
|
||||
|
||||
Args:
|
||||
prompt: The prompt or list of prompts to encode.
|
||||
device: The device to place the embeddings on. If None, uses the pipeline's device.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
||||
- The prompt embeddings tensor
|
||||
- The attention mask tensor
|
||||
|
||||
Raises:
|
||||
Warning: If the input text is truncated due to sequence length limitations.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
inputs = self.processor(
|
||||
text=prompt,
|
||||
images=input_images,
|
||||
videos=None,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = inputs.to(device)
|
||||
|
||||
prompt_embeds = self.mllm(
|
||||
**inputs,
|
||||
output_hidden_states=True,
|
||||
).hidden_states[-1]
|
||||
|
||||
text_input_ids = inputs.input_ids
|
||||
text_mask = inputs.attention_mask
|
||||
if use_only_text_hidden_states:
|
||||
mask = text_input_ids != self.mllm.config.image_token_id
|
||||
mask = mask & text_mask
|
||||
mask = mask.bool()
|
||||
|
||||
text_l = mask.sum(dim=-1)
|
||||
max_l = text_l.max()
|
||||
text_batch_size = prompt_embeds.size(0)
|
||||
new_prompt_embeds = torch.zeros((text_batch_size, max_l, prompt_embeds.size(-1)), device=prompt_embeds.device, dtype=prompt_embeds.dtype)
|
||||
new_text_mask = torch.zeros((text_batch_size, max_l), dtype=text_mask.dtype, device=text_mask.device)
|
||||
for i in range(text_batch_size):
|
||||
new_prompt_embeds[i, :text_l[i]] = prompt_embeds[i, mask[i]]
|
||||
new_text_mask[i, :text_l[i]] = 1
|
||||
|
||||
prompt_embeds = new_prompt_embeds
|
||||
text_mask = new_text_mask
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.mllm.dtype, device=device)
|
||||
return prompt_embeds, text_mask
|
||||
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
input_images: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 256,
|
||||
use_text_encoder_penultimate_layer_feats: bool = False
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
|
||||
instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
|
||||
Lumina-T2I, this should be "".
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
whether to use classifier free guidance or not
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
number of images that should be generated per prompt
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For Lumina-T2I, it's should be the embeddings of the "" string.
|
||||
max_sequence_length (`int`, defaults to `256`):
|
||||
Maximum sequence length to use for the prompt.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, prompt_attention_mask = self._get_qwen2_prompt_embeds(
|
||||
prompt=prompt,
|
||||
input_images=input_images,
|
||||
device=device,
|
||||
)
|
||||
|
||||
batch_size, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
||||
prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
# Get negative embeddings for classifier free guidance
|
||||
negative_prompt_embeds, negative_prompt_attention_mask = None, None
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt if negative_prompt is not None else ""
|
||||
|
||||
# Normalize str to list
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
negative_prompt = [self._apply_chat_template(_negative_prompt) for _negative_prompt in negative_prompt]
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
negative_prompt = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
negative_prompt_embeds, negative_prompt_attention_mask = self._get_qwen2_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
device=device,
|
||||
)
|
||||
|
||||
batch_size, seq_len, _ = negative_prompt_embeds.shape
|
||||
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.view(
|
||||
batch_size * num_images_per_prompt, -1
|
||||
)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def text_guidance_scale(self):
|
||||
return self._text_guidance_scale
|
||||
|
||||
@property
|
||||
def image_guidance_scale(self):
|
||||
return self._image_guidance_scale
|
||||
|
||||
@property
|
||||
def cfg_range(self):
|
||||
return self._cfg_range
|
||||
|
||||
def prepare_inputs_for_text_generation(self, prompts, input_images, device):
|
||||
if isinstance(prompts, str):
|
||||
prompts = [prompts]
|
||||
|
||||
ori_padding_side = self.processor.tokenizer.padding_side
|
||||
self.processor.tokenizer.padding_side = "left"
|
||||
inputs = self.processor(
|
||||
text=prompts,
|
||||
images=input_images,
|
||||
videos=None,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
).to(device)
|
||||
self.processor.tokenizer.padding_side = ori_padding_side
|
||||
return inputs
|
||||
|
||||
def generate_text(self, prompt, input_images):
|
||||
inputs = self.prepare_inputs_for_text_generation(
|
||||
prompt, input_images, self.mllm.device
|
||||
)
|
||||
generated_ids = self.mllm.generate(
|
||||
**inputs,
|
||||
tokenizer=self.processor.tokenizer,
|
||||
max_new_tokens=256,
|
||||
stop_strings=["<|im_end|>", "<|img|>", "<|endoftext|>"],
|
||||
) # stop_words=[151643, 151645, 151665]
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids) :]
|
||||
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
]
|
||||
output_texts = self.processor.batch_decode(
|
||||
generated_ids_trimmed,
|
||||
# skip_special_tokens=True,
|
||||
skip_special_tokens=False,
|
||||
clean_up_tokenization_spaces=False,
|
||||
)
|
||||
return output_texts
|
||||
|
||||
def generate_image(
|
||||
self,
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
prompt_attention_mask: Optional[torch.LongTensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.LongTensor] = None,
|
||||
use_text_encoder_penultimate_layer_feats: bool = False,
|
||||
max_sequence_length: Optional[int] = None,
|
||||
callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
|
||||
input_images: Optional[List[PIL.Image.Image]] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
max_pixels: int = 1024 * 1024,
|
||||
max_input_image_side_length: int = 1024,
|
||||
align_res: bool = True,
|
||||
num_inference_steps: int = 28,
|
||||
text_guidance_scale: float = 4.0,
|
||||
image_guidance_scale: float = 1.0,
|
||||
cfg_range: Tuple[float, float] = (0.0, 1.0),
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
timesteps: List[int] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
verbose: bool = False,
|
||||
step_func=None,
|
||||
):
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
self._text_guidance_scale = text_guidance_scale
|
||||
self._image_guidance_scale = image_guidance_scale
|
||||
self._cfg_range = cfg_range
|
||||
self._attention_kwargs = attention_kwargs
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 3. Encode input promptb
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_attention_mask,
|
||||
) = self.encode_prompt(
|
||||
prompt,
|
||||
input_images,
|
||||
self.text_guidance_scale > 1.0,
|
||||
negative_prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
max_sequence_length=max_sequence_length,
|
||||
use_text_encoder_penultimate_layer_feats=use_text_encoder_penultimate_layer_feats
|
||||
)
|
||||
|
||||
dtype = self.vae.dtype
|
||||
# 3. Prepare control image
|
||||
ref_latents = self.prepare_image(
|
||||
images=input_images,
|
||||
batch_size=batch_size,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_pixels=max_pixels,
|
||||
max_side_length=max_input_image_side_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if input_images is None:
|
||||
input_images = []
|
||||
|
||||
if len(input_images) == 1 and align_res:
|
||||
width, height = ref_latents[0][0].shape[-1] * self.vae_scale_factor, ref_latents[0][0].shape[-2] * self.vae_scale_factor
|
||||
ori_width, ori_height = width, height
|
||||
else:
|
||||
ori_width, ori_height = width, height
|
||||
|
||||
cur_pixels = height * width
|
||||
ratio = (max_pixels / cur_pixels) ** 0.5
|
||||
ratio = min(ratio, 1.0)
|
||||
|
||||
height, width = int(height * ratio) // 16 * 16, int(width * ratio) // 16 * 16
|
||||
|
||||
if len(input_images) == 0:
|
||||
self._image_guidance_scale = 1
|
||||
|
||||
# 4. Prepare latents.
|
||||
latent_channels = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
latent_channels,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
freqs_cis = OmniGen2RotaryPosEmbed.get_freqs_cis(
|
||||
self.transformer.config.axes_dim_rope,
|
||||
self.transformer.config.axes_lens,
|
||||
theta=10000,
|
||||
)
|
||||
|
||||
image = self.processing(
|
||||
latents=latents,
|
||||
ref_latents=ref_latents,
|
||||
prompt_embeds=prompt_embeds,
|
||||
freqs_cis=freqs_cis,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
num_inference_steps=num_inference_steps,
|
||||
timesteps=timesteps,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
verbose=verbose,
|
||||
step_func=step_func,
|
||||
)
|
||||
|
||||
image = F.interpolate(image, size=(ori_height, ori_width), mode='bilinear')
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
return image
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
prompt_attention_mask: Optional[torch.LongTensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.LongTensor] = None,
|
||||
use_text_encoder_penultimate_layer_feats: bool = False,
|
||||
max_sequence_length: Optional[int] = None,
|
||||
callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
|
||||
input_images: Optional[List[PIL.Image.Image]] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
height: Optional[int] = 1024,
|
||||
width: Optional[int] = 1024,
|
||||
max_pixels: Optional[int] = 1024 * 1024,
|
||||
max_input_image_side_length: int = 1024,
|
||||
align_res: bool = True,
|
||||
num_inference_steps: int = 28,
|
||||
text_guidance_scale: float = 4.0,
|
||||
image_guidance_scale: float = 1.0,
|
||||
cfg_range: Tuple[float, float] = (0.0, 1.0),
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
timesteps: List[int] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
verbose: bool = False,
|
||||
step_func=None,
|
||||
):
|
||||
assert isinstance(prompt, str), "prompt must be a string since chat mode only support one prompt per turn"
|
||||
|
||||
# input_images = self.preprocess_images(input_images, max_input_image_size)
|
||||
prompt = self._apply_chat_template(prompt, input_images)
|
||||
generated_text = self.generate_text(prompt, input_images)[0]
|
||||
|
||||
images = None
|
||||
if generated_text.startswith("<|img|>"):
|
||||
#TODO: reuse the hidden state when generate text instead of re-generating
|
||||
prompt = prompt + generated_text.split("<|img|>")[0]
|
||||
images = self.generate_image(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
use_text_encoder_penultimate_layer_feats=use_text_encoder_penultimate_layer_feats,
|
||||
max_sequence_length=max_sequence_length,
|
||||
input_images=input_images,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
max_pixels=max_pixels,
|
||||
max_input_image_side_length=max_input_image_side_length,
|
||||
align_res=align_res,
|
||||
num_inference_steps=num_inference_steps,
|
||||
text_guidance_scale=text_guidance_scale,
|
||||
image_guidance_scale=image_guidance_scale,
|
||||
cfg_range=cfg_range,
|
||||
timesteps=timesteps,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
return_dict=False,
|
||||
verbose=verbose,
|
||||
step_func=step_func,
|
||||
)
|
||||
|
||||
generated_text = generated_text.replace("<|im_end|>", "")
|
||||
if not return_dict:
|
||||
return generated_text, images
|
||||
else:
|
||||
return OmniGen2PipelineOutput(text=generated_text, images=images)
|
||||
|
||||
def processing(
|
||||
self,
|
||||
latents,
|
||||
ref_latents,
|
||||
prompt_embeds,
|
||||
freqs_cis,
|
||||
negative_prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_attention_mask,
|
||||
num_inference_steps,
|
||||
timesteps,
|
||||
device,
|
||||
dtype,
|
||||
verbose,
|
||||
step_func=None
|
||||
):
|
||||
batch_size = latents.shape[0]
|
||||
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
timesteps,
|
||||
num_tokens=latents.shape[-2] * latents.shape[-1]
|
||||
)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
model_pred = self.predict(
|
||||
t=t,
|
||||
latents=latents,
|
||||
prompt_embeds=prompt_embeds,
|
||||
freqs_cis=freqs_cis,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
ref_image_hidden_states=ref_latents,
|
||||
)
|
||||
|
||||
text_guidance_scale = self.text_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
|
||||
image_guidance_scale = self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
|
||||
if text_guidance_scale > 1.0 and image_guidance_scale > 1.0:
|
||||
model_pred_ref = self.predict(
|
||||
t=t,
|
||||
latents=latents,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
freqs_cis=freqs_cis,
|
||||
prompt_attention_mask=negative_prompt_attention_mask,
|
||||
ref_image_hidden_states=ref_latents,
|
||||
)
|
||||
|
||||
if image_guidance_scale != 1:
|
||||
model_pred_uncond = self.predict(
|
||||
t=t,
|
||||
latents=latents,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
freqs_cis=freqs_cis,
|
||||
prompt_attention_mask=negative_prompt_attention_mask,
|
||||
ref_image_hidden_states=None,
|
||||
)
|
||||
else:
|
||||
model_pred_uncond = torch.zeros_like(model_pred)
|
||||
|
||||
model_pred = model_pred_uncond + image_guidance_scale * (model_pred_ref - model_pred_uncond) + \
|
||||
text_guidance_scale * (model_pred - model_pred_ref)
|
||||
elif text_guidance_scale > 1.0:
|
||||
model_pred_uncond = self.predict(
|
||||
t=t,
|
||||
latents=latents,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
freqs_cis=freqs_cis,
|
||||
prompt_attention_mask=negative_prompt_attention_mask,
|
||||
ref_image_hidden_states=None,
|
||||
)
|
||||
model_pred = model_pred_uncond + text_guidance_scale * (model_pred - model_pred_uncond)
|
||||
|
||||
latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
latents = latents.to(dtype=dtype)
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if step_func is not None:
|
||||
step_func(i, self._num_timesteps)
|
||||
|
||||
latents = latents.to(dtype=dtype)
|
||||
if self.vae.config.scaling_factor is not None:
|
||||
latents = latents / self.vae.config.scaling_factor
|
||||
if self.vae.config.shift_factor is not None:
|
||||
latents = latents + self.vae.config.shift_factor
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
|
||||
return image
|
||||
|
||||
def predict(
|
||||
self,
|
||||
t,
|
||||
latents,
|
||||
prompt_embeds,
|
||||
freqs_cis,
|
||||
prompt_attention_mask,
|
||||
ref_image_hidden_states,
|
||||
):
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
batch_size, num_channels_latents, height, width = latents.shape
|
||||
|
||||
optional_kwargs = {}
|
||||
if 'ref_image_hidden_states' in set(inspect.signature(self.transformer.forward).parameters.keys()):
|
||||
optional_kwargs['ref_image_hidden_states'] = ref_image_hidden_states
|
||||
|
||||
model_pred = self.transformer(
|
||||
latents,
|
||||
timestep,
|
||||
prompt_embeds,
|
||||
freqs_cis,
|
||||
prompt_attention_mask,
|
||||
**optional_kwargs
|
||||
)
|
||||
return model_pred
|
||||
@@ -0,0 +1,62 @@
|
||||
import torch
|
||||
|
||||
|
||||
def get_pipeline_embeds(pipeline, prompt, negative_prompt, device):
|
||||
""" Get pipeline embeds for prompts bigger than the maxlength of the pipe
|
||||
:param pipeline:
|
||||
:param prompt:
|
||||
:param negative_prompt:
|
||||
:param device:
|
||||
:return:
|
||||
"""
|
||||
max_length = pipeline.tokenizer.model_max_length
|
||||
|
||||
# simple way to determine length of tokens
|
||||
# count_prompt = len(prompt.split(" "))
|
||||
# count_negative_prompt = len(negative_prompt.split(" "))
|
||||
|
||||
# create the tensor based on which prompt is longer
|
||||
# if count_prompt >= count_negative_prompt:
|
||||
input_ids = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False, padding='longest').input_ids.to(device)
|
||||
# input_ids = pipeline.tokenizer(prompt, padding="max_length",
|
||||
# max_length=pipeline.tokenizer.model_max_length,
|
||||
# truncation=True,
|
||||
# return_tensors="pt",).input_ids.to(device)
|
||||
shape_max_length = input_ids.shape[-1]
|
||||
|
||||
if negative_prompt is not None:
|
||||
negative_ids = pipeline.tokenizer(negative_prompt, truncation=True, padding="max_length",
|
||||
max_length=shape_max_length, return_tensors="pt").input_ids.to(device)
|
||||
|
||||
# else:
|
||||
# negative_ids = pipeline.tokenizer(negative_prompt, return_tensors="pt", truncation=False).input_ids.to(device)
|
||||
# shape_max_length = negative_ids.shape[-1]
|
||||
# input_ids = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False, padding="max_length",
|
||||
# max_length=shape_max_length).input_ids.to(device)
|
||||
|
||||
concat_embeds = []
|
||||
neg_embeds = []
|
||||
for i in range(0, shape_max_length, max_length):
|
||||
if hasattr(pipeline.text_encoder.config, "use_attention_mask") and pipeline.text_encoder.config.use_attention_mask:
|
||||
attention_mask = input_ids[:, i: i + max_length].attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
concat_embeds.append(pipeline.text_encoder(input_ids[:, i: i + max_length],
|
||||
attention_mask=attention_mask)[0])
|
||||
|
||||
if negative_prompt is not None:
|
||||
if hasattr(pipeline.text_encoder.config, "use_attention_mask") and pipeline.text_encoder.config.use_attention_mask:
|
||||
attention_mask = negative_ids[:, i: i + max_length].attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
neg_embeds.append(pipeline.text_encoder(negative_ids[:, i: i + max_length],
|
||||
attention_mask=attention_mask)[0])
|
||||
|
||||
concat_embeds = torch.cat(concat_embeds, dim=1)
|
||||
|
||||
if negative_prompt is not None:
|
||||
neg_embeds = torch.cat(neg_embeds, dim=1)
|
||||
else:
|
||||
neg_embeds = None
|
||||
|
||||
return concat_embeds, neg_embeds
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,229 @@
|
||||
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.utils import BaseOutput, logging
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
|
||||
"""
|
||||
Output class for the scheduler's `step` function output.
|
||||
|
||||
Args:
|
||||
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
||||
denoising loop.
|
||||
"""
|
||||
|
||||
prev_sample: torch.FloatTensor
|
||||
|
||||
|
||||
class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
Euler scheduler.
|
||||
|
||||
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
||||
methods the library implements for all schedulers such as loading and saving.
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`, defaults to 1000):
|
||||
The number of diffusion steps to train the model.
|
||||
timestep_spacing (`str`, defaults to `"linspace"`):
|
||||
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
||||
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
||||
shift (`float`, defaults to 1.0):
|
||||
The shift value for the timestep schedule.
|
||||
"""
|
||||
|
||||
_compatibles = []
|
||||
order = 1
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
dynamic_time_shift: bool = True
|
||||
):
|
||||
timesteps = torch.linspace(0, 1, num_train_timesteps + 1, dtype=torch.float32)[:-1]
|
||||
|
||||
self.timesteps = timesteps
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
"""
|
||||
The index counter for current timestep. It will increase 1 after each scheduler step.
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self._timesteps
|
||||
|
||||
indices = (schedule_timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
|
||||
return indices[pos].item()
|
||||
|
||||
# def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
|
||||
# return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: int = None,
|
||||
device: Union[str, torch.device] = None,
|
||||
timesteps: Optional[List[float]] = None,
|
||||
num_tokens: Optional[int] = None
|
||||
):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
Args:
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
"""
|
||||
|
||||
if timesteps is None:
|
||||
self.num_inference_steps = num_inference_steps
|
||||
timesteps = np.linspace(0, 1, num_inference_steps + 1, dtype=np.float32)[:-1]
|
||||
if self.config.dynamic_time_shift and num_tokens is not None:
|
||||
m = np.sqrt(num_tokens) / 40 # when input resolution is 320 * 320, m = 1, when input resolution is 1024 * 1024, m = 3.2
|
||||
timesteps = timesteps / (m - m * timesteps + timesteps)
|
||||
|
||||
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device)
|
||||
_timesteps = torch.cat([timesteps, torch.ones(1, device=timesteps.device)])
|
||||
|
||||
self.timesteps = timesteps
|
||||
self._timesteps = _timesteps
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: Union[float, torch.FloatTensor],
|
||||
sample: torch.FloatTensor,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`):
|
||||
The direct output from learned diffusion model.
|
||||
timestep (`float`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
s_churn (`float`):
|
||||
s_tmin (`float`):
|
||||
s_tmax (`float`):
|
||||
s_noise (`float`, defaults to 1.0):
|
||||
Scaling factor for noise added to the sample.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A random number generator.
|
||||
return_dict (`bool`):
|
||||
Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
[`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
|
||||
If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
|
||||
returned, otherwise a tuple is returned where the first element is the sample tensor.
|
||||
"""
|
||||
|
||||
if (
|
||||
isinstance(timestep, int)
|
||||
or isinstance(timestep, torch.IntTensor)
|
||||
or isinstance(timestep, torch.LongTensor)
|
||||
):
|
||||
raise ValueError(
|
||||
(
|
||||
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
||||
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
||||
" one of the `scheduler.timesteps` as a timestep."
|
||||
),
|
||||
)
|
||||
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
# Upcast to avoid precision issues when computing prev_sample
|
||||
sample = sample.to(torch.float32)
|
||||
t = self._timesteps[self.step_index]
|
||||
t_next = self._timesteps[self.step_index + 1]
|
||||
|
||||
prev_sample = sample + (t_next - t) * model_output
|
||||
|
||||
# Cast sample back to model compatible dtype
|
||||
prev_sample = prev_sample.to(model_output.dtype)
|
||||
|
||||
# upon completion increase step index by one
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
|
||||
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
@@ -0,0 +1,31 @@
|
||||
from typing import List
|
||||
|
||||
from PIL import Image
|
||||
|
||||
import torch
|
||||
from torchvision.transforms.functional import to_pil_image
|
||||
|
||||
def resize_image(image, max_pixels, img_scale_num):
|
||||
width, height = image.size
|
||||
cur_pixels = height * width
|
||||
ratio = (max_pixels / cur_pixels) ** 0.5
|
||||
ratio = min(ratio, 1.0) # do not upscale input image
|
||||
|
||||
new_height, new_width = int(height * ratio) // img_scale_num * img_scale_num, int(width * ratio) // img_scale_num * img_scale_num
|
||||
|
||||
image = image.resize((new_width, new_height), resample=Image.BICUBIC)
|
||||
return image
|
||||
|
||||
def create_collage(images: List[torch.Tensor]) -> Image.Image:
|
||||
"""Create a horizontal collage from a list of images."""
|
||||
max_height = max(img.shape[-2] for img in images)
|
||||
total_width = sum(img.shape[-1] for img in images)
|
||||
canvas = torch.zeros((3, max_height, total_width), device=images[0].device)
|
||||
|
||||
current_x = 0
|
||||
for img in images:
|
||||
h, w = img.shape[-2:]
|
||||
canvas[:, :h, current_x:current_x+w] = img * 0.5 + 0.5
|
||||
current_x += w
|
||||
|
||||
return to_pil_image(canvas)
|
||||
@@ -0,0 +1,46 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Import utilities: Utilities related to imports and our lazy inits.
|
||||
"""
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
|
||||
# The package importlib_metadata is in a different place, depending on the python version.
|
||||
if sys.version_info < (3, 8):
|
||||
import importlib_metadata
|
||||
else:
|
||||
import importlib.metadata as importlib_metadata
|
||||
|
||||
def _is_package_available(pkg_name: str):
|
||||
pkg_exists = importlib.util.find_spec(pkg_name) is not None
|
||||
pkg_version = "N/A"
|
||||
|
||||
if pkg_exists:
|
||||
try:
|
||||
pkg_version = importlib_metadata.version(pkg_name)
|
||||
except (ImportError, importlib_metadata.PackageNotFoundError):
|
||||
pkg_exists = False
|
||||
|
||||
return pkg_exists, pkg_version
|
||||
|
||||
_triton_available, _triton_version = _is_package_available("triton")
|
||||
_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
|
||||
|
||||
def is_triton_available():
|
||||
return _triton_available
|
||||
|
||||
def is_flash_attn_available():
|
||||
return _flash_attn_available
|
||||
@@ -34,7 +34,7 @@ def get_all_extensions() -> List[Extension]:
|
||||
for sub_dir in extension_folders:
|
||||
extensions_dir = os.path.join(TOOLKIT_ROOT, sub_dir)
|
||||
for (_, name, _) in pkgutil.iter_modules([extensions_dir]):
|
||||
try:
|
||||
# try:
|
||||
# Import the module
|
||||
module = importlib.import_module(f"{sub_dir}.{name}")
|
||||
# Get the value of the AI_TOOLKIT_EXTENSIONS variable
|
||||
@@ -43,8 +43,8 @@ def get_all_extensions() -> List[Extension]:
|
||||
if isinstance(extensions, list):
|
||||
# Iterate over the list and add the classes to the main list
|
||||
all_extension_classes.extend(extensions)
|
||||
except ImportError as e:
|
||||
print(f"Failed to import the {name} module. Error: {str(e)}")
|
||||
# except ImportError as e:
|
||||
# print(f"Failed to import the {name} module. Error: {str(e)}")
|
||||
|
||||
return all_extension_classes
|
||||
|
||||
|
||||
@@ -170,6 +170,19 @@ export const modelArchs: ModelArch[] = [
|
||||
},
|
||||
disableSections: ['model.quantize', 'train.timestep_type'],
|
||||
},
|
||||
{
|
||||
name: 'omnigen2',
|
||||
label: 'OmniGen2',
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.name_or_path': ['OmniGen2/OmniGen2', defaultNameOrPath],
|
||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].model.quantize': [false, false],
|
||||
'config.process[0].model.quantize_te': [true, false],
|
||||
},
|
||||
disableSections: ['network.conv'],
|
||||
},
|
||||
].sort((a, b) => {
|
||||
// Sort by label, case-insensitive
|
||||
return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' })
|
||||
|
||||
@@ -1 +1 @@
|
||||
VERSION = "0.3.1"
|
||||
VERSION = "0.3.2"
|
||||
Reference in New Issue
Block a user