diff --git a/toolkit/models/wan21/wan_utils.py b/toolkit/models/wan21/wan_utils.py index 3a300d2b..6755007a 100644 --- a/toolkit/models/wan21/wan_utils.py +++ b/toolkit/models/wan21/wan_utils.py @@ -116,7 +116,8 @@ def add_first_frame_conditioning( def add_first_frame_conditioning_v22( latent_model_input, first_frame, - vae + vae, + last_frame=None ): """ Overwrites first few time steps in latent_model_input with VAE-encoded first_frame, @@ -161,5 +162,16 @@ def add_first_frame_conditioning_v22( # Mask: 0 where conditioned, 1 otherwise mask = torch.ones(bs, 1, T, H, W, device=device, dtype=dtype) mask[:, :, :encoded.shape[2]] = 0.0 + + if last_frame is not None: + # Optional last-frame conditioning: resize to (target_h, target_w), add a time axis, VAE-encode, then normalize with mean/std + last_frame_up = F.interpolate(last_frame, size=(target_h, target_w), mode="bilinear", align_corners=False) + last_frame_up = last_frame_up.unsqueeze(2) + last_encoded = vae.encode(last_frame_up).latent_dist.sample().to(dtype).to(device) + last_encoded = (last_encoded - mean) * std + latent[:, :, -last_encoded.shape[2]:] = last_encoded # overwrite the trailing latent time steps with the encoded last frame + mask[:, :, -last_encoded.shape[2]:] = 0.0 # 0 = conditioned, same convention as the first-frame mask above + # clamp only bounds values to [0, 1]; the mask already holds just 0s and 1s at this point + mask = mask.clamp(0.0, 1.0) return latent, mask \ No newline at end of file