import torch
import torch.nn.functional as F


def add_first_frame_conditioning(
    latent_model_input,
    first_frame,
    vae
):
    """
    Adds first frame conditioning to a video diffusion model input.

    Args:
        latent_model_input: Original latent input (bs, channels, num_frames, height, width)
        first_frame: Tensor of first frame to condition on (bs, channels, height, width)
        vae: VAE model for encoding the conditioning

    Returns:
        conditioned_latent: The complete conditioned latent input (bs, 36, num_frames, height, width)
    """
    device = latent_model_input.device
    dtype = latent_model_input.dtype
    # "temperal_downsample" is the (misspelled) attribute name used by
    # diffusers' AutoencoderKLWan, so it is kept as-is here
    vae_scale_factor_temporal = 2 ** sum(vae.temperal_downsample)

    # Get number of frames from latent model input
    _, _, num_latent_frames, _, _ = latent_model_input.shape

    # Calculate the original number of pixel-space frames.
    # For n original frames there are (n - 1) // 4 + 1 latent frames,
    # so to recover n: n = (num_latent_frames - 1) * 4 + 1
    num_frames = (num_latent_frames - 1) * 4 + 1
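    # e.g. (assuming the Wan VAE's 4x temporal compression) 21 latent
    # frames correspond to (21 - 1) * 4 + 1 = 81 pixel-space frames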

    if len(first_frame.shape) == 3:
        # we have a single image (channels, height, width); add a batch dim
        first_frame = first_frame.unsqueeze(0)

    # if it doesn't match the batch size, we need to expand it
    if first_frame.shape[0] != latent_model_input.shape[0]:
        first_frame = first_frame.expand(latent_model_input.shape[0], -1, -1, -1)

    # resize first frame to match the latent model input
    vae_scale_factor = vae.config.scale_factor_spatial
    first_frame = F.interpolate(
        first_frame,
        size=(latent_model_input.shape[3] * vae_scale_factor, latent_model_input.shape[4] * vae_scale_factor),
        mode='bilinear',
        align_corners=False
    )

    # Add temporal dimension to first frame
    first_frame = first_frame.unsqueeze(2)

    # Create video condition with first frame and zeros for remaining frames
    zero_frame = torch.zeros_like(first_frame)
    video_condition = torch.cat([
        first_frame,
        *[zero_frame for _ in range(num_frames - 1)]
    ], dim=2)

    # Already in VAE encoding order (bs, channels, num_frames, height, width),
    # so no permute is needed here
    # video_condition = video_condition.permute(0, 2, 1, 3, 4)

    # Encode with VAE
    latent_condition = vae.encode(
        video_condition.to(device, dtype)
    ).latent_dist.sample()
    latent_condition = latent_condition.to(device, dtype)

    # Normalize with the VAE's per-channel latent statistics
    latents_mean = (
        torch.tensor(vae.config.latents_mean)
        .view(1, vae.config.z_dim, 1, 1, 1)
        .to(device, dtype)
    )
    latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1).to(
        device, dtype
    )
    latent_condition = (latent_condition - latents_mean) * latents_std

    # Create mask: 1 for conditioning frames, 0 for frames to generate
    batch_size = first_frame.shape[0]
    latent_height = latent_condition.shape[3]
    latent_width = latent_condition.shape[4]

    # Initialize mask for all frames
    mask_lat_size = torch.ones(
        batch_size, 1, num_frames, latent_height, latent_width)

    # Set all non-first frames to 0
    mask_lat_size[:, :, list(range(1, num_frames))] = 0

    # Special handling for first frame: repeat its mask entry so the
    # pixel-space frame count lines up with the VAE's temporal compression
    first_frame_mask = mask_lat_size[:, :, 0:1]
    first_frame_mask = torch.repeat_interleave(
        first_frame_mask, dim=2, repeats=vae_scale_factor_temporal)

    # Combine first frame mask with the rest
    mask_lat_size = torch.concat(
        [first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)

    # Reshape and transpose so the mask becomes
    # (batch_size, vae_scale_factor_temporal, num_latent_frames, height, width)
    mask_lat_size = mask_lat_size.view(
        batch_size, -1, vae_scale_factor_temporal, latent_height, latent_width)
    mask_lat_size = mask_lat_size.transpose(1, 2)
    mask_lat_size = mask_lat_size.to(device, dtype)

    # Combine conditioning with latent input
    first_frame_condition = torch.concat(
        [mask_lat_size, latent_condition], dim=1)
    conditioned_latent = torch.cat(
        [latent_model_input, first_frame_condition], dim=1)

    return conditioned_latent

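
# A minimal usage sketch, not part of the original file. It assumes a
# diffusers AutoencoderKLWan-style VAE (config.scale_factor_spatial,
# config.z_dim, config.latents_mean/latents_std, temperal_downsample);
# the model id and tensor shapes below are illustrative placeholders.
def _example_first_frame_conditioning():
    from diffusers import AutoencoderKLWan

    vae = AutoencoderKLWan.from_pretrained(
        "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae"
    )
    # 16-channel latent, 21 latent frames -> (21 - 1) * 4 + 1 = 81 pixel frames
    latent_model_input = torch.randn(1, 16, 21, 60, 104)
    first_frame = torch.rand(3, 480, 832)  # single image; batch dim added inside
    conditioned = add_first_frame_conditioning(latent_model_input, first_frame, vae)
    # 16 latent + 4 mask + 16 condition channels = 36
    assert conditioned.shape[1] == 36

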
def add_first_frame_conditioning_v22(
    latent_model_input,
    first_frame,
    vae,
    last_frame=None
):
    """
    Overwrites the first few time steps in latent_model_input with the
    VAE-encoded first_frame, and returns the modified latent plus a binary
    mask (0 = conditioned, 1 = noise).

    Args:
        latent_model_input: torch.Tensor of shape (bs, 48, T, H, W)
        first_frame: torch.Tensor of shape (bs, 3, H*scale, W*scale)
        vae: VAE model with .encode() and .config.latents_mean/std
        last_frame: optional torch.Tensor like first_frame; when given, its
            encoding overwrites the last time steps as well

    Returns:
        latent: (bs, 48, T, H, W) - modified input latent
        mask: (bs, 1, T, H, W) - binary mask
    """
    device = latent_model_input.device
    dtype = latent_model_input.dtype
    bs, _, T, H, W = latent_model_input.shape
    scale = vae.config.scale_factor_spatial
    target_h = H * scale
    target_w = W * scale

    # Ensure a batch dimension and a matching batch size
    if first_frame.ndim == 3:
        first_frame = first_frame.unsqueeze(0)
    if first_frame.shape[0] != bs:
        first_frame = first_frame.expand(bs, -1, -1, -1)

    # Resize to pixel resolution and encode
    first_frame_up = F.interpolate(first_frame, size=(target_h, target_w), mode="bilinear", align_corners=False)
    first_frame_up = first_frame_up.unsqueeze(2)  # (bs, 3, 1, target_h, target_w)
    encoded = vae.encode(first_frame_up).latent_dist.sample().to(dtype).to(device)

    # Normalize with the VAE's latent statistics
    mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1, 1).to(device, dtype)
    std = 1.0 / torch.tensor(vae.config.latents_std).view(1, -1, 1, 1, 1).to(device, dtype)
    encoded = (encoded - mean) * std

    # Replace the leading latent time steps with the encoded first frame
    latent = latent_model_input.clone()
    latent[:, :, :encoded.shape[2]] = encoded  # typically just the first frame: [:, :, 0]

    # Mask: 0 where conditioned, 1 otherwise
    mask = torch.ones(bs, 1, T, H, W, device=device, dtype=dtype)
    mask[:, :, :encoded.shape[2]] = 0.0

    if last_frame is not None:
        # If last_frame is provided, encode it the same way and overwrite
        # the trailing time steps
        last_frame_up = F.interpolate(last_frame, size=(target_h, target_w), mode="bilinear", align_corners=False)
        last_frame_up = last_frame_up.unsqueeze(2)
        last_encoded = vae.encode(last_frame_up).latent_dist.sample().to(dtype).to(device)
        last_encoded = (last_encoded - mean) * std
        latent[:, :, -last_encoded.shape[2]:] = last_encoded  # replace last frames
        mask[:, :, -last_encoded.shape[2]:] = 0.0

    # Ensure mask is still binary
    mask = mask.clamp(0.0, 1.0)

    return latent, mask
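

# A minimal usage sketch, not part of the original file. It assumes a Wan 2.2
# style diffusers VAE with a 48-channel latent space (e.g. AutoencoderKLWan for
# the TI2V-5B checkpoint); the model id and shapes are illustrative placeholders.
def _example_first_frame_conditioning_v22():
    from diffusers import AutoencoderKLWan

    vae = AutoencoderKLWan.from_pretrained(
        "Wan-AI/Wan2.2-TI2V-5B-Diffusers", subfolder="vae"
    )
    latent_model_input = torch.randn(1, 48, 21, 44, 80)
    scale = vae.config.scale_factor_spatial  # assumed to be 16 for this checkpoint
    first_frame = torch.rand(1, 3, 44 * scale, 80 * scale)
    latent, mask = add_first_frame_conditioning_v22(latent_model_input, first_frame, vae)
    # the first latent time step now holds the encoded image, flagged by mask == 0
    assert mask[:, :, 0].max().item() == 0.0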