Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-04-30 19:21:39 +00:00
Add LTX-2 Support (#644)
* WIP, adding support for LTX2
* Training on images working
* Fix loading comfy models
* Handle converting and deconverting LoRA so it matches the original format
* Reworked UI to handle LTX and proper dataset default overwriting
* Update the way LoKr saves so it is more compatible with comfy
* Audio loading and synchronization/resampling is working
* Add audio to training. Does it work? Maybe, still testing.
* Fixed fps default issue for sound
* Have UI set fps for accurate audio mapping on LTX
* Added audio processing options to the UI for LTX
* Clean up requirements
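For context on the "accurate audio mapping" bullet: the training clip's duration is defined entirely by the dataset's num_frames and fps, and the extracted audio is stretched to exactly that duration. A quick worked sketch of the arithmetic (the numbers are illustrative, not from this commit):

    # 121 frames sampled for training, dataset configured at 24 fps, 48 kHz audio
    num_frames, dataset_fps, sample_rate = 121, 24.0, 48_000
    target_duration = num_frames / dataset_fps               # ~5.0417 s
    target_samples = round(target_duration * sample_rate)    # exactly 242_000 samples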
@@ -16,6 +16,7 @@ from safetensors.torch import load_file, save_file
 from tqdm import tqdm
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor
 
+from toolkit.audio.preserve_pitch import time_stretch_preserve_pitch
 from toolkit.basic import flush, value_map
 from toolkit.buckets import get_bucket_for_image_size, get_resolution
 from toolkit.config_modules import ControlTypes
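The new import above is the pitch-preserving stretch helper used in the audio path below; its implementation is not part of this diff. A minimal sketch of what such a helper could look like, assuming a torchaudio phase-vocoder approach (the function name mirrors the import, the internals are illustrative, and sample_rate is kept only to match the call site):

    import math
    import torch
    import torchaudio.functional as AF

    def time_stretch_preserve_pitch(waveform: torch.Tensor, sample_rate: int, target_samples: int) -> torch.Tensor:
        # Stretch/shrink [channels, samples] audio to target_samples without shifting pitch.
        n_fft, hop = 1024, 256
        window = torch.hann_window(n_fft)
        rate = waveform.shape[-1] / float(target_samples)  # >1 shortens, <1 lengthens
        spec = torch.stft(waveform, n_fft=n_fft, hop_length=hop,
                          window=window, return_complex=True)  # [C, F, T]
        phase_advance = torch.linspace(0, math.pi * hop, spec.shape[-2])[..., None]
        stretched = AF.phase_vocoder(spec, rate=rate, phase_advance=phase_advance)
        out = torch.istft(stretched, n_fft=n_fft, hop_length=hop, window=window)
        # istft rarely lands exactly on target_samples; pad or trim to match
        if out.shape[-1] < target_samples:
            out = torch.nn.functional.pad(out, (0, target_samples - out.shape[-1]))
        return out[..., :target_samples]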
@@ -467,6 +468,8 @@ class ImageProcessingDTOMixin:
         if not self.dataset_config.buckets:
             raise Exception('Buckets required for video processing')
 
+        do_audio = self.dataset_config.do_audio
+
         try:
             # Use OpenCV to capture video frames
             cap = cv2.VideoCapture(self.path)
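The audio logic in the next hunk keys off video_fps and frames_to_extract, which come from this capture earlier in the method. A standalone sketch of how OpenCV exposes those source properties (the file name and sampling policy here are illustrative, not the toolkit's actual ones):

    import cv2

    cap = cv2.VideoCapture("example.mp4")                  # illustrative path
    video_fps = cap.get(cv2.CAP_PROP_FPS)                  # source frame rate as a float
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frames_to_extract = list(range(0, frame_count, 2))     # e.g. take every other frame
    cap.release()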
@@ -596,6 +599,84 @@ class ImageProcessingDTOMixin:
 
         # Stack frames into tensor [frames, channels, height, width]
         self.tensor = torch.stack(frames)
+
+        # ------------------------------
+        # Audio extraction + stretching
+        # ------------------------------
+        if do_audio:
+            # Default to "no audio" unless we successfully extract it
+            self.audio_data = None
+            self.audio_tensor = None
+
+            try:
+                import torchaudio
+                import torch.nn.functional as F
+
+                # Compute the time range of the selected frames in the *source* video.
+                # Include the last frame by extending to the next frame boundary.
+                if video_fps and video_fps > 0 and len(frames_to_extract) > 0:
+                    clip_start_frame = int(frames_to_extract[0])
+                    clip_end_frame = int(frames_to_extract[-1])
+                    clip_start_time = clip_start_frame / float(video_fps)
+                    clip_end_time = (clip_end_frame + 1) / float(video_fps)
+                    source_duration = max(0.0, clip_end_time - clip_start_time)
+                else:
+                    clip_start_time = 0.0
+                    clip_end_time = 0.0
+                    source_duration = 0.0
+
+                # Target duration is how this sampled/stretched clip is interpreted for training
+                # (i.e. num_frames at the configured dataset FPS).
+                if hasattr(self.dataset_config, "fps") and self.dataset_config.fps and self.dataset_config.fps > 0:
+                    target_duration = float(self.dataset_config.num_frames) / float(self.dataset_config.fps)
+                else:
+                    target_duration = source_duration
+
+                waveform, sample_rate = torchaudio.load(self.path)  # [channels, samples]
+
+                if self.dataset_config.audio_normalize:
+                    peak = waveform.abs().amax()  # global peak across channels
+                    eps = 1e-9
+                    target_peak = 0.999  # ~ -0.01 dBFS
+                    gain = target_peak / (peak + eps)
+                    waveform = waveform * gain
+
+                # Slice to the selected clip region (when we have a meaningful time range)
+                if source_duration > 0.0:
+                    start_sample = int(round(clip_start_time * sample_rate))
+                    end_sample = int(round(clip_end_time * sample_rate))
+                    start_sample = max(0, min(start_sample, waveform.shape[-1]))
+                    end_sample = max(0, min(end_sample, waveform.shape[-1]))
+                    if end_sample > start_sample:
+                        waveform = waveform[..., start_sample:end_sample]
+                    else:
+                        # No valid audio segment
+                        waveform = None
+                else:
+                    # If we can't compute a meaningful time range, treat as no-audio
+                    waveform = None
+
+                if waveform is not None and waveform.numel() > 0:
+                    target_samples = int(round(target_duration * sample_rate))
+                    if target_samples > 0 and waveform.shape[-1] != target_samples:
+                        # Time-stretch/shrink to match the video clip duration implied by dataset FPS.
+                        if self.dataset_config.audio_preserve_pitch:
+                            waveform = time_stretch_preserve_pitch(waveform, sample_rate, target_samples)  # waveform is [C, L]
+                        else:
+                            # Use linear interpolation over the time axis.
+                            wf = waveform.unsqueeze(0)  # [1, C, L]
+                            wf = F.interpolate(wf, size=target_samples, mode="linear", align_corners=False)
+                            waveform = wf.squeeze(0)  # [C, L]
+
+                    self.audio_tensor = waveform
+                    self.audio_data = {"waveform": waveform, "sample_rate": int(sample_rate)}
+
+            except Exception as e:
+                # Keep behavior identical for non-audio datasets; for audio datasets, just skip if missing/broken.
+                if hasattr(self.dataset_config, 'debug') and self.dataset_config.debug:
+                    print_acc(f"Could not extract/stretch audio for {self.path}: {e}")
+                self.audio_data = None
+                self.audio_tensor = None
 
         # Only log success in debug mode
         if hasattr(self.dataset_config, 'debug') and self.dataset_config.debug:
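End to end, the block above slices the source audio to the sampled frame range and then resizes it to the duration implied by the dataset fps; the 0.999 target peak used in normalization sits just under full scale (20 * log10(0.999) is roughly -0.009 dBFS). A self-contained run of the non-pitch-preserving path with synthetic audio (all values illustrative):

    import torch
    import torch.nn.functional as F

    sample_rate = 16_000
    waveform = torch.randn(2, 3 * sample_rate)            # 3 s of stereo noise, [C, L]

    # Pretend frames 12..59 of a 24 fps source were sampled for training
    clip_start_time = 12 / 24.0                           # 0.5 s
    clip_end_time = (59 + 1) / 24.0                       # 2.5 s (last frame extended)
    start = int(round(clip_start_time * sample_rate))
    end = int(round(clip_end_time * sample_rate))
    clip = waveform[..., start:end]                       # 2.0 s of audio

    # Dataset says 41 frames at 16 fps -> 2.5625 s target duration
    target_samples = round(41 / 16.0 * sample_rate)       # 41_000 samples
    stretched = F.interpolate(clip.unsqueeze(0), size=target_samples,
                              mode="linear", align_corners=False).squeeze(0)
    print(stretched.shape)                                # torch.Size([2, 41000])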