import os
from typing import TYPE_CHECKING, List, Optional

import torch
import yaml
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.models.base_model import BaseModel
from diffusers import AutoencoderKL
from toolkit.basic import flush
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import (
    CustomFlowMatchEulerDiscreteScheduler,
)
from toolkit.accelerator import unwrap_model
from optimum.quanto import freeze
from toolkit.util.quantize import quantize, get_qtype
from .src.pipelines.omnigen2.pipeline_omnigen2 import OmniGen2Pipeline
from .src.models.transformers import OmniGen2Transformer2DModel
from .src.models.transformers.repo import OmniGen2RotaryPosEmbed
from .src.schedulers.scheduling_flow_match_euler_discrete import (
    FlowMatchEulerDiscreteScheduler as OmniFlowMatchEuler,
)
from PIL import Image
from transformers import (
    CLIPProcessor,
    Qwen2_5_VLForConditionalGeneration,
)
import torch.nn.functional as F

if TYPE_CHECKING:
    from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO

scheduler_config = {"num_train_timesteps": 1000}

BASE_MODEL_PATH = "OmniGen2/OmniGen2"


class OmniGen2Model(BaseModel):
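    """OmniGen2 model wrapper for ai-toolkit.

    Ties together the OmniGen2 transformer, its AutoencoderKL VAE, and the
    Qwen2.5-VL multimodal text encoder behind the BaseModel interface, and
    trains with a flow-matching scheduler.
    """
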
arch = "omnigen2"
|
|
|
|
def __init__(
|
|
self,
|
|
device,
|
|
model_config: ModelConfig,
|
|
dtype="bf16",
|
|
custom_pipeline=None,
|
|
noise_scheduler=None,
|
|
**kwargs,
|
|
):
|
|
super().__init__(
|
|
device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs
|
|
)
|
|
self.is_flow_matching = True
|
|
self.is_transformer = True
|
|
self.target_lora_modules = ["OmniGen2Transformer2DModel"]
|
|
self._control_latent = None
|
|
|
|
# static method to get the noise scheduler
|
|
@staticmethod
|
|
def get_train_scheduler():
|
|
return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
|
|
|
|
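    # note: 16 presumably reflects the 8x VAE downsampling combined with the
    # transformer's 2x2 latent patching, so bucketed resolutions stay
    # compatible with both (assumption, not verified here).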
    def get_bucket_divisibility(self):
        return 16

    def load_model(self):
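        """Load the Qwen2.5-VL text encoder, OmniGen2 transformer, and VAE.

        Optionally quantizes the text encoder and transformer, and shuffles
        components between CPU and GPU when low_vram is enabled.
        """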
        dtype = self.torch_dtype
        self.print_and_status_update("Loading OmniGen2 model")
        # will be updated if we detect an existing checkpoint in the training folder
        model_path = self.model_config.name_or_path
        extras_path = self.model_config.extras_name_or_path

        scheduler = OmniGen2Model.get_train_scheduler()

        self.print_and_status_update("Loading Qwen2.5 VL")
        processor = CLIPProcessor.from_pretrained(
            extras_path, subfolder="processor", use_fast=True
        )

        mllm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            extras_path, subfolder="mllm", torch_dtype=torch.bfloat16
        )
        mllm.to(self.device_torch, dtype=dtype)
        if self.model_config.quantize_te:
            self.print_and_status_update("Quantizing Qwen2.5 VL model")
            quantization_type = get_qtype(self.model_config.qtype_te)
            quantize(mllm, weights=quantization_type)
            freeze(mllm)

        if self.low_vram:
            # unload it for now
            mllm.to("cpu")

        flush()

        self.print_and_status_update("Loading transformer")

        transformer = OmniGen2Transformer2DModel.from_pretrained(
            model_path, subfolder="transformer", torch_dtype=torch.bfloat16
        )

        if not self.low_vram:
            transformer.to(self.device_torch, dtype=dtype)

        if self.model_config.quantize:
            self.print_and_status_update("Quantizing transformer")
            quantization_type = get_qtype(self.model_config.qtype)
            quantize(transformer, weights=quantization_type)
            freeze(transformer)

        if self.low_vram:
            # unload it for now
            transformer.to("cpu")

        flush()

        self.print_and_status_update("Loading vae")

        vae = AutoencoderKL.from_pretrained(
            extras_path, subfolder="vae", torch_dtype=torch.bfloat16
        ).to(self.device_torch, dtype=dtype)

        flush()
        self.print_and_status_update("Loading Qwen2.5 VLProcessor")

        flush()

        if self.low_vram:
            self.print_and_status_update("Moving everything to device")
            # move it all back
            transformer.to(self.device_torch, dtype=dtype)
            vae.to(self.device_torch, dtype=dtype)
            mllm.to(self.device_torch, dtype=dtype)

        # set to eval mode
        # transformer.eval()
        vae.eval()
        mllm.eval()
        mllm.requires_grad_(False)

        pipe: OmniGen2Pipeline = OmniGen2Pipeline(
            transformer=transformer,
            vae=vae,
            scheduler=scheduler,
            mllm=mllm,
            processor=processor,
        )

        flush()

        text_encoder_list = [mllm]
        tokenizer_list = [processor]

        flush()

        # save it to the model class
        self.vae = vae
        self.text_encoder = text_encoder_list  # list of text encoders
        self.tokenizer = tokenizer_list  # list of tokenizers
        self.model = pipe.transformer
        self.pipeline = pipe

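        # Precompute the rotary position-embedding frequency tables once here;
        # they are passed to the transformer on every forward pass in
        # get_noise_prediction rather than being rebuilt per step.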
        self.freqs_cis = OmniGen2RotaryPosEmbed.get_freqs_cis(
            transformer.config.axes_dim_rope,
            transformer.config.axes_lens,
            theta=10000,
        )

        self.print_and_status_update("Model Loaded")

    def get_generation_pipeline(self):
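        # Build a fresh inference pipeline around the already-loaded components,
        # using the stock OmniGen2 flow-match Euler scheduler (with dynamic time
        # shift) instead of the training scheduler.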
        scheduler = OmniFlowMatchEuler(
            dynamic_time_shift=True, num_train_timesteps=1000
        )

        pipeline: OmniGen2Pipeline = OmniGen2Pipeline(
            transformer=self.model,
            vae=self.vae,
            scheduler=scheduler,
            mllm=self.text_encoder[0],
            processor=self.tokenizer[0],
        )

        pipeline = pipeline.to(self.device_torch)

        return pipeline

    def generate_single_image(
        self,
        pipeline: OmniGen2Pipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
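        # If a control image is configured, load and resize it so it can be fed
        # to the pipeline as a reference image via input_images.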
        input_images = []
        if gen_config.ctrl_img is not None:
            control_img = Image.open(gen_config.ctrl_img)
            control_img = control_img.convert("RGB")
            # resize to width and height
            if control_img.size != (gen_config.width, gen_config.height):
                control_img = control_img.resize(
                    (gen_config.width, gen_config.height), Image.BILINEAR
                )
            input_images = [control_img]

        img = pipeline(
            prompt_embeds=conditional_embeds.text_embeds,
            prompt_attention_mask=conditional_embeds.attention_mask,
            negative_prompt_embeds=unconditional_embeds.text_embeds,
            negative_prompt_attention_mask=unconditional_embeds.attention_mask,
            height=gen_config.height,
            width=gen_config.width,
            num_inference_steps=gen_config.num_inference_steps,
            text_guidance_scale=gen_config.guidance_scale,
            image_guidance_scale=1.0,  # reference image guidance scale. Add this for controls
            latents=gen_config.latents,
            align_res=False,
            generator=generator,
            input_images=input_images,
            **extra,
        ).images[0]
        return img

    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        **kwargs,
    ):
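        # Predict the flow-matching output for the noisy latents. Control latents
        # stashed by condition_noisy_latents (if any) are forwarded to the
        # transformer as ref_image_hidden_states.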
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        try:
            timestep = timestep.expand(latent_model_input.shape[0]).to(
                latent_model_input.dtype
            )
        except Exception:
            pass

        timesteps = timestep / 1000  # convert to 0 to 1 scale
        # the model's timestep runs in the opposite direction (0 where ours is 1),
        # so reverse it
        timestep = 1 - timesteps
        model_pred = self.model(
            latent_model_input,
            timestep,
            text_embeddings.text_embeds,
            self.freqs_cis,
            text_embeddings.attention_mask,
            ref_image_hidden_states=self._control_latent,
        )

        return model_pred

    def condition_noisy_latents(
        self, latents: torch.Tensor, batch: "DataLoaderBatchDTO"
    ):
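        # Encode the batch's control image (if any) into VAE latents and stash
        # them in self._control_latent for the next forward pass; the noisy
        # latents themselves are returned unchanged.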
        # reset the control latent
        self._control_latent = None
        with torch.no_grad():
            control_tensor = batch.control_tensor
            if control_tensor is not None:
                self.vae.to(self.device_torch)
                # nothing is packed at this stage; just rescale the control
                # images and encode them so they can be packed later
                control_tensor = control_tensor * 2 - 1
                control_tensor = control_tensor.to(
                    self.vae_device_torch, dtype=self.torch_dtype
                )

                # if it does not match the size of batch.tensor (bs, ch, h, w), resize it
                # todo: we may not need to do this, check
                if batch.tensor is not None:
                    target_h, target_w = batch.tensor.shape[2], batch.tensor.shape[3]
                else:
                    # When caching latents, batch.tensor is None. We get the size from the file_items instead.
                    target_h = batch.file_items[0].crop_height
                    target_w = batch.file_items[0].crop_width

                if (
                    control_tensor.shape[2] != target_h
                    or control_tensor.shape[3] != target_w
                ):
                    control_tensor = F.interpolate(
                        control_tensor, size=(target_h, target_w), mode="bilinear"
                    )

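                # Encode to latents and store as a list of single-element lists
                # (one inner list of reference latents per batch sample), which is
                # presumably the per-sample format the transformer expects for
                # ref_image_hidden_states.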
                control_latent = self.encode_images(control_tensor).to(
                    latents.device, latents.dtype
                )
                self._control_latent = [
                    [x.squeeze(0)]
                    for x in torch.chunk(control_latent, control_latent.shape[0], dim=0)
                ]

        return latents.detach()

    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
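        # Wrap each prompt in the pipeline's chat template and encode it with the
        # Qwen2.5-VL text encoder; the attention mask is kept on the returned
        # PromptEmbeds so it can be passed to the transformer later.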
        prompt = [prompt] if isinstance(prompt, str) else prompt
        prompt = [self.pipeline._apply_chat_template(_prompt) for _prompt in prompt]
        self.text_encoder_to(self.device_torch, dtype=self.torch_dtype)
        max_sequence_length = 256
        prompt_embeds, prompt_attention_mask, _, _ = self.pipeline.encode_prompt(
            prompt=prompt,
            do_classifier_free_guidance=False,
            device=self.device_torch,
            max_sequence_length=max_sequence_length,
        )
        pe = PromptEmbeds(prompt_embeds)
        pe.attention_mask = prompt_attention_mask
        return pe

    def get_model_has_grad(self):
        # could check a weight's grad here, but always report False
        return False

    def get_te_has_grad(self):
        # the text encoder is kept frozen, so it never has grad
        return False

    def save_model(self, output_path, meta, save_dtype):
        # only save the transformer
        transformer: OmniGen2Transformer2DModel = unwrap_model(self.model)
        transformer.save_pretrained(
            save_directory=os.path.join(output_path, "transformer"),
            safe_serialization=True,
        )

        meta_path = os.path.join(output_path, "aitk_meta.yaml")
        with open(meta_path, "w") as f:
            yaml.dump(meta, f)

    def get_loss_target(self, *args, **kwargs):
        noise = kwargs.get("noise")
        batch = kwargs.get("batch")
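        # The target here is the difference (latents - noise), which under this
        # flow-matching convention points from the noise toward the clean
        # latents; the commented-out line is the opposite sign convention.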
        # return (noise - batch.latents).detach()
        return (batch.latents - noise).detach()

    def get_transformer_block_names(self) -> Optional[List[str]]:
        # omnigen2 has a few block groups: noise_refiner, ref_image_refiner,
        # context_refiner, and layers. Include all but the ref image refiner
        # unless it is explicitly enabled.
        if self.model_config.model_kwargs.get("use_image_refiner", False):
            return ["noise_refiner", "context_refiner", "ref_image_refiner", "layers"]
        return ["noise_refiner", "context_refiner", "layers"]

    def convert_lora_weights_before_save(self, state_dict):
        # keys currently start with "transformer." but need to start with
        # "diffusion_model." for ComfyUI
        new_sd = {}
        for key, value in state_dict.items():
            new_key = key.replace("transformer.", "diffusion_model.")
            new_sd[new_key] = value
        return new_sd

    def convert_lora_weights_before_load(self, state_dict):
        # saved with the "diffusion_model." prefix, but ai-toolkit expects
        # "transformer."
        new_sd = {}
        for key, value in state_dict.items():
            new_key = key.replace("diffusion_model.", "transformer.")
            new_sd[new_key] = value
        return new_sd

    def get_base_model_version(self):
        return "omnigen2"