Allow control image for omnigen training and sampling

This commit is contained in:
Jaret Burkett
2025-07-09 13:54:55 -06:00
parent bbb57de6ec
commit 611969ec1f
6 changed files with 187 additions and 132 deletions

View File

@@ -1,8 +1,6 @@
import inspect
import os
from typing import TYPE_CHECKING, List, Optional
import einops
import torch
import yaml
from toolkit.config_modules import GenerateImageConfig, ModelConfig
@@ -10,22 +8,29 @@ from toolkit.models.base_model import BaseModel
from diffusers import AutoencoderKL
from toolkit.basic import flush
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
from toolkit.samplers.custom_flowmatch_sampler import (
CustomFlowMatchEulerDiscreteScheduler,
)
from toolkit.accelerator import unwrap_model
from optimum.quanto import freeze
from toolkit.util.quantize import quantize, get_qtype
from .src.pipelines.omnigen2.pipeline_omnigen2 import OmniGen2Pipeline
from .src.models.transformers import OmniGen2Transformer2DModel
from .src.models.transformers.repo import OmniGen2RotaryPosEmbed
from .src.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler as OmniFlowMatchEuler
from transformers import CLIPProcessor, Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration
from .src.schedulers.scheduling_flow_match_euler_discrete import (
FlowMatchEulerDiscreteScheduler as OmniFlowMatchEuler,
)
from PIL import Image
from transformers import (
CLIPProcessor,
Qwen2_5_VLForConditionalGeneration,
)
import torch.nn.functional as F
if TYPE_CHECKING:
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
scheduler_config = {
"num_train_timesteps": 1000
}
scheduler_config = {"num_train_timesteps": 1000}
BASE_MODEL_PATH = "OmniGen2/OmniGen2"
@@ -34,25 +39,21 @@ class OmniGen2Model(BaseModel):
arch = "omnigen2"
def __init__(
self,
device,
model_config: ModelConfig,
dtype='bf16',
custom_pipeline=None,
noise_scheduler=None,
**kwargs
self,
device,
model_config: ModelConfig,
dtype="bf16",
custom_pipeline=None,
noise_scheduler=None,
**kwargs,
):
super().__init__(
device,
model_config,
dtype,
custom_pipeline,
noise_scheduler,
**kwargs
device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs
)
self.is_flow_matching = True
self.is_transformer = True
self.target_lora_modules = ['OmniGen2Transformer2DModel']
self.target_lora_modules = ["OmniGen2Transformer2DModel"]
self._control_latent = None
# static method to get the noise scheduler
@staticmethod
@@ -69,20 +70,16 @@ class OmniGen2Model(BaseModel):
# will be updated if we detect a existing checkpoint in training folder
model_path = self.model_config.name_or_path
extras_path = self.model_config.extras_name_or_path
scheduler = OmniGen2Model.get_train_scheduler()
self.print_and_status_update("Loading Qwen2.5 VL")
processor = CLIPProcessor.from_pretrained(
extras_path,
subfolder="processor",
use_fast=True
extras_path, subfolder="processor", use_fast=True
)
mllm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
extras_path,
subfolder="mllm",
torch_dtype=torch.bfloat16
extras_path, subfolder="mllm", torch_dtype=torch.bfloat16
)
mllm.to(self.device_torch, dtype=dtype)
if self.model_config.quantize_te:
@@ -90,57 +87,52 @@ class OmniGen2Model(BaseModel):
quantization_type = get_qtype(self.model_config.qtype_te)
quantize(mllm, weights=quantization_type)
freeze(mllm)
if self.low_vram:
# unload it for now
mllm.to('cpu')
mllm.to("cpu")
flush()
self.print_and_status_update("Loading transformer")
transformer = OmniGen2Transformer2DModel.from_pretrained(
model_path,
subfolder="transformer",
torch_dtype=torch.bfloat16
model_path, subfolder="transformer", torch_dtype=torch.bfloat16
)
if not self.low_vram:
transformer.to(self.device_torch, dtype=dtype)
if self.model_config.quantize:
self.print_and_status_update("Quantizing transformer")
quantization_type = get_qtype(self.model_config.qtype)
quantize(transformer, weights=quantization_type)
freeze(transformer)
if self.low_vram:
# unload it for now
transformer.to('cpu')
transformer.to("cpu")
flush()
self.print_and_status_update("Loading vae")
vae = AutoencoderKL.from_pretrained(
extras_path,
subfolder="vae",
torch_dtype=torch.bfloat16
extras_path, subfolder="vae", torch_dtype=torch.bfloat16
).to(self.device_torch, dtype=dtype)
flush()
self.print_and_status_update("Loading Qwen2.5 VLProcessor")
flush()
if self.low_vram:
self.print_and_status_update("Moving everything to device")
# move it all back
transformer.to(self.device_torch, dtype=dtype)
vae.to(self.device_torch, dtype=dtype)
mllm.to(self.device_torch, dtype=dtype)
# set to eval mode
# transformer.eval()
vae.eval()
@@ -149,28 +141,17 @@ class OmniGen2Model(BaseModel):
pipe: OmniGen2Pipeline = OmniGen2Pipeline(
transformer=transformer,
vae=vae,
vae=vae,
scheduler=scheduler,
mllm=mllm,
processor=processor,
)
# pipe: OmniGen2Pipeline = OmniGen2Pipeline.from_pretrained(
# model_path,
# transformer=transformer,
# vae=vae,
# scheduler=scheduler,
# mllm=mllm,
# trust_remote_code=True,
# )
# processor = pipe.processor
flush()
text_encoder_list = [mllm]
tokenizer_list = [processor]
flush()
# save it to the model class
@@ -179,21 +160,20 @@ class OmniGen2Model(BaseModel):
self.tokenizer = tokenizer_list # list of tokenizers
self.model = pipe.transformer
self.pipeline = pipe
self.freqs_cis = OmniGen2RotaryPosEmbed.get_freqs_cis(
transformer.config.axes_dim_rope,
transformer.config.axes_lens,
theta=10000,
)
self.print_and_status_update("Model Loaded")
def get_generation_pipeline(self):
scheduler = OmniFlowMatchEuler(
dynamic_time_shift=True,
num_train_timesteps=1000
dynamic_time_shift=True, num_train_timesteps=1000
)
pipeline: OmniGen2Pipeline = OmniGen2Pipeline(
transformer=self.model,
vae=self.vae,
@@ -215,6 +195,17 @@ class OmniGen2Model(BaseModel):
generator: torch.Generator,
extra: dict,
):
input_images = []
if gen_config.ctrl_img is not None:
control_img = Image.open(gen_config.ctrl_img)
control_img = control_img.convert("RGB")
# resize to width and height
if control_img.size != (gen_config.width, gen_config.height):
control_img = control_img.resize(
(gen_config.width, gen_config.height), Image.BILINEAR
)
input_images = [control_img]
img = pipeline(
prompt_embeds=conditional_embeds.text_embeds,
prompt_attention_mask=conditional_embeds.attention_mask,
@@ -224,10 +215,12 @@ class OmniGen2Model(BaseModel):
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
text_guidance_scale=gen_config.guidance_scale,
image_guidance_scale=1.0, # reference image guidance scale. Add this for controls
image_guidance_scale=1.0, # reference image guidance scale. Add this for controls
latents=gen_config.latents,
align_res=False,
generator=generator,
**extra
input_images=input_images,
**extra,
).images[0]
return img
@@ -236,18 +229,16 @@ class OmniGen2Model(BaseModel):
latent_model_input: torch.Tensor,
timestep: torch.Tensor, # 0 to 1000 scale
text_embeddings: PromptEmbeds,
**kwargs
**kwargs,
):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
try:
timestep = timestep.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
timestep = timestep.expand(latent_model_input.shape[0]).to(
latent_model_input.dtype
)
except Exception as e:
pass
# optional_kwargs = {}
# if 'ref_image_hidden_states' in set(inspect.signature(self.model.forward).parameters.keys()):
# optional_kwargs['ref_image_hidden_states'] = ref_image_hidden_states
timesteps = timestep / 1000 # convert to 0 to 1 scale
# timestep for model starts at 0 instead of 1. So we need to reverse them
timestep = 1 - timesteps
@@ -257,18 +248,60 @@ class OmniGen2Model(BaseModel):
text_embeddings.text_embeds,
self.freqs_cis,
text_embeddings.attention_mask,
ref_image_hidden_states=None, # todo add ref latent ability
ref_image_hidden_states=self._control_latent,
)
return model_pred
def condition_noisy_latents(
self, latents: torch.Tensor, batch: "DataLoaderBatchDTO"
):
# reset the control latent
self._control_latent = None
with torch.no_grad():
control_tensor = batch.control_tensor
if control_tensor is not None:
self.vae.to(self.device_torch)
# we are not packed here, so we just need to pass them so we can pack them later
control_tensor = control_tensor * 2 - 1
control_tensor = control_tensor.to(
self.vae_device_torch, dtype=self.torch_dtype
)
# if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it
# todo, we may not need to do this, check
if batch.tensor is not None:
target_h, target_w = batch.tensor.shape[2], batch.tensor.shape[3]
else:
# When caching latents, batch.tensor is None. We get the size from the file_items instead.
target_h = batch.file_items[0].crop_height
target_w = batch.file_items[0].crop_width
if (
control_tensor.shape[2] != target_h
or control_tensor.shape[3] != target_w
):
control_tensor = F.interpolate(
control_tensor, size=(target_h, target_w), mode="bilinear"
)
control_latent = self.encode_images(control_tensor).to(
latents.device, latents.dtype
)
self._control_latent = [
[x.squeeze(0)]
for x in torch.chunk(control_latent, control_latent.shape[0], dim=0)
]
return latents.detach()
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt = [self.pipeline._apply_chat_template(_prompt) for _prompt in prompt]
self.text_encoder_to(self.device_torch, dtype=self.torch_dtype)
max_sequence_length = 256
prompt_embeds, prompt_attention_mask, _, _ = self.pipeline.encode_prompt(
prompt = prompt,
prompt=prompt,
do_classifier_free_guidance=False,
device=self.device_torch,
max_sequence_length=max_sequence_length,
@@ -276,7 +309,7 @@ class OmniGen2Model(BaseModel):
pe = PromptEmbeds(prompt_embeds)
pe.attention_mask = prompt_attention_mask
return pe
def get_model_has_grad(self):
# return from a weight if it has grad
return False
@@ -284,30 +317,31 @@ class OmniGen2Model(BaseModel):
def get_te_has_grad(self):
# assume no one wants to finetune 4 text encoders.
return False
def save_model(self, output_path, meta, save_dtype):
# only save the transformer
transformer: OmniGen2Transformer2DModel = unwrap_model(self.model)
transformer.save_pretrained(
save_directory=os.path.join(output_path, 'transformer'),
save_directory=os.path.join(output_path, "transformer"),
safe_serialization=True,
)
meta_path = os.path.join(output_path, 'aitk_meta.yaml')
with open(meta_path, 'w') as f:
meta_path = os.path.join(output_path, "aitk_meta.yaml")
with open(meta_path, "w") as f:
yaml.dump(meta, f)
def get_loss_target(self, *args, **kwargs):
noise = kwargs.get('noise')
batch = kwargs.get('batch')
noise = kwargs.get("noise")
batch = kwargs.get("batch")
# return (noise - batch.latents).detach()
return (batch.latents - noise).detach()
def get_transformer_block_names(self) -> Optional[List[str]]:
# omnigen2 had a few blocks for things like noise_refiner, ref_image_refiner, context_refiner, and layers.
# lets do all but image refiner until we add it
return ['noise_refiner', 'context_refiner', 'layers']
# return ['layers']
if self.model_config.model_kwargs.get("use_image_refiner", False):
return ["noise_refiner", "context_refiner", "ref_image_refiner", "layers"]
return ["noise_refiner", "context_refiner", "layers"]
def convert_lora_weights_before_save(self, state_dict):
# currently starte with transformer. but needs to start with diffusion_model. for comfyui
@@ -324,7 +358,6 @@ class OmniGen2Model(BaseModel):
new_key = key.replace("diffusion_model.", "transformer.")
new_sd[new_key] = value
return new_sd
def get_base_model_version(self):
return "omnigen2"

View File

@@ -676,7 +676,8 @@ class OmniGen2Pipeline(DiffusionPipeline):
prompt_embeds=negative_prompt_embeds,
freqs_cis=freqs_cis,
prompt_attention_mask=negative_prompt_attention_mask,
ref_image_hidden_states=None,
ref_image_hidden_states=ref_latents,
# ref_image_hidden_states=None,
)
model_pred = model_pred_uncond + text_guidance_scale * (model_pred - model_pred_uncond)

