Add LTX-2 Support (#644)

* WIP, adding support for LTX2

* Training on images working

* Fix loading comfy models

* Handle converting and deconverting lora so it matches original format

* Reworked ui to habdle ltx and propert dataset default overwriting.

* Update the way lokr saves to it is more compatable with comfy

* Audio loading and synchronization/resampling is working

* Add audio to training. Does it work? Maybe, still testing.

* Fixed fps default issue for sound

* Have ui set fps for accurate audio mapping on ltx

* Added audio procession options to the ui for ltx

* Clean up requirements
This commit is contained in:
Jaret Burkett
2026-01-13 04:55:30 -07:00
committed by GitHub
parent 6870ab490f
commit 5b5aadadb8
28 changed files with 2180 additions and 71 deletions

View File

@@ -16,6 +16,7 @@ from safetensors.torch import load_file, save_file
from tqdm import tqdm
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor
from toolkit.audio.preserve_pitch import time_stretch_preserve_pitch
from toolkit.basic import flush, value_map
from toolkit.buckets import get_bucket_for_image_size, get_resolution
from toolkit.config_modules import ControlTypes
@@ -467,6 +468,8 @@ class ImageProcessingDTOMixin:
if not self.dataset_config.buckets:
raise Exception('Buckets required for video processing')
do_audio = self.dataset_config.do_audio
try:
# Use OpenCV to capture video frames
cap = cv2.VideoCapture(self.path)
@@ -596,6 +599,84 @@ class ImageProcessingDTOMixin:
# Stack frames into tensor [frames, channels, height, width]
self.tensor = torch.stack(frames)
# ------------------------------
# Audio extraction + stretching
# ------------------------------
if do_audio:
# Default to "no audio" unless we successfully extract it
self.audio_data = None
self.audio_tensor = None
try:
import torchaudio
import torch.nn.functional as F
# Compute the time range of the selected frames in the *source* video
# Include the last frame by extending to the next frame boundary.
if video_fps and video_fps > 0 and len(frames_to_extract) > 0:
clip_start_frame = int(frames_to_extract[0])
clip_end_frame = int(frames_to_extract[-1])
clip_start_time = clip_start_frame / float(video_fps)
clip_end_time = (clip_end_frame + 1) / float(video_fps)
source_duration = max(0.0, clip_end_time - clip_start_time)
else:
clip_start_time = 0.0
clip_end_time = 0.0
source_duration = 0.0
# Target duration is how this sampled/stretched clip is interpreted for training
# (i.e. num_frames at the configured dataset FPS).
if hasattr(self.dataset_config, "fps") and self.dataset_config.fps and self.dataset_config.fps > 0:
target_duration = float(self.dataset_config.num_frames) / float(self.dataset_config.fps)
else:
target_duration = source_duration
waveform, sample_rate = torchaudio.load(self.path) # [channels, samples]
if self.dataset_config.audio_normalize:
peak = waveform.abs().amax() # global peak across channels
eps = 1e-9
target_peak = 0.999 # ~ -0.01 dBFS
gain = target_peak / (peak + eps)
waveform = waveform * gain
# Slice to the selected clip region (when we have a meaningful time range)
if source_duration > 0.0:
start_sample = int(round(clip_start_time * sample_rate))
end_sample = int(round(clip_end_time * sample_rate))
start_sample = max(0, min(start_sample, waveform.shape[-1]))
end_sample = max(0, min(end_sample, waveform.shape[-1]))
if end_sample > start_sample:
waveform = waveform[..., start_sample:end_sample]
else:
# No valid audio segment
waveform = None
else:
# If we can't compute a meaningful time range, treat as no-audio
waveform = None
if waveform is not None and waveform.numel() > 0:
target_samples = int(round(target_duration * sample_rate))
if target_samples > 0 and waveform.shape[-1] != target_samples:
# Time-stretch/shrink to match the video clip duration implied by dataset FPS.
if self.dataset_config.audio_preserve_pitch:
waveform = time_stretch_preserve_pitch(waveform, sample_rate, target_samples) # waveform is [C, L]
else:
# Use linear interpolation over the time axis.
wf = waveform.unsqueeze(0) # [1, C, L]
wf = F.interpolate(wf, size=target_samples, mode="linear", align_corners=False)
waveform = wf.squeeze(0) # [C, L]
self.audio_tensor = waveform
self.audio_data = {"waveform": waveform, "sample_rate": int(sample_rate)}
except Exception as e:
# Keep behavior identical for non-audio datasets; for audio datasets, just skip if missing/broken.
if hasattr(self.dataset_config, 'debug') and self.dataset_config.debug:
print_acc(f"Could not extract/stretch audio for {self.path}: {e}")
self.audio_data = None
self.audio_tensor = None
# Only log success in debug mode
if hasattr(self.dataset_config, 'debug') and self.dataset_config.debug: