mirror of https://github.com/ostris/ai-toolkit.git
* WIP, adding support for LTX2
* Training on images working
* Fix loading comfy models
* Handle converting and deconverting lora so it matches original format
* Reworked ui to handle ltx and proper dataset default overwriting
* Update the way lokr saves so it is more compatible with comfy
* Audio loading and synchronization/resampling is working
* Add audio to training. Does it work? Maybe, still testing.
* Fixed fps default issue for sound
* Have ui set fps for accurate audio mapping on ltx
* Added audio processing options to the ui for ltx
* Clean up requirements
235 lines
7.8 KiB
Python
import time
import sys
import os
import argparse
import subprocess

import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision.io import write_video

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader_setup_epoch
from toolkit.config_modules import DatasetConfig

parser = argparse.ArgumentParser()
# parser.add_argument('dataset_folder', type=str, default='input')
parser.add_argument('dataset_folder', type=str)
parser.add_argument('--epochs', type=int, default=1)
parser.add_argument('--num_frames', type=int, default=121)
parser.add_argument('--output_path', type=str, default='output/dataset_test')

args = parser.parse_args()

if args.output_path is None:
    raise ValueError('output_path is required for this test script')

args.output_path = os.path.abspath(args.output_path)
os.makedirs(args.output_path, exist_ok=True)

dataset_folder = args.dataset_folder
resolution = 512
bucket_tolerance = 64
batch_size = 1
frame_rate = 24


# make a fake SD adapter so the dataloader can be tested without loading a model
class FakeSD:
    def __init__(self):
        self.use_raw_control_images = False

    def encode_control_in_text_embeddings(self, *args, **kwargs):
        return None

    def get_bucket_divisibility(self):
        return 32

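# Note (assumption): these three members appear to be all that
# get_dataloader_from_datasets touches during dataset setup -- the bucket
# divisibility for resizing, plus the control-image/text-embedding hooks.
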
dataset_config = DatasetConfig(
    dataset_path=dataset_folder,
    resolution=resolution,
    default_caption='default',
    buckets=True,
    bucket_tolerance=bucket_tolerance,
    shrink_video_to_frames=True,
    num_frames=args.num_frames,
    do_i2v=True,
    fps=frame_rate,
    do_audio=True,
    debug=True,
    audio_preserve_pitch=False,
    audio_normalize=True,
)

dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size, sd=FakeSD())

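# With num_frames > 1, each batch tensor is expected to collate to
# [B, F, C, H, W]; single images come back as [B, C, H, W]. The epoch loop
# further down branches on exactly that shape difference.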

def _tensor_to_uint8_video(frames_fchw: torch.Tensor) -> torch.Tensor:
    """
    frames_fchw: [F, C, H, W] float/uint8
    returns: [F, H, W, C] uint8 on CPU
    """
    x = frames_fchw.detach()

    if x.dtype != torch.uint8:
        x = x.to(torch.float32)

        # Heuristic: if negatives exist, assume [-1, 1] normalization; else assume [0, 1]
        if torch.isfinite(x).all():
            if x.min().item() < 0.0:
                x = x * 0.5 + 0.5
            x = x.clamp(0.0, 1.0)
            x = (x * 255.0).round().to(torch.uint8)
        else:
            x = x.to(torch.uint8)

    # [F, C, H, W] -> [F, H, W, C]
    x = x.permute(0, 2, 3, 1).contiguous().cpu()
    return x

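# A minimal sanity check for the helper above (shapes illustrative, not taken
# from the dataset):
#
#   fake = torch.rand(8, 3, 64, 64) * 2.0 - 1.0   # [-1, 1] float frames, [F, C, H, W]
#   vid = _tensor_to_uint8_video(fake)            # -> torch.Size([8, 64, 64, 3]), uint8 on CPU
#   assert vid.dtype == torch.uint8 and vid.shape == (8, 64, 64, 3)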

def _mux_with_ffmpeg(video_in: str, wav_in: str, mp4_out: str):
    # Copy the video stream, encode the audio to AAC, align to the shortest stream
    subprocess.run(
        [
            "ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
            "-i", video_in,
            "-i", wav_in,
            "-c:v", "copy",
            "-c:a", "aac",
            "-shortest",
            mp4_out,
        ],
        check=True,
    )
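
# For reference, the subprocess call above matches this shell command
# (file names illustrative; assumes ffmpeg is on PATH):
#
#   ffmpeg -y -hide_banner -loglevel error \
#       -i clip.tmp_video.mp4 -i clip.tmp_audio.wav \
#       -c:v copy -c:a aac -shortest clip.mp4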

# run through the epochs and check sizes
dataloader_iterator = iter(dataloader)
idx = 0
for epoch in range(args.epochs):
    for batch in tqdm(dataloader):
        batch: 'DataLoaderBatchDTO'
        img_batch = batch.tensor
        frames = 1
        if len(img_batch.shape) == 5:
            frames = img_batch.shape[1]
            batch_size, frames, channels, height, width = img_batch.shape
        else:
            batch_size, channels, height, width = img_batch.shape

        # load audio
        audio_tensor = batch.audio_tensor  # all file items concatenated on the batch dimension
        audio_data = batch.audio_data  # list of raw audio data per item in the batch

        # save each item as an mp4, muxing in audio when available
        fps = getattr(dataset_config, "fps", None)
        if fps is None or fps <= 0:
            fps = 1.0

        # Ensure we can iterate items even if batch_size > 1
        for b in range(batch_size):
            # Get per-item frames as [F, C, H, W]
            if len(img_batch.shape) == 5:
                frames_fchw = img_batch[b]
            else:
                # single image: [C, H, W] -> [1, C, H, W]
                frames_fchw = img_batch[b].unsqueeze(0)

            video_uint8 = _tensor_to_uint8_video(frames_fchw)
            out_mp4 = os.path.join(args.output_path, f"{idx:06d}_{b:02d}.mp4")

            # Pick audio for this item (prefer the audio_data list; fall back to audio_tensor)
            item_audio = None
            item_sr = None

            if isinstance(audio_data, (list, tuple)) and len(audio_data) > b:
                ad = audio_data[b]
                if isinstance(ad, dict) and ("waveform" in ad) and ("sample_rate" in ad) and ad["waveform"] is not None:
                    item_audio = ad["waveform"]
                    item_sr = int(ad["sample_rate"])
            elif audio_tensor is not None and torch.is_tensor(audio_tensor):
                # audio_tensor is expected as [B, C, L] (or [C, L] if the batch collate differs)
                if audio_tensor.dim() == 3 and audio_tensor.shape[0] > b:
                    item_audio = audio_tensor[b]
                elif audio_tensor.dim() == 2 and b == 0:
                    item_audio = audio_tensor
                if item_audio is not None:
                    # best-effort sample rate from audio_data if it exists but is not a per-item dict
                    if isinstance(audio_data, dict) and "sample_rate" in audio_data:
                        try:
                            item_sr = int(audio_data["sample_rate"])
                        except Exception:
                            item_sr = None
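
            # Resolution order above: a per-item dict in batch.audio_data
            # (waveform + sample_rate) wins; otherwise batch.audio_tensor is
            # sliced and a sample rate is recovered, best-effort, from a
            # batch-level audio_data dict. With no sample rate, the item falls
            # through to video-only output below.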

            # Write the mp4 (with audio if available) via ffmpeg muxing,
            # since torchvision's own audio muxing is unreliable
            tmp_video = out_mp4 + ".tmp_video.mp4"
            tmp_wav = out_mp4 + ".tmp_audio.wav"
            try:
                # Always write the video-only file first
                write_video(tmp_video, video_uint8, fps=float(fps), video_codec="libx264")

                if item_audio is not None and item_sr is not None and item_audio.numel() > 0:
                    import torchaudio

                    wav = item_audio.detach()
                    # torchaudio.save expects [channels, samples]
                    if wav.dim() == 1:
                        wav = wav.unsqueeze(0)
                    torchaudio.save(tmp_wav, wav.cpu().to(torch.float32), int(item_sr))

                    # Mux to the final mp4
                    _mux_with_ffmpeg(tmp_video, tmp_wav, out_mp4)
                else:
                    # No audio: just move the video into place
                    os.replace(tmp_video, out_mp4)

            except Exception as e:
                # Best-effort fallback: leave a playable video-only file
                if os.path.exists(tmp_video):
                    os.replace(tmp_video, out_mp4)
                else:
                    write_video(out_mp4, video_uint8, fps=float(fps), video_codec="libx264")

                if hasattr(dataset_config, 'debug') and dataset_config.debug:
                    print(f"Warning: failed to mux audio into mp4 for {out_mp4}: {e}")

            finally:
                # clean up temp files (don't leave separate wavs lying around)
                try:
                    if os.path.exists(tmp_video):
                        os.remove(tmp_video)
                except Exception:
                    pass
                try:
                    if os.path.exists(tmp_wav):
                        os.remove(tmp_wav)
                except Exception:
                    pass

        time.sleep(0.2)

        idx += 1
    # if not the last epoch, reset the dataloader for the next pass
    if epoch < args.epochs - 1:
        trigger_dataloader_setup_epoch(dataloader)

print('done')
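
# Example invocation (script and dataset paths illustrative):
#   python test_dataloader_video.py /path/to/dataset --epochs 2 --num_frames 121 \
#       --output_path output/dataset_test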