Add support for Wan2.2 5B

Jaret Burkett
2025-07-29 05:31:54 -06:00
parent e55116d8c9
commit ca7c5c950b
11 changed files with 1241 additions and 92 deletions

View File

@@ -3,13 +3,15 @@ from .hidream import HidreamModel, HidreamE1Model
from .f_light import FLiteModel
from .omnigen2 import OmniGen2Model
from .flux_kontext import FluxKontextModel
from .wan22 import Wan22Model
AI_TOOLKIT_MODELS = [
# put a list of models here
ChromaModel,
HidreamModel,
HidreamE1Model,
FLiteModel,
OmniGen2Model,
FluxKontextModel
ChromaModel,
HidreamModel,
HidreamE1Model,
FLiteModel,
OmniGen2Model,
FluxKontextModel,
Wan22Model,
]
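A minimal sketch (not part of this commit) of how the registry above can be resolved by the arch string each model class defines, e.g. Wan22Model.arch == "wan22_5b"; get_model_class is a hypothetical helper name, not an existing toolkit function:

def get_model_class(arch: str):
    # scan the registry for a model class whose arch attribute matches
    for model_cls in AI_TOOLKIT_MODELS:
        if getattr(model_cls, "arch", None) == arch:
            return model_cls
    raise ValueError(f"Unknown model arch: {arch}")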

View File

@@ -0,0 +1 @@
from .wan22_model import Wan22Model

View File

@@ -0,0 +1,259 @@
import torch
from toolkit.prompt_utils import PromptEmbeds
from PIL import Image
from diffusers import UniPCMultistepScheduler
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.samplers.custom_flowmatch_sampler import (
CustomFlowMatchEulerDiscreteScheduler,
)
from .wan22_pipeline import Wan22Pipeline
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from torchvision.transforms import functional as TF
from toolkit.models.wan21.wan21 import Wan21, AggressiveWanUnloadPipeline
from toolkit.models.wan21.wan_utils import add_first_frame_conditioning_v22
# scheduler config used only for generation (consumed by UniPCMultistepScheduler in get_generation_pipeline)
scheduler_configUniPC = {
"_class_name": "UniPCMultistepScheduler",
"_diffusers_version": "0.35.0.dev0",
"beta_end": 0.02,
"beta_schedule": "linear",
"beta_start": 0.0001,
"disable_corrector": [],
"dynamic_thresholding_ratio": 0.995,
"final_sigmas_type": "zero",
"flow_shift": 5.0,
"lower_order_final": True,
"num_train_timesteps": 1000,
"predict_x0": True,
"prediction_type": "flow_prediction",
"rescale_betas_zero_snr": False,
"sample_max_value": 1.0,
"solver_order": 2,
"solver_p": None,
"solver_type": "bh2",
"steps_offset": 0,
"thresholding": False,
"time_shift_type": "exponential",
"timestep_spacing": "linspace",
"trained_betas": None,
"use_beta_sigmas": False,
"use_dynamic_shifting": False,
"use_exponential_sigmas": False,
"use_flow_sigmas": True,
"use_karras_sigmas": False,
}
# scheduler config used for training (consumed by CustomFlowMatchEulerDiscreteScheduler in get_train_scheduler)
scheduler_config = {
"num_train_timesteps": 1000,
"shift": 5.0,
"use_dynamic_shifting": False,
}
class Wan22Model(Wan21):
arch = "wan22_5b"
_wan_generation_scheduler_config = scheduler_configUniPC
_wan_expand_timesteps = True
def __init__(
self,
device,
model_config: ModelConfig,
dtype="bf16",
custom_pipeline=None,
noise_scheduler=None,
**kwargs,
):
super().__init__(
device=device,
model_config=model_config,
dtype=dtype,
custom_pipeline=custom_pipeline,
noise_scheduler=noise_scheduler,
**kwargs,
)
self._wan_cache = None
def get_bucket_divisibility(self):
# 16x compression and 2x2 patch size
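# -> 16 * 2 = 32; sample sizes are rounded down to the nearest multiple of 32 (e.g. 770 -> 768)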
return 32
def get_generation_pipeline(self):
scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config)
pipeline = Wan22Pipeline(
vae=self.vae,
transformer=self.model,
transformer_2=self.model,
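# the 5B TI2V checkpoint uses a single transformer, so the same module is passed for both slots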
text_encoder=self.text_encoder,
tokenizer=self.tokenizer,
scheduler=scheduler,
expand_timesteps=self._wan_expand_timesteps,
device=self.device_torch,
aggressive_offload=self.model_config.low_vram,
)
pipeline = pipeline.to(self.device_torch)
return pipeline
# static method to get the scheduler
@staticmethod
def get_train_scheduler():
scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
return scheduler
def get_base_model_version(self):
return "wan_2.2_5b"
def generate_single_image(
self,
pipeline: AggressiveWanUnloadPipeline,
gen_config: GenerateImageConfig,
conditional_embeds: PromptEmbeds,
unconditional_embeds: PromptEmbeds,
generator: torch.Generator,
extra: dict,
):
# re-enable the progress bar since video generation is slow
pipeline.set_progress_bar_config(disable=False)
num_frames = (
(gen_config.num_frames - 1) // 4
) * 4 + 1  # snap to a valid frame count of the form 4k + 1
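# e.g. a requested 121 stays 121, while 120 snaps to 117 and 50 snaps to 49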
gen_config.num_frames = num_frames
height = gen_config.height
width = gen_config.width
noise_mask = None
if gen_config.ctrl_img is not None:
control_img = Image.open(gen_config.ctrl_img).convert("RGB")
d = self.get_bucket_divisibility()
# make sure they are divisible by d
height = height // d * d
width = width // d * d
# resize the control image
control_img = control_img.resize((width, height), Image.LANCZOS)
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = pipeline.prepare_latents(
1,
num_channels_latents,
height,
width,
gen_config.num_frames,
torch.float32,
self.device_torch,
generator,
None,
).to(self.torch_dtype)
first_frame_n1p1 = (
TF.to_tensor(control_img)
.unsqueeze(0)
.to(self.device_torch, dtype=self.torch_dtype)
* 2.0
- 1.0
) # normalize to [-1, 1]
gen_config.latents, noise_mask = add_first_frame_conditioning_v22(
latent_model_input=latents, first_frame=first_frame_n1p1, vae=self.vae
)
output = pipeline(
prompt_embeds=conditional_embeds.text_embeds.to(
self.device_torch, dtype=self.torch_dtype
),
negative_prompt_embeds=unconditional_embeds.text_embeds.to(
self.device_torch, dtype=self.torch_dtype
),
height=height,
width=width,
num_inference_steps=gen_config.num_inference_steps,
guidance_scale=gen_config.guidance_scale,
latents=gen_config.latents,
num_frames=gen_config.num_frames,
generator=generator,
return_dict=False,
output_type="pil",
noise_mask=noise_mask,
**extra,
)[0]
# shape = [1, frames, channels, height, width]
batch_item = output[0] # list of pil images
if gen_config.num_frames > 1:
return batch_item # return the frames.
else:
# get just the first image
img = batch_item[0]
return img
def get_noise_prediction(
self,
latent_model_input: torch.Tensor,
timestep: torch.Tensor, # 0 to 1000 scale
text_embeddings: PromptEmbeds,
batch: DataLoaderBatchDTO,
**kwargs,
):
# videos come in (bs, num_frames, channels, height, width)
# images come in (bs, channels, height, width)
# for Wan, video batches are trained i2v-style (first-frame conditioning) for now; image batches train as normal t2i
conditioned_latent = latent_model_input
noise_mask = None
with torch.no_grad():
frames = batch.tensor
if len(frames.shape) == 4:
first_frames = frames
elif len(frames.shape) == 5:
first_frames = frames[:, 0]
# Add conditioning using the standalone function
conditioned_latent, noise_mask = add_first_frame_conditioning_v22(
latent_model_input=latent_model_input.to(
self.device_torch, self.torch_dtype
),
first_frame=first_frames.to(self.device_torch, self.torch_dtype),
vae=self.vae,
)
else:
raise ValueError(f"Unknown frame shape {frames.shape}")
# make the noise mask
if noise_mask is None:
noise_mask = torch.ones(
conditioned_latent.shape,
dtype=conditioned_latent.dtype,
device=conditioned_latent.device,
)
# todo write this better
t_chunks = torch.chunk(timestep, timestep.shape[0])
out_t_chunks = []
for t in t_chunks:
# seq_len: num_latent_frames * latent_height//2 * latent_width//2
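# e.g. with the 16x spatial VAE compression noted above, a 121-frame 768x768 clip gives 31 latent frames and a 24x24 token grid, so seq_len = 31 * 24 * 24 = 17856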
temp_ts = (noise_mask[0][0][:, ::2, ::2] * t).flatten()
# batch_size, seq_len
temp_ts = temp_ts.unsqueeze(0)
out_t_chunks.append(temp_ts)
timestep = torch.cat(out_t_chunks, dim=0)
noise_pred = self.model(
hidden_states=conditioned_latent,
timestep=timestep,
encoder_hidden_states=text_embeddings.text_embeds,
return_dict=False,
**kwargs,
)[0]
return noise_pred

View File

@@ -0,0 +1,263 @@
import torch
from toolkit.basic import flush
from transformers import AutoTokenizer, UMT5EncoderModel
from diffusers import WanPipeline, WanTransformer3DModel, AutoencoderKLWan
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
from diffusers.pipelines.wan.pipeline_wan import XLA_AVAILABLE
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from typing import Any, Callable, Dict, List, Optional, Union
class Wan22Pipeline(WanPipeline):
def __init__(
self,
tokenizer: AutoTokenizer,
text_encoder: UMT5EncoderModel,
transformer: WanTransformer3DModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
transformer_2: Optional[WanTransformer3DModel] = None,
boundary_ratio: Optional[float] = None,
expand_timesteps: bool = False, # Wan2.2 ti2v
device: torch.device = torch.device("cuda"),
aggressive_offload: bool = False,
):
super().__init__(
tokenizer=tokenizer,
text_encoder=text_encoder,
transformer=transformer,
transformer_2=transformer_2,
boundary_ratio=boundary_ratio,
expand_timesteps=expand_timesteps,
vae=vae,
scheduler=scheduler,
)
self._aggressive_offload = aggressive_offload
self._exec_device = device
@property
def _execution_device(self):
return self._exec_device
def __call__(
self: WanPipeline,
prompt: Union[str, List[str]] = None,
negative_prompt: Union[str, List[str]] = None,
height: int = 480,
width: int = 832,
num_frames: int = 81,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator,
List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "np",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None],
PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
noise_mask: Optional[torch.Tensor] = None,
):
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# unload vae and transformer
vae_device = self.vae.device
transformer_device = self.transformer.device
text_encoder_device = self.text_encoder.device
device = self.transformer.device
if self._aggressive_offload:
print("Unloading vae")
self.vae.to("cpu")
self.text_encoder.to(device)
flush()
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
negative_prompt,
height,
width,
prompt_embeds,
negative_prompt_embeds,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
device=device,
)
if self._aggressive_offload:
# unload text encoder
print("Unloading text encoder")
self.text_encoder.to("cpu")
self.transformer.to(device)
flush()
transformer_dtype = self.transformer.dtype
prompt_embeds = prompt_embeds.to(device, transformer_dtype)
if negative_prompt_embeds is not None:
negative_prompt_embeds = negative_prompt_embeds.to(
device, transformer_dtype)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
num_frames,
torch.float32,
device,
generator,
latents,
)
mask = noise_mask
if mask is None:
mask = torch.ones(latents.shape, dtype=torch.float32, device=device)
# 6. Denoising loop
num_warmup_steps = len(timesteps) - \
num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = latents.to(device, transformer_dtype)
if self.config.expand_timesteps:
# seq_len: num_latent_frames * latent_height//2 * latent_width//2
temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()
# batch_size, seq_len
timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
else:
timestep = t.expand(latents.shape[0])
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
noise_uncond = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_uncond + guidance_scale * \
(noise_pred - noise_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
noise_pred, t, latents, return_dict=False)[0]
# apply i2v mask
latents = (latent_model_input * (1 - mask)) + (
latents * mask
)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(
self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop(
"prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop(
"negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if self._aggressive_offload:
# unload transformer
print("Unloading transformer")
self.transformer.to("cpu")
if self.transformer_2 is not None:
self.transformer_2.to("cpu")
# load vae
print("Loading Vae")
self.vae.to(vae_device)
flush()
if not output_type == "latent":
latents = latents.to(self.vae.dtype)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
latents.device, latents.dtype
)
latents = latents / latents_std + latents_mean
video = self.vae.decode(latents, return_dict=False)[0]
video = self.video_processor.postprocess_video(
video, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return WanPipelineOutput(frames=video)

View File

@@ -1,7 +1,7 @@
torchao==0.10.0
safetensors
git+https://github.com/jaretburkett/easy_dwpose.git
git+https://github.com/huggingface/diffusers@00f95b9755718aabb65456e791b8408526ae6e76
git+https://github.com/huggingface/diffusers@56d438727036b0918b30bbe3110c5fe1634ed19d
transformers==4.52.4
lycoris-lora==1.8.3
flatten_json

View File

@@ -20,6 +20,7 @@ import torch.nn.functional as F
import torch.utils.checkpoint
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin
from diffusers.utils import logging
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.models.activations import get_activation
@@ -34,6 +35,104 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
CACHE_T = 2
class AvgDown3D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
factor_t,
factor_s=1,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor_t = factor_t
self.factor_s = factor_s
self.factor = self.factor_t * self.factor_s * self.factor_s
assert in_channels * self.factor % out_channels == 0
self.group_size = in_channels * self.factor // out_channels
def forward(self, x: torch.Tensor) -> torch.Tensor:
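# Space-to-channel average pooling: fold each factor_t x factor_s x factor_s neighborhood
# into the channel dim, then mean over groups of group_size, e.g. (B, 3, 8, 64, 64) with
# factor_t=2, factor_s=2, out_channels=12 -> (B, 12, 4, 32, 32)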
pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
pad = (0, 0, 0, 0, pad_t, 0)
x = F.pad(x, pad)
B, C, T, H, W = x.shape
x = x.view(
B,
C,
T // self.factor_t,
self.factor_t,
H // self.factor_s,
self.factor_s,
W // self.factor_s,
self.factor_s,
)
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
x = x.view(
B,
C * self.factor,
T // self.factor_t,
H // self.factor_s,
W // self.factor_s,
)
x = x.view(
B,
self.out_channels,
self.group_size,
T // self.factor_t,
H // self.factor_s,
W // self.factor_s,
)
x = x.mean(dim=2)
return x
class DupUp3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
factor_t,
factor_s=1,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor_t = factor_t
self.factor_s = factor_s
self.factor = self.factor_t * self.factor_s * self.factor_s
assert out_channels * self.factor % in_channels == 0
self.repeats = out_channels * self.factor // in_channels
def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
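# Channel-to-space duplication (shortcut counterpart of AvgDown3D): repeat channels, then
# unfold them back into time/space, e.g. (B, 12, 4, 32, 32) with factor_t=2, factor_s=2,
# out_channels=3 -> (B, 3, 8, 64, 64); first_chunk drops the factor_t - 1 duplicated
# leading frames so the causal first chunk keeps the correct length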
x = x.repeat_interleave(self.repeats, dim=1)
x = x.view(
x.size(0),
self.out_channels,
self.factor_t,
self.factor_s,
self.factor_s,
x.size(2),
x.size(3),
x.size(4),
)
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
x = x.view(
x.size(0),
self.out_channels,
x.size(2) * self.factor_t,
x.size(4) * self.factor_s,
x.size(6) * self.factor_s,
)
if first_chunk:
x = x[:, :, self.factor_t - 1:, :, :]
return x
class WanCausalConv3d(nn.Conv3d):
r"""
A custom 3D causal convolution layer with feature caching support.
@@ -134,19 +233,23 @@ class WanResample(nn.Module):
- 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
"""
def __init__(self, dim: int, mode: str) -> None:
def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None:
super().__init__()
self.dim = dim
self.mode = mode
# default to dim // 2
if upsample_out_dim is None:
upsample_out_dim = dim // 2
# layers
if mode == "upsample2d":
self.resample = nn.Sequential(
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, upsample_out_dim, 3, padding=1)
)
elif mode == "upsample3d":
self.resample = nn.Sequential(
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, upsample_out_dim, 3, padding=1)
)
self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
@@ -363,6 +466,48 @@ class WanMidBlock(nn.Module):
return x
class WanResidualDownBlock(nn.Module):
def __init__(self,
in_dim,
out_dim,
dropout,
num_res_blocks,
temperal_downsample=False,
down_flag=False):
super().__init__()
# Shortcut path with downsample
self.avg_shortcut = AvgDown3D(
in_dim,
out_dim,
factor_t=2 if temperal_downsample else 1,
factor_s=2 if down_flag else 1,
)
# Main path with residual blocks and downsample
resnets = []
for _ in range(num_res_blocks):
resnets.append(WanResidualBlock(in_dim, out_dim, dropout))
in_dim = out_dim
self.resnets = nn.ModuleList(resnets)
# Add the final downsample block
if down_flag:
mode = "downsample3d" if temperal_downsample else "downsample2d"
self.downsampler = WanResample(out_dim, mode=mode)
else:
self.downsampler = None
def forward(self, x, feat_cache=None, feat_idx=[0]):
x_copy = x.clone()
for resnet in self.resnets:
x = resnet(x, feat_cache, feat_idx)
if self.downsampler is not None:
x = self.downsampler(x, feat_cache, feat_idx)
return x + self.avg_shortcut(x_copy)
class WanEncoder3d(nn.Module):
r"""
A 3D encoder module.
@@ -380,6 +525,7 @@ class WanEncoder3d(nn.Module):
def __init__(
self,
in_channels: int = 3,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
@@ -388,6 +534,7 @@ class WanEncoder3d(nn.Module):
temperal_downsample=[True, True, False],
dropout=0.0,
non_linearity: str = "silu",
is_residual: bool = False,  # the wan 2.2 vae uses a residual down block
):
super().__init__()
self.dim = dim
@@ -403,23 +550,35 @@ class WanEncoder3d(nn.Module):
scale = 1.0
# init block
self.conv_in = WanCausalConv3d(3, dims[0], 3, padding=1)
self.conv_in = WanCausalConv3d(in_channels, dims[0], 3, padding=1)
# downsample blocks
self.down_blocks = nn.ModuleList([])
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
for _ in range(num_res_blocks):
self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
self.down_blocks.append(WanAttentionBlock(out_dim))
in_dim = out_dim
if is_residual:
self.down_blocks.append(
WanResidualDownBlock(
in_dim,
out_dim,
dropout,
num_res_blocks,
temperal_downsample=temperal_downsample[i] if i != len(dim_mult) - 1 else False,
down_flag=i != len(dim_mult) - 1,
)
)
else:
for _ in range(num_res_blocks):
self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
self.down_blocks.append(WanAttentionBlock(out_dim))
in_dim = out_dim
# downsample block
if i != len(dim_mult) - 1:
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
self.down_blocks.append(WanResample(out_dim, mode=mode))
scale /= 2.0
# downsample block
if i != len(dim_mult) - 1:
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
self.down_blocks.append(WanResample(out_dim, mode=mode))
scale /= 2.0
# middle blocks
self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
@@ -469,6 +628,92 @@ class WanEncoder3d(nn.Module):
x = self.conv_out(x)
return x
class WanResidualUpBlock(nn.Module):
"""
A block that handles upsampling for the WanVAE decoder.
Args:
in_dim (int): Input dimension
out_dim (int): Output dimension
num_res_blocks (int): Number of residual blocks
dropout (float): Dropout rate
temperal_upsample (bool): Whether to upsample on temporal dimension
up_flag (bool): Whether to upsample or not
non_linearity (str): Type of non-linearity to use
"""
def __init__(
self,
in_dim: int,
out_dim: int,
num_res_blocks: int,
dropout: float = 0.0,
temperal_upsample: bool = False,
up_flag: bool = False,
non_linearity: str = "silu",
):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
if up_flag:
self.avg_shortcut = DupUp3D(
in_dim,
out_dim,
factor_t=2 if temperal_upsample else 1,
factor_s=2,
)
else:
self.avg_shortcut = None
# create residual blocks
resnets = []
current_dim = in_dim
for _ in range(num_res_blocks + 1):
resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
current_dim = out_dim
self.resnets = nn.ModuleList(resnets)
# Add upsampling layer if needed
if up_flag:
upsample_mode = "upsample3d" if temperal_upsample else "upsample2d"
self.upsampler = WanResample(out_dim, mode=upsample_mode, upsample_out_dim=out_dim)
else:
self.upsampler = None
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
"""
Forward pass through the upsampling block.
Args:
x (torch.Tensor): Input tensor
feat_cache (list, optional): Feature cache for causal convolutions
feat_idx (list, optional): Feature index for cache management
Returns:
torch.Tensor: Output tensor
"""
x_copy = x.clone()
for resnet in self.resnets:
if feat_cache is not None:
x = resnet(x, feat_cache, feat_idx)
else:
x = resnet(x)
if self.upsampler is not None:
if feat_cache is not None:
x = self.upsampler(x, feat_cache, feat_idx)
else:
x = self.upsampler(x)
if self.avg_shortcut is not None:
x = x + self.avg_shortcut(x_copy, first_chunk=first_chunk)
return x
class WanUpBlock(nn.Module):
"""
@@ -513,7 +758,7 @@ class WanUpBlock(nn.Module):
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0]):
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None):
"""
Forward pass through the upsampling block.
@@ -564,6 +809,8 @@ class WanDecoder3d(nn.Module):
temperal_upsample=[False, True, True],
dropout=0.0,
non_linearity: str = "silu",
out_channels: int = 3,
is_residual: bool = False,
):
super().__init__()
self.dim = dim
@@ -577,7 +824,6 @@ class WanDecoder3d(nn.Module):
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2 ** (len(dim_mult) - 2)
# init block
self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
@@ -589,36 +835,47 @@ class WanDecoder3d(nn.Module):
self.up_blocks = nn.ModuleList([])
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
if i > 0:
if i > 0 and not is_residual:
# wan vae 2.1
in_dim = in_dim // 2
# Determine if we need upsampling
# determine if we need upsampling
up_flag = i != len(dim_mult) - 1
# determine upsampling mode, if not upsampling, set to None
upsample_mode = None
if i != len(dim_mult) - 1:
upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
if up_flag and temperal_upsample[i]:
upsample_mode = "upsample3d"
elif up_flag:
upsample_mode = "upsample2d"
# Create and add the upsampling block
up_block = WanUpBlock(
in_dim=in_dim,
out_dim=out_dim,
num_res_blocks=num_res_blocks,
dropout=dropout,
upsample_mode=upsample_mode,
non_linearity=non_linearity,
)
if is_residual:
up_block = WanResidualUpBlock(
in_dim=in_dim,
out_dim=out_dim,
num_res_blocks=num_res_blocks,
dropout=dropout,
temperal_upsample=temperal_upsample[i] if up_flag else False,
up_flag=up_flag,
non_linearity=non_linearity,
)
else:
up_block = WanUpBlock(
in_dim=in_dim,
out_dim=out_dim,
num_res_blocks=num_res_blocks,
dropout=dropout,
upsample_mode=upsample_mode,
non_linearity=non_linearity,
)
self.up_blocks.append(up_block)
# Update scale for next iteration
if upsample_mode is not None:
scale *= 2.0
# output blocks
self.norm_out = WanRMS_norm(out_dim, images=False)
self.conv_out = WanCausalConv3d(out_dim, 3, 3, padding=1)
self.conv_out = WanCausalConv3d(out_dim, out_channels, 3, padding=1)
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0]):
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
@@ -633,20 +890,11 @@ class WanDecoder3d(nn.Module):
x = self.conv_in(x)
## middle
if torch.is_grad_enabled() and self.gradient_checkpointing:
# middle
x = self._gradient_checkpointing_func(self.mid_block, x, feat_cache, feat_idx)
## upsamples
for up_block in self.up_blocks:
x = self._gradient_checkpointing_func(up_block, x, feat_cache, feat_idx)
else:
x = self.mid_block(x, feat_cache, feat_idx)
x = self.mid_block(x, feat_cache, feat_idx)
## upsamples
for up_block in self.up_blocks:
x = up_block(x, feat_cache, feat_idx)
## upsamples
for up_block in self.up_blocks:
x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
## head
x = self.norm_out(x)
@@ -665,7 +913,46 @@ class WanDecoder3d(nn.Module):
return x
class AutoencoderKLWan(ModelMixin, ConfigMixin):
def patchify(x, patch_size):
# YiYi TODO: refactor this
from einops import rearrange
if patch_size == 1:
return x
if x.dim() == 4:
x = rearrange(
x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
elif x.dim() == 5:
x = rearrange(
x,
"b c f (h q) (w r) -> b (c r q) f h w",
q=patch_size,
r=patch_size,
)
else:
raise ValueError(f"Invalid input shape: {x.shape}")
return x
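# e.g. patchify with patch_size=2 turns (B, C, F, H, W) into (B, 4*C, F, H//2, W//2); unpatchify below inverts it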
def unpatchify(x, patch_size):
# YiYi TODO: refactor this
from einops import rearrange
if patch_size == 1:
return x
if x.dim() == 4:
x = rearrange(
x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
elif x.dim() == 5:
x = rearrange(
x,
"b (c r q) f h w -> b c f (h q) (w r)",
q=patch_size,
r=patch_size,
)
return x
class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
Introduced in [Wan 2.1].
@@ -674,12 +961,13 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
for all models (such as downloading or saving).
"""
_supports_gradient_checkpointing = True
_supports_gradient_checkpointing = False
@register_to_config
def __init__(
self,
base_dim: int = 96,
decoder_base_dim: Optional[int] = None,
z_dim: int = 16,
dim_mult: Tuple[int] = [1, 2, 4, 4],
num_res_blocks: int = 2,
@@ -722,6 +1010,13 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
2.8251,
1.9160,
],
is_residual: bool = False,
in_channels: int = 3,
out_channels: int = 3,
patch_size: Optional[int] = None,
scale_factor_temporal: Optional[int] = 4,
scale_factor_spatial: Optional[int] = 8,
clip_output: bool = True,
) -> None:
super().__init__()
@@ -729,37 +1024,119 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
if decoder_base_dim is None:
decoder_base_dim = base_dim
self.encoder = WanEncoder3d(
base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
in_channels=in_channels, dim=base_dim, z_dim=z_dim * 2, dim_mult=dim_mult, num_res_blocks=num_res_blocks, attn_scales=attn_scales, temperal_downsample=temperal_downsample, dropout=dropout, is_residual=is_residual
)
self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)
self.decoder = WanDecoder3d(
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
dim=decoder_base_dim, z_dim=z_dim, dim_mult=dim_mult, num_res_blocks=num_res_blocks, attn_scales=attn_scales, temperal_upsample=self.temperal_upsample, dropout=dropout, out_channels=out_channels, is_residual=is_residual
)
def clear_cache(self):
def _count_conv3d(model):
count = 0
for m in model.modules():
if isinstance(m, WanCausalConv3d):
count += 1
return count
self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
self._conv_num = _count_conv3d(self.decoder)
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
# to perform decoding of a single video latent at a time.
self.use_slicing = False
# When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
# frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
# intermediate tiles together, the memory requirement can be lowered.
self.use_tiling = False
# The minimal tile height and width for spatial tiling to be used
self.tile_sample_min_height = 256
self.tile_sample_min_width = 256
# The minimal distance between two spatial tiles
self.tile_sample_stride_height = 192
self.tile_sample_stride_width = 192
# Precompute and cache conv counts for encoder and decoder for clear_cache speedup
self._cached_conv_counts = {
"decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules())
if self.decoder is not None
else 0,
"encoder": sum(isinstance(m, WanCausalConv3d) for m in self.encoder.modules())
if self.encoder is not None
else 0,
}
def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
tile_sample_min_width: Optional[int] = None,
tile_sample_stride_height: Optional[float] = None,
tile_sample_stride_width: Optional[float] = None,
) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
Args:
tile_sample_min_height (`int`, *optional*):
The minimum height required for a sample to be separated into tiles across the height dimension.
tile_sample_min_width (`int`, *optional*):
The minimum width required for a sample to be separated into tiles across the width dimension.
tile_sample_stride_height (`int`, *optional*):
The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
no tiling artifacts produced across the height dimension.
tile_sample_stride_width (`int`, *optional*):
The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
artifacts produced across the width dimension.
"""
self.use_tiling = True
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def clear_cache(self):
# Use cached conv counts for decoder and encoder to avoid re-iterating modules each call
self._conv_num = self._cached_conv_counts["decoder"]
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
# cache encode
self._enc_conv_num = _count_conv3d(self.encoder)
self._enc_conv_num = self._cached_conv_counts["encoder"]
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num
def _encode(self, x: torch.Tensor) -> torch.Tensor:
def _encode(self, x: torch.Tensor):
_, _, num_frame, height, width = x.shape
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
return self.tiled_encode(x)
self.clear_cache()
## cache
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
if self.config.patch_size is not None:
x = patchify(x, patch_size=self.config.patch_size)
iter_ = 1 + (num_frame - 1) // 4
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
@@ -773,8 +1150,6 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
out = torch.cat([out, out_], 2)
enc = self.quant_conv(out)
mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]
enc = torch.cat([mu, logvar], dim=1)
self.clear_cache()
return enc
@@ -794,27 +1169,39 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
The latent representations of the encoded videos. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
h = self._encode(x)
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
self.clear_cache()
def _decode(self, z: torch.Tensor, return_dict: bool = True):
_, _, num_frame, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
iter_ = z.shape[2]
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
return self.tiled_decode(z, return_dict=return_dict)
self.clear_cache()
x = self.post_quant_conv(z)
for i in range(iter_):
for i in range(num_frame):
self._conv_idx = [0]
if i == 0:
out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True)
else:
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2)
out = torch.clamp(out, min=-1.0, max=1.0)
if self.config.clip_output:
out = torch.clamp(out, min=-1.0, max=1.0)
if self.config.patch_size is not None:
out = unpatchify(out, patch_size=self.config.patch_size)
self.clear_cache()
if not return_dict:
return (out,)
@@ -836,12 +1223,161 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
decoded = self._decode(z).sample
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
y / blend_extent
)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
x / blend_extent
)
return b
def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
r"""Encode a batch of images using a tiled encoder.
Args:
x (`torch.Tensor`): Input batch of videos.
Returns:
`torch.Tensor`:
The latent representation of the encoded videos.
"""
_, _, num_frames, height, width = x.shape
latent_height = height // self.spatial_compression_ratio
latent_width = width // self.spatial_compression_ratio
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
blend_height = tile_latent_min_height - tile_latent_stride_height
blend_width = tile_latent_min_width - tile_latent_stride_width
# Split x into overlapping tiles and encode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, self.tile_sample_stride_height):
row = []
for j in range(0, width, self.tile_sample_stride_width):
self.clear_cache()
time = []
frame_range = 1 + (num_frames - 1) // 4
for k in range(frame_range):
self._enc_conv_idx = [0]
if k == 0:
tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
else:
tile = x[
:,
:,
1 + 4 * (k - 1) : 1 + 4 * k,
i : i + self.tile_sample_min_height,
j : j + self.tile_sample_min_width,
]
tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
tile = self.quant_conv(tile)
time.append(tile)
row.append(torch.cat(time, dim=2))
rows.append(row)
self.clear_cache()
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
result_rows.append(torch.cat(result_row, dim=-1))
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
return enc
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images using a tiled decoder.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
_, _, num_frames, height, width = z.shape
sample_height = height * self.spatial_compression_ratio
sample_width = width * self.spatial_compression_ratio
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, tile_latent_stride_height):
row = []
for j in range(0, width, tile_latent_stride_width):
self.clear_cache()
time = []
for k in range(num_frames):
self._conv_idx = [0]
tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
tile = self.post_quant_conv(tile)
decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
time.append(decoded)
row.append(torch.cat(time, dim=2))
rows.append(row)
self.clear_cache()
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
result_rows.append(torch.cat(result_row, dim=-1))
dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
self,
sample: torch.Tensor,
@@ -862,4 +1398,4 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
else:
z = posterior.mode()
dec = self.decode(z, return_dict=return_dict)
return dec
return dec

View File

@@ -89,12 +89,18 @@ class AggressiveWanUnloadPipeline(WanPipeline):
transformer: WanTransformer3DModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
transformer_2: Optional[WanTransformer3DModel] = None,
boundary_ratio: Optional[float] = None,
expand_timesteps: bool = False, # Wan2.2 ti2v
device: torch.device = torch.device("cuda"),
):
super().__init__(
tokenizer=tokenizer,
text_encoder=text_encoder,
transformer=transformer,
transformer_2=transformer_2,
boundary_ratio=boundary_ratio,
expand_timesteps=expand_timesteps,
vae=vae,
scheduler=scheduler,
)
@@ -300,6 +306,8 @@ class AggressiveWanUnloadPipeline(WanPipeline):
class Wan21(BaseModel):
arch = 'wan21'
_wan_generation_scheduler_config = scheduler_configUniPC
_wan_expand_timesteps = False
def __init__(
self,
device,
@@ -331,7 +339,7 @@ class Wan21(BaseModel):
dtype = self.torch_dtype
model_path = self.model_config.name_or_path
self.print_and_status_update("Loading Wan2.1 model")
self.print_and_status_update("Loading Wan model")
subfolder = 'transformer'
transformer_path = model_path
if os.path.exists(transformer_path):
@@ -380,7 +388,6 @@ class Wan21(BaseModel):
# patch the state dict method
patch_dequantization_on_save(transformer)
quantization_type = get_qtype(self.model_config.qtype)
self.print_and_status_update("Quantizing transformer")
if self.model_config.low_vram:
print("Quantizing blocks")
orig_exclude = copy.deepcopy(quantization_args['exclude'])
@@ -474,22 +481,26 @@ class Wan21(BaseModel):
self.tokenizer = tokenizer
def get_generation_pipeline(self):
scheduler = UniPCMultistepScheduler(**scheduler_configUniPC)
scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config)
if self.model_config.low_vram:
pipeline = AggressiveWanUnloadPipeline(
vae=self.vae,
transformer=self.model,
transformer_2=self.model,
text_encoder=self.text_encoder,
tokenizer=self.tokenizer,
scheduler=scheduler,
expand_timesteps=self._wan_expand_timesteps,
device=self.device_torch
)
else:
pipeline = WanPipeline(
vae=self.vae,
transformer=self.unet,
transformer_2=self.unet,
text_encoder=self.text_encoder,
tokenizer=self.tokenizer,
expand_timesteps=self._wan_expand_timesteps,
scheduler=scheduler,
)

View File

@@ -48,11 +48,13 @@ class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline):
self,
tokenizer: AutoTokenizer,
text_encoder: UMT5EncoderModel,
image_encoder: CLIPVisionModel,
image_processor: CLIPImageProcessor,
transformer: WanTransformer3DModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
image_processor: CLIPImageProcessor = None,
image_encoder: CLIPVisionModel = None,
transformer_2: WanTransformer3DModel = None,
boundary_ratio: Optional[float] = None,
device: torch.device = torch.device("cuda"),
):
super().__init__(
@@ -63,6 +65,8 @@ class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline):
transformer=transformer,
scheduler=scheduler,
image_processor=image_processor,
transformer_2=transformer_2,
boundary_ratio=boundary_ratio,
)
self._exec_device = device

View File

@@ -39,7 +39,7 @@ def add_first_frame_conditioning(
first_frame = first_frame.expand(latent_model_input.shape[0], -1, -1, -1)
# resize first frame to match the latent model input
vae_scale_factor = 8
vae_scale_factor = vae.config.scale_factor_spatial
first_frame = F.interpolate(
first_frame,
size=(latent_model_input.shape[3] * vae_scale_factor, latent_model_input.shape[4] * vae_scale_factor),
@@ -111,3 +111,55 @@ def add_first_frame_conditioning(
[latent_model_input, first_frame_condition], dim=1)
return conditioned_latent
def add_first_frame_conditioning_v22(
latent_model_input,
first_frame,
vae
):
"""
Overwrites the first latent frame(s) of latent_model_input with the VAE-encoded first_frame,
and returns the modified latent plus a binary mask (0 = conditioned, 1 = noise).
Args:
latent_model_input: torch.Tensor of shape (bs, 48, T, H, W)
first_frame: torch.Tensor of shape (bs, 3, H*scale, W*scale)
vae: VAE model with .encode() and .config.latents_mean/std
Returns:
latent: (bs, 48, T, H, W) - modified input latent
mask: (bs, 1, T, H, W) - binary mask
"""
device = latent_model_input.device
dtype = latent_model_input.dtype
bs, _, T, H, W = latent_model_input.shape
scale = vae.config.scale_factor_spatial
target_h = H * scale
target_w = W * scale
# Ensure shape
if first_frame.ndim == 3:
first_frame = first_frame.unsqueeze(0)
if first_frame.shape[0] != bs:
first_frame = first_frame.expand(bs, -1, -1, -1)
# Resize and encode
first_frame_up = F.interpolate(first_frame, size=(target_h, target_w), mode="bilinear", align_corners=False)
first_frame_up = first_frame_up.unsqueeze(2) # (bs, 3, 1, H, W)
encoded = vae.encode(first_frame_up).latent_dist.sample().to(dtype).to(device)
# Normalize
mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1, 1).to(device, dtype)
std = 1.0 / torch.tensor(vae.config.latents_std).view(1, -1, 1, 1, 1).to(device, dtype)
encoded = (encoded - mean) * std
# Replace in latent
latent = latent_model_input.clone()
latent[:, :, :encoded.shape[2]] = encoded # typically first frame: [:, :, 0]
# Mask: 0 where conditioned, 1 otherwise
mask = torch.ones(bs, 1, T, H, W, device=device, dtype=dtype)
mask[:, :, :encoded.shape[2]] = 0.0
return latent, mask
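A minimal sketch of how the (latent, mask) pair returned here is consumed downstream, mirroring the masked update in Wan22Pipeline and the per-token timestep expansion in get_noise_prediction; denoised_latents and t are placeholder names:

# mask == 0 marks VAE-conditioned (first-frame) latent positions, mask == 1 marks noised positions
latents = conditioned_latent * (1 - mask) + denoised_latents * mask
# conditioned tokens also receive a zero timestep when timesteps are expanded per token
timestep = (mask[0][0][:, ::2, ::2] * t).flatten().unsqueeze(0)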

View File

@@ -181,6 +181,27 @@ export const modelArchs: ModelArch[] = [
disableSections: ['network.conv'],
additionalSections: ['datasets.num_frames', 'model.low_vram'],
},
{
name: 'wan22_5b',
label: 'Wan 2.2 TI2V (5B)',
group: 'video',
isVideoModel: true,
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['Wan-AI/Wan2.2-TI2V-5B-Diffusers', defaultNameOrPath],
'config.process[0].model.quantize': [true, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].model.low_vram': [true, false],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].sample.num_frames': [121, 1],
'config.process[0].sample.fps': [24, 1],
'config.process[0].sample.width': [768, 1024],
'config.process[0].sample.height': [768, 1024],
},
disableSections: ['network.conv'],
additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram'],
},
{
name: 'lumina2',
label: 'Lumina2',

View File

@@ -1 +1 @@
VERSION = "0.3.13"
VERSION = "0.3.14"