mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Add LTX-2 Support (#644)
* WIP, adding support for LTX2 * Training on images working * Fix loading comfy models * Handle converting and deconverting lora so it matches original format * Reworked ui to habdle ltx and propert dataset default overwriting. * Update the way lokr saves to it is more compatable with comfy * Audio loading and synchronization/resampling is working * Add audio to training. Does it work? Maybe, still testing. * Fixed fps default issue for sound * Have ui set fps for accurate audio mapping on ltx * Added audio procession options to the ui for ltx * Clean up requirements
This commit is contained in:
234
testing/test_ltx_dataloader.py
Normal file
234
testing/test_ltx_dataloader.py
Normal file
@@ -0,0 +1,234 @@
|
||||
import time
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
import sys
|
||||
import os
|
||||
import argparse
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from torchvision.io import write_video
|
||||
import subprocess
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||
from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader_setup_epoch
|
||||
from toolkit.config_modules import DatasetConfig
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
# parser.add_argument('dataset_folder', type=str, default='input')
|
||||
parser.add_argument('dataset_folder', type=str)
|
||||
parser.add_argument('--epochs', type=int, default=1)
|
||||
parser.add_argument('--num_frames', type=int, default=121)
|
||||
parser.add_argument('--output_path', type=str, default='output/dataset_test')
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.output_path is None:
|
||||
raise ValueError('output_path is required for this test script')
|
||||
|
||||
if args.output_path is not None:
|
||||
args.output_path = os.path.abspath(args.output_path)
|
||||
os.makedirs(args.output_path, exist_ok=True)
|
||||
|
||||
dataset_folder = args.dataset_folder
|
||||
resolution = 512
|
||||
bucket_tolerance = 64
|
||||
batch_size = 1
|
||||
frame_rate = 24
|
||||
|
||||
|
||||
## make fake sd
|
||||
class FakeSD:
|
||||
def __init__(self):
|
||||
self.use_raw_control_images = False
|
||||
|
||||
def encode_control_in_text_embeddings(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
def get_bucket_divisibility(self):
|
||||
return 32
|
||||
|
||||
dataset_config = DatasetConfig(
|
||||
dataset_path=dataset_folder,
|
||||
resolution=resolution,
|
||||
default_caption='default',
|
||||
buckets=True,
|
||||
bucket_tolerance=bucket_tolerance,
|
||||
shrink_video_to_frames=True,
|
||||
num_frames=args.num_frames,
|
||||
do_i2v=True,
|
||||
fps=frame_rate,
|
||||
do_audio=True,
|
||||
debug=True,
|
||||
audio_preserve_pitch=False,
|
||||
audio_normalize=True
|
||||
|
||||
)
|
||||
|
||||
dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size, sd=FakeSD())
|
||||
|
||||
|
||||
def _tensor_to_uint8_video(frames_fchw: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
frames_fchw: [F, C, H, W] float/uint8
|
||||
returns: [F, H, W, C] uint8 on CPU
|
||||
"""
|
||||
x = frames_fchw.detach()
|
||||
|
||||
if x.dtype != torch.uint8:
|
||||
x = x.to(torch.float32)
|
||||
|
||||
# Heuristic: if negatives exist, assume [-1,1] normalization; else assume [0,1]
|
||||
if torch.isfinite(x).all():
|
||||
if x.min().item() < 0.0:
|
||||
x = x * 0.5 + 0.5
|
||||
x = x.clamp(0.0, 1.0)
|
||||
x = (x * 255.0).round().to(torch.uint8)
|
||||
else:
|
||||
x = x.to(torch.uint8)
|
||||
|
||||
# [F,C,H,W] -> [F,H,W,C]
|
||||
x = x.permute(0, 2, 3, 1).contiguous().cpu()
|
||||
return x
|
||||
|
||||
|
||||
def _mux_with_ffmpeg(video_in: str, wav_in: str, mp4_out: str):
|
||||
# Copy video stream, encode audio to AAC, align to shortest
|
||||
subprocess.run(
|
||||
[
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-hide_banner",
|
||||
"-loglevel",
|
||||
"error",
|
||||
"-i",
|
||||
video_in,
|
||||
"-i",
|
||||
wav_in,
|
||||
"-c:v",
|
||||
"copy",
|
||||
"-c:a",
|
||||
"aac",
|
||||
"-shortest",
|
||||
mp4_out,
|
||||
],
|
||||
check=True,
|
||||
)
|
||||
|
||||
|
||||
# run through an epoch ang check sizes
|
||||
dataloader_iterator = iter(dataloader)
|
||||
idx = 0
|
||||
for epoch in range(args.epochs):
|
||||
for batch in tqdm(dataloader):
|
||||
batch: 'DataLoaderBatchDTO'
|
||||
img_batch = batch.tensor
|
||||
frames = 1
|
||||
if len(img_batch.shape) == 5:
|
||||
frames = img_batch.shape[1]
|
||||
batch_size, frames, channels, height, width = img_batch.shape
|
||||
else:
|
||||
batch_size, channels, height, width = img_batch.shape
|
||||
|
||||
# load audio
|
||||
audio_tensor = batch.audio_tensor # all file items contatinated on the batch dimension
|
||||
audio_data = batch.audio_data # list of raw audio data per item in the batch
|
||||
|
||||
# llm save the videos here with audio and video as mp4
|
||||
fps = getattr(dataset_config, "fps", None)
|
||||
if fps is None or fps <= 0:
|
||||
fps = 1.0
|
||||
|
||||
# Ensure we can iterate items even if batch_size > 1
|
||||
for b in range(batch_size):
|
||||
# Get per-item frames as [F,C,H,W]
|
||||
if len(img_batch.shape) == 5:
|
||||
frames_fchw = img_batch[b]
|
||||
else:
|
||||
# single image: [C,H,W] -> [1,C,H,W]
|
||||
frames_fchw = img_batch[b].unsqueeze(0)
|
||||
|
||||
video_uint8 = _tensor_to_uint8_video(frames_fchw)
|
||||
out_mp4 = os.path.join(args.output_path, f"{idx:06d}_{b:02d}.mp4")
|
||||
|
||||
# Pick audio for this item (prefer audio_data list; fallback to audio_tensor)
|
||||
item_audio = None
|
||||
item_sr = None
|
||||
|
||||
if isinstance(audio_data, (list, tuple)) and len(audio_data) > b:
|
||||
ad = audio_data[b]
|
||||
if isinstance(ad, dict) and ("waveform" in ad) and ("sample_rate" in ad) and ad["waveform"] is not None:
|
||||
item_audio = ad["waveform"]
|
||||
item_sr = int(ad["sample_rate"])
|
||||
elif audio_tensor is not None and torch.is_tensor(audio_tensor):
|
||||
# audio_tensor expected [B, C, L] (or [C,L] if batch collate differs)
|
||||
if audio_tensor.dim() == 3 and audio_tensor.shape[0] > b:
|
||||
item_audio = audio_tensor[b]
|
||||
elif audio_tensor.dim() == 2 and b == 0:
|
||||
item_audio = audio_tensor
|
||||
if item_audio is not None:
|
||||
# best-effort sample rate from audio_data if present but not per-item dict
|
||||
if isinstance(audio_data, dict) and "sample_rate" in audio_data:
|
||||
try:
|
||||
item_sr = int(audio_data["sample_rate"])
|
||||
except Exception:
|
||||
item_sr = None
|
||||
|
||||
# Write mp4 (with audio if available) using ffmpeg muxing (torchvision audio muxing is unreliable)
|
||||
tmp_video = out_mp4 + ".tmp_video.mp4"
|
||||
tmp_wav = out_mp4 + ".tmp_audio.wav"
|
||||
try:
|
||||
# Always write video-only first
|
||||
write_video(tmp_video, video_uint8, fps=float(fps), video_codec="libx264")
|
||||
|
||||
if item_audio is not None and item_sr is not None and item_audio.numel() > 0:
|
||||
import torchaudio
|
||||
|
||||
wav = item_audio.detach()
|
||||
# torchaudio.save expects [channels, samples]
|
||||
if wav.dim() == 1:
|
||||
wav = wav.unsqueeze(0)
|
||||
torchaudio.save(tmp_wav, wav.cpu().to(torch.float32), int(item_sr))
|
||||
|
||||
# Mux to final mp4
|
||||
_mux_with_ffmpeg(tmp_video, tmp_wav, out_mp4)
|
||||
else:
|
||||
# No audio: just move video into place
|
||||
os.replace(tmp_video, out_mp4)
|
||||
|
||||
except Exception as e:
|
||||
# Best-effort fallback: leave a playable video-only file
|
||||
try:
|
||||
if os.path.exists(tmp_video):
|
||||
os.replace(tmp_video, out_mp4)
|
||||
else:
|
||||
write_video(out_mp4, video_uint8, fps=float(fps), video_codec="libx264")
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
if hasattr(dataset_config, 'debug') and dataset_config.debug:
|
||||
print(f"Warning: failed to mux audio into mp4 for {out_mp4}: {e}")
|
||||
|
||||
finally:
|
||||
# Cleanup temps (don't leave separate wavs lying around)
|
||||
try:
|
||||
if os.path.exists(tmp_video):
|
||||
os.remove(tmp_video)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if os.path.exists(tmp_wav):
|
||||
os.remove(tmp_wav)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
time.sleep(0.2)
|
||||
|
||||
idx += 1
|
||||
# if not last epoch
|
||||
if epoch < args.epochs - 1:
|
||||
trigger_dataloader_setup_epoch(dataloader)
|
||||
|
||||
print('done')
|
||||
Reference in New Issue
Block a user