Cache first frame and audio latents along with video latents when caching for LTX2

Jaret Burkett
2026-01-14 11:05:23 -07:00
parent 64fe29b182
commit 73dedbf662
6 changed files with 324 additions and 127 deletions


@@ -98,6 +98,41 @@ def blank_log_image_function(self, *args, **kwargs):
return
class ComboVae(torch.nn.Module):
"""Combines video and audio VAEs for joint encoding and decoding."""
def __init__(
self,
vae: AutoencoderKLLTX2Video,
audio_vae: AutoencoderKLLTX2Audio,
) -> None:
super().__init__()
self.vae = vae
self.audio_vae = audio_vae
@property
def device(self):
return self.vae.device
@property
def dtype(self):
return self.vae.dtype
def encode(
self,
*args,
**kwargs,
):
return self.vae.encode(*args, **kwargs)
def decode(
self,
*args,
**kwargs,
):
return self.vae.decode(*args, **kwargs)
class AudioProcessor(torch.nn.Module):
"""Converts audio waveforms to log-mel spectrograms with optional resampling."""
@@ -445,7 +480,7 @@ class LTX2Model(BaseModel):
flush()
# save it to the model class
self.vae = vae
self.vae = ComboVae(pipe.vae, pipe.audio_vae)
self.text_encoder = text_encoder # list of text encoders
self.tokenizer = tokenizer # list of tokenizers
self.model = pipe.transformer
@@ -467,10 +502,10 @@ class LTX2Model(BaseModel):
if dtype is None:
dtype = self.vae_torch_dtype
if self.vae.device == torch.device("cpu"):
self.vae.to(device)
self.vae.eval()
self.vae.requires_grad_(False)
if self.pipeline.vae.device == torch.device("cpu"):
self.pipeline.vae.to(device)
self.pipeline.vae.eval()
self.pipeline.vae.requires_grad_(False)
image_list = [image.to(device, dtype=dtype) for image in image_list]
@@ -489,7 +524,7 @@ class LTX2Model(BaseModel):
# Stack to (B, C, T, H, W)
images = torch.stack(norm_images)
latents = self.vae.encode(images).latent_dist.mode()
latents = self.pipeline.vae.encode(images).latent_dist.mode()
# Normalize latents across the channel dimension [B, C, F, H, W]
scaling_factor = 1.0
@@ -534,7 +569,7 @@ class LTX2Model(BaseModel):
):
if self.model.device == torch.device("cpu"):
self.model.to(self.device_torch)
# handle control image
if gen_config.ctrl_img is not None:
# switch to image to video pipeline
@@ -566,12 +601,14 @@ class LTX2Model(BaseModel):
bd = self.get_bucket_divisibility()
gen_config.height = (gen_config.height // bd) * bd
gen_config.width = (gen_config.width // bd) * bd
# handle control image
if gen_config.ctrl_img is not None:
control_img = Image.open(gen_config.ctrl_img).convert("RGB")
# resize the control image
control_img = control_img.resize((gen_config.width, gen_config.height), Image.LANCZOS)
control_img = control_img.resize(
(gen_config.width, gen_config.height), Image.LANCZOS
)
# add the control image to the extra dict
extra["image"] = control_img
@@ -642,7 +679,8 @@ class LTX2Model(BaseModel):
img = video[0]
return img
def encode_audio(self, batch: "DataLoaderBatchDTO"):
def encode_audio(self, audio_data_list):
# audio_data_list is a list of {"waveform": waveform[C, L], "sample_rate": int(sample_rate)}
if self.pipeline.audio_vae.device == torch.device("cpu"):
self.pipeline.audio_vae.to(self.device_torch)
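Not part of the commit: a usage sketch for the new signature, which accepts raw audio dicts instead of a DataLoaderBatchDTO so the latent-caching step can call it directly; the ltx2 handle, the waveform, and the sample rate below are illustrative.

import torch

# Sketch only: build the list format described above and encode it.
waveform = torch.randn(1, 16000 * 5)                             # [channels, samples], 5 s of mono audio
audio_data_list = [{"waveform": waveform, "sample_rate": 16000}]
packed = ltx2.encode_audio(audio_data_list)                      # mono is duplicated to stereo internally
print(packed.shape)                                              # [B, L, C * M] normalized packed latents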
@@ -650,7 +688,7 @@ class LTX2Model(BaseModel):
audio_num_frames = None
# do them separately for now
for audio_data in batch.audio_data:
for audio_data in audio_data_list:
waveform = audio_data["waveform"].to(
device=self.device_torch, dtype=torch.float32
)
@@ -659,26 +697,30 @@ class LTX2Model(BaseModel):
# Add batch dimension if needed: [channels, samples] -> [batch, channels, samples]
if waveform.dim() == 2:
waveform = waveform.unsqueeze(0)
if waveform.shape[1] == 1:
# make sure it is stereo
waveform = waveform.repeat(1, 2, 1)
# Convert waveform to mel spectrogram using AudioProcessor
mel_spectrogram = self.audio_processor.waveform_to_mel(waveform, waveform_sample_rate=sample_rate)
mel_spectrogram = self.audio_processor.waveform_to_mel(
waveform, waveform_sample_rate=sample_rate
)
mel_spectrogram = mel_spectrogram.to(dtype=self.torch_dtype)
# Encode mel spectrogram to latents
latents = self.pipeline.audio_vae.encode(mel_spectrogram.to(self.device_torch, dtype=self.torch_dtype)).latent_dist.mode()
latents = self.pipeline.audio_vae.encode(
mel_spectrogram.to(self.device_torch, dtype=self.torch_dtype)
).latent_dist.mode()
if audio_num_frames is None:
audio_num_frames = latents.shape[2] #(latents is [B, C, T, F])
audio_num_frames = latents.shape[2] # (latents is [B, C, T, F])
packed_latents = self.pipeline._pack_audio_latents(
latents,
# patch_size=self.pipeline.transformer.config.audio_patch_size,
# patch_size_t=self.pipeline.transformer.config.audio_patch_size_t,
) # [B, L, C * M]
) # [B, L, C * M]
if output_tensor is None:
output_tensor = packed_latents
else:
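Not part of the commit: the shape flow through this loop, traced as a commented sketch; only the [B, C, T, F] and [B, L, C * M] shapes are stated in the hunk itself, the rest is inferred and the concrete sizes are illustrative.

# waveform               [C, L]          raw audio, e.g. [1, 160000]
# unsqueeze(0)           [1, C, L]       batch dimension added
# repeat(1, 2, 1)        [1, 2, L]       mono duplicated to stereo
# waveform_to_mel(...)   log-mel spectrogram (axis layout per AudioProcessor)
# audio_vae.encode(...)  [B, C, T, F]    latent_dist.mode() of the audio VAE
# _pack_audio_latents    [B, L, C * M]   per-frame tokens, accumulated into output_tensor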
@@ -688,7 +730,7 @@ class LTX2Model(BaseModel):
latents_mean = self.pipeline.audio_vae.latents_mean
latents_std = self.pipeline.audio_vae.latents_std
output_tensor = (output_tensor - latents_mean) / latents_std
return output_tensor, audio_num_frames
return output_tensor
def get_noise_prediction(
self,
@@ -701,38 +743,57 @@ class LTX2Model(BaseModel):
with torch.no_grad():
if self.model.device == torch.device("cpu"):
self.model.to(self.device_torch)
batch_size, C, latent_num_frames, latent_height, latent_width = (
latent_model_input.shape
)
video_timestep = timestep.clone()
# i2v from first frame
if batch.dataset_config.do_i2v:
# videos come in (bs, num_frames, channels, height, width)
# images come in (bs, channels, height, width)
frames = batch.tensor
if len(frames.shape) == 4:
first_frames = frames
elif len(frames.shape) == 5:
first_frames = frames[:, 0]
# check to see if we had it cached
if batch.first_frame_latents is not None:
init_latents = batch.first_frame_latents.to(
self.device_torch, dtype=self.torch_dtype
)
else:
raise ValueError(f"Unknown frame shape {frames.shape}")
# first frame doesn't have a time dim, add it back
init_latents = self.encode_images(first_frames, device=self.device_torch, dtype=self.torch_dtype)
# extract the first frame and encode it
# videos come in (bs, num_frames, channels, height, width)
# images come in (bs, channels, height, width)
frames = batch.tensor
if len(frames.shape) == 4:
first_frames = frames
elif len(frames.shape) == 5:
first_frames = frames[:, 0]
else:
raise ValueError(f"Unknown frame shape {frames.shape}")
# first frame doesn't have a time dim, add it back
init_latents = self.encode_images(
first_frames, device=self.device_torch, dtype=self.torch_dtype
)
# expand the latents to match video frames
init_latents = init_latents.repeat(1, 1, latent_num_frames, 1, 1)
mask_shape = (batch_size, 1, latent_num_frames, latent_height, latent_width)
mask_shape = (
batch_size,
1,
latent_num_frames,
latent_height,
latent_width,
)
# First condition is image latents and those should be kept clean.
conditioning_mask = torch.zeros(mask_shape, device=self.device_torch, dtype=self.torch_dtype)
conditioning_mask = torch.zeros(
mask_shape, device=self.device_torch, dtype=self.torch_dtype
)
conditioning_mask[:, :, 0] = 1.0
# use conditioning mask to replace latents
latent_model_input = (
latent_model_input * (1 - conditioning_mask)
+ init_latents * conditioning_mask
)
# set video timestep
video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
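Not part of the commit: a sketch of the assumed caching-time counterpart that would populate batch.first_frame_latents, letting the branch above skip the per-step VAE pass; the ltx2 handle is illustrative and the caching hook itself is not shown in this diff.

# Sketch only (assumed cache step): encode the first frame once and store it with the cached latents.
frames = batch.tensor                                           # (bs, frames, c, h, w) or (bs, c, h, w)
first_frames = frames[:, 0] if frames.dim() == 5 else frames
batch.first_frame_latents = ltx2.encode_images(
    first_frames, device=ltx2.device_torch, dtype=ltx2.torch_dtype
).detach().cpu()                                                # kept on CPU until training needs it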
@@ -746,10 +807,18 @@ class LTX2Model(BaseModel):
patch_size_t=self.pipeline.transformer_temporal_patch_size,
)
if batch.audio_tensor is not None:
# use audio from the batch if available
#(1, 190, 128)
raw_audio_latents, audio_num_frames = self.encode_audio(batch)
if batch.audio_latents is not None or batch.audio_tensor is not None:
if batch.audio_latents is not None:
# we have audio latents cached
raw_audio_latents = batch.audio_latents.to(
self.device_torch, dtype=self.torch_dtype
)
else:
# we have audio waveforms to encode
# use audio from the batch if available
raw_audio_latents = self.encode_audio(batch.audio_data)
audio_num_frames = raw_audio_latents.shape[1]
# add the audio targets to the batch for loss calculation later
audio_noise = torch.randn_like(raw_audio_latents)
batch.audio_target = (audio_noise - raw_audio_latents).detach()
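Not part of the commit: the matching assumed cache step for audio. Packed, normalized latents from encode_audio are stored once, and audio_num_frames is re-derived from shape[1] at train time, so it does not need to be cached; names are illustrative as above.

# Sketch only (assumed cache step): encode waveforms once and store the packed latents.
batch.audio_latents = ltx2.encode_audio(batch.audio_data).detach().cpu()   # [B, L, C * M]
audio_num_frames = batch.audio_latents.shape[1]                            # recoverable from the packed length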
@@ -758,7 +827,6 @@ class LTX2Model(BaseModel):
audio_noise,
timestep,
).to(self.device_torch, dtype=self.torch_dtype)
else:
# no audio
num_mel_bins = self.pipeline.audio_vae.config.mel_bins
@@ -766,7 +834,6 @@ class LTX2Model(BaseModel):
num_channels_latents_audio = (
self.pipeline.audio_vae.config.latent_channels
)
# audio latents are (1, 126, 128), audio_num_frames = 126
audio_latents, audio_num_frames = self.pipeline.prepare_audio_latents(
batch_size,
num_channels_latents=num_channels_latents_audio,