Add support for Wan2.2 5B

Jaret Burkett
2025-07-29 05:31:54 -06:00
parent e55116d8c9
commit ca7c5c950b
11 changed files with 1241 additions and 92 deletions

View File

@@ -20,6 +20,7 @@ import torch.nn.functional as F
import torch.utils.checkpoint
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin
from diffusers.utils import logging
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.models.activations import get_activation
@@ -34,6 +35,104 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
CACHE_T = 2
class AvgDown3D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
factor_t,
factor_s=1,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor_t = factor_t
self.factor_s = factor_s
self.factor = self.factor_t * self.factor_s * self.factor_s
assert in_channels * self.factor % out_channels == 0
self.group_size = in_channels * self.factor // out_channels
def forward(self, x: torch.Tensor) -> torch.Tensor:
pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
pad = (0, 0, 0, 0, pad_t, 0)
x = F.pad(x, pad)
B, C, T, H, W = x.shape
x = x.view(
B,
C,
T // self.factor_t,
self.factor_t,
H // self.factor_s,
self.factor_s,
W // self.factor_s,
self.factor_s,
)
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
x = x.view(
B,
C * self.factor,
T // self.factor_t,
H // self.factor_s,
W // self.factor_s,
)
x = x.view(
B,
self.out_channels,
self.group_size,
T // self.factor_t,
H // self.factor_s,
W // self.factor_s,
)
x = x.mean(dim=2)
return x
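A quick shape sanity check for AvgDown3D (a sketch with arbitrarily chosen dimensions, assuming the class above is in scope): it folds each factor_t x factor_s x factor_s neighbourhood into the channel dimension and group-averages the result down to out_channels.

import torch

down = AvgDown3D(in_channels=64, out_channels=128, factor_t=2, factor_s=2)
x = torch.randn(1, 64, 8, 32, 32)       # (B, C, T, H, W)
y = down(x)
assert y.shape == (1, 128, 4, 16, 16)   # T, H and W each reduced by their factor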
class DupUp3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
factor_t,
factor_s=1,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor_t = factor_t
self.factor_s = factor_s
self.factor = self.factor_t * self.factor_s * self.factor_s
assert out_channels * self.factor % in_channels == 0
self.repeats = out_channels * self.factor // in_channels
def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
x = x.repeat_interleave(self.repeats, dim=1)
x = x.view(
x.size(0),
self.out_channels,
self.factor_t,
self.factor_s,
self.factor_s,
x.size(2),
x.size(3),
x.size(4),
)
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
x = x.view(
x.size(0),
self.out_channels,
x.size(2) * self.factor_t,
x.size(4) * self.factor_s,
x.size(6) * self.factor_s,
)
if first_chunk:
x = x[:, :, self.factor_t - 1:, :, :]
return x
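DupUp3D is the duplication-based inverse used on the decoder shortcut path; a minimal shape sketch, with dimensions chosen to mirror the AvgDown3D example above:

import torch

up = DupUp3D(in_channels=128, out_channels=64, factor_t=2, factor_s=2)
z = torch.randn(1, 128, 4, 16, 16)
assert up(z).shape == (1, 64, 8, 32, 32)
# On the first temporal chunk the duplicated leading frames are trimmed to keep causality:
assert up(z, first_chunk=True).shape == (1, 64, 7, 32, 32)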
class WanCausalConv3d(nn.Conv3d):
r"""
A custom 3D causal convolution layer with feature caching support.
@@ -134,19 +233,23 @@ class WanResample(nn.Module):
- 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
"""
def __init__(self, dim: int, mode: str) -> None:
def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None:
super().__init__()
self.dim = dim
self.mode = mode
# default to dim // 2
if upsample_out_dim is None:
upsample_out_dim = dim // 2
# layers
if mode == "upsample2d":
self.resample = nn.Sequential(
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, upsample_out_dim, 3, padding=1)
)
elif mode == "upsample3d":
self.resample = nn.Sequential(
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, upsample_out_dim, 3, padding=1)
)
self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
@@ -363,6 +466,48 @@ class WanMidBlock(nn.Module):
return x
class WanResidualDownBlock(nn.Module):
def __init__(self,
in_dim,
out_dim,
dropout,
num_res_blocks,
temperal_downsample=False,
down_flag=False):
super().__init__()
# Shortcut path with downsample
self.avg_shortcut = AvgDown3D(
in_dim,
out_dim,
factor_t=2 if temperal_downsample else 1,
factor_s=2 if down_flag else 1,
)
# Main path with residual blocks and downsample
resnets = []
for _ in range(num_res_blocks):
resnets.append(WanResidualBlock(in_dim, out_dim, dropout))
in_dim = out_dim
self.resnets = nn.ModuleList(resnets)
# Add the final downsample block
if down_flag:
mode = "downsample3d" if temperal_downsample else "downsample2d"
self.downsampler = WanResample(out_dim, mode=mode)
else:
self.downsampler = None
def forward(self, x, feat_cache=None, feat_idx=[0]):
x_copy = x.clone()
for resnet in self.resnets:
x = resnet(x, feat_cache, feat_idx)
if self.downsampler is not None:
x = self.downsampler(x, feat_cache, feat_idx)
return x + self.avg_shortcut(x_copy)
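Both branches of WanResidualDownBlock land on the same shape, so the final sum is well defined. A small sketch with assumed dimensions (and assuming WanResidualBlock and WanResample accept feat_cache=None, as in upstream diffusers):

import torch

block = WanResidualDownBlock(
    in_dim=96, out_dim=192, dropout=0.0, num_res_blocks=2,
    temperal_downsample=False, down_flag=True,
)
x = torch.randn(1, 96, 5, 32, 32)
y = block(x)                              # resnets + downsample2d, plus the AvgDown3D shortcut
assert y.shape == (1, 192, 5, 16, 16)     # spatial 2x downsample, temporal length preserved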
class WanEncoder3d(nn.Module):
r"""
A 3D encoder module.
@@ -380,6 +525,7 @@ class WanEncoder3d(nn.Module):
def __init__(
self,
in_channels: int = 3,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
@@ -388,6 +534,7 @@ class WanEncoder3d(nn.Module):
temperal_downsample=[True, True, False],
dropout=0.0,
non_linearity: str = "silu",
is_residual: bool = False,  # the Wan 2.2 VAE uses a residual down block
):
super().__init__()
self.dim = dim
@@ -403,23 +550,35 @@ class WanEncoder3d(nn.Module):
scale = 1.0
# init block
self.conv_in = WanCausalConv3d(3, dims[0], 3, padding=1)
self.conv_in = WanCausalConv3d(in_channels, dims[0], 3, padding=1)
# downsample blocks
self.down_blocks = nn.ModuleList([])
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
for _ in range(num_res_blocks):
self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
self.down_blocks.append(WanAttentionBlock(out_dim))
in_dim = out_dim
if is_residual:
self.down_blocks.append(
WanResidualDownBlock(
in_dim,
out_dim,
dropout,
num_res_blocks,
temperal_downsample=temperal_downsample[i] if i != len(dim_mult) - 1 else False,
down_flag=i != len(dim_mult) - 1,
)
)
else:
for _ in range(num_res_blocks):
self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
self.down_blocks.append(WanAttentionBlock(out_dim))
in_dim = out_dim
# downsample block
if i != len(dim_mult) - 1:
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
self.down_blocks.append(WanResample(out_dim, mode=mode))
scale /= 2.0
# downsample block
if i != len(dim_mult) - 1:
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
self.down_blocks.append(WanResample(out_dim, mode=mode))
scale /= 2.0
# middle blocks
self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
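With is_residual=True the whole per-stage stack (residual blocks plus optional downsample) is replaced by one WanResidualDownBlock per stage. An illustrative stage layout, assuming the upstream convention dims = [dim] + [dim * m for m in dim_mult]:

dim, dim_mult = 96, [1, 2, 4, 4]
temperal_downsample = [True, True, False]
dims = [dim] + [dim * m for m in dim_mult]
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
    last = i == len(dim_mult) - 1
    t_down = temperal_downsample[i] if not last else False
    print(f"stage {i}: {in_dim} -> {out_dim}, "
          f"{'temporal+spatial' if t_down else 'spatial' if not last else 'no'} downsample")
# stage 0: 96 -> 96, temporal+spatial downsample
# stage 1: 96 -> 192, temporal+spatial downsample
# stage 2: 192 -> 384, spatial downsample
# stage 3: 384 -> 384, no downsample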
@@ -469,6 +628,92 @@ class WanEncoder3d(nn.Module):
x = self.conv_out(x)
return x
class WanResidualUpBlock(nn.Module):
"""
A block that handles upsampling for the WanVAE decoder.
Args:
in_dim (int): Input dimension
out_dim (int): Output dimension
num_res_blocks (int): Number of residual blocks
dropout (float): Dropout rate
temperal_upsample (bool): Whether to upsample on the temporal dimension
up_flag (bool): Whether to upsample or not
non_linearity (str): Type of non-linearity to use
"""
def __init__(
self,
in_dim: int,
out_dim: int,
num_res_blocks: int,
dropout: float = 0.0,
temperal_upsample: bool = False,
up_flag: bool = False,
non_linearity: str = "silu",
):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
if up_flag:
self.avg_shortcut = DupUp3D(
in_dim,
out_dim,
factor_t=2 if temperal_upsample else 1,
factor_s=2,
)
else:
self.avg_shortcut = None
# create residual blocks
resnets = []
current_dim = in_dim
for _ in range(num_res_blocks + 1):
resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
current_dim = out_dim
self.resnets = nn.ModuleList(resnets)
# Add upsampling layer if needed
if up_flag:
upsample_mode = "upsample3d" if temperal_upsample else "upsample2d"
self.upsampler = WanResample(out_dim, mode=upsample_mode, upsample_out_dim=out_dim)
else:
self.upsampler = None
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
"""
Forward pass through the upsampling block.
Args:
x (torch.Tensor): Input tensor
feat_cache (list, optional): Feature cache for causal convolutions
feat_idx (list, optional): Feature index for cache management
first_chunk (bool, optional): Whether this is the first temporal chunk; trims the duplicated leading frames from the shortcut path
Returns:
torch.Tensor: Output tensor
"""
x_copy = x.clone()
for resnet in self.resnets:
if feat_cache is not None:
x = resnet(x, feat_cache, feat_idx)
else:
x = resnet(x)
if self.upsampler is not None:
if feat_cache is not None:
x = self.upsampler(x, feat_cache, feat_idx)
else:
x = self.upsampler(x)
if self.avg_shortcut is not None:
x = x + self.avg_shortcut(x_copy, first_chunk=first_chunk)
return x
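As on the encoder side, the DupUp3D shortcut mirrors the main path so the residual add is shape-consistent. A shape sketch with assumed dimensions (feat_cache=None, spatial-only upsampling):

import torch

up_block = WanResidualUpBlock(
    in_dim=384, out_dim=192, num_res_blocks=2,
    temperal_upsample=False, up_flag=True,
)
z = torch.randn(1, 384, 3, 8, 8)
y = up_block(z)
assert y.shape == (1, 192, 3, 16, 16)    # channels halved, spatial 2x upsample, T unchanged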
class WanUpBlock(nn.Module):
"""
@@ -513,7 +758,7 @@ class WanUpBlock(nn.Module):
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0]):
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None):
"""
Forward pass through the upsampling block.
@@ -564,6 +809,8 @@ class WanDecoder3d(nn.Module):
temperal_upsample=[False, True, True],
dropout=0.0,
non_linearity: str = "silu",
out_channels: int = 3,
is_residual: bool = False,
):
super().__init__()
self.dim = dim
@@ -577,7 +824,6 @@ class WanDecoder3d(nn.Module):
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2 ** (len(dim_mult) - 2)
# init block
self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
@@ -589,36 +835,47 @@ class WanDecoder3d(nn.Module):
self.up_blocks = nn.ModuleList([])
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
if i > 0:
if i > 0 and not is_residual:
# wan vae 2.1
in_dim = in_dim // 2
# Determine if we need upsampling
# determine if we need upsampling
up_flag = i != len(dim_mult) - 1
# determine upsampling mode, if not upsampling, set to None
upsample_mode = None
if i != len(dim_mult) - 1:
upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
if up_flag and temperal_upsample[i]:
upsample_mode = "upsample3d"
elif up_flag:
upsample_mode = "upsample2d"
# Create and add the upsampling block
up_block = WanUpBlock(
in_dim=in_dim,
out_dim=out_dim,
num_res_blocks=num_res_blocks,
dropout=dropout,
upsample_mode=upsample_mode,
non_linearity=non_linearity,
)
if is_residual:
up_block = WanResidualUpBlock(
in_dim=in_dim,
out_dim=out_dim,
num_res_blocks=num_res_blocks,
dropout=dropout,
temperal_upsample=temperal_upsample[i] if up_flag else False,
up_flag=up_flag,
non_linearity=non_linearity,
)
else:
up_block = WanUpBlock(
in_dim=in_dim,
out_dim=out_dim,
num_res_blocks=num_res_blocks,
dropout=dropout,
upsample_mode=upsample_mode,
non_linearity=non_linearity,
)
self.up_blocks.append(up_block)
# Update scale for next iteration
if upsample_mode is not None:
scale *= 2.0
# output blocks
self.norm_out = WanRMS_norm(out_dim, images=False)
self.conv_out = WanCausalConv3d(out_dim, 3, 3, padding=1)
self.conv_out = WanCausalConv3d(out_dim, out_channels, 3, padding=1)
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0]):
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
@@ -633,20 +890,11 @@ class WanDecoder3d(nn.Module):
x = self.conv_in(x)
## middle
if torch.is_grad_enabled() and self.gradient_checkpointing:
# middle
x = self._gradient_checkpointing_func(self.mid_block, x, feat_cache, feat_idx)
## upsamples
for up_block in self.up_blocks:
x = self._gradient_checkpointing_func(up_block, x, feat_cache, feat_idx)
else:
x = self.mid_block(x, feat_cache, feat_idx)
x = self.mid_block(x, feat_cache, feat_idx)
## upsamples
for up_block in self.up_blocks:
x = up_block(x, feat_cache, feat_idx)
## upsamples
for up_block in self.up_blocks:
x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
## head
x = self.norm_out(x)
@@ -665,7 +913,46 @@ class WanDecoder3d(nn.Module):
return x
class AutoencoderKLWan(ModelMixin, ConfigMixin):
def patchify(x, patch_size):
# YiYi TODO: refactor this
from einops import rearrange
if patch_size == 1:
return x
if x.dim() == 4:
x = rearrange(
x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
elif x.dim() == 5:
x = rearrange(
x,
"b c f (h q) (w r) -> b (c r q) f h w",
q=patch_size,
r=patch_size,
)
else:
raise ValueError(f"Invalid input shape: {x.shape}")
return x
def unpatchify(x, patch_size):
# YiYi TODO: refactor this
from einops import rearrange
if patch_size == 1:
return x
if x.dim() == 4:
x = rearrange(
x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
elif x.dim() == 5:
x = rearrange(
x,
"b (c r q) f h w -> b c f (h q) (w r)",
q=patch_size,
r=patch_size,
)
return x
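patchify folds each patch_size x patch_size pixel neighbourhood into the channel dimension before encoding, and unpatchify undoes it after decoding; with patch_size=2 this doubles the effective spatial compression on top of what the encoder itself provides. A quick round-trip sketch:

import torch

x = torch.randn(1, 3, 5, 64, 64)                     # (B, C, F, H, W)
p = patchify(x, patch_size=2)                        # pixels folded into channels
assert p.shape == (1, 12, 5, 32, 32)
assert torch.equal(unpatchify(p, patch_size=2), x)   # exact inverse, pure rearrange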
class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
Introduced in [Wan 2.1].
@@ -674,12 +961,13 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
for all models (such as downloading or saving).
"""
_supports_gradient_checkpointing = True
_supports_gradient_checkpointing = False
@register_to_config
def __init__(
self,
base_dim: int = 96,
decoder_base_dim: Optional[int] = None,
z_dim: int = 16,
dim_mult: Tuple[int] = [1, 2, 4, 4],
num_res_blocks: int = 2,
@@ -722,6 +1010,13 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
2.8251,
1.9160,
],
is_residual: bool = False,
in_channels: int = 3,
out_channels: int = 3,
patch_size: Optional[int] = None,
scale_factor_temporal: Optional[int] = 4,
scale_factor_spatial: Optional[int] = 8,
clip_output: bool = True,
) -> None:
super().__init__()
@@ -729,37 +1024,119 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
if decoder_base_dim is None:
decoder_base_dim = base_dim
self.encoder = WanEncoder3d(
base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
in_channels=in_channels, dim=base_dim, z_dim=z_dim * 2, dim_mult=dim_mult, num_res_blocks=num_res_blocks, attn_scales=attn_scales, temperal_downsample=temperal_downsample, dropout=dropout, is_residual=is_residual
)
self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)
self.decoder = WanDecoder3d(
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
dim=decoder_base_dim, z_dim=z_dim, dim_mult=dim_mult, num_res_blocks=num_res_blocks, attn_scales=attn_scales, temperal_upsample=self.temperal_upsample, dropout=dropout, out_channels=out_channels, is_residual=is_residual
)
def clear_cache(self):
def _count_conv3d(model):
count = 0
for m in model.modules():
if isinstance(m, WanCausalConv3d):
count += 1
return count
self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
self._conv_num = _count_conv3d(self.decoder)
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
# to perform decoding of a single video latent at a time.
self.use_slicing = False
# When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
# frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
# intermediate tiles together, the memory requirement can be lowered.
self.use_tiling = False
# The minimal tile height and width for spatial tiling to be used
self.tile_sample_min_height = 256
self.tile_sample_min_width = 256
# The minimal distance between two spatial tiles
self.tile_sample_stride_height = 192
self.tile_sample_stride_width = 192
# Precompute and cache conv counts for encoder and decoder for clear_cache speedup
self._cached_conv_counts = {
"decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules())
if self.decoder is not None
else 0,
"encoder": sum(isinstance(m, WanCausalConv3d) for m in self.encoder.modules())
if self.encoder is not None
else 0,
}
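Rough latent-geometry arithmetic implied by these config fields. The concrete numbers below are assumptions for a Wan2.2-style checkpoint (48 latent channels, 16x spatial and 4x temporal compression), not values taken from this diff:

frames, height, width = 121, 704, 1280
scale_factor_temporal, scale_factor_spatial, z_dim = 4, 16, 48   # assumed config values
latent_t = 1 + (frames - 1) // scale_factor_temporal             # causal: first frame encoded on its own
latent_h = height // scale_factor_spatial
latent_w = width // scale_factor_spatial
print(latent_t, latent_h, latent_w)    # 31 44 80 -> latents of shape (B, 48, 31, 44, 80)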
def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
tile_sample_min_width: Optional[int] = None,
tile_sample_stride_height: Optional[float] = None,
tile_sample_stride_width: Optional[float] = None,
) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and allows
processing larger images.
Args:
tile_sample_min_height (`int`, *optional*):
The minimum height required for a sample to be separated into tiles across the height dimension.
tile_sample_min_width (`int`, *optional*):
The minimum width required for a sample to be separated into tiles across the width dimension.
tile_sample_stride_height (`int`, *optional*):
The stride between two consecutive vertical tiles. This is to ensure that there are no tiling
artifacts produced across the height dimension.
tile_sample_stride_width (`int`, *optional*):
The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
artifacts produced across the width dimension.
"""
self.use_tiling = True
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
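A typical memory-saving setup using the new tiling and slicing switches; a sketch with a default-constructed model and an arbitrary latent shape:

import torch

vae = AutoencoderKLWan()                 # default (Wan2.1-style) config, for illustration only
vae.enable_tiling()                      # 256/192 tile/stride defaults from __init__ above
vae.enable_slicing()                     # decode one video latent at a time when batch size > 1
latents = torch.randn(2, vae.config.z_dim, 2, 40, 40)
video = vae.decode(latents).sample       # (2, 3, 5, 320, 320)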
def clear_cache(self):
# Use cached conv counts for decoder and encoder to avoid re-iterating modules each call
self._conv_num = self._cached_conv_counts["decoder"]
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
# cache encode
self._enc_conv_num = _count_conv3d(self.encoder)
self._enc_conv_num = self._cached_conv_counts["encoder"]
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num
def _encode(self, x: torch.Tensor) -> torch.Tensor:
def _encode(self, x: torch.Tensor):
_, _, num_frame, height, width = x.shape
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
return self.tiled_encode(x)
self.clear_cache()
## cache
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
if self.config.patch_size is not None:
x = patchify(x, patch_size=self.config.patch_size)
iter_ = 1 + (num_frame - 1) // 4
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
@@ -773,8 +1150,6 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
out = torch.cat([out, out_], 2)
enc = self.quant_conv(out)
mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]
enc = torch.cat([mu, logvar], dim=1)
self.clear_cache()
return enc
@@ -794,27 +1169,39 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
The latent representations of the encoded videos. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
h = self._encode(x)
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
self.clear_cache()
def _decode(self, z: torch.Tensor, return_dict: bool = True):
_, _, num_frame, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
iter_ = z.shape[2]
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
return self.tiled_decode(z, return_dict=return_dict)
self.clear_cache()
x = self.post_quant_conv(z)
for i in range(iter_):
for i in range(num_frame):
self._conv_idx = [0]
if i == 0:
out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True)
else:
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2)
out = torch.clamp(out, min=-1.0, max=1.0)
if self.config.clip_output:
out = torch.clamp(out, min=-1.0, max=1.0)
if self.config.patch_size is not None:
out = unpatchify(out, patch_size=self.config.patch_size)
self.clear_cache()
if not return_dict:
return (out,)
@@ -836,12 +1223,161 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
decoded = self._decode(z).sample
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
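End-to-end shape sketch of the non-tiled path with a default (Wan2.1-style) config, mainly to show the 4x temporal / 8x spatial relationship between pixels and latents:

import torch

vae = AutoencoderKLWan()                          # defaults: z_dim=16, 8x spatial, 4x temporal
video = torch.randn(1, 3, 13, 128, 128)           # Wan expects frame counts of the form 4*k + 1
posterior = vae.encode(video).latent_dist
z = posterior.sample()                            # (1, 16, 4, 16, 16)
recon = vae.decode(z).sample                      # (1, 3, 13, 128, 128)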
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
y / blend_extent
)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
x / blend_extent
)
return b
def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
r"""Encode a batch of images using a tiled encoder.
Args:
x (`torch.Tensor`): Input batch of videos.
Returns:
`torch.Tensor`:
The latent representation of the encoded videos.
"""
_, _, num_frames, height, width = x.shape
latent_height = height // self.spatial_compression_ratio
latent_width = width // self.spatial_compression_ratio
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
blend_height = tile_latent_min_height - tile_latent_stride_height
blend_width = tile_latent_min_width - tile_latent_stride_width
# Split x into overlapping tiles and encode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, self.tile_sample_stride_height):
row = []
for j in range(0, width, self.tile_sample_stride_width):
self.clear_cache()
time = []
frame_range = 1 + (num_frames - 1) // 4
for k in range(frame_range):
self._enc_conv_idx = [0]
if k == 0:
tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
else:
tile = x[
:,
:,
1 + 4 * (k - 1) : 1 + 4 * k,
i : i + self.tile_sample_min_height,
j : j + self.tile_sample_min_width,
]
tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
tile = self.quant_conv(tile)
time.append(tile)
row.append(torch.cat(time, dim=2))
rows.append(row)
self.clear_cache()
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
result_rows.append(torch.cat(result_row, dim=-1))
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
return enc
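The latent-space tile bookkeeping above reduces to a little arithmetic; a worked sketch with the default 256/192 sample-space tile/stride, 8x spatial compression, and an assumed 480x480 input:

tile_min, stride, ratio = 256, 192, 8
height = width = 480
tile_starts = list(range(0, height, stride))      # [0, 192, 384] -> a 3x3 grid of tiles
blend = (tile_min - stride) // ratio              # 8 latent rows/cols blended with each neighbour
kept = stride // ratio                            # 24 latent rows/cols kept per tile before the final crop
print(tile_starts, blend, kept)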
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images using a tiled decoder.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
_, _, num_frames, height, width = z.shape
sample_height = height * self.spatial_compression_ratio
sample_width = width * self.spatial_compression_ratio
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, tile_latent_stride_height):
row = []
for j in range(0, width, tile_latent_stride_width):
self.clear_cache()
time = []
for k in range(num_frames):
self._conv_idx = [0]
tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
tile = self.post_quant_conv(tile)
decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
time.append(decoded)
row.append(torch.cat(time, dim=2))
rows.append(row)
self.clear_cache()
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
result_rows.append(torch.cat(result_row, dim=-1))
dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
self,
sample: torch.Tensor,
@@ -862,4 +1398,4 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin):
else:
z = posterior.mode()
dec = self.decode(z, return_dict=return_dict)
return dec
return dec

View File

@@ -89,12 +89,18 @@ class AggressiveWanUnloadPipeline(WanPipeline):
transformer: WanTransformer3DModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
transformer_2: Optional[WanTransformer3DModel] = None,
boundary_ratio: Optional[float] = None,
expand_timesteps: bool = False, # Wan2.2 ti2v
device: torch.device = torch.device("cuda"),
):
super().__init__(
tokenizer=tokenizer,
text_encoder=text_encoder,
transformer=transformer,
transformer_2=transformer_2,
boundary_ratio=boundary_ratio,
expand_timesteps=expand_timesteps,
vae=vae,
scheduler=scheduler,
)
@@ -300,6 +306,8 @@ class AggressiveWanUnloadPipeline(WanPipeline):
class Wan21(BaseModel):
arch = 'wan21'
_wan_generation_scheduler_config = scheduler_configUniPC
_wan_expand_timesteps = False
def __init__(
self,
device,
@@ -331,7 +339,7 @@ class Wan21(BaseModel):
dtype = self.torch_dtype
model_path = self.model_config.name_or_path
self.print_and_status_update("Loading Wan2.1 model")
self.print_and_status_update("Loading Wan model")
subfolder = 'transformer'
transformer_path = model_path
if os.path.exists(transformer_path):
@@ -380,7 +388,6 @@ class Wan21(BaseModel):
# patch the state dict method
patch_dequantization_on_save(transformer)
quantization_type = get_qtype(self.model_config.qtype)
self.print_and_status_update("Quantizing transformer")
if self.model_config.low_vram:
print("Quantizing blocks")
orig_exclude = copy.deepcopy(quantization_args['exclude'])
@@ -474,22 +481,26 @@ class Wan21(BaseModel):
self.tokenizer = tokenizer
def get_generation_pipeline(self):
scheduler = UniPCMultistepScheduler(**scheduler_configUniPC)
scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config)
if self.model_config.low_vram:
pipeline = AggressiveWanUnloadPipeline(
vae=self.vae,
transformer=self.model,
transformer_2=self.model,
text_encoder=self.text_encoder,
tokenizer=self.tokenizer,
scheduler=scheduler,
expand_timesteps=self._wan_expand_timesteps,
device=self.device_torch
)
else:
pipeline = WanPipeline(
vae=self.vae,
transformer=self.unet,
transformer_2=self.unet,
text_encoder=self.text_encoder,
tokenizer=self.tokenizer,
expand_timesteps=self._wan_expand_timesteps,
scheduler=scheduler,
)
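The new `_wan_expand_timesteps` and `_wan_generation_scheduler_config` class attributes let a subclass retarget generation without overriding `get_generation_pipeline`. A hypothetical sketch (the class name, arch tag, and override below are assumptions, not taken from this commit):

class Wan22TI2V(Wan21):
    arch = 'wan22_ti2v'                  # hypothetical arch tag
    _wan_expand_timesteps = True         # enables the Wan2.2 ti2v timestep handling in WanPipeline
    _wan_generation_scheduler_config = {
        **scheduler_configUniPC,
        # scheduler overrides for Wan2.2 (e.g. a different flow shift) would go here
    }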

View File

@@ -48,11 +48,13 @@ class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline):
self,
tokenizer: AutoTokenizer,
text_encoder: UMT5EncoderModel,
image_encoder: CLIPVisionModel,
image_processor: CLIPImageProcessor,
transformer: WanTransformer3DModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
image_processor: CLIPImageProcessor = None,
image_encoder: CLIPVisionModel = None,
transformer_2: WanTransformer3DModel = None,
boundary_ratio: Optional[float] = None,
device: torch.device = torch.device("cuda"),
):
super().__init__(
@@ -63,6 +65,8 @@ class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline):
transformer=transformer,
scheduler=scheduler,
image_processor=image_processor,
transformer_2=transformer_2,
boundary_ratio=boundary_ratio,
)
self._exec_device = device

View File

@@ -39,7 +39,7 @@ def add_first_frame_conditioning(
first_frame = first_frame.expand(latent_model_input.shape[0], -1, -1, -1)
# resize first frame to match the latent model input
vae_scale_factor = 8
vae_scale_factor = vae.config.scale_factor_spatial
first_frame = F.interpolate(
first_frame,
size=(latent_model_input.shape[3] * vae_scale_factor, latent_model_input.shape[4] * vae_scale_factor),
@@ -111,3 +111,55 @@ def add_first_frame_conditioning(
[latent_model_input, first_frame_condition], dim=1)
return conditioned_latent
def add_first_frame_conditioning_v22(
latent_model_input,
first_frame,
vae
):
"""
Overwrites the first few time steps of latent_model_input with the VAE-encoded first_frame and returns
the modified latent plus a binary mask (0 = conditioned, 1 = noise).
Args:
latent_model_input: torch.Tensor of shape (bs, 48, T, H, W)
first_frame: torch.Tensor of shape (bs, 3, H*scale, W*scale)
vae: VAE model with .encode() and .config.latents_mean/std
Returns:
latent: (bs, 48, T, H, W) - modified input latent
mask: (bs, 1, T, H, W) - binary mask
"""
device = latent_model_input.device
dtype = latent_model_input.dtype
bs, _, T, H, W = latent_model_input.shape
scale = vae.config.scale_factor_spatial
target_h = H * scale
target_w = W * scale
# Ensure a batch dimension and a matching batch size
if first_frame.ndim == 3:
first_frame = first_frame.unsqueeze(0)
if first_frame.shape[0] != bs:
first_frame = first_frame.expand(bs, -1, -1, -1)
# Resize and encode
first_frame_up = F.interpolate(first_frame, size=(target_h, target_w), mode="bilinear", align_corners=False)
first_frame_up = first_frame_up.unsqueeze(2) # (bs, 3, 1, H, W)
encoded = vae.encode(first_frame_up).latent_dist.sample().to(dtype).to(device)
# Normalize: subtract latents_mean, then scale by the reciprocal of latents_std
mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1, 1).to(device, dtype)
std = 1.0 / torch.tensor(vae.config.latents_std).view(1, -1, 1, 1, 1).to(device, dtype)
encoded = (encoded - mean) * std
# Replace in latent
latent = latent_model_input.clone()
latent[:, :, :encoded.shape[2]] = encoded # typically first frame: [:, :, 0]
# Mask: 0 where conditioned, 1 otherwise
mask = torch.ones(bs, 1, T, H, W, device=device, dtype=dtype)
mask[:, :, :encoded.shape[2]] = 0.0
return latent, mask
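One possible way to consume the returned latent/mask pair during denoising, assuming `vae` is a loaded Wan2.2 AutoencoderKLWan; the shapes follow the docstring above and every concrete value here is an illustrative assumption:

import torch

bs, c, T, H, W = 1, 48, 21, 44, 80
noisy_latents = torch.randn(bs, c, T, H, W)
first_frame = torch.rand(bs, 3, H * vae.config.scale_factor_spatial, W * vae.config.scale_factor_spatial)

conditioned, mask = add_first_frame_conditioning_v22(noisy_latents, first_frame, vae)
# Re-impose the clean first-frame latent after each denoising update; only mask == 1 regions stay noisy.
latents = noisy_latents * mask + conditioned * (1 - mask)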