mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Initial work for training wan first and last frame
This commit is contained in:
@@ -116,7 +116,8 @@ def add_first_frame_conditioning(
|
|||||||
def add_first_frame_conditioning_v22(
|
def add_first_frame_conditioning_v22(
|
||||||
latent_model_input,
|
latent_model_input,
|
||||||
first_frame,
|
first_frame,
|
||||||
vae
|
vae,
|
||||||
|
last_frame=None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Overwrites first few time steps in latent_model_input with VAE-encoded first_frame,
|
Overwrites first few time steps in latent_model_input with VAE-encoded first_frame,
|
||||||
@@ -161,5 +162,16 @@ def add_first_frame_conditioning_v22(
|
|||||||
# Mask: 0 where conditioned, 1 otherwise
|
# Mask: 0 where conditioned, 1 otherwise
|
||||||
mask = torch.ones(bs, 1, T, H, W, device=device, dtype=dtype)
|
mask = torch.ones(bs, 1, T, H, W, device=device, dtype=dtype)
|
||||||
mask[:, :, :encoded.shape[2]] = 0.0
|
mask[:, :, :encoded.shape[2]] = 0.0
|
||||||
|
|
||||||
|
if last_frame is not None:
|
||||||
|
# If last_frame is provided, encode it similarly
|
||||||
|
last_frame_up = F.interpolate(last_frame, size=(target_h, target_w), mode="bilinear", align_corners=False)
|
||||||
|
last_frame_up = last_frame_up.unsqueeze(2)
|
||||||
|
last_encoded = vae.encode(last_frame_up).latent_dist.sample().to(dtype).to(device)
|
||||||
|
last_encoded = (last_encoded - mean) * std
|
||||||
|
latent[:, :, -last_encoded.shape[2]:] = last_encoded # replace last
|
||||||
|
mask[:, :, -last_encoded.shape[2]:] = 0.0 #
|
||||||
|
# Ensure mask is still binary
|
||||||
|
mask = mask.clamp(0.0, 1.0)
|
||||||
|
|
||||||
return latent, mask
|
return latent, mask
|
||||||
Reference in New Issue
Block a user