mirror of https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00

Add support for fine tuning Wan 2.2 I2V 14B

@@ -3,7 +3,7 @@ from .hidream import HidreamModel, HidreamE1Model
from .f_light import FLiteModel
from .omnigen2 import OmniGen2Model
from .flux_kontext import FluxKontextModel
-from .wan22 import Wan225bModel, Wan2214bModel
+from .wan22 import Wan225bModel, Wan2214bModel, Wan2214bI2VModel
from .qwen_image import QwenImageModel

AI_TOOLKIT_MODELS = [
@@ -15,6 +15,7 @@ AI_TOOLKIT_MODELS = [
    OmniGen2Model,
    FluxKontextModel,
    Wan225bModel,
    Wan2214bI2VModel,
    Wan2214bModel,
    QwenImageModel,
]

@@ -1,2 +1,3 @@
from .wan22_5b_model import Wan225bModel
from .wan22_14b_model import Wan2214bModel
from .wan22_14b_i2v_model import Wan2214bI2VModel

@@ -0,0 +1,144 @@
import torch
from toolkit.models.wan21.wan_utils import add_first_frame_conditioning
from toolkit.prompt_utils import PromptEmbeds
from PIL import Image
import torch
from toolkit.config_modules import GenerateImageConfig
from .wan22_pipeline import Wan22Pipeline

from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from diffusers import WanImageToVideoPipeline
from torchvision.transforms import functional as TF

from .wan22_14b_model import Wan2214bModel


class Wan2214bI2VModel(Wan2214bModel):
    arch = "wan22_14b_i2v"
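    # behaves like Wan2214bModel, but samples and trains with first-frame (image) conditioning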

    def generate_single_image(
        self,
        pipeline: Wan22Pipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        # todo
        # reactivate progress bar since this is slooooow
        pipeline.set_progress_bar_config(disable=False)

        num_frames = (
            (gen_config.num_frames - 1) // 4
        ) * 4 + 1  # num_frames must be of the form 4k + 1
        gen_config.num_frames = num_frames
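        # e.g. a requested 44 frames becomes 41, and 41 stays 41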

        height = gen_config.height
        width = gen_config.width
        first_frame_n1p1 = None
        if gen_config.ctrl_img is not None:
            control_img = Image.open(gen_config.ctrl_img).convert("RGB")

            d = self.get_bucket_divisibility()

            # make sure they are divisible by d
            height = height // d * d
            width = width // d * d
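            # d is 16 for this model (8x VAE compression x 2x2 patch size),
            # e.g. 1080x720 snaps down to 1072x720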

            # resize the control image
            control_img = control_img.resize((width, height), Image.LANCZOS)

            # 5. Prepare latent variables
            # num_channels_latents = self.transformer.config.in_channels
            num_channels_latents = 16
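            # only the 16 noise channels are prepared here; the I2V transformer takes
            # 36 input channels (16 latent + 20 conditioning), which are added below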
            latents = pipeline.prepare_latents(
                1,
                num_channels_latents,
                height,
                width,
                gen_config.num_frames,
                torch.float32,
                self.device_torch,
                generator,
                None,
            ).to(self.torch_dtype)

            first_frame_n1p1 = (
                TF.to_tensor(control_img)
                .unsqueeze(0)
                .to(self.device_torch, dtype=self.torch_dtype)
                * 2.0
                - 1.0
            )  # normalize to [-1, 1]

            # Add conditioning using the standalone function
            gen_config.latents = add_first_frame_conditioning(
                latent_model_input=latents,
                first_frame=first_frame_n1p1,
                vae=self.vae
            )
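            # gen_config.latents now carries 36 channels; Wan22Pipeline splits the extra
            # 20 conditioning channels back off before denoising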

        output = pipeline(
            prompt_embeds=conditional_embeds.text_embeds.to(
                self.device_torch, dtype=self.torch_dtype
            ),
            negative_prompt_embeds=unconditional_embeds.text_embeds.to(
                self.device_torch, dtype=self.torch_dtype
            ),
            height=height,
            width=width,
            num_inference_steps=gen_config.num_inference_steps,
            guidance_scale=gen_config.guidance_scale,
            latents=gen_config.latents,
            num_frames=gen_config.num_frames,
            generator=generator,
            return_dict=False,
            output_type="pil",
            **extra,
        )[0]

        # shape = [1, frames, channels, height, width]
        batch_item = output[0]  # list of pil images
        if gen_config.num_frames > 1:
            return batch_item  # return the frames.
        else:
            # get just the first image
            img = batch_item[0]
            return img

    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        batch: DataLoaderBatchDTO,
        **kwargs
    ):
        # videos come in (bs, num_frames, channels, height, width)
        # images come in (bs, channels, height, width)
        with torch.no_grad():
            frames = batch.tensor
            if len(frames.shape) == 4:
                first_frames = frames
            elif len(frames.shape) == 5:
                first_frames = frames[:, 0]
            else:
                raise ValueError(f"Unknown frame shape {frames.shape}")

            # Add conditioning using the standalone function
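            # add_first_frame_conditioning VAE-encodes the first frame and appends
            # the conditioning channels to the noisy latents (done without gradients)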
            conditioned_latent = add_first_frame_conditioning(
                latent_model_input=latent_model_input,
                first_frame=first_frames,
                vae=self.vae
            )

        noise_pred = self.model(
            hidden_states=conditioned_latent,
            timestep=timestep,
            encoder_hidden_states=text_embeddings.text_embeds,
            return_dict=False,
            **kwargs
        )[0]
        return noise_pred

@@ -6,6 +6,7 @@ import torch
import yaml
from toolkit.accelerator import unwrap_model
from toolkit.basic import flush
from toolkit.models.wan21.wan_utils import add_first_frame_conditioning
from toolkit.prompt_utils import PromptEmbeds
from PIL import Image
from diffusers import UniPCMultistepScheduler
@@ -21,11 +22,10 @@ from diffusers import WanTransformer3DModel
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from torchvision.transforms import functional as TF

-from toolkit.models.wan21.wan21 import AggressiveWanUnloadPipeline, Wan21
+from toolkit.models.wan21.wan21 import Wan21
from .wan22_5b_model import (
    scheduler_config,
    time_text_monkeypatch,
    Wan225bModel,
)
from safetensors.torch import load_file, save_file

@@ -239,8 +239,8 @@ class Wan2214bModel(Wan21):
        )

    def get_bucket_divisibility(self):
-        # 16x compression and 2x2 patch size
-        return 32
+        # 8x compression and 2x2 patch size
+        return 16

    def load_wan_transformer(self, transformer_path, subfolder=None):
        if self.model_config.split_model_over_gpus:
@@ -378,7 +378,7 @@ class Wan2214bModel(Wan21):

    def generate_single_image(
        self,
-        pipeline: AggressiveWanUnloadPipeline,
+        pipeline: Wan22Pipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
@@ -513,7 +513,7 @@ class Wan2214bModel(Wan21):
                combined_dict[new_key] = low_noise_lora[key]

        # if we are not training both stages, we won't have transformer designations in the keys
-        if not self.train_high_noise and not self.train_low_noise:
+        if not self.train_high_noise or not self.train_low_noise:
            new_dict = {}
            for key in combined_dict:
                if ".transformer_1." in key:

@@ -10,6 +10,7 @@ from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
from diffusers.pipelines.wan.pipeline_wan import XLA_AVAILABLE
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from typing import Any, Callable, Dict, List, Optional, Union
from diffusers.image_processor import PipelineImageInput


class Wan22Pipeline(WanPipeline):
@@ -149,6 +150,18 @@ class Wan22Pipeline(WanPipeline):

        # 5. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels

        conditioning = None  # wan2.2 i2v conditioning
        # check shape of latents to see if it is first-frame conditioned for 2.2 14b i2v
        if latents is not None:
            if latents.shape[1] == 36:
                # the first 16 channels are the latent; the other 20 are conditioning
                conditioning = latents[:, 16:]
                latents = latents[:, :16]

                # override num_channels_latents so prepare_latents only creates the 16 noise channels
                num_channels_latents = 16

        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_channels_latents,
@@ -210,6 +223,13 @@ class Wan22Pipeline(WanPipeline):
                    timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
                else:
                    timestep = t.expand(latents.shape[0])

                pre_condition_latent_model_input = latent_model_input.clone()

                if conditioning is not None:
                    # conditioning is first frame conditioning for 2.2 i2v
                    latent_model_input = torch.cat(
                        [latent_model_input, conditioning], dim=1)

                noise_pred = current_model(
                    hidden_states=latent_model_input,
@@ -235,7 +255,7 @@ class Wan22Pipeline(WanPipeline):
                    noise_pred, t, latents, return_dict=False)[0]

                # apply i2v mask
-                latents = (latent_model_input * (1 - mask)) + (
+                latents = (pre_condition_latent_model_input * (1 - mask)) + (
                    latents * mask
                )

@@ -218,6 +218,33 @@ export const modelArchs: ModelArch[] = [
    //   '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/wan22_14b_t2i_torchao_uint3.safetensors',
    // },
  },
  {
    name: 'wan22_14b_i2v',
    label: 'Wan 2.2 I2V (14B)',
    group: 'video',
    isVideoModel: true,
    defaults: {
      // default updates when [selected, unselected] in the UI
      'config.process[0].model.name_or_path': ['ai-toolkit/Wan2.2-I2V-A14B-Diffusers-bf16', defaultNameOrPath],
      'config.process[0].model.quantize': [true, false],
      'config.process[0].model.quantize_te': [true, false],
      'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
      'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
      'config.process[0].sample.num_frames': [41, 1],
      'config.process[0].sample.fps': [16, 1],
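      // 41 frames keeps the 4k + 1 frame-count constraint (~2.5 s at 16 fps)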
      'config.process[0].model.low_vram': [true, false],
      'config.process[0].train.timestep_type': ['linear', 'sigmoid'],
      'config.process[0].model.model_kwargs': [
        {
          train_high_noise: true,
          train_low_noise: true,
        },
        {},
      ],
    },
    disableSections: ['network.conv'],
    additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram', 'model.multistage'],
  },
  {
    name: 'wan22_5b',
    label: 'Wan 2.2 TI2V (5B)',

@@ -1 +1 @@
-VERSION = "0.5.1"
+VERSION = "0.5.2"