mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Add support for Wan2.2 5B
This commit is contained in:
@@ -39,7 +39,7 @@ def add_first_frame_conditioning(
|
||||
first_frame = first_frame.expand(latent_model_input.shape[0], -1, -1, -1)
|
||||
|
||||
# resize first frame to match the latent model input
|
||||
vae_scale_factor = 8
|
||||
vae_scale_factor = vae.config.scale_factor_spatial
|
||||
first_frame = F.interpolate(
|
||||
first_frame,
|
||||
size=(latent_model_input.shape[3] * vae_scale_factor, latent_model_input.shape[4] * vae_scale_factor),
|
||||
@@ -111,3 +111,55 @@ def add_first_frame_conditioning(
|
||||
[latent_model_input, first_frame_condition], dim=1)
|
||||
|
||||
return conditioned_latent
|
||||
|
||||
|
||||
def add_first_frame_conditioning_v22(
|
||||
latent_model_input,
|
||||
first_frame,
|
||||
vae
|
||||
):
|
||||
"""
|
||||
Overwrites first few time steps in latent_model_input with VAE-encoded first_frame,
|
||||
and returns the modified latent + binary mask (0=conditioned, 1=noise).
|
||||
|
||||
Args:
|
||||
latent_model_input: torch.Tensor of shape (bs, 48, T, H, W)
|
||||
first_frame: torch.Tensor of shape (bs, 3, H*scale, W*scale)
|
||||
vae: VAE model with .encode() and .config.latents_mean/std
|
||||
|
||||
Returns:
|
||||
latent: (bs, 48, T, H, W) - modified input latent
|
||||
mask: (bs, 1, T, H, W) - binary mask
|
||||
"""
|
||||
device = latent_model_input.device
|
||||
dtype = latent_model_input.dtype
|
||||
bs, _, T, H, W = latent_model_input.shape
|
||||
scale = vae.config.scale_factor_spatial
|
||||
target_h = H * scale
|
||||
target_w = W * scale
|
||||
|
||||
# Ensure shape
|
||||
if first_frame.ndim == 3:
|
||||
first_frame = first_frame.unsqueeze(0)
|
||||
if first_frame.shape[0] != bs:
|
||||
first_frame = first_frame.expand(bs, -1, -1, -1)
|
||||
|
||||
# Resize and encode
|
||||
first_frame_up = F.interpolate(first_frame, size=(target_h, target_w), mode="bilinear", align_corners=False)
|
||||
first_frame_up = first_frame_up.unsqueeze(2) # (bs, 3, 1, H, W)
|
||||
encoded = vae.encode(first_frame_up).latent_dist.sample().to(dtype).to(device)
|
||||
|
||||
# Normalize
|
||||
mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1, 1).to(device, dtype)
|
||||
std = 1.0 / torch.tensor(vae.config.latents_std).view(1, -1, 1, 1, 1).to(device, dtype)
|
||||
encoded = (encoded - mean) * std
|
||||
|
||||
# Replace in latent
|
||||
latent = latent_model_input.clone()
|
||||
latent[:, :, :encoded.shape[2]] = encoded # typically first frame: [:, :, 0]
|
||||
|
||||
# Mask: 0 where conditioned, 1 otherwise
|
||||
mask = torch.ones(bs, 1, T, H, W, device=device, dtype=dtype)
|
||||
mask[:, :, :encoded.shape[2]] = 0.0
|
||||
|
||||
return latent, mask
|
||||
Reference in New Issue
Block a user