Add LTX-2 Support (#644)

* WIP, adding support for LTX2

* Training on images working

* Fix loading comfy models

* Handle converting and deconverting lora so it matches original format

* Reworked UI to handle LTX and proper dataset default overwriting.

* Update the way lokr saves so it is more compatible with comfy

* Audio loading and synchronization/resampling is working

* Add audio to training. Does it work? Maybe, still testing.

* Fixed fps default issue for sound

* Have UI set fps for accurate audio mapping on LTX

* Added audio processing options to the UI for LTX

* Clean up requirements
Author: Jaret Burkett
Committed: 2026-01-13 04:55:30 -07:00 (by GitHub)
Parent: 6870ab490f
Commit: 5b5aadadb8
28 changed files with 2180 additions and 71 deletions

View File

@@ -0,0 +1,75 @@
import math

import torch
import torch.nn.functional as F
import torchaudio


def time_stretch_preserve_pitch(waveform: torch.Tensor, sample_rate: int, target_samples: int) -> torch.Tensor:
    """
    Pitch-preserving time stretch to match target_samples.

    waveform: [C, L] float tensor (CPU or GPU)
    returns: [C, target_samples] float tensor
    """
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)
    waveform = waveform.to(torch.float32)

    src_len = waveform.shape[-1]
    if src_len == 0 or target_samples <= 0:
        return waveform[..., :0]
    if src_len == target_samples:
        return waveform

    # rate > 1.0 speeds up (shorter), rate < 1.0 slows down (longer)
    rate = float(src_len) / float(target_samples)

    # Use sample_rate to pick STFT params
    win_seconds = 0.046
    hop_seconds = 0.0115
    n_fft_target = int(sample_rate * win_seconds)
    n_fft = 1 << max(8, int(math.floor(math.log2(max(256, n_fft_target)))))  # >=256, pow2
    win_length = n_fft
    hop_length = max(64, int(sample_rate * hop_seconds))
    hop_length = min(hop_length, win_length // 2)

    window = torch.hann_window(win_length, device=waveform.device, dtype=waveform.dtype)
    stft = torch.stft(
        waveform,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=True,
        return_complex=True,
    )  # [C, F, T] complex

    # IMPORTANT: n_freq must match the STFT's frequency bins (n_fft // 2 + 1)
    stretcher = torchaudio.transforms.TimeStretch(
        n_freq=stft.shape[-2],
        hop_length=hop_length,
        fixed_rate=rate,
    ).to(waveform.device)
    stft_stretched = stretcher(stft)  # [C, F, T']

    stretched = torch.istft(
        stft_stretched,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=True,
        length=target_samples,
    )
    if stretched.shape[-1] > target_samples:
        stretched = stretched[..., :target_samples]
    elif stretched.shape[-1] < target_samples:
        stretched = F.pad(stretched, (0, target_samples - stretched.shape[-1]))
    return stretched
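For context, a minimal usage sketch of the helper above (the file name and frame counts are illustrative, not from the repo; assumes torchaudio's backend can decode the container):

import torchaudio
from toolkit.audio.preserve_pitch import time_stretch_preserve_pitch

# Fit a clip's audio to 48 frames at 24 fps (2.0 s), pitch preserved
waveform, sample_rate = torchaudio.load("clip.mp4")  # [C, L]
target_samples = int(round((48 / 24.0) * sample_rate))
stretched = time_stretch_preserve_pitch(waveform, sample_rate, target_samples)
assert stretched.shape == (waveform.shape[0], target_samples)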

View File

@@ -207,6 +207,9 @@ class NetworkConfig:
# -1 automatically finds the largest factor
self.lokr_factor = kwargs.get('lokr_factor', -1)
# Use the old lokr format
self.old_lokr_format = kwargs.get('old_lokr_format', False)
# for multi-stage models
self.split_multistage_loras = kwargs.get('split_multistage_loras', True)
@@ -672,6 +675,9 @@ class ModelConfig:
# kwargs to pass to the model
self.model_kwargs = kwargs.get("model_kwargs", {})
# model paths for models that support it
self.model_paths = kwargs.get("model_paths", {})
# allow frontend to pass arch with a colon like arch:tag
# but remove the tag
if self.arch is not None:
@@ -956,7 +962,7 @@ class DatasetConfig:
# it will select a random start frame and pull the frames at the given fps
# this could have various issues with shorter videos and videos with variable fps
# I recommend trimming your videos to the desired length and using shrink_video_to_frames (default)
self.fps: int = kwargs.get('fps', 16)
self.fps: int = kwargs.get('fps', 24)
# debug the frame count and frame selection. You don't need this; it is for debugging.
self.debug: bool = kwargs.get('debug', False)
@@ -972,6 +978,9 @@ class DatasetConfig:
self.fast_image_size: bool = kwargs.get('fast_image_size', False)
self.do_i2v: bool = kwargs.get('do_i2v', True) # do image to video on models that are both t2i and i2v capable
self.do_audio: bool = kwargs.get('do_audio', False) # load audio from video files for models that support it
self.audio_preserve_pitch: bool = kwargs.get('audio_preserve_pitch', False) # preserve pitch when stretching audio to fit num_frames
self.audio_normalize: bool = kwargs.get('audio_normalize', False) # normalize audio volume levels when loading
def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
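Taken together, a dataset entry that exercises the new audio options might look like the sketch below (key names follow the kwargs read above; folder_path and the values are illustrative, not from the diff):

dataset_config = {
    "folder_path": "/path/to/videos",  # hypothetical key/path
    "num_frames": 48,
    "fps": 24,                         # new default; sets the audio target duration
    "do_audio": True,                  # load audio from video files
    "audio_preserve_pitch": True,      # STFT time stretch instead of linear interpolation
    "audio_normalize": True,           # peak-normalize volume on load
}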

View File

@@ -123,9 +123,13 @@ class FileItemDTO(
self.is_reg = self.dataset_config.is_reg
self.prior_reg = self.dataset_config.prior_reg
self.tensor: Union[torch.Tensor, None] = None
self.audio_data = None
self.audio_tensor = None
def cleanup(self):
    self.tensor = None
    self.audio_data = None
    self.audio_tensor = None
    self.cleanup_latent()
    self.cleanup_text_embedding()
    self.cleanup_control()
@@ -154,6 +158,13 @@ class DataLoaderBatchDTO:
self.clip_image_embeds_unconditional: Union[List[dict], None] = None
self.sigmas: Union[torch.Tensor, None] = None  # can be added elsewhere and passed along the training code
self.extra_values: Union[torch.Tensor, None] = torch.tensor([x.extra_values for x in self.file_items]) if len(self.file_items[0].extra_values) > 0 else None
self.audio_data: Union[List, None] = [x.audio_data for x in self.file_items] if self.file_items[0].audio_data is not None else None
self.audio_tensor: Union[torch.Tensor, None] = None
# just for holding noise and preds during training
self.audio_target: Union[torch.Tensor, None] = None
self.audio_pred: Union[torch.Tensor, None] = None
if not is_latents_cached:
    # only return a tensor if latents are not cached
    self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
@@ -304,6 +315,21 @@ class DataLoaderBatchDTO:
y.text_embeds = [y.text_embeds]
prompt_embeds_list.append(y)
self.prompt_embeds = concat_prompt_embeds(prompt_embeds_list)
if any([x.audio_tensor is not None for x in self.file_items]):
    # find one to use as a base
    base_audio_tensor = None
    for x in self.file_items:
        if x.audio_tensor is not None:
            base_audio_tensor = x.audio_tensor
            break
    audio_tensors = []
    for x in self.file_items:
        if x.audio_tensor is None:
            audio_tensors.append(torch.zeros_like(base_audio_tensor))
        else:
            audio_tensors.append(x.audio_tensor)
    self.audio_tensor = torch.cat([x.unsqueeze(0) for x in audio_tensors])
except Exception as e:
@@ -336,6 +362,10 @@ class DataLoaderBatchDTO:
del self.latents
del self.tensor
del self.control_tensor
del self.audio_tensor
del self.audio_data
del self.audio_target
del self.audio_pred
for file_item in self.file_items:
    file_item.cleanup()
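The collation above zero-fills items that lack audio so the batch stays rectangular. A self-contained sketch of that behavior (tensor sizes are illustrative):

import torch

items = [torch.randn(2, 48000), None, torch.randn(2, 48000)]  # per-item [C, L], one missing
base = next(t for t in items if t is not None)
audio_tensors = [t if t is not None else torch.zeros_like(base) for t in items]
batch_audio = torch.cat([t.unsqueeze(0) for t in audio_tensors])  # [B, C, L]
assert batch_audio.shape == (3, 2, 48000)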

View File

@@ -16,6 +16,7 @@ from safetensors.torch import load_file, save_file
from tqdm import tqdm
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor
from toolkit.audio.preserve_pitch import time_stretch_preserve_pitch
from toolkit.basic import flush, value_map
from toolkit.buckets import get_bucket_for_image_size, get_resolution
from toolkit.config_modules import ControlTypes
@@ -467,6 +468,8 @@ class ImageProcessingDTOMixin:
if not self.dataset_config.buckets:
    raise Exception('Buckets required for video processing')

do_audio = self.dataset_config.do_audio

try:
    # Use OpenCV to capture video frames
    cap = cv2.VideoCapture(self.path)
@@ -596,6 +599,84 @@ class ImageProcessingDTOMixin:
# Stack frames into tensor [frames, channels, height, width]
self.tensor = torch.stack(frames)
# ------------------------------
# Audio extraction + stretching
# ------------------------------
if do_audio:
    # Default to "no audio" unless we successfully extract it
    self.audio_data = None
    self.audio_tensor = None
    try:
        import torchaudio
        import torch.nn.functional as F

        # Compute the time range of the selected frames in the *source* video.
        # Include the last frame by extending to the next frame boundary.
        if video_fps and video_fps > 0 and len(frames_to_extract) > 0:
            clip_start_frame = int(frames_to_extract[0])
            clip_end_frame = int(frames_to_extract[-1])
            clip_start_time = clip_start_frame / float(video_fps)
            clip_end_time = (clip_end_frame + 1) / float(video_fps)
            source_duration = max(0.0, clip_end_time - clip_start_time)
        else:
            clip_start_time = 0.0
            clip_end_time = 0.0
            source_duration = 0.0

        # Target duration is how this sampled/stretched clip is interpreted for training
        # (i.e. num_frames at the configured dataset FPS).
        if hasattr(self.dataset_config, "fps") and self.dataset_config.fps and self.dataset_config.fps > 0:
            target_duration = float(self.dataset_config.num_frames) / float(self.dataset_config.fps)
        else:
            target_duration = source_duration

        waveform, sample_rate = torchaudio.load(self.path)  # [channels, samples]

        if self.dataset_config.audio_normalize:
            peak = waveform.abs().amax()  # global peak across channels
            eps = 1e-9
            target_peak = 0.999  # ~ -0.01 dBFS
            gain = target_peak / (peak + eps)
            waveform = waveform * gain

        # Slice to the selected clip region (when we have a meaningful time range)
        if source_duration > 0.0:
            start_sample = int(round(clip_start_time * sample_rate))
            end_sample = int(round(clip_end_time * sample_rate))
            start_sample = max(0, min(start_sample, waveform.shape[-1]))
            end_sample = max(0, min(end_sample, waveform.shape[-1]))
            if end_sample > start_sample:
                waveform = waveform[..., start_sample:end_sample]
            else:
                # No valid audio segment
                waveform = None
        else:
            # If we can't compute a meaningful time range, treat as no-audio
            waveform = None

        if waveform is not None and waveform.numel() > 0:
            target_samples = int(round(target_duration * sample_rate))
            if target_samples > 0 and waveform.shape[-1] != target_samples:
                # Time-stretch/shrink to match the video clip duration implied by dataset FPS.
                if self.dataset_config.audio_preserve_pitch:
                    waveform = time_stretch_preserve_pitch(waveform, sample_rate, target_samples)  # waveform is [C, L]
                else:
                    # Use linear interpolation over the time axis.
                    wf = waveform.unsqueeze(0)  # [1, C, L]
                    wf = F.interpolate(wf, size=target_samples, mode="linear", align_corners=False)
                    waveform = wf.squeeze(0)  # [C, L]
            self.audio_tensor = waveform
            self.audio_data = {"waveform": waveform, "sample_rate": int(sample_rate)}
    except Exception as e:
        # Keep behavior identical for non-audio datasets; for audio datasets, just skip if missing/broken.
        if hasattr(self.dataset_config, 'debug') and self.dataset_config.debug:
            print_acc(f"Could not extract/stretch audio for {self.path}: {e}")
        self.audio_data = None
        self.audio_tensor = None

# Only log success in debug mode
if hasattr(self.dataset_config, 'debug') and self.dataset_config.debug:

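To make the timing math concrete, a worked example with assumed numbers (not taken from the diff): a 30 fps source video where frames 10 through 57 are selected, trained as num_frames=48 at the dataset's fps=24:

video_fps = 30.0
frames_to_extract = list(range(10, 58))            # 48 source frames
clip_start_time = 10 / video_fps                   # 0.3333 s
clip_end_time = (57 + 1) / video_fps               # 1.9333 s, next frame boundary
source_duration = clip_end_time - clip_start_time  # 1.6 s of source audio
target_duration = 48 / 24.0                        # 2.0 s at dataset fps
# time_stretch rate = 1.6 / 2.0 = 0.8 < 1.0, so the audio is slowed to fit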
View File

@@ -265,11 +265,18 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
self.peft_format = peft_format
self.is_transformer = is_transformer
# use the old format for older models unless the user has specified otherwise
self.use_old_lokr_format = False
if self.network_config is not None and hasattr(self.network_config, 'old_lokr_format'):
self.use_old_lokr_format = self.network_config.old_lokr_format
# also allow the model itself to opt out of the old format
if base_model is not None and not base_model.use_old_lokr_format:
self.use_old_lokr_format = False
# always do peft for flux only for now
if self.is_flux or self.is_v3 or self.is_lumina2 or is_transformer:
# don't do peft format for lokr
if self.network_type.lower() != "lokr":
# don't do peft format for lokr if using old format
if self.network_type.lower() != "lokr" or not self.use_old_lokr_format:
self.peft_format = True
if self.peft_format:

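The reworked condition means peft format is skipped only for lokr networks kept in the old format. A small sketch of the decision table (the function name is illustrative):

def uses_peft_format(network_type: str, use_old_lokr_format: bool) -> bool:
    return network_type.lower() != "lokr" or not use_old_lokr_format

assert uses_peft_format("lora", False)       # plain lora -> peft keys
assert uses_peft_format("lokr", False)       # new lokr format -> peft keys
assert not uses_peft_format("lokr", True)    # old lokr format -> legacy keys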
View File

@@ -185,6 +185,11 @@ class BaseModel:
self.has_multiple_control_images = False
# do not resize control images
self.use_raw_control_images = False
# defines if the model supports model paths. Only some will
self.supports_model_paths = False
# use the old lokr format by default for backwards compatibility with older models
self.use_old_lokr_format = True
# properties for old arch for backwards compatibility
@property

View File

@@ -11,6 +11,7 @@ from toolkit.network_mixins import ToolkitModuleMixin
from typing import TYPE_CHECKING, Union, List
from optimum.quanto import QBytesTensor, QTensor
from torchao.dtypes import AffineQuantizedTensor
if TYPE_CHECKING:
@@ -284,17 +285,26 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
org_sd[weight_key] = merged_weight.to(orig_dtype)
self.org_module[0].load_state_dict(org_sd)
def get_orig_weight(self):
def get_orig_weight(self, device):
weight = self.org_module[0].weight
if weight.device != device:
weight = weight.to(device)
if isinstance(weight, QTensor) or isinstance(weight, QBytesTensor):
return weight.dequantize().data.detach()
elif isinstance(weight, AffineQuantizedTensor):
return weight.dequantize().data.detach()
else:
return weight.data.detach()
def get_orig_bias(self):
def get_orig_bias(self, device):
if hasattr(self.org_module[0], 'bias') and self.org_module[0].bias is not None:
if isinstance(self.org_module[0].bias, QTensor) or isinstance(self.org_module[0].bias, QBytesTensor):
return self.org_module[0].bias.dequantize().data.detach()
bias = self.org_module[0].bias
if bias.device != device:
bias = bias.to(device)
if isinstance(bias, QTensor) or isinstance(bias, QBytesTensor):
return bias.dequantize().data.detach()
elif isinstance(bias, AffineQuantizedTensor):
return bias.dequantize().data.detach()
else:
return bias.data.detach()
return None
@@ -305,7 +315,7 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
orig_dtype = x.dtype
orig_weight = self.get_orig_weight()
orig_weight = self.get_orig_weight(x.device)
lokr_weight = self.get_weight(orig_weight).to(dtype=orig_weight.dtype)
multiplier = self.network_ref().torch_multiplier
@@ -319,7 +329,7 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
orig_weight
+ lokr_weight * multiplier
)
bias = self.get_orig_bias()
bias = self.get_orig_bias(x.device)
if bias is not None:
bias = bias.to(weight.device, dtype=weight.dtype)
output = self.op(

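Both getters above follow the same pattern: move the parameter to the activation's device first, then dequantize if it is a quantized tensor. A condensed sketch (the helper name is illustrative; QTensor, QBytesTensor, and AffineQuantizedTensor all expose dequantize() as used above):

import torch

def materialize(t: torch.Tensor, device: torch.device) -> torch.Tensor:
    # move first so the dequantized copy lands on the right device
    if t.device != device:
        t = t.to(device)
    if hasattr(t, "dequantize"):
        return t.dequantize().data.detach()
    return t.data.detach()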
View File

@@ -546,7 +546,8 @@ class ToolkitNetworkMixin:
new_save_dict = {}
for key, value in save_dict.items():
if key.endswith('.alpha'):
# lokr needs alpha
if key.endswith('.alpha') and self.network_type.lower() != "lokr":
continue
new_key = key
new_key = new_key.replace('lora_down', 'lora_A')
@@ -558,7 +559,7 @@ class ToolkitNetworkMixin:
save_dict = new_save_dict
if self.network_type.lower() == "lokr":
if self.network_type.lower() == "lokr" and self.use_old_lokr_format:
new_save_dict = {}
for key, value in save_dict.items():
# lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1 to lycoris_transformer_blocks_7_attn_to_v.lokr_w1
@@ -632,7 +633,7 @@ class ToolkitNetworkMixin:
# lora_down = lora_A
# lora_up = lora_B
# no alpha
if load_key.endswith('.alpha'):
if load_key.endswith('.alpha') and self.network_type.lower() != "lokr":
continue
load_key = load_key.replace('lora_A', 'lora_down')
load_key = load_key.replace('lora_B', 'lora_up')
@@ -640,6 +641,13 @@ class ToolkitNetworkMixin:
load_key = load_key.replace('.', '$$')
load_key = load_key.replace('$$lora_down$$', '.lora_down.')
load_key = load_key.replace('$$lora_up$$', '.lora_up.')
# patch lokr key names; the reason is unclear, but loading requires it
if self.network_type.lower() == "lokr":
load_key = load_key.replace('$$lokr_w1', '.lokr_w1')
load_key = load_key.replace('$$lokr_w2', '.lokr_w2')
if load_key.endswith('$$alpha'):
load_key = load_key[:-7] + '.alpha'
if self.network_type.lower() == "lokr":
# lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1 to lycoris_transformer_blocks_7_attn_to_v.lokr_w1

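The rename described in the comment above swaps the 'lora_transformer_' prefix for 'lycoris_'. A minimal sketch built from the example key in the comment:

key = "lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1"
new_key = "lycoris_" + key[len("lora_transformer_"):]
# -> "lycoris_transformer_blocks_7_attn_to_v.lokr_w1"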
View File

@@ -223,6 +223,11 @@ class StableDiffusion:
self.has_multiple_control_images = False
# do not resize control images
self.use_raw_control_images = False
# defines if the model supports model paths. Only some will
self.supports_model_paths = False
# use the old lokr format by default for backwards compatibility with older models
self.use_old_lokr_format = True
# properties for old arch for backwards compatibility
@property