diff --git a/extensions_built_in/diffusion_models/omnigen2/__init__.py b/extensions_built_in/diffusion_models/omnigen2/__init__.py
index ddff811e..77ce910d 100644
--- a/extensions_built_in/diffusion_models/omnigen2/__init__.py
+++ b/extensions_built_in/diffusion_models/omnigen2/__init__.py
@@ -1,8 +1,6 @@
-import inspect
 import os
 from typing import TYPE_CHECKING, List, Optional
 
-import einops
 import torch
 import yaml
 from toolkit.config_modules import GenerateImageConfig, ModelConfig
@@ -10,22 +8,29 @@ from toolkit.models.base_model import BaseModel
 from diffusers import AutoencoderKL
 from toolkit.basic import flush
 from toolkit.prompt_utils import PromptEmbeds
-from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
+from toolkit.samplers.custom_flowmatch_sampler import (
+    CustomFlowMatchEulerDiscreteScheduler,
+)
 from toolkit.accelerator import unwrap_model
 from optimum.quanto import freeze
 from toolkit.util.quantize import quantize, get_qtype
 from .src.pipelines.omnigen2.pipeline_omnigen2 import OmniGen2Pipeline
 from .src.models.transformers import OmniGen2Transformer2DModel
 from .src.models.transformers.repo import OmniGen2RotaryPosEmbed
-from .src.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler as OmniFlowMatchEuler
-from transformers import CLIPProcessor, Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration
+from .src.schedulers.scheduling_flow_match_euler_discrete import (
+    FlowMatchEulerDiscreteScheduler as OmniFlowMatchEuler,
+)
+from PIL import Image
+from transformers import (
+    CLIPProcessor,
+    Qwen2_5_VLForConditionalGeneration,
+)
+import torch.nn.functional as F
 
 if TYPE_CHECKING:
     from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
 
-scheduler_config = {
-    "num_train_timesteps": 1000
-}
+scheduler_config = {"num_train_timesteps": 1000}
 
 BASE_MODEL_PATH = "OmniGen2/OmniGen2"
@@ -34,25 +39,21 @@ class OmniGen2Model(BaseModel):
     arch = "omnigen2"
 
     def __init__(
-            self,
-            device,
-            model_config: ModelConfig,
-            dtype='bf16',
-            custom_pipeline=None,
-            noise_scheduler=None,
-            **kwargs
+        self,
+        device,
+        model_config: ModelConfig,
+        dtype="bf16",
+        custom_pipeline=None,
+        noise_scheduler=None,
+        **kwargs,
     ):
         super().__init__(
-            device,
-            model_config,
-            dtype,
-            custom_pipeline,
-            noise_scheduler,
-            **kwargs
+            device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs
         )
         self.is_flow_matching = True
         self.is_transformer = True
-        self.target_lora_modules = ['OmniGen2Transformer2DModel']
+        self.target_lora_modules = ["OmniGen2Transformer2DModel"]
+        self._control_latent = None
 
     # static method to get the noise scheduler
     @staticmethod
@@ -69,20 +70,16 @@ class OmniGen2Model(BaseModel):
         # will be updated if we detect a existing checkpoint in training folder
         model_path = self.model_config.name_or_path
         extras_path = self.model_config.extras_name_or_path
-
+
         scheduler = OmniGen2Model.get_train_scheduler()
-
+
         self.print_and_status_update("Loading Qwen2.5 VL")
         processor = CLIPProcessor.from_pretrained(
-            extras_path,
-            subfolder="processor",
-            use_fast=True
+            extras_path, subfolder="processor", use_fast=True
         )
-
+
         mllm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-            extras_path,
-            subfolder="mllm",
-            torch_dtype=torch.bfloat16
+            extras_path, subfolder="mllm", torch_dtype=torch.bfloat16
         )
         mllm.to(self.device_torch, dtype=dtype)
         if self.model_config.quantize_te:
@@ -90,57 +87,52 @@ class OmniGen2Model(BaseModel):
             quantization_type = get_qtype(self.model_config.qtype_te)
             quantize(mllm, weights=quantization_type)
             freeze(mllm)
-
+
         if self.low_vram:
             # unload it for now
-            mllm.to('cpu')
-
+            mllm.to("cpu")
+
         flush()
-
+
         self.print_and_status_update("Loading transformer")
-
+
         transformer = OmniGen2Transformer2DModel.from_pretrained(
-            model_path,
-            subfolder="transformer",
-            torch_dtype=torch.bfloat16
+            model_path, subfolder="transformer", torch_dtype=torch.bfloat16
         )
-
+
         if not self.low_vram:
             transformer.to(self.device_torch, dtype=dtype)
-
+
         if self.model_config.quantize:
             self.print_and_status_update("Quantizing transformer")
             quantization_type = get_qtype(self.model_config.qtype)
             quantize(transformer, weights=quantization_type)
             freeze(transformer)
-
+
         if self.low_vram:
             # unload it for now
-            transformer.to('cpu')
-
+            transformer.to("cpu")
+
         flush()
-
+
         self.print_and_status_update("Loading vae")
-
+
         vae = AutoencoderKL.from_pretrained(
-            extras_path,
-            subfolder="vae",
-            torch_dtype=torch.bfloat16
+            extras_path, subfolder="vae", torch_dtype=torch.bfloat16
         ).to(self.device_torch, dtype=dtype)
-
-
+
         flush()
         self.print_and_status_update("Loading Qwen2.5 VLProcessor")
-
+
         flush()
-
+
         if self.low_vram:
             self.print_and_status_update("Moving everything to device")
             # move it all back
             transformer.to(self.device_torch, dtype=dtype)
             vae.to(self.device_torch, dtype=dtype)
             mllm.to(self.device_torch, dtype=dtype)
-
+
         # set to eval mode
         # transformer.eval()
         vae.eval()
@@ -149,28 +141,17 @@ class OmniGen2Model(BaseModel):
         pipe: OmniGen2Pipeline = OmniGen2Pipeline(
             transformer=transformer,
-            vae=vae, 
+            vae=vae,
             scheduler=scheduler,
             mllm=mllm,
             processor=processor,
         )
-        # pipe: OmniGen2Pipeline = OmniGen2Pipeline.from_pretrained(
-        #     model_path,
-        #     transformer=transformer,
-        #     vae=vae,
-        #     scheduler=scheduler,
-        #     mllm=mllm,
-        #     trust_remote_code=True,
-        # )
-        # processor = pipe.processor
-        flush()
-
+
         text_encoder_list = [mllm]
         tokenizer_list = [processor]
-
-
+
         flush()
 
         # save it to the model class
@@ -179,21 +160,20 @@ class OmniGen2Model(BaseModel):
         self.tokenizer = tokenizer_list  # list of tokenizers
         self.model = pipe.transformer
         self.pipeline = pipe
-
+
         self.freqs_cis = OmniGen2RotaryPosEmbed.get_freqs_cis(
             transformer.config.axes_dim_rope,
             transformer.config.axes_lens,
             theta=10000,
         )
-
+
         self.print_and_status_update("Model Loaded")
 
     def get_generation_pipeline(self):
         scheduler = OmniFlowMatchEuler(
-            dynamic_time_shift=True,
-            num_train_timesteps=1000
+            dynamic_time_shift=True, num_train_timesteps=1000
         )
-
+
         pipeline: OmniGen2Pipeline = OmniGen2Pipeline(
             transformer=self.model,
             vae=self.vae,
@@ -215,6 +195,17 @@ class OmniGen2Model(BaseModel):
         generator: torch.Generator,
         extra: dict,
     ):
+        input_images = []
+        if gen_config.ctrl_img is not None:
+            control_img = Image.open(gen_config.ctrl_img)
+            control_img = control_img.convert("RGB")
+            # resize to width and height
+            if control_img.size != (gen_config.width, gen_config.height):
+                control_img = control_img.resize(
+                    (gen_config.width, gen_config.height), Image.BILINEAR
+                )
+            input_images = [control_img]
+
         img = pipeline(
             prompt_embeds=conditional_embeds.text_embeds,
             prompt_attention_mask=conditional_embeds.attention_mask,
@@ -224,10 +215,12 @@ class OmniGen2Model(BaseModel):
             width=gen_config.width,
             num_inference_steps=gen_config.num_inference_steps,
             text_guidance_scale=gen_config.guidance_scale,
-            image_guidance_scale=1.0, # reference image guidance scale. Add this for controls
+            image_guidance_scale=1.0,  # reference image guidance scale. Add this for controls
             latents=gen_config.latents,
+            align_res=False,
             generator=generator,
-            **extra
+            input_images=input_images,
+            **extra,
         ).images[0]
 
         return img
@@ -236,18 +229,16 @@ class OmniGen2Model(BaseModel):
         latent_model_input: torch.Tensor,
         timestep: torch.Tensor,  # 0 to 1000 scale
         text_embeddings: PromptEmbeds,
-        **kwargs
+        **kwargs,
    ):
         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
         try:
-            timestep = timestep.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
+            timestep = timestep.expand(latent_model_input.shape[0]).to(
+                latent_model_input.dtype
+            )
         except Exception as e:
             pass
-
-        # optional_kwargs = {}
-        # if 'ref_image_hidden_states' in set(inspect.signature(self.model.forward).parameters.keys()):
-        #     optional_kwargs['ref_image_hidden_states'] = ref_image_hidden_states
-
+
         timesteps = timestep / 1000  # convert to 0 to 1 scale
         # timestep for model starts at 0 instead of 1. So we need to reverse them
         timestep = 1 - timesteps
@@ -257,18 +248,60 @@ class OmniGen2Model(BaseModel):
             text_embeddings.text_embeds,
             self.freqs_cis,
             text_embeddings.attention_mask,
-            ref_image_hidden_states=None, # todo add ref latent ability
+            ref_image_hidden_states=self._control_latent,
         )
 
         return model_pred
-
+
+    def condition_noisy_latents(
+        self, latents: torch.Tensor, batch: "DataLoaderBatchDTO"
+    ):
+        # reset the control latent
+        self._control_latent = None
+        with torch.no_grad():
+            control_tensor = batch.control_tensor
+            if control_tensor is not None:
+                self.vae.to(self.device_torch)
+                # we are not packed here, so we just need to pass them so we can pack them later
+                control_tensor = control_tensor * 2 - 1
+                control_tensor = control_tensor.to(
+                    self.vae_device_torch, dtype=self.torch_dtype
+                )
+
+                # if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it
+                # todo, we may not need to do this, check
+                if batch.tensor is not None:
+                    target_h, target_w = batch.tensor.shape[2], batch.tensor.shape[3]
+                else:
+                    # When caching latents, batch.tensor is None. We get the size from the file_items instead.
+                    target_h = batch.file_items[0].crop_height
+                    target_w = batch.file_items[0].crop_width
+
+                if (
+                    control_tensor.shape[2] != target_h
+                    or control_tensor.shape[3] != target_w
+                ):
+                    control_tensor = F.interpolate(
+                        control_tensor, size=(target_h, target_w), mode="bilinear"
+                    )
+
+                control_latent = self.encode_images(control_tensor).to(
+                    latents.device, latents.dtype
+                )
+                self._control_latent = [
+                    [x.squeeze(0)]
+                    for x in torch.chunk(control_latent, control_latent.shape[0], dim=0)
+                ]
+
+        return latents.detach()
+
     def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
         prompt = [prompt] if isinstance(prompt, str) else prompt
         prompt = [self.pipeline._apply_chat_template(_prompt) for _prompt in prompt]
         self.text_encoder_to(self.device_torch, dtype=self.torch_dtype)
         max_sequence_length = 256
         prompt_embeds, prompt_attention_mask, _, _ = self.pipeline.encode_prompt(
-            prompt = prompt,
+            prompt=prompt,
             do_classifier_free_guidance=False,
             device=self.device_torch,
             max_sequence_length=max_sequence_length,
@@ -276,7 +309,7 @@ class OmniGen2Model(BaseModel):
         pe = PromptEmbeds(prompt_embeds)
         pe.attention_mask = prompt_attention_mask
         return pe
-
+
     def get_model_has_grad(self):
         # return from a weight if it has grad
         return False
@@ -284,30 +317,31 @@ class OmniGen2Model(BaseModel):
     def get_te_has_grad(self):
         # assume no one wants to finetune 4 text encoders.
         return False
-
+
     def save_model(self, output_path, meta, save_dtype):
         # only save the transformer
         transformer: OmniGen2Transformer2DModel = unwrap_model(self.model)
         transformer.save_pretrained(
-            save_directory=os.path.join(output_path, 'transformer'),
+            save_directory=os.path.join(output_path, "transformer"),
             safe_serialization=True,
         )
-        meta_path = os.path.join(output_path, 'aitk_meta.yaml')
-        with open(meta_path, 'w') as f:
+        meta_path = os.path.join(output_path, "aitk_meta.yaml")
+        with open(meta_path, "w") as f:
             yaml.dump(meta, f)
 
     def get_loss_target(self, *args, **kwargs):
-        noise = kwargs.get('noise')
-        batch = kwargs.get('batch')
+        noise = kwargs.get("noise")
+        batch = kwargs.get("batch")
         # return (noise - batch.latents).detach()
         return (batch.latents - noise).detach()
-
+
     def get_transformer_block_names(self) -> Optional[List[str]]:
         # omnigen2 had a few blocks for things like noise_refiner, ref_image_refiner, context_refiner, and layers.
         # lets do all but image refiner until we add it
-        return ['noise_refiner', 'context_refiner', 'layers']
-        # return ['layers']
+        if self.model_config.model_kwargs.get("use_image_refiner", False):
+            return ["noise_refiner", "context_refiner", "ref_image_refiner", "layers"]
+        return ["noise_refiner", "context_refiner", "layers"]
 
     def convert_lora_weights_before_save(self, state_dict):
         # currently starte with transformer. but needs to start with diffusion_model. for comfyui
@@ -324,7 +358,6 @@ class OmniGen2Model(BaseModel):
             new_key = key.replace("diffusion_model.", "transformer.")
             new_sd[new_key] = value
         return new_sd
-
+
     def get_base_model_version(self):
         return "omnigen2"
-
diff --git a/extensions_built_in/diffusion_models/omnigen2/src/pipelines/omnigen2/pipeline_omnigen2.py b/extensions_built_in/diffusion_models/omnigen2/src/pipelines/omnigen2/pipeline_omnigen2.py
index 651f1add..a48548f1 100644
--- a/extensions_built_in/diffusion_models/omnigen2/src/pipelines/omnigen2/pipeline_omnigen2.py
+++ b/extensions_built_in/diffusion_models/omnigen2/src/pipelines/omnigen2/pipeline_omnigen2.py
@@ -676,7 +676,8 @@ class OmniGen2Pipeline(DiffusionPipeline):
                     prompt_embeds=negative_prompt_embeds,
                     freqs_cis=freqs_cis,
                     prompt_attention_mask=negative_prompt_attention_mask,
-                    ref_image_hidden_states=None,
+                    ref_image_hidden_states=ref_latents,
+                    # ref_image_hidden_states=None,
                 )
 
                 model_pred = model_pred_uncond + text_guidance_scale * (model_pred - model_pred_uncond)
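Editor's note on the training-side control path above: `condition_noisy_latents` VAE-encodes the control image and stashes it on the model, and `get_noise_prediction` then hands it to the transformer as `ref_image_hidden_states`. The sketch below is a minimal illustration of that packing, not toolkit code; `pack_control_latents` and `vae_encode` are placeholder names (`vae_encode` stands in for `self.encode_images`), and the shapes are assumed from the diff.

```python
import torch
import torch.nn.functional as F


def pack_control_latents(control_images, target_hw, vae_encode):
    """control_images: (bs, ch, h, w) in [0, 1]; returns a ref_image_hidden_states-style list."""
    with torch.no_grad():
        x = control_images * 2 - 1                      # VAE expects inputs in [-1, 1]
        if tuple(x.shape[2:]) != tuple(target_hw):      # match the training crop size
            x = F.interpolate(x, size=tuple(target_hw), mode="bilinear")
        latents = vae_encode(x)                         # (bs, c, h', w') latent batch
        # one list of reference latents per batch item, mirroring the nested-list
        # packing built in condition_noisy_latents above
        return [[lat] for lat in latents]
```

The nested lists are the point: the transformer takes a list of reference latents per sample rather than a single batched tensor, which is why the diff splits the encoded batch with `torch.chunk` instead of passing it whole.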
diff --git a/ui/src/app/jobs/new/SimpleJob.tsx b/ui/src/app/jobs/new/SimpleJob.tsx
index 4949f94a..82675e3b 100644
--- a/ui/src/app/jobs/new/SimpleJob.tsx
+++ b/ui/src/app/jobs/new/SimpleJob.tsx
@@ -310,6 +310,7 @@ export default function SimpleJob({
               { value: 'sigmoid', label: 'Sigmoid' },
               { value: 'linear', label: 'Linear' },
               { value: 'shift', label: 'Shift' },
+              { value: 'weighted', label: 'Weighted' },
             ]}
           />
         )}
@@ -541,13 +542,12 @@ export default function SimpleJob({
             { value: 'ddpm', label: 'DDPM' },
           ]}
         />
[The remainder of this hunk rewraps the "Control Images" help block; its JSX markup was lost in extraction and is not reproduced here. Both the removed and the added markup render the note "To use control images on samples, add --ctrl_img to the prompts below." together with the example prompt "make this a cartoon --ctrl_img /path/to/image.png".]
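At sampling time, the `--ctrl_img` flag surfaced by this help text is handled in `generate_single_image` above: the image is opened, converted to RGB, resized to the requested sample resolution, and passed to the pipeline as `input_images` alongside `image_guidance_scale=1.0` and `align_res=False`. A rough sketch of that preprocessing follows; the helper name is ours for illustration, while the PIL calls mirror the diff.

```python
from PIL import Image


def load_control_image(ctrl_img_path: str, width: int, height: int) -> list:
    """Load a --ctrl_img reference and match it to the requested sample size."""
    img = Image.open(ctrl_img_path).convert("RGB")
    if img.size != (width, height):
        # bilinear resize to the sample resolution, as in generate_single_image
        img = img.resize((width, height), Image.BILINEAR)
    return [img]  # the pipeline expects a list of reference images


# e.g. input_images = load_control_image("/path/to/image.png", 1024, 1024)
# img = pipeline(..., input_images=input_images, image_guidance_scale=1.0, align_res=False).images[0]
```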