Add support for fine-tuning Wan 2.2 I2V 14B

Jaret Burkett
2025-08-18 11:43:32 -06:00
parent b3e666daf4
commit d2bbe1872c
7 changed files with 203 additions and 10 deletions

View File

@@ -3,7 +3,7 @@ from .hidream import HidreamModel, HidreamE1Model
from .f_light import FLiteModel
from .omnigen2 import OmniGen2Model
from .flux_kontext import FluxKontextModel
from .wan22 import Wan225bModel, Wan2214bModel
from .wan22 import Wan225bModel, Wan2214bModel, Wan2214bI2VModel
from .qwen_image import QwenImageModel
AI_TOOLKIT_MODELS = [
@@ -15,6 +15,7 @@ AI_TOOLKIT_MODELS = [
OmniGen2Model,
FluxKontextModel,
Wan225bModel,
Wan2214bI2VModel,
Wan2214bModel,
QwenImageModel,
]

View File

@@ -1,2 +1,3 @@
from .wan22_5b_model import Wan225bModel
from .wan22_14b_model import Wan2214bModel
from .wan22_14b_model import Wan2214bModel
from .wan22_14b_i2v_model import Wan2214bI2VModel

View File

@@ -0,0 +1,144 @@
import torch
from toolkit.models.wan21.wan_utils import add_first_frame_conditioning
from toolkit.prompt_utils import PromptEmbeds
from PIL import Image
import torch
from toolkit.config_modules import GenerateImageConfig
from .wan22_pipeline import Wan22Pipeline
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from diffusers import WanImageToVideoPipeline
from torchvision.transforms import functional as TF
from .wan22_14b_model import Wan2214bModel
class Wan2214bI2VModel(Wan2214bModel):
arch = "wan22_14b_i2v"
def generate_single_image(
self,
pipeline: Wan22Pipeline,
gen_config: GenerateImageConfig,
conditional_embeds: PromptEmbeds,
unconditional_embeds: PromptEmbeds,
generator: torch.Generator,
extra: dict,
):
# todo
# reactivate progress bar since this is slooooow
pipeline.set_progress_bar_config(disable=False)
num_frames = (
(gen_config.num_frames - 1) // 4
) * 4 + 1 # num_frames must be a multiple of 4, plus 1
gen_config.num_frames = num_frames
height = gen_config.height
width = gen_config.width
first_frame_n1p1 = None
if gen_config.ctrl_img is not None:
control_img = Image.open(gen_config.ctrl_img).convert("RGB")
d = self.get_bucket_divisibility()
# make sure they are divisible by d
height = height // d * d
width = width // d * d
# resize the control image
control_img = control_img.resize((width, height), Image.LANCZOS)
# 5. Prepare latent variables
# num_channels_latents = self.transformer.config.in_channels
num_channels_latents = 16
latents = pipeline.prepare_latents(
1,
num_channels_latents,
height,
width,
gen_config.num_frames,
torch.float32,
self.device_torch,
generator,
None,
).to(self.torch_dtype)
first_frame_n1p1 = (
TF.to_tensor(control_img)
.unsqueeze(0)
.to(self.device_torch, dtype=self.torch_dtype)
* 2.0
- 1.0
) # normalize to [-1, 1]
# Add conditioning using the standalone function
gen_config.latents = add_first_frame_conditioning(
latent_model_input=latents,
first_frame=first_frame_n1p1,
vae=self.vae
)
output = pipeline(
prompt_embeds=conditional_embeds.text_embeds.to(
self.device_torch, dtype=self.torch_dtype
),
negative_prompt_embeds=unconditional_embeds.text_embeds.to(
self.device_torch, dtype=self.torch_dtype
),
height=height,
width=width,
num_inference_steps=gen_config.num_inference_steps,
guidance_scale=gen_config.guidance_scale,
latents=gen_config.latents,
num_frames=gen_config.num_frames,
generator=generator,
return_dict=False,
output_type="pil",
**extra,
)[0]
# shape = [1, frames, channels, height, width]
batch_item = output[0] # list of pil images
if gen_config.num_frames > 1:
return batch_item # return the frames.
else:
# get just the first image
img = batch_item[0]
return img
def get_noise_prediction(
self,
latent_model_input: torch.Tensor,
timestep: torch.Tensor, # 0 to 1000 scale
text_embeddings: PromptEmbeds,
batch: DataLoaderBatchDTO,
**kwargs
):
# videos come in (bs, num_frames, channels, height, width)
# images come in (bs, channels, height, width)
with torch.no_grad():
frames = batch.tensor
if len(frames.shape) == 4:
first_frames = frames
elif len(frames.shape) == 5:
first_frames = frames[:, 0]
else:
raise ValueError(f"Unknown frame shape {frames.shape}")
# Add conditioning using the standalone function
conditioned_latent = add_first_frame_conditioning(
latent_model_input=latent_model_input,
first_frame=first_frames,
vae=self.vae
)
noise_pred = self.model(
hidden_states=conditioned_latent,
timestep=timestep,
encoder_hidden_states=text_embeddings.text_embeds,
return_dict=False,
**kwargs
)[0]
return noise_pred
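
For reference, a minimal standalone sketch of the two preprocessing steps in this new file: snapping the requested frame count to the 4k + 1 form the model expects, and normalizing the control image to [-1, 1] before it is handed to add_first_frame_conditioning. The helper name snap_num_frames is made up for illustration; only the arithmetic mirrors the code above.

from PIL import Image
from torchvision.transforms import functional as TF

def snap_num_frames(requested: int) -> int:
    # Wan 2.2 expects frame counts of the form 4k + 1 (1, 5, 9, ..., 41, 45, ...)
    return ((requested - 1) // 4) * 4 + 1

assert snap_num_frames(41) == 41  # already valid
assert snap_num_frames(44) == 41  # rounded down to the nearest 4k + 1
assert snap_num_frames(1) == 1    # single-image case

# control image -> float tensor in [0, 1] -> [-1, 1], shape (1, 3, H, W)
img = Image.new("RGB", (64, 64))
first_frame_n1p1 = TF.to_tensor(img).unsqueeze(0) * 2.0 - 1.0
assert first_frame_n1p1.min() >= -1.0 and first_frame_n1p1.max() <= 1.0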

View File

@@ -6,6 +6,7 @@ import torch
import yaml
from toolkit.accelerator import unwrap_model
from toolkit.basic import flush
from toolkit.models.wan21.wan_utils import add_first_frame_conditioning
from toolkit.prompt_utils import PromptEmbeds
from PIL import Image
from diffusers import UniPCMultistepScheduler
@@ -21,11 +22,10 @@ from diffusers import WanTransformer3DModel
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from torchvision.transforms import functional as TF
from toolkit.models.wan21.wan21 import AggressiveWanUnloadPipeline, Wan21
from toolkit.models.wan21.wan21 import Wan21
from .wan22_5b_model import (
scheduler_config,
time_text_monkeypatch,
Wan225bModel,
)
from safetensors.torch import load_file, save_file
@@ -239,8 +239,8 @@ class Wan2214bModel(Wan21):
)
def get_bucket_divisibility(self):
# 16x compression and 2x2 patch size
return 32
# 8x compression and 2x2 patch size
return 16
def load_wan_transformer(self, transformer_path, subfolder=None):
if self.model_config.split_model_over_gpus:
@@ -378,7 +378,7 @@ class Wan2214bModel(Wan21):
def generate_single_image(
self,
pipeline: AggressiveWanUnloadPipeline,
pipeline: Wan22Pipeline,
gen_config: GenerateImageConfig,
conditional_embeds: PromptEmbeds,
unconditional_embeds: PromptEmbeds,
@@ -513,7 +513,7 @@ class Wan2214bModel(Wan21):
combined_dict[new_key] = low_noise_lora[key]
# if we are not training both stages, we won't have transformer designations in the keys
if not self.train_high_noise and not self.train_low_noise:
if not self.train_high_noise or not self.train_low_noise:
new_dict = {}
for key in combined_dict:
if ".transformer_1." in key:

View File

@@ -10,6 +10,7 @@ from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
from diffusers.pipelines.wan.pipeline_wan import XLA_AVAILABLE
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from typing import Any, Callable, Dict, List, Optional, Union
from diffusers.image_processor import PipelineImageInput
class Wan22Pipeline(WanPipeline):
@@ -149,6 +150,18 @@ class Wan22Pipeline(WanPipeline):
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
conditioning = None # wan2.2 i2v conditioning
# check shape of latents to see if it is first frame conditioned for 2.2 14b i2v
if latents is not None:
if latents.shape[1] == 36:
# the first 16 channels are the latent; the other 20 are conditioning
conditioning = latents[:, 16:]
latents = latents[:, :16]
# override num_channels_latents so prepare_latents treats this as only 16 channels
num_channels_latents = 16
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
@@ -210,6 +223,13 @@ class Wan22Pipeline(WanPipeline):
timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
else:
timestep = t.expand(latents.shape[0])
pre_condition_latent_model_input = latent_model_input.clone()
if conditioning is not None:
# conditioning is first frame conditioning for 2.2 i2v
latent_model_input = torch.cat(
[latent_model_input, conditioning], dim=1)
noise_pred = current_model(
hidden_states=latent_model_input,
@@ -235,7 +255,7 @@ class Wan22Pipeline(WanPipeline):
noise_pred, t, latents, return_dict=False)[0]
# apply i2v mask
latents = (latent_model_input * (1 - mask)) + (
latents = (pre_condition_latent_model_input * (1 - mask)) + (
latents * mask
)
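
A minimal sketch of what the pipeline additions above do with a 36-channel packed latent, assuming Wan-style (batch, channels, frames, latent_height, latent_width) tensors; the sizes below are placeholders, and only the channel split at index 16 and the concatenation on dim 1 mirror the hunks above.

import torch

# placeholder sizes; only the channel layout matters here
bs, frames, lh, lw = 1, 11, 60, 104
packed = torch.randn(bs, 36, frames, lh, lw)  # 16 latent + 20 conditioning channels

# unpacking (mirrors the shape check after prepare_latents)
conditioning = packed[:, 16:]  # (bs, 20, frames, lh, lw) first-frame conditioning
latents = packed[:, :16]       # (bs, 16, frames, lh, lw) noise latents that get denoised

# per denoising step: keep a copy of the un-conditioned input (used with the i2v mask),
# then re-attach the conditioning before calling the transformer
pre_condition_latent_model_input = latents.clone()
latent_model_input = torch.cat([latents, conditioning], dim=1)
assert latent_model_input.shape[1] == 36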

View File

@@ -218,6 +218,33 @@ export const modelArchs: ModelArch[] = [
// '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/wan22_14b_t2i_torchao_uint3.safetensors',
// },
},
{
name: 'wan22_14b_i2v',
label: 'Wan 2.2 I2V (14B)',
group: 'video',
isVideoModel: true,
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['ai-toolkit/Wan2.2-I2V-A14B-Diffusers-bf16', defaultNameOrPath],
'config.process[0].model.quantize': [true, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].sample.num_frames': [41, 1],
'config.process[0].sample.fps': [16, 1],
'config.process[0].model.low_vram': [true, false],
'config.process[0].train.timestep_type': ['linear', 'sigmoid'],
'config.process[0].model.model_kwargs': [
{
train_high_noise: true,
train_low_noise: true,
},
{},
],
},
disableSections: ['network.conv'],
additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram', 'model.multistage'],
},
{
name: 'wan22_5b',
label: 'Wan 2.2 TI2V (5B)',

View File

@@ -1 +1 @@
VERSION = "0.5.1"
VERSION = "0.5.2"