mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 02:31:17 +00:00
Add LTX-2 Support (#644)
* WIP, adding support for LTX2 * Training on images working * Fix loading comfy models * Handle converting and deconverting lora so it matches original format * Reworked ui to handle ltx and proper dataset default overwriting. * Update the way lokr saves so it is more compatible with comfy * Audio loading and synchronization/resampling is working * Add audio to training. Does it work? Maybe, still testing. * Fixed fps default issue for sound * Have ui set fps for accurate audio mapping on ltx * Added audio processing options to the ui for ltx * Clean up requirements
This commit is contained in:
75
toolkit/audio/preserve_pitch.py
Normal file
75
toolkit/audio/preserve_pitch.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
|
||||
def time_stretch_preserve_pitch(waveform: torch.Tensor, sample_rate: int, target_samples: int) -> torch.Tensor:
    """
    Pitch-preserving time stretch of an audio waveform to an exact sample count.

    Works in the STFT domain: a phase-vocoder ``torchaudio.transforms.TimeStretch``
    changes the number of frames while keeping the frequency content, then an
    inverse STFT resynthesizes the waveform at exactly ``target_samples`` samples.

    Args:
        waveform: ``[C, L]`` (or ``[L]``, promoted to ``[1, L]``) float tensor,
            CPU or GPU.
        sample_rate: sample rate in Hz; used only to size the STFT window/hop.
        target_samples: desired output length in samples.

    Returns:
        ``[C, target_samples]`` float32 tensor on the input's device. If the
        input is empty or ``target_samples <= 0``, a zero-length ``[C, 0]``
        tensor is returned instead.
    """
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)

    waveform = waveform.to(torch.float32)

    src_len = waveform.shape[-1]
    if src_len == 0 or target_samples <= 0:
        # Nothing to stretch; return an empty [C, 0] tensor.
        return waveform[..., :0]

    if src_len == target_samples:
        return waveform

    # rate > 1.0 speeds up (shorter), rate < 1.0 slows down (longer)
    rate = float(src_len) / float(target_samples)

    # Use sample_rate to pick STFT params: ~46 ms analysis window and
    # ~11.5 ms hop, with the FFT size rounded down to a power of two >= 256.
    win_seconds = 0.046
    hop_seconds = 0.0115

    n_fft_target = int(sample_rate * win_seconds)
    n_fft = 1 << max(8, int(math.floor(math.log2(max(256, n_fft_target)))))  # >=256, pow2
    win_length = n_fft
    hop_length = max(64, int(sample_rate * hop_seconds))
    # Keep at least 2x window overlap so the iSTFT window envelope is invertible.
    hop_length = min(hop_length, win_length // 2)

    window = torch.hann_window(win_length, device=waveform.device, dtype=waveform.dtype)

    stft = torch.stft(
        waveform,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=True,
        return_complex=True,
    )  # [C, F, T] complex

    # IMPORTANT: n_freq must match STFT's frequency bins (n_fft//2 + 1)
    stretcher = torchaudio.transforms.TimeStretch(
        n_freq=stft.shape[-2],
        hop_length=hop_length,
        fixed_rate=rate,
    ).to(waveform.device)

    stft_stretched = stretcher(stft)  # [C, F, T']

    # `length=target_samples` makes istft zero-pad or trim the output to the
    # exact requested length, so no post-hoc fixup is needed.
    stretched = torch.istft(
        stft_stretched,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=True,
        length=target_samples,
    )

    return stretched
|
||||
@@ -207,6 +207,9 @@ class NetworkConfig:
|
||||
# -1 automatically finds the largest factor
|
||||
self.lokr_factor = kwargs.get('lokr_factor', -1)
|
||||
|
||||
# Use the old lokr format
|
||||
self.old_lokr_format = kwargs.get('old_lokr_format', False)
|
||||
|
||||
# for multi stage models
|
||||
self.split_multistage_loras = kwargs.get('split_multistage_loras', True)
|
||||
|
||||
@@ -672,6 +675,9 @@ class ModelConfig:
|
||||
# kwargs to pass to the model
|
||||
self.model_kwargs = kwargs.get("model_kwargs", {})
|
||||
|
||||
# model paths for models that support it
|
||||
self.model_paths = kwargs.get("model_paths", {})
|
||||
|
||||
# allow frontend to pass arch with a color like arch:tag
|
||||
# but remove the tag
|
||||
if self.arch is not None:
|
||||
@@ -956,7 +962,7 @@ class DatasetConfig:
|
||||
# it will select a random start frame and pull the frames at the given fps
|
||||
# this could have various issues with shorter videos and videos with variable fps
|
||||
# I recommend trimming your videos to the desired length and using shrink_video_to_frames(default)
|
||||
self.fps: int = kwargs.get('fps', 16)
|
||||
self.fps: int = kwargs.get('fps', 24)
|
||||
|
||||
# debug the frame count and frame selection. You dont need this. It is for debugging.
|
||||
self.debug: bool = kwargs.get('debug', False)
|
||||
@@ -972,6 +978,9 @@ class DatasetConfig:
|
||||
self.fast_image_size: bool = kwargs.get('fast_image_size', False)
|
||||
|
||||
self.do_i2v: bool = kwargs.get('do_i2v', True) # do image to video on models that are both t2i and i2v capable
|
||||
self.do_audio: bool = kwargs.get('do_audio', False) # load audio from video files for models that support it
|
||||
self.audio_preserve_pitch: bool = kwargs.get('audio_preserve_pitch', False) # preserve pitch when stretching audio to fit num_frames
|
||||
self.audio_normalize: bool = kwargs.get('audio_normalize', False) # normalize audio volume levels when loading
|
||||
|
||||
|
||||
def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
|
||||
|
||||
@@ -123,9 +123,13 @@ class FileItemDTO(
|
||||
self.is_reg = self.dataset_config.is_reg
|
||||
self.prior_reg = self.dataset_config.prior_reg
|
||||
self.tensor: Union[torch.Tensor, None] = None
|
||||
self.audio_data = None
|
||||
self.audio_tensor = None
|
||||
|
||||
def cleanup(self):
|
||||
self.tensor = None
|
||||
self.audio_data = None
|
||||
self.audio_tensor = None
|
||||
self.cleanup_latent()
|
||||
self.cleanup_text_embedding()
|
||||
self.cleanup_control()
|
||||
@@ -154,6 +158,13 @@ class DataLoaderBatchDTO:
|
||||
self.clip_image_embeds_unconditional: Union[List[dict], None] = None
|
||||
self.sigmas: Union[torch.Tensor, None] = None # can be added elseware and passed along training code
|
||||
self.extra_values: Union[torch.Tensor, None] = torch.tensor([x.extra_values for x in self.file_items]) if len(self.file_items[0].extra_values) > 0 else None
|
||||
self.audio_data: Union[List, None] = [x.audio_data for x in self.file_items] if self.file_items[0].audio_data is not None else None
|
||||
self.audio_tensor: Union[torch.Tensor, None] = None
|
||||
|
||||
# just for holding noise and preds during training
|
||||
self.audio_target: Union[torch.Tensor, None] = None
|
||||
self.audio_pred: Union[torch.Tensor, None] = None
|
||||
|
||||
if not is_latents_cached:
|
||||
# only return a tensor if latents are not cached
|
||||
self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
|
||||
@@ -304,6 +315,21 @@ class DataLoaderBatchDTO:
|
||||
y.text_embeds = [y.text_embeds]
|
||||
prompt_embeds_list.append(y)
|
||||
self.prompt_embeds = concat_prompt_embeds(prompt_embeds_list)
|
||||
|
||||
if any([x.audio_tensor is not None for x in self.file_items]):
|
||||
# find one to use as a base
|
||||
base_audio_tensor = None
|
||||
for x in self.file_items:
|
||||
if x.audio_tensor is not None:
|
||||
base_audio_tensor = x.audio_tensor
|
||||
break
|
||||
audio_tensors = []
|
||||
for x in self.file_items:
|
||||
if x.audio_tensor is None:
|
||||
audio_tensors.append(torch.zeros_like(base_audio_tensor))
|
||||
else:
|
||||
audio_tensors.append(x.audio_tensor)
|
||||
self.audio_tensor = torch.cat([x.unsqueeze(0) for x in audio_tensors])
|
||||
|
||||
|
||||
except Exception as e:
|
||||
@@ -336,6 +362,10 @@ class DataLoaderBatchDTO:
|
||||
del self.latents
|
||||
del self.tensor
|
||||
del self.control_tensor
|
||||
del self.audio_tensor
|
||||
del self.audio_data
|
||||
del self.audio_target
|
||||
del self.audio_pred
|
||||
for file_item in self.file_items:
|
||||
file_item.cleanup()
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor
|
||||
|
||||
from toolkit.audio.preserve_pitch import time_stretch_preserve_pitch
|
||||
from toolkit.basic import flush, value_map
|
||||
from toolkit.buckets import get_bucket_for_image_size, get_resolution
|
||||
from toolkit.config_modules import ControlTypes
|
||||
@@ -467,6 +468,8 @@ class ImageProcessingDTOMixin:
|
||||
if not self.dataset_config.buckets:
|
||||
raise Exception('Buckets required for video processing')
|
||||
|
||||
do_audio = self.dataset_config.do_audio
|
||||
|
||||
try:
|
||||
# Use OpenCV to capture video frames
|
||||
cap = cv2.VideoCapture(self.path)
|
||||
@@ -596,6 +599,84 @@ class ImageProcessingDTOMixin:
|
||||
|
||||
# Stack frames into tensor [frames, channels, height, width]
|
||||
self.tensor = torch.stack(frames)
|
||||
|
||||
# ------------------------------
|
||||
# Audio extraction + stretching
|
||||
# ------------------------------
|
||||
if do_audio:
|
||||
# Default to "no audio" unless we successfully extract it
|
||||
self.audio_data = None
|
||||
self.audio_tensor = None
|
||||
|
||||
try:
|
||||
import torchaudio
|
||||
import torch.nn.functional as F
|
||||
|
||||
# Compute the time range of the selected frames in the *source* video
|
||||
# Include the last frame by extending to the next frame boundary.
|
||||
if video_fps and video_fps > 0 and len(frames_to_extract) > 0:
|
||||
clip_start_frame = int(frames_to_extract[0])
|
||||
clip_end_frame = int(frames_to_extract[-1])
|
||||
clip_start_time = clip_start_frame / float(video_fps)
|
||||
clip_end_time = (clip_end_frame + 1) / float(video_fps)
|
||||
source_duration = max(0.0, clip_end_time - clip_start_time)
|
||||
else:
|
||||
clip_start_time = 0.0
|
||||
clip_end_time = 0.0
|
||||
source_duration = 0.0
|
||||
|
||||
# Target duration is how this sampled/stretched clip is interpreted for training
|
||||
# (i.e. num_frames at the configured dataset FPS).
|
||||
if hasattr(self.dataset_config, "fps") and self.dataset_config.fps and self.dataset_config.fps > 0:
|
||||
target_duration = float(self.dataset_config.num_frames) / float(self.dataset_config.fps)
|
||||
else:
|
||||
target_duration = source_duration
|
||||
|
||||
waveform, sample_rate = torchaudio.load(self.path) # [channels, samples]
|
||||
|
||||
if self.dataset_config.audio_normalize:
|
||||
peak = waveform.abs().amax() # global peak across channels
|
||||
eps = 1e-9
|
||||
target_peak = 0.999 # ~ -0.01 dBFS
|
||||
gain = target_peak / (peak + eps)
|
||||
waveform = waveform * gain
|
||||
|
||||
# Slice to the selected clip region (when we have a meaningful time range)
|
||||
if source_duration > 0.0:
|
||||
start_sample = int(round(clip_start_time * sample_rate))
|
||||
end_sample = int(round(clip_end_time * sample_rate))
|
||||
start_sample = max(0, min(start_sample, waveform.shape[-1]))
|
||||
end_sample = max(0, min(end_sample, waveform.shape[-1]))
|
||||
if end_sample > start_sample:
|
||||
waveform = waveform[..., start_sample:end_sample]
|
||||
else:
|
||||
# No valid audio segment
|
||||
waveform = None
|
||||
else:
|
||||
# If we can't compute a meaningful time range, treat as no-audio
|
||||
waveform = None
|
||||
|
||||
if waveform is not None and waveform.numel() > 0:
|
||||
target_samples = int(round(target_duration * sample_rate))
|
||||
if target_samples > 0 and waveform.shape[-1] != target_samples:
|
||||
# Time-stretch/shrink to match the video clip duration implied by dataset FPS.
|
||||
if self.dataset_config.audio_preserve_pitch:
|
||||
waveform = time_stretch_preserve_pitch(waveform, sample_rate, target_samples) # waveform is [C, L]
|
||||
else:
|
||||
# Use linear interpolation over the time axis.
|
||||
wf = waveform.unsqueeze(0) # [1, C, L]
|
||||
wf = F.interpolate(wf, size=target_samples, mode="linear", align_corners=False)
|
||||
waveform = wf.squeeze(0) # [C, L]
|
||||
|
||||
self.audio_tensor = waveform
|
||||
self.audio_data = {"waveform": waveform, "sample_rate": int(sample_rate)}
|
||||
|
||||
except Exception as e:
|
||||
# Keep behavior identical for non-audio datasets; for audio datasets, just skip if missing/broken.
|
||||
if hasattr(self.dataset_config, 'debug') and self.dataset_config.debug:
|
||||
print_acc(f"Could not extract/stretch audio for {self.path}: {e}")
|
||||
self.audio_data = None
|
||||
self.audio_tensor = None
|
||||
|
||||
# Only log success in debug mode
|
||||
if hasattr(self.dataset_config, 'debug') and self.dataset_config.debug:
|
||||
|
||||
@@ -265,11 +265,18 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
self.peft_format = peft_format
|
||||
self.is_transformer = is_transformer
|
||||
|
||||
# use the old format for older models unless the user has specified otherwise
|
||||
self.use_old_lokr_format = False
|
||||
if self.network_config is not None and hasattr(self.network_config, 'old_lokr_format'):
|
||||
self.use_old_lokr_format = self.network_config.old_lokr_format
|
||||
# also allow a false from the model itself
|
||||
if base_model is not None and not base_model.use_old_lokr_format:
|
||||
self.use_old_lokr_format = False
|
||||
|
||||
# always do peft for flux only for now
|
||||
if self.is_flux or self.is_v3 or self.is_lumina2 or is_transformer:
|
||||
# don't do peft format for lokr
|
||||
if self.network_type.lower() != "lokr":
|
||||
# don't do peft format for lokr if using old format
|
||||
if self.network_type.lower() != "lokr" or not self.use_old_lokr_format:
|
||||
self.peft_format = True
|
||||
|
||||
if self.peft_format:
|
||||
|
||||
@@ -185,6 +185,11 @@ class BaseModel:
|
||||
self.has_multiple_control_images = False
|
||||
# do not resize control images
|
||||
self.use_raw_control_images = False
|
||||
# defines if the model supports model paths. Only some will
|
||||
self.supports_model_paths = False
|
||||
|
||||
# use new lokr format (default false for old models for backwards compatibility)
|
||||
self.use_old_lokr_format = True
|
||||
|
||||
# properties for old arch for backwards compatibility
|
||||
@property
|
||||
|
||||
@@ -11,6 +11,7 @@ from toolkit.network_mixins import ToolkitModuleMixin
|
||||
from typing import TYPE_CHECKING, Union, List
|
||||
|
||||
from optimum.quanto import QBytesTensor, QTensor
|
||||
from torchao.dtypes import AffineQuantizedTensor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@@ -284,17 +285,26 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
|
||||
org_sd[weight_key] = merged_weight.to(orig_dtype)
|
||||
self.org_module[0].load_state_dict(org_sd)
|
||||
|
||||
def get_orig_weight(self):
|
||||
def get_orig_weight(self, device):
|
||||
weight = self.org_module[0].weight
|
||||
if weight.device != device:
|
||||
weight = weight.to(device)
|
||||
if isinstance(weight, QTensor) or isinstance(weight, QBytesTensor):
|
||||
return weight.dequantize().data.detach()
|
||||
elif isinstance(weight, AffineQuantizedTensor):
|
||||
return weight.dequantize().data.detach()
|
||||
else:
|
||||
return weight.data.detach()
|
||||
|
||||
def get_orig_bias(self):
|
||||
def get_orig_bias(self, device):
|
||||
if hasattr(self.org_module[0], 'bias') and self.org_module[0].bias is not None:
|
||||
if isinstance(self.org_module[0].bias, QTensor) or isinstance(self.org_module[0].bias, QBytesTensor):
|
||||
return self.org_module[0].bias.dequantize().data.detach()
|
||||
bias = self.org_module[0].bias
|
||||
if bias.device != device:
|
||||
bias = bias.to(device)
|
||||
if isinstance(bias, QTensor) or isinstance(bias, QBytesTensor):
|
||||
return bias.dequantize().data.detach()
|
||||
elif isinstance(bias, AffineQuantizedTensor):
|
||||
return bias.dequantize().data.detach()
|
||||
else:
|
||||
return self.org_module[0].bias.data.detach()
|
||||
return None
|
||||
@@ -305,7 +315,7 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
|
||||
|
||||
orig_dtype = x.dtype
|
||||
|
||||
orig_weight = self.get_orig_weight()
|
||||
orig_weight = self.get_orig_weight(x.device)
|
||||
lokr_weight = self.get_weight(orig_weight).to(dtype=orig_weight.dtype)
|
||||
multiplier = self.network_ref().torch_multiplier
|
||||
|
||||
@@ -319,7 +329,7 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
|
||||
orig_weight
|
||||
+ lokr_weight * multiplier
|
||||
)
|
||||
bias = self.get_orig_bias()
|
||||
bias = self.get_orig_bias(x.device)
|
||||
if bias is not None:
|
||||
bias = bias.to(weight.device, dtype=weight.dtype)
|
||||
output = self.op(
|
||||
|
||||
@@ -546,7 +546,8 @@ class ToolkitNetworkMixin:
|
||||
|
||||
new_save_dict = {}
|
||||
for key, value in save_dict.items():
|
||||
if key.endswith('.alpha'):
|
||||
# lokr needs alpha
|
||||
if key.endswith('.alpha') and self.network_type.lower() != "lokr":
|
||||
continue
|
||||
new_key = key
|
||||
new_key = new_key.replace('lora_down', 'lora_A')
|
||||
@@ -558,7 +559,7 @@ class ToolkitNetworkMixin:
|
||||
save_dict = new_save_dict
|
||||
|
||||
|
||||
if self.network_type.lower() == "lokr":
|
||||
if self.network_type.lower() == "lokr" and self.use_old_lokr_format:
|
||||
new_save_dict = {}
|
||||
for key, value in save_dict.items():
|
||||
# lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1 to lycoris_transformer_blocks_7_attn_to_v.lokr_w1
|
||||
@@ -632,7 +633,7 @@ class ToolkitNetworkMixin:
|
||||
# lora_down = lora_A
|
||||
# lora_up = lora_B
|
||||
# no alpha
|
||||
if load_key.endswith('.alpha'):
|
||||
if load_key.endswith('.alpha') and self.network_type.lower() != "lokr":
|
||||
continue
|
||||
load_key = load_key.replace('lora_A', 'lora_down')
|
||||
load_key = load_key.replace('lora_B', 'lora_up')
|
||||
@@ -640,6 +641,13 @@ class ToolkitNetworkMixin:
|
||||
load_key = load_key.replace('.', '$$')
|
||||
load_key = load_key.replace('$$lora_down$$', '.lora_down.')
|
||||
load_key = load_key.replace('$$lora_up$$', '.lora_up.')
|
||||
|
||||
# patch lokr, not sure why we need to but whatever
|
||||
if self.network_type.lower() == "lokr":
|
||||
load_key = load_key.replace('$$lokr_w1', '.lokr_w1')
|
||||
load_key = load_key.replace('$$lokr_w2', '.lokr_w2')
|
||||
if load_key.endswith('$$alpha'):
|
||||
load_key = load_key[:-7] + '.alpha'
|
||||
|
||||
if self.network_type.lower() == "lokr":
|
||||
# lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1 to lycoris_transformer_blocks_7_attn_to_v.lokr_w1
|
||||
|
||||
@@ -223,6 +223,11 @@ class StableDiffusion:
|
||||
self.has_multiple_control_images = False
|
||||
# do not resize control images
|
||||
self.use_raw_control_images = False
|
||||
# defines if the model supports model paths. Only some will
|
||||
self.supports_model_paths = False
|
||||
|
||||
# use new lokr format (default false for old models for backwards compatibility)
|
||||
self.use_old_lokr_format = True
|
||||
|
||||
# properties for old arch for backwards compatibility
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user