Files
ai-toolkit/testing/test_ltx_dataloader.py
Jaret Burkett 5b5aadadb8 Add LTX-2 Support (#644)
* WIP, adding support for LTX2

* Training on images working

* Fix loading comfy models

* Handle converting and deconverting lora so it matches original format

* Reworked ui to habdle ltx and propert dataset default overwriting.

* Update the way lokr saves to it is more compatable with comfy

* Audio loading and synchronization/resampling is working

* Add audio to training. Does it work? Maybe, still testing.

* Fixed fps default issue for sound

* Have ui set fps for accurate audio mapping on ltx

* Added audio procession options to the ui for ltx

* Clean up requirements
2026-01-13 04:55:30 -07:00

235 lines
7.8 KiB
Python

import time
from torch.utils.data import DataLoader
import sys
import os
import argparse
from tqdm import tqdm
import torch
from torchvision.io import write_video
import subprocess
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader_setup_epoch
from toolkit.config_modules import DatasetConfig
parser = argparse.ArgumentParser()
# parser.add_argument('dataset_folder', type=str, default='input')
parser.add_argument('dataset_folder', type=str)
parser.add_argument('--epochs', type=int, default=1)
parser.add_argument('--num_frames', type=int, default=121)
parser.add_argument('--output_path', type=str, default='output/dataset_test')
args = parser.parse_args()
if args.output_path is None:
raise ValueError('output_path is required for this test script')
if args.output_path is not None:
args.output_path = os.path.abspath(args.output_path)
os.makedirs(args.output_path, exist_ok=True)
dataset_folder = args.dataset_folder
resolution = 512
bucket_tolerance = 64
batch_size = 1
frame_rate = 24
## make fake sd
class FakeSD:
def __init__(self):
self.use_raw_control_images = False
def encode_control_in_text_embeddings(self, *args, **kwargs):
return None
def get_bucket_divisibility(self):
return 32
dataset_config = DatasetConfig(
dataset_path=dataset_folder,
resolution=resolution,
default_caption='default',
buckets=True,
bucket_tolerance=bucket_tolerance,
shrink_video_to_frames=True,
num_frames=args.num_frames,
do_i2v=True,
fps=frame_rate,
do_audio=True,
debug=True,
audio_preserve_pitch=False,
audio_normalize=True
)
dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size, sd=FakeSD())
def _tensor_to_uint8_video(frames_fchw: torch.Tensor) -> torch.Tensor:
"""
frames_fchw: [F, C, H, W] float/uint8
returns: [F, H, W, C] uint8 on CPU
"""
x = frames_fchw.detach()
if x.dtype != torch.uint8:
x = x.to(torch.float32)
# Heuristic: if negatives exist, assume [-1,1] normalization; else assume [0,1]
if torch.isfinite(x).all():
if x.min().item() < 0.0:
x = x * 0.5 + 0.5
x = x.clamp(0.0, 1.0)
x = (x * 255.0).round().to(torch.uint8)
else:
x = x.to(torch.uint8)
# [F,C,H,W] -> [F,H,W,C]
x = x.permute(0, 2, 3, 1).contiguous().cpu()
return x
def _mux_with_ffmpeg(video_in: str, wav_in: str, mp4_out: str):
# Copy video stream, encode audio to AAC, align to shortest
subprocess.run(
[
"ffmpeg",
"-y",
"-hide_banner",
"-loglevel",
"error",
"-i",
video_in,
"-i",
wav_in,
"-c:v",
"copy",
"-c:a",
"aac",
"-shortest",
mp4_out,
],
check=True,
)
# run through an epoch ang check sizes
dataloader_iterator = iter(dataloader)
idx = 0
for epoch in range(args.epochs):
for batch in tqdm(dataloader):
batch: 'DataLoaderBatchDTO'
img_batch = batch.tensor
frames = 1
if len(img_batch.shape) == 5:
frames = img_batch.shape[1]
batch_size, frames, channels, height, width = img_batch.shape
else:
batch_size, channels, height, width = img_batch.shape
# load audio
audio_tensor = batch.audio_tensor # all file items contatinated on the batch dimension
audio_data = batch.audio_data # list of raw audio data per item in the batch
# llm save the videos here with audio and video as mp4
fps = getattr(dataset_config, "fps", None)
if fps is None or fps <= 0:
fps = 1.0
# Ensure we can iterate items even if batch_size > 1
for b in range(batch_size):
# Get per-item frames as [F,C,H,W]
if len(img_batch.shape) == 5:
frames_fchw = img_batch[b]
else:
# single image: [C,H,W] -> [1,C,H,W]
frames_fchw = img_batch[b].unsqueeze(0)
video_uint8 = _tensor_to_uint8_video(frames_fchw)
out_mp4 = os.path.join(args.output_path, f"{idx:06d}_{b:02d}.mp4")
# Pick audio for this item (prefer audio_data list; fallback to audio_tensor)
item_audio = None
item_sr = None
if isinstance(audio_data, (list, tuple)) and len(audio_data) > b:
ad = audio_data[b]
if isinstance(ad, dict) and ("waveform" in ad) and ("sample_rate" in ad) and ad["waveform"] is not None:
item_audio = ad["waveform"]
item_sr = int(ad["sample_rate"])
elif audio_tensor is not None and torch.is_tensor(audio_tensor):
# audio_tensor expected [B, C, L] (or [C,L] if batch collate differs)
if audio_tensor.dim() == 3 and audio_tensor.shape[0] > b:
item_audio = audio_tensor[b]
elif audio_tensor.dim() == 2 and b == 0:
item_audio = audio_tensor
if item_audio is not None:
# best-effort sample rate from audio_data if present but not per-item dict
if isinstance(audio_data, dict) and "sample_rate" in audio_data:
try:
item_sr = int(audio_data["sample_rate"])
except Exception:
item_sr = None
# Write mp4 (with audio if available) using ffmpeg muxing (torchvision audio muxing is unreliable)
tmp_video = out_mp4 + ".tmp_video.mp4"
tmp_wav = out_mp4 + ".tmp_audio.wav"
try:
# Always write video-only first
write_video(tmp_video, video_uint8, fps=float(fps), video_codec="libx264")
if item_audio is not None and item_sr is not None and item_audio.numel() > 0:
import torchaudio
wav = item_audio.detach()
# torchaudio.save expects [channels, samples]
if wav.dim() == 1:
wav = wav.unsqueeze(0)
torchaudio.save(tmp_wav, wav.cpu().to(torch.float32), int(item_sr))
# Mux to final mp4
_mux_with_ffmpeg(tmp_video, tmp_wav, out_mp4)
else:
# No audio: just move video into place
os.replace(tmp_video, out_mp4)
except Exception as e:
# Best-effort fallback: leave a playable video-only file
try:
if os.path.exists(tmp_video):
os.replace(tmp_video, out_mp4)
else:
write_video(out_mp4, video_uint8, fps=float(fps), video_codec="libx264")
except Exception:
raise
if hasattr(dataset_config, 'debug') and dataset_config.debug:
print(f"Warning: failed to mux audio into mp4 for {out_mp4}: {e}")
finally:
# Cleanup temps (don't leave separate wavs lying around)
try:
if os.path.exists(tmp_video):
os.remove(tmp_video)
except Exception:
pass
try:
if os.path.exists(tmp_wav):
os.remove(tmp_wav)
except Exception:
pass
time.sleep(0.2)
idx += 1
# if not last epoch
if epoch < args.epochs - 1:
trigger_dataloader_setup_epoch(dataloader)
print('done')