Mirror of https://github.com/ostris/ai-toolkit.git
Cache first-frame and audio latents alongside video latents when caching latents for LTX2
@@ -98,6 +98,41 @@ def blank_log_image_function(self, *args, **kwargs):
     return


+class ComboVae(torch.nn.Module):
+    """Combines video and audio VAEs for joint encoding and decoding."""
+
+    def __init__(
+        self,
+        vae: AutoencoderKLLTX2Video,
+        audio_vae: AutoencoderKLLTX2Audio,
+    ) -> None:
+        super().__init__()
+        self.vae = vae
+        self.audio_vae = audio_vae
+
+    @property
+    def device(self):
+        return self.vae.device
+
+    @property
+    def dtype(self):
+        return self.vae.dtype
+
+    def encode(
+        self,
+        *args,
+        **kwargs,
+    ):
+        return self.vae.encode(*args, **kwargs)
+
+    def decode(
+        self,
+        *args,
+        **kwargs,
+    ):
+        return self.vae.decode(*args, **kwargs)
+
+
 class AudioProcessor(torch.nn.Module):
     """Converts audio waveforms to log-mel spectrograms with optional resampling."""

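A note on the wrapper: encode() and decode() pass straight through to the video VAE, so existing call sites keep working unchanged, while audio code reaches the audio VAE explicitly. A minimal sketch of the resulting call pattern (a sketch only; pipe, images, and mel_spectrogram are assumed context from elsewhere in the file):

    combo = ComboVae(pipe.vae, pipe.audio_vae)

    # Video path: unchanged call sites work through the pass-through encode().
    video_latents = combo.encode(images).latent_dist.mode()

    # Audio path: addressed explicitly on the wrapped audio VAE.
    audio_latents = combo.audio_vae.encode(mel_spectrogram).latent_dist.mode()

    # device/dtype report the video VAE's placement, so checks like
    # combo.device == torch.device("cpu") behave exactly as before.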
@@ -445,7 +480,7 @@ class LTX2Model(BaseModel):
         flush()

         # save it to the model class
-        self.vae = vae
+        self.vae = ComboVae(pipe.vae, pipe.audio_vae)
         self.text_encoder = text_encoder  # list of text encoders
         self.tokenizer = tokenizer  # list of tokenizers
         self.model = pipe.transformer
@@ -467,10 +502,10 @@ class LTX2Model(BaseModel):
         if dtype is None:
             dtype = self.vae_torch_dtype

-        if self.vae.device == torch.device("cpu"):
-            self.vae.to(device)
-        self.vae.eval()
-        self.vae.requires_grad_(False)
+        if self.pipeline.vae.device == torch.device("cpu"):
+            self.pipeline.vae.to(device)
+        self.pipeline.vae.eval()
+        self.pipeline.vae.requires_grad_(False)

         image_list = [image.to(device, dtype=dtype) for image in image_list]

@@ -489,7 +524,7 @@ class LTX2Model(BaseModel):
         # Stack to (B, C, T, H, W)
         images = torch.stack(norm_images)

-        latents = self.vae.encode(images).latent_dist.mode()
+        latents = self.pipeline.vae.encode(images).latent_dist.mode()

         # Normalize latents across the channel dimension [B, C, F, H, W]
         scaling_factor = 1.0
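The caching path takes the latent distribution's mode rather than a sample, then normalizes per channel. A hedged sketch of channel-wise normalization over [B, C, F, H, W] latents, assuming diffusers-style per-channel latents_mean/latents_std tensors (the scaling_factor seen above is the toolkit's actual mechanism):

    import torch

    def normalize_video_latents(
        latents: torch.Tensor,       # [B, C, F, H, W] raw VAE latents
        latents_mean: torch.Tensor,  # [C] per-channel mean from the VAE config
        latents_std: torch.Tensor,   # [C] per-channel std from the VAE config
    ) -> torch.Tensor:
        # Broadcast per-channel statistics over batch, frame, and spatial dims.
        mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
        std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
        return (latents - mean) / std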
@@ -534,7 +569,7 @@ class LTX2Model(BaseModel):
     ):
         if self.model.device == torch.device("cpu"):
             self.model.to(self.device_torch)
-
+
         # handle control image
         if gen_config.ctrl_img is not None:
             # switch to image to video pipeline
@@ -566,12 +601,14 @@ class LTX2Model(BaseModel):
         bd = self.get_bucket_divisibility()
         gen_config.height = (gen_config.height // bd) * bd
         gen_config.width = (gen_config.width // bd) * bd

         # handle control image
         if gen_config.ctrl_img is not None:
             control_img = Image.open(gen_config.ctrl_img).convert("RGB")
             # resize the control image
-            control_img = control_img.resize((gen_config.width, gen_config.height), Image.LANCZOS)
+            control_img = control_img.resize(
+                (gen_config.width, gen_config.height), Image.LANCZOS
+            )
             # add the control image to the extra dict
             extra["image"] = control_img

@@ -642,7 +679,8 @@ class LTX2Model(BaseModel):
             img = video[0]
         return img

-    def encode_audio(self, batch: "DataLoaderBatchDTO"):
+    def encode_audio(self, audio_data_list):
+        # audio_data_list is a list of {"waveform": waveform[C, L], "sample_rate": int(sample_rate)}
         if self.pipeline.audio_vae.device == torch.device("cpu"):
             self.pipeline.audio_vae.to(self.device_torch)

@@ -650,7 +688,7 @@ class LTX2Model(BaseModel):
         audio_num_frames = None

         # do them separately for now
-        for audio_data in batch.audio_data:
+        for audio_data in audio_data_list:
            waveform = audio_data["waveform"].to(
                device=self.device_torch, dtype=torch.float32
            )
@@ -659,26 +697,30 @@ class LTX2Model(BaseModel):
            # Add batch dimension if needed: [channels, samples] -> [batch, channels, samples]
            if waveform.dim() == 2:
                waveform = waveform.unsqueeze(0)

            if waveform.shape[1] == 1:
                # make sure it is stereo
                waveform = waveform.repeat(1, 2, 1)

            # Convert waveform to mel spectrogram using AudioProcessor
-            mel_spectrogram = self.audio_processor.waveform_to_mel(waveform, waveform_sample_rate=sample_rate)
+            mel_spectrogram = self.audio_processor.waveform_to_mel(
+                waveform, waveform_sample_rate=sample_rate
+            )
            mel_spectrogram = mel_spectrogram.to(dtype=self.torch_dtype)

            # Encode mel spectrogram to latents
-            latents = self.pipeline.audio_vae.encode(mel_spectrogram.to(self.device_torch, dtype=self.torch_dtype)).latent_dist.mode()
+            latents = self.pipeline.audio_vae.encode(
+                mel_spectrogram.to(self.device_torch, dtype=self.torch_dtype)
+            ).latent_dist.mode()

            if audio_num_frames is None:
-                audio_num_frames = latents.shape[2] #(latents is [B, C, T, F])
+                audio_num_frames = latents.shape[2]  # (latents is [B, C, T, F])

            packed_latents = self.pipeline._pack_audio_latents(
                latents,
                # patch_size=self.pipeline.transformer.config.audio_patch_size,
                # patch_size_t=self.pipeline.transformer.config.audio_patch_size_t,
-            ) # [B, L, C * M]
+            )  # [B, L, C * M]
            if output_tensor is None:
                output_tensor = packed_latents
            else:
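End to end, this loop turns each raw waveform into packed audio latents: mono is widened to stereo, the waveform becomes a log-mel spectrogram, the audio VAE encodes it, and the result is packed to [B, L, C * M]. A sketch of that flow under assumed shapes; torchaudio's MelSpectrogram stands in for the toolkit's AudioProcessor, and the packing mirrors what _pack_audio_latents is expected to do, not its actual implementation:

    import torch
    import torchaudio

    def waveform_to_packed_latents(waveform: torch.Tensor, sample_rate: int, encode_fn):
        if waveform.dim() == 2:        # [C, L] -> [1, C, L]
            waveform = waveform.unsqueeze(0)
        if waveform.shape[1] == 1:     # mono -> stereo by repeating the channel
            waveform = waveform.repeat(1, 2, 1)

        # Log-mel spectrogram; the clamp avoids log(0) on silent stretches.
        mel = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate)(waveform)
        mel = torch.log(mel.clamp(min=1e-5))

        latents = encode_fn(mel)       # assumed to return [B, C, T, F]
        b, c, t, f = latents.shape
        # Pack [B, C, T, F] -> [B, T, C * F]: one token per temporal frame.
        packed = latents.permute(0, 2, 1, 3).reshape(b, t, c * f)
        return packed, t               # t is audio_num_frames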
@@ -688,7 +730,7 @@ class LTX2Model(BaseModel):
        latents_mean = self.pipeline.audio_vae.latents_mean
        latents_std = self.pipeline.audio_vae.latents_std
        output_tensor = (output_tensor - latents_mean) / latents_std
-        return output_tensor, audio_num_frames
+        return output_tensor

    def get_noise_prediction(
        self,
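Because the packed latents are normalized with the audio VAE's recorded statistics before they are cached, anything that later decodes them has to invert the transform. A sketch of the inverse pair, assuming latents_mean/latents_std broadcast over the packed [B, L, C * M] layout:

    import torch

    def normalize_audio_latents(x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
        # Applied once at caching/encode time, after packing to [B, L, C * M].
        return (x - mean) / std

    def denormalize_audio_latents(x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
        # Inverse transform; needed before handing latents back to the audio VAE decoder.
        return x * std + mean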
@@ -701,38 +743,57 @@ class LTX2Model(BaseModel):
        with torch.no_grad():
            if self.model.device == torch.device("cpu"):
                self.model.to(self.device_torch)

            batch_size, C, latent_num_frames, latent_height, latent_width = (
                latent_model_input.shape
            )

            video_timestep = timestep.clone()

            # i2v from first frame
            if batch.dataset_config.do_i2v:
-                # videos come in (bs, num_frames, channels, height, width)
-                # images come in (bs, channels, height, width)
-                frames = batch.tensor
-                if len(frames.shape) == 4:
-                    first_frames = frames
-                elif len(frames.shape) == 5:
-                    first_frames = frames[:, 0]
+                # check to see if we had it cached
+                if batch.first_frame_latents is not None:
+                    init_latents = batch.first_frame_latents.to(
+                        self.device_torch, dtype=self.torch_dtype
+                    )
                else:
-                    raise ValueError(f"Unknown frame shape {frames.shape}")
-                # first frame doesn't have a time dim, add it back
-                init_latents = self.encode_images(first_frames, device=self.device_torch, dtype=self.torch_dtype)
+                    # extract the first frame and encode it
+                    # videos come in (bs, num_frames, channels, height, width)
+                    # images come in (bs, channels, height, width)
+                    frames = batch.tensor
+                    if len(frames.shape) == 4:
+                        first_frames = frames
+                    elif len(frames.shape) == 5:
+                        first_frames = frames[:, 0]
+                    else:
+                        raise ValueError(f"Unknown frame shape {frames.shape}")
+                    # first frame doesn't have a time dim, add it back
+                    init_latents = self.encode_images(
+                        first_frames, device=self.device_torch, dtype=self.torch_dtype
+                    )

                # expand the latents to match video frames
                init_latents = init_latents.repeat(1, 1, latent_num_frames, 1, 1)
-                mask_shape = (batch_size, 1, latent_num_frames, latent_height, latent_width)
+                mask_shape = (
+                    batch_size,
+                    1,
+                    latent_num_frames,
+                    latent_height,
+                    latent_width,
+                )
                # First condition is image latents and those should be kept clean.
-                conditioning_mask = torch.zeros(mask_shape, device=self.device_torch, dtype=self.torch_dtype)
+                conditioning_mask = torch.zeros(
+                    mask_shape, device=self.device_torch, dtype=self.torch_dtype
+                )
                conditioning_mask[:, :, 0] = 1.0

                # use conditioning mask to replace latents
                latent_model_input = (
                    latent_model_input * (1 - conditioning_mask)
                    + init_latents * conditioning_mask
                )

                # set video timestep
                video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
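The i2v branch pins the first latent frame to the (possibly cached) first-frame latents and zeroes its timestep, so the model sees that frame as clean conditioning rather than something to denoise. A compact, runnable sketch of the masking arithmetic with illustrative shapes:

    import torch

    B, C, F, H, W = 1, 4, 8, 16, 16                 # illustrative latent shape
    latents = torch.randn(B, C, F, H, W)            # noisy video latents
    init = torch.randn(B, C, 1, H, W).repeat(1, 1, F, 1, 1)  # first frame tiled over time
    timestep = torch.tensor([500.0])

    mask = torch.zeros(B, 1, F, H, W)
    mask[:, :, 0] = 1.0                             # frame 0 is conditioning, kept clean

    # Masked blend: frame 0 comes from the init latents, later frames stay noisy.
    latents = latents * (1 - mask) + init * mask
    # Per-position timestep: 0 on the clean frame, the sampled t everywhere else.
    video_timestep = timestep.unsqueeze(-1) * (1 - mask)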
@@ -746,10 +807,18 @@ class LTX2Model(BaseModel):
                patch_size_t=self.pipeline.transformer_temporal_patch_size,
            )

-            if batch.audio_tensor is not None:
-                # use audio from the batch if available
-                # (1, 190, 128)
-                raw_audio_latents, audio_num_frames = self.encode_audio(batch)
+            if batch.audio_latents is not None or batch.audio_tensor is not None:
+                if batch.audio_latents is not None:
+                    # we have audio latents cached
+                    raw_audio_latents = batch.audio_latents.to(
+                        self.device_torch, dtype=self.torch_dtype
+                    )
+                else:
+                    # we have audio waveforms to encode
+                    # use audio from the batch if available
+                    raw_audio_latents = self.encode_audio(batch.audio_data)
+
+                audio_num_frames = raw_audio_latents.shape[1]
                # add the audio targets to the batch for loss calculation later
                audio_noise = torch.randn_like(raw_audio_latents)
                batch.audio_target = (audio_noise - raw_audio_latents).detach()
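The first-frame and audio paths now share the same cache-first shape: use precomputed latents when the dataloader ships them, otherwise fall back to encoding on the fly. A hedged sketch of the pattern (the helper name is illustrative, not toolkit API):

    import torch
    from typing import Callable, Optional

    def latents_or_encode(
        cached: Optional[torch.Tensor],
        encode: Callable[[], torch.Tensor],
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        if cached is not None:
            # Cache hit: just move the precomputed latents onto the training device.
            return cached.to(device, dtype=dtype)
        # Cache miss: run the much slower VAE encode on the raw data.
        return encode()

The same helper shape covers both batch.first_frame_latents and batch.audio_latents above.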
@@ -758,7 +827,6 @@ class LTX2Model(BaseModel):
                    audio_noise,
                    timestep,
                ).to(self.device_torch, dtype=self.torch_dtype)
-
            else:
                # no audio
                num_mel_bins = self.pipeline.audio_vae.config.mel_bins
@@ -766,7 +834,6 @@ class LTX2Model(BaseModel):
            num_channels_latents_audio = (
                self.pipeline.audio_vae.config.latent_channels
            )
-            # audio latents are (1, 126, 128), audio_num_frames = 126
            audio_latents, audio_num_frames = self.pipeline.prepare_audio_latents(
                batch_size,
                num_channels_latents=num_channels_latents_audio,
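When a batch carries no audio at all, the pipeline still needs audio tokens of the right shape, so it prepares noise-only latents sized from the audio VAE config. A stand-in sketch of what such preparation amounts to (assumed shapes and parameters; prepare_audio_latents' real signature lives in the LTX2 pipeline):

    import torch

    def prepare_noise_audio_latents(
        batch_size: int,
        num_channels_latents: int,  # audio_vae.config.latent_channels
        latent_mel_bins: int,       # latent-space mel dimension M (assumed)
        audio_num_frames: int,      # token count L, matched to the video duration
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        # Noise-only packed audio latents [B, L, C * M], so the transformer
        # always receives audio tokens even for silent samples.
        shape = (batch_size, audio_num_frames, num_channels_latents * latent_mel_bins)
        return torch.randn(shape, device=device, dtype=dtype)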