View File

@@ -310,6 +310,7 @@ export default function SimpleJob({
{ value: 'sigmoid', label: 'Sigmoid' },
{ value: 'linear', label: 'Linear' },
{ value: 'shift', label: 'Shift' },
{ value: 'weighted', label: 'Weighted' },
]}
/>
)}
@@ -541,13 +542,12 @@ export default function SimpleJob({
{ value: 'ddpm', label: 'DDPM' },
]}
/>
</div>
<div>
<NumberInput
label="Guidance Scale"
value={jobConfig.config.process[0].sample.guidance_scale}
onChange={value => setJobConfig(value, 'config.process[0].sample.guidance_scale')}
placeholder="eg. 1.0"
className="pt-2"
min={0}
required
/>
@@ -579,6 +579,26 @@ export default function SimpleJob({
min={0}
required
/>
{isVideoModel && (
<div>
<NumberInput
label="Num Frames"
value={jobConfig.config.process[0].sample.num_frames}
onChange={value => setJobConfig(value, 'config.process[0].sample.num_frames')}
placeholder="eg. 0"
min={0}
required
/>
<NumberInput
label="FPS"
value={jobConfig.config.process[0].sample.fps}
onChange={value => setJobConfig(value, 'config.process[0].sample.fps')}
placeholder="eg. 0"
min={0}
required
/>
</div>
)}
</div>
<div>
@@ -597,40 +617,36 @@ export default function SimpleJob({
onChange={value => setJobConfig(value, 'config.process[0].sample.walk_seed')}
/>
</div>
{isVideoModel && (
<div>
<NumberInput
label="Num Frames"
value={jobConfig.config.process[0].sample.num_frames}
onChange={value => setJobConfig(value, 'config.process[0].sample.num_frames')}
placeholder="eg. 0"
min={0}
required
/>
<NumberInput
label="FPS"
value={jobConfig.config.process[0].sample.fps}
onChange={value => setJobConfig(value, 'config.process[0].sample.fps')}
placeholder="eg. 0"
min={0}
required
/>
</div>
)}
<div>
<FormGroup label="Advanced Sampling" className="pt-2">
<div>
<Checkbox
label="Skip First Sample"
className="pt-4"
checked={jobConfig.config.process[0].train.skip_first_sample || false}
onChange={value => setJobConfig(value, 'config.process[0].train.skip_first_sample')}
/>
</div>
<div>
<Checkbox
label="Disable Sampling"
className="pt-1"
checked={jobConfig.config.process[0].train.disable_sampling || false}
onChange={value => setJobConfig(value, 'config.process[0].train.disable_sampling')}
/>
</div>
</FormGroup>
</div>
</div>
<FormGroup label={`Sample Prompts (${jobConfig.config.process[0].sample.prompts.length})`} className="pt-2">
{
modelArch?.additionalSections?.includes('sample.ctrl_img') && (
<div className='text-sm text-gray-100 mb-2 py-2 px-4 bg-yellow-700 rounded-lg'>
<p className='font-semibold mb-1'>
Control Images
</p>
To use control images on samples, add --ctrl_img to the prompts below.
<br />
Example: <code className='bg-yellow-900 p-1'>make this a cartoon --ctrl_img /path/to/image.png</code>
</div>
)
}
{modelArch?.additionalSections?.includes('sample.ctrl_img') && (
<div className="text-sm text-gray-100 mb-2 py-2 px-4 bg-yellow-700 rounded-lg">
<p className="font-semibold mb-1">Control Images</p>
To use control images on samples, add --ctrl_img to the prompts below.
<br />
Example: <code className="bg-yellow-900 p-1">make this a cartoon --ctrl_img /path/to/image.png</code>
</div>
)}
{jobConfig.config.process[0].sample.prompts.map((prompt, i) => (
<div key={i} className="flex items-center space-x-2">
<div className="flex-1">

View File

@@ -68,6 +68,8 @@ export const defaultJobConfig: JobConfig = {
use_ema: false,
ema_decay: 0.99,
},
skip_first_sample: false,
disable_sampling: false,
dtype: 'bf16',
diff_output_preservation: false,
diff_output_preservation_multiplier: 1.0,

View File

@@ -200,6 +200,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].model.quantize_te': [true, false],
},
disableSections: ['network.conv'],
additionalSections: ['datasets.control_path', 'sample.ctrl_img'],
},
].sort((a, b) => {
// Sort by label, case-insensitive

View File

@@ -110,6 +110,8 @@ export interface TrainConfig {
optimizer_params: {
weight_decay: number;
};
skip_first_sample: boolean;
disable_sampling: boolean;
diff_output_preservation: boolean;
diff_output_preservation_multiplier: number;
diff_output_preservation_class: string;