diff --git a/extensions_built_in/diffusion_models/__init__.py b/extensions_built_in/diffusion_models/__init__.py
index a2d5df69..d8bf232b 100644
--- a/extensions_built_in/diffusion_models/__init__.py
+++ b/extensions_built_in/diffusion_models/__init__.py
@@ -3,7 +3,7 @@ from .hidream import HidreamModel, HidreamE1Model
 from .f_light import FLiteModel
 from .omnigen2 import OmniGen2Model
 from .flux_kontext import FluxKontextModel
-from .wan22 import Wan225bModel, Wan2214bModel
+from .wan22 import Wan225bModel, Wan2214bModel, Wan2214bI2VModel
 from .qwen_image import QwenImageModel
 
 AI_TOOLKIT_MODELS = [
@@ -15,6 +15,7 @@ AI_TOOLKIT_MODELS = [
     OmniGen2Model,
     FluxKontextModel,
     Wan225bModel,
+    Wan2214bI2VModel,
     Wan2214bModel,
     QwenImageModel,
 ]
diff --git a/extensions_built_in/diffusion_models/wan22/__init__.py b/extensions_built_in/diffusion_models/wan22/__init__.py
index 5c88152b..61b5e803 100644
--- a/extensions_built_in/diffusion_models/wan22/__init__.py
+++ b/extensions_built_in/diffusion_models/wan22/__init__.py
@@ -1,2 +1,3 @@
 from .wan22_5b_model import Wan225bModel
-from .wan22_14b_model import Wan2214bModel
\ No newline at end of file
+from .wan22_14b_model import Wan2214bModel
+from .wan22_14b_i2v_model import Wan2214bI2VModel
\ No newline at end of file
diff --git a/extensions_built_in/diffusion_models/wan22/wan22_14b_i2v_model.py b/extensions_built_in/diffusion_models/wan22/wan22_14b_i2v_model.py
new file mode 100644
index 00000000..32eb11e8
--- /dev/null
+++ b/extensions_built_in/diffusion_models/wan22/wan22_14b_i2v_model.py
@@ -0,0 +1,144 @@
+import torch
+from toolkit.models.wan21.wan_utils import add_first_frame_conditioning
+from toolkit.prompt_utils import PromptEmbeds
+from PIL import Image
+import torch
+from toolkit.config_modules import GenerateImageConfig
+from .wan22_pipeline import Wan22Pipeline
+
+from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
+from diffusers import WanImageToVideoPipeline
+from torchvision.transforms import functional as TF
+
+from .wan22_14b_model import Wan2214bModel
+
+class Wan2214bI2VModel(Wan2214bModel):
+    arch = "wan22_14b_i2v"
+
+
+    def generate_single_image(
+        self,
+        pipeline: Wan22Pipeline,
+        gen_config: GenerateImageConfig,
+        conditional_embeds: PromptEmbeds,
+        unconditional_embeds: PromptEmbeds,
+        generator: torch.Generator,
+        extra: dict,
+    ):
+
+        # todo
+        # reactivate progress bar since this is slooooow
+        pipeline.set_progress_bar_config(disable=False)
+
+        num_frames = (
+            (gen_config.num_frames - 1) // 4
+        ) * 4 + 1  # make sure num_frames is of the form 4k + 1
+        gen_config.num_frames = num_frames
+
+        height = gen_config.height
+        width = gen_config.width
+        first_frame_n1p1 = None
+        if gen_config.ctrl_img is not None:
+            control_img = Image.open(gen_config.ctrl_img).convert("RGB")
+
+            d = self.get_bucket_divisibility()
+
+            # make sure they are divisible by d
+            height = height // d * d
+            width = width // d * d
+
+            # resize the control image
+            control_img = control_img.resize((width, height), Image.LANCZOS)
+
+            # 5. Prepare latent variables
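+            # NOTE: 16 is hardcoded below (rather than read from
+            # self.transformer.config.in_channels), presumably because the I2V
+            # transformer's in_channels also counts the first-frame conditioning
+            # channels concatenated later by add_first_frame_conditioning, while
+            # only the 16 VAE latent channels should be sampled here.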
+            # num_channels_latents = self.transformer.config.in_channels
+            num_channels_latents = 16
+            latents = pipeline.prepare_latents(
+                1,
+                num_channels_latents,
+                height,
+                width,
+                gen_config.num_frames,
+                torch.float32,
+                self.device_torch,
+                generator,
+                None,
+            ).to(self.torch_dtype)
+
+            first_frame_n1p1 = (
+                TF.to_tensor(control_img)
+                .unsqueeze(0)
+                .to(self.device_torch, dtype=self.torch_dtype)
+                * 2.0
+                - 1.0
+            )  # normalize to [-1, 1]
+
+            # Add conditioning using the standalone function
+            gen_config.latents = add_first_frame_conditioning(
+                latent_model_input=latents,
+                first_frame=first_frame_n1p1,
+                vae=self.vae
+            )
+
+        output = pipeline(
+            prompt_embeds=conditional_embeds.text_embeds.to(
+                self.device_torch, dtype=self.torch_dtype
+            ),
+            negative_prompt_embeds=unconditional_embeds.text_embeds.to(
+                self.device_torch, dtype=self.torch_dtype
+            ),
+            height=height,
+            width=width,
+            num_inference_steps=gen_config.num_inference_steps,
+            guidance_scale=gen_config.guidance_scale,
+            latents=gen_config.latents,
+            num_frames=gen_config.num_frames,
+            generator=generator,
+            return_dict=False,
+            output_type="pil",
+            **extra,
+        )[0]
+
+        # shape = [1, frames, channels, height, width]
+        batch_item = output[0]  # list of pil images
+        if gen_config.num_frames > 1:
+            return batch_item  # return the frames.
+        else:
+            # get just the first image
+            img = batch_item[0]
+            return img
+
+    def get_noise_prediction(
+        self,
+        latent_model_input: torch.Tensor,
+        timestep: torch.Tensor,  # 0 to 1000 scale
+        text_embeddings: PromptEmbeds,
+        batch: DataLoaderBatchDTO,
+        **kwargs
+    ):
+        # videos come in (bs, num_frames, channels, height, width)
+        # images come in (bs, channels, height, width)
+        with torch.no_grad():
+            frames = batch.tensor
+            if len(frames.shape) == 4:
+                first_frames = frames
+            elif len(frames.shape) == 5:
+                first_frames = frames[:, 0]
+            else:
+                raise ValueError(f"Unknown frame shape {frames.shape}")
+
+            # Add conditioning using the standalone function
+            conditioned_latent = add_first_frame_conditioning(
+                latent_model_input=latent_model_input,
+                first_frame=first_frames,
+                vae=self.vae
+            )
+
+        noise_pred = self.model(
+            hidden_states=conditioned_latent,
+            timestep=timestep,
+            encoder_hidden_states=text_embeddings.text_embeds,
+            return_dict=False,
+            **kwargs
+        )[0]
+        return noise_pred
\ No newline at end of file
diff --git a/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py b/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py
index abc1803e..aee75272 100644
--- a/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py
+++ b/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py
@@ -6,6 +6,7 @@ import torch
 import yaml
 from toolkit.accelerator import unwrap_model
 from toolkit.basic import flush
+from toolkit.models.wan21.wan_utils import add_first_frame_conditioning
 from toolkit.prompt_utils import PromptEmbeds
 from PIL import Image
 from diffusers import UniPCMultistepScheduler
@@ -21,11 +22,10 @@ from diffusers import WanTransformer3DModel
 from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
 from torchvision.transforms import functional as TF
 
-from toolkit.models.wan21.wan21 import AggressiveWanUnloadPipeline, Wan21
+from toolkit.models.wan21.wan21 import Wan21
 from .wan22_5b_model import (
     scheduler_config,
     time_text_monkeypatch,
-    Wan225bModel,
 )
 
 from safetensors.torch import load_file, save_file
@@ -239,8 +239,8 @@ class Wan2214bModel(Wan21):
         )
 
     def get_bucket_divisibility(self):
-        # 16x compression and 2x2 patch size
-        return 32
+        # 8x compression and 2x2 patch size
+        return 16
 
     def load_wan_transformer(self, transformer_path, subfolder=None):
         if self.model_config.split_model_over_gpus:
@@ -378,7 +378,7 @@ class Wan2214bModel(Wan21):
 
     def generate_single_image(
         self,
-        pipeline: AggressiveWanUnloadPipeline,
+        pipeline: Wan22Pipeline,
         gen_config: GenerateImageConfig,
         conditional_embeds: PromptEmbeds,
         unconditional_embeds: PromptEmbeds,
@@ -513,7 +513,7 @@ class Wan2214bModel(Wan21):
                 combined_dict[new_key] = low_noise_lora[key]
 
         # if we are not training both stages, we wont have transformer designations in the keys
-        if not self.train_high_noise and not self.train_low_noise:
+        if not self.train_high_noise or not self.train_low_noise:
             new_dict = {}
             for key in combined_dict:
                 if ".transformer_1." in key:
diff --git a/extensions_built_in/diffusion_models/wan22/wan22_pipeline.py b/extensions_built_in/diffusion_models/wan22/wan22_pipeline.py
index dafa2012..f434327f 100644
--- a/extensions_built_in/diffusion_models/wan22/wan22_pipeline.py
+++ b/extensions_built_in/diffusion_models/wan22/wan22_pipeline.py
@@ -10,6 +10,7 @@ from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
 from diffusers.pipelines.wan.pipeline_wan import XLA_AVAILABLE
 from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
 from typing import Any, Callable, Dict, List, Optional, Union
+from diffusers.image_processor import PipelineImageInput
 
 
 class Wan22Pipeline(WanPipeline):
@@ -149,6 +150,18 @@ class Wan22Pipeline(WanPipeline):
 
         # 5. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels
+
+        conditioning = None  # wan2.2 i2v conditioning
+        # check shape of latents to see if it is first frame conditioned for 2.2 14b i2v
+        if latents is not None:
+            if latents.shape[1] == 36:
+                # first 16 channels are latent. other 20 are conditioning
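+                # (these extra channels are presumably the conditioning produced by
+                # add_first_frame_conditioning: a mask plus the VAE-encoded first frame)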
+                conditioning = latents[:, 16:]
+                latents = latents[:, :16]
+
+                # we need to trick in_channels so prepare_latents only allocates 16 channels
+                num_channels_latents = 16
+
         latents = self.prepare_latents(
             batch_size * num_videos_per_prompt,
             num_channels_latents,
@@ -210,6 +223,13 @@ class Wan22Pipeline(WanPipeline):
                     timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
                 else:
                     timestep = t.expand(latents.shape[0])
+
+                pre_condition_latent_model_input = latent_model_input.clone()
+
+                if conditioning is not None:
+                    # conditioning is first frame conditioning for 2.2 i2v
+                    latent_model_input = torch.cat(
+                        [latent_model_input, conditioning], dim=1)
 
                 noise_pred = current_model(
                     hidden_states=latent_model_input,
@@ -235,7 +255,7 @@ class Wan22Pipeline(WanPipeline):
                     noise_pred, t, latents, return_dict=False)[0]
 
                 # apply i2v mask
-                latents = (latent_model_input * (1 - mask)) + (
+                latents = (pre_condition_latent_model_input * (1 - mask)) + (
                     latents * mask
                 )
 
diff --git a/ui/src/app/jobs/new/options.ts b/ui/src/app/jobs/new/options.ts
index 581d2c83..c180c240 100644
--- a/ui/src/app/jobs/new/options.ts
+++ b/ui/src/app/jobs/new/options.ts
@@ -218,6 +218,33 @@ export const modelArchs: ModelArch[] = [
     //   '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/wan22_14b_t2i_torchao_uint3.safetensors',
     // },
   },
+  {
+    name: 'wan22_14b_i2v',
+    label: 'Wan 2.2 I2V (14B)',
+    group: 'video',
+    isVideoModel: true,
+    defaults: {
+      // default updates when [selected, unselected] in the UI
+      'config.process[0].model.name_or_path': ['ai-toolkit/Wan2.2-I2V-A14B-Diffusers-bf16', defaultNameOrPath],
+      'config.process[0].model.quantize': [true, false],
+      'config.process[0].model.quantize_te': [true, false],
+      'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
+      'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
+      'config.process[0].sample.num_frames': [41, 1],
+      'config.process[0].sample.fps': [16, 1],
+      'config.process[0].model.low_vram': [true, false],
+      'config.process[0].train.timestep_type': ['linear', 'sigmoid'],
+      'config.process[0].model.model_kwargs': [
+        {
+          train_high_noise: true,
+          train_low_noise: true,
+        },
+        {},
+      ],
+    },
+    disableSections: ['network.conv'],
+    additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram', 'model.multistage'],
+  },
   {
     name: 'wan22_5b',
     label: 'Wan 2.2 TI2V (5B)',
diff --git a/version.py b/version.py
index 9b13ade4..214abb41 100644
--- a/version.py
+++ b/version.py
@@ -1 +1 @@
-VERSION = "0.5.1"
\ No newline at end of file
+VERSION = "0.5.2"
\ No newline at end of file