From 5b5aadadb8310a0ed435fd142e5571c7cc0e0385 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Tue, 13 Jan 2026 04:55:30 -0700
Subject: [PATCH] Add LTX-2 Support (#644)

* WIP, adding support for LTX2

* Training on images working

* Fix loading comfy models

* Handle converting and deconverting lora so it matches original format

* Reworked ui to handle ltx and properly overwrite dataset defaults

* Update the way lokr saves so it is more compatible with comfy

* Audio loading and synchronization/resampling is working

* Add audio to training. Does it work? Maybe, still testing.

* Fixed fps default issue for sound

* Have ui set fps for accurate audio mapping on ltx

* Added audio processing options to the ui for ltx

* Clean up requirements
---
 .../diffusion_models/__init__.py              |   2 +
 .../diffusion_models/ltx2/__init__.py         |   1 +
 .../ltx2/convert_ltx2_to_diffusers.py         | 648 +++++++++++++
 .../diffusion_models/ltx2/ltx2.py             | 852 ++++++++++++++++++
 extensions_built_in/sd_trainer/SDTrainer.py   |   5 +
 requirements.txt                              |   6 +-
 testing/test_ltx_dataloader.py                | 234 +++++
 toolkit/audio/preserve_pitch.py               |  75 ++
 toolkit/config_modules.py                     |  11 +-
 toolkit/data_transfer_object/data_loader.py   |  30 +
 toolkit/dataloader_mixins.py                  |  81 ++
 toolkit/lora_special.py                       |  11 +-
 toolkit/models/base_model.py                  |   5 +
 toolkit/models/lokr.py                        |  22 +-
 toolkit/network_mixins.py                     |  14 +-
 toolkit/stable_diffusion_model.py             |   5 +
 ui/src/app/api/img/delete/route.ts            |   2 +-
 ui/src/app/api/jobs/[jobID]/samples/route.ts  |   2 +-
 ui/src/app/jobs/new/SimpleJob.tsx             |  42 +
 ui/src/app/jobs/new/jobConfig.ts              |   1 -
 ui/src/app/jobs/new/options.ts                |  29 +
 ui/src/app/jobs/new/utils.ts                  |  55 +-
 ui/src/components/SampleImageCard.tsx         |  14 +-
 ui/src/components/SampleImageViewer.tsx       |  26 +-
 ui/src/docs.tsx                               |  29 +
 ui/src/types.ts                               |   6 +-
 ui/src/utils/hooks.tsx                        |  41 +-
 version.py                                    |   2 +-
 28 files changed, 2180 insertions(+), 71 deletions(-)
 create mode 100644 extensions_built_in/diffusion_models/ltx2/__init__.py
 create mode 100644 extensions_built_in/diffusion_models/ltx2/convert_ltx2_to_diffusers.py
 create mode 100644 extensions_built_in/diffusion_models/ltx2/ltx2.py
 create mode 100644 testing/test_ltx_dataloader.py
 create mode 100644 toolkit/audio/preserve_pitch.py

diff --git a/extensions_built_in/diffusion_models/__init__.py b/extensions_built_in/diffusion_models/__init__.py
index 31fb1ac5..909c3ccb 100644
--- a/extensions_built_in/diffusion_models/__init__.py
+++ b/extensions_built_in/diffusion_models/__init__.py
@@ -7,6 +7,7 @@ from .wan22 import Wan225bModel, Wan2214bModel, Wan2214bI2VModel
 from .qwen_image import QwenImageModel, QwenImageEditModel, QwenImageEditPlusModel
 from .flux2 import Flux2Model
 from .z_image import ZImageModel
+from .ltx2 import LTX2Model
 
 AI_TOOLKIT_MODELS = [
     # put a list of models here
@@ -25,4 +26,5 @@ AI_TOOLKIT_MODELS = [
     QwenImageEditPlusModel,
     Flux2Model,
     ZImageModel,
+    LTX2Model,
 ]
diff --git a/extensions_built_in/diffusion_models/ltx2/__init__.py b/extensions_built_in/diffusion_models/ltx2/__init__.py
new file mode 100644
index 00000000..6ac67372
--- /dev/null
+++ b/extensions_built_in/diffusion_models/ltx2/__init__.py
@@ -0,0 +1 @@
+from .ltx2 import LTX2Model
\ No newline at end of file
diff --git a/extensions_built_in/diffusion_models/ltx2/convert_ltx2_to_diffusers.py b/extensions_built_in/diffusion_models/ltx2/convert_ltx2_to_diffusers.py
new file mode 100644
index 00000000..7ee8af2b
--- /dev/null
+++ b/extensions_built_in/diffusion_models/ltx2/convert_ltx2_to_diffusers.py
@@ -0,0 +1,648 @@
+# ref 
https://github.com/huggingface/diffusers/blob/249ae1f853be8775c7d0b3a26ca4fbc4f6aa9998/scripts/convert_ltx2_to_diffusers.py + +from contextlib import nullcontext +from typing import Any, Dict, Tuple + +import torch +from accelerate import init_empty_weights + +from diffusers import ( + AutoencoderKLLTX2Audio, + AutoencoderKLLTX2Video, + LTX2VideoTransformer3DModel, +) +from diffusers.pipelines.ltx2 import LTX2TextConnectors, LTX2Vocoder +from diffusers.utils.import_utils import is_accelerate_available + + +CTX = init_empty_weights if is_accelerate_available() else nullcontext + + +LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = { + # Input Patchify Projections + "patchify_proj": "proj_in", + "audio_patchify_proj": "audio_proj_in", + # Modulation Parameters + # Handle adaln_single --> time_embed, audioln_single --> audio_time_embed separately as the original keys are + # substrings of the other modulation parameters below + "av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift", + "av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate", + "av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift", + "av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate", + # Transformer Blocks + # Per-Block Cross Attention Modulatin Parameters + "scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table", + "scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table", + # Attention QK Norms + "q_norm": "norm_q", + "k_norm": "norm_k", +} + +LTX_2_0_VIDEO_VAE_RENAME_DICT = { + # Encoder + "down_blocks.0": "down_blocks.0", + "down_blocks.1": "down_blocks.0.downsamplers.0", + "down_blocks.2": "down_blocks.1", + "down_blocks.3": "down_blocks.1.downsamplers.0", + "down_blocks.4": "down_blocks.2", + "down_blocks.5": "down_blocks.2.downsamplers.0", + "down_blocks.6": "down_blocks.3", + "down_blocks.7": "down_blocks.3.downsamplers.0", + "down_blocks.8": "mid_block", + # Decoder + "up_blocks.0": "mid_block", + "up_blocks.1": "up_blocks.0.upsamplers.0", + "up_blocks.2": "up_blocks.0", + "up_blocks.3": "up_blocks.1.upsamplers.0", + "up_blocks.4": "up_blocks.1", + "up_blocks.5": "up_blocks.2.upsamplers.0", + "up_blocks.6": "up_blocks.2", + # Common + # For all 3D ResNets + "res_blocks": "resnets", + "per_channel_statistics.mean-of-means": "latents_mean", + "per_channel_statistics.std-of-means": "latents_std", +} + +LTX_2_0_AUDIO_VAE_RENAME_DICT = { + "per_channel_statistics.mean-of-means": "latents_mean", + "per_channel_statistics.std-of-means": "latents_std", +} + +LTX_2_0_VOCODER_RENAME_DICT = { + "ups": "upsamplers", + "resblocks": "resnets", + "conv_pre": "conv_in", + "conv_post": "conv_out", +} + +LTX_2_0_TEXT_ENCODER_RENAME_DICT = { + "video_embeddings_connector": "video_connector", + "audio_embeddings_connector": "audio_connector", + "transformer_1d_blocks": "transformer_blocks", + # Attention QK Norms + "q_norm": "norm_q", + "k_norm": "norm_k", +} + + +def update_state_dict_inplace( + state_dict: Dict[str, Any], old_key: str, new_key: str +) -> None: + state_dict[new_key] = state_dict.pop(old_key) + + +def remove_keys_inplace(key: str, state_dict: Dict[str, Any]) -> None: + state_dict.pop(key) + + +def convert_ltx2_transformer_adaln_single(key: str, state_dict: Dict[str, Any]) -> None: + # Skip if not a weight, bias + if ".weight" not in key and ".bias" not in key: + return + + if key.startswith("adaln_single."): + new_key = key.replace("adaln_single.", "time_embed.") + param = state_dict.pop(key) + state_dict[new_key] = param + + if 
key.startswith("audio_adaln_single."): + new_key = key.replace("audio_adaln_single.", "audio_time_embed.") + param = state_dict.pop(key) + state_dict[new_key] = param + + return + + +def convert_ltx2_audio_vae_per_channel_statistics( + key: str, state_dict: Dict[str, Any] +) -> None: + if key.startswith("per_channel_statistics"): + new_key = ".".join(["decoder", key]) + param = state_dict.pop(key) + state_dict[new_key] = param + + return + + +LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = { + "video_embeddings_connector": remove_keys_inplace, + "audio_embeddings_connector": remove_keys_inplace, + "adaln_single": convert_ltx2_transformer_adaln_single, +} + +LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = { + "connectors.": "", + "video_embeddings_connector": "video_connector", + "audio_embeddings_connector": "audio_connector", + "transformer_1d_blocks": "transformer_blocks", + "text_embedding_projection.aggregate_embed": "text_proj_in", + # Attention QK Norms + "q_norm": "norm_q", + "k_norm": "norm_k", +} + +LTX_2_0_VAE_SPECIAL_KEYS_REMAP = { + "per_channel_statistics.channel": remove_keys_inplace, + "per_channel_statistics.mean-of-stds": remove_keys_inplace, +} + +LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {} + +LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {} + + +def split_transformer_and_connector_state_dict( + state_dict: Dict[str, Any], +) -> Tuple[Dict[str, Any], Dict[str, Any]]: + connector_prefixes = ( + "video_embeddings_connector", + "audio_embeddings_connector", + "transformer_1d_blocks", + "text_embedding_projection.aggregate_embed", + "connectors.", + "video_connector", + "audio_connector", + "text_proj_in", + ) + + transformer_state_dict, connector_state_dict = {}, {} + for key, value in state_dict.items(): + if key.startswith(connector_prefixes): + connector_state_dict[key] = value + else: + transformer_state_dict[key] = value + + return transformer_state_dict, connector_state_dict + + +def get_ltx2_transformer_config() -> Tuple[ + Dict[str, Any], Dict[str, Any], Dict[str, Any] +]: + config = { + "model_id": "diffusers-internal-dev/new-ltx-model", + "diffusers_config": { + "in_channels": 128, + "out_channels": 128, + "patch_size": 1, + "patch_size_t": 1, + "num_attention_heads": 32, + "attention_head_dim": 128, + "cross_attention_dim": 4096, + "vae_scale_factors": (8, 32, 32), + "pos_embed_max_pos": 20, + "base_height": 2048, + "base_width": 2048, + "audio_in_channels": 128, + "audio_out_channels": 128, + "audio_patch_size": 1, + "audio_patch_size_t": 1, + "audio_num_attention_heads": 32, + "audio_attention_head_dim": 64, + "audio_cross_attention_dim": 2048, + "audio_scale_factor": 4, + "audio_pos_embed_max_pos": 20, + "audio_sampling_rate": 16000, + "audio_hop_length": 160, + "num_layers": 48, + "activation_fn": "gelu-approximate", + "qk_norm": "rms_norm_across_heads", + "norm_elementwise_affine": False, + "norm_eps": 1e-6, + "caption_channels": 3840, + "attention_bias": True, + "attention_out_bias": True, + "rope_theta": 10000.0, + "rope_double_precision": True, + "causal_offset": 1, + "timestep_scale_multiplier": 1000, + "cross_attn_timestep_scale_multiplier": 1000, + "rope_type": "split", + }, + } + rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT + special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP + return config, rename_dict, special_keys_remap + + +def get_ltx2_connectors_config() -> Tuple[ + Dict[str, Any], Dict[str, Any], Dict[str, Any] +]: + config = { + "model_id": "diffusers-internal-dev/new-ltx-model", + "diffusers_config": { + "caption_channels": 3840, + "text_proj_in_factor": 
49, + "video_connector_num_attention_heads": 30, + "video_connector_attention_head_dim": 128, + "video_connector_num_layers": 2, + "video_connector_num_learnable_registers": 128, + "audio_connector_num_attention_heads": 30, + "audio_connector_attention_head_dim": 128, + "audio_connector_num_layers": 2, + "audio_connector_num_learnable_registers": 128, + "connector_rope_base_seq_len": 4096, + "rope_theta": 10000.0, + "rope_double_precision": True, + "causal_temporal_positioning": False, + "rope_type": "split", + }, + } + + rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT + special_keys_remap = {} + + return config, rename_dict, special_keys_remap + + +def convert_ltx2_transformer(original_state_dict: Dict[str, Any]) -> Dict[str, Any]: + config, rename_dict, special_keys_remap = get_ltx2_transformer_config() + diffusers_config = config["diffusers_config"] + + transformer_state_dict, _ = split_transformer_and_connector_state_dict( + original_state_dict + ) + + with init_empty_weights(): + transformer = LTX2VideoTransformer3DModel.from_config(diffusers_config) + + # Handle official code --> diffusers key remapping via the remap dict + for key in list(transformer_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in rename_dict.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(transformer_state_dict, key, new_key) + + # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in + # special_keys_remap + for key in list(transformer_state_dict.keys()): + for special_key, handler_fn_inplace in special_keys_remap.items(): + if special_key not in key: + continue + handler_fn_inplace(key, transformer_state_dict) + + transformer.load_state_dict(transformer_state_dict, strict=True, assign=True) + return transformer + + +def convert_ltx2_connectors(original_state_dict: Dict[str, Any]) -> LTX2TextConnectors: + config, rename_dict, special_keys_remap = get_ltx2_connectors_config() + diffusers_config = config["diffusers_config"] + + _, connector_state_dict = split_transformer_and_connector_state_dict( + original_state_dict + ) + if len(connector_state_dict) == 0: + raise ValueError("No connector weights found in the provided state dict.") + + with init_empty_weights(): + connectors = LTX2TextConnectors.from_config(diffusers_config) + + for key in list(connector_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in rename_dict.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(connector_state_dict, key, new_key) + + for key in list(connector_state_dict.keys()): + for special_key, handler_fn_inplace in special_keys_remap.items(): + if special_key not in key: + continue + handler_fn_inplace(key, connector_state_dict) + + connectors.load_state_dict(connector_state_dict, strict=True, assign=True) + return connectors + + +def get_ltx2_video_vae_config() -> Tuple[ + Dict[str, Any], Dict[str, Any], Dict[str, Any] +]: + config = { + "model_id": "diffusers-internal-dev/dummy-ltx2", + "diffusers_config": { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "block_out_channels": (256, 512, 1024, 2048), + "down_block_types": ( + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + ), + "decoder_block_out_channels": (256, 512, 1024), + "layers_per_block": (4, 6, 6, 2, 2), + "decoder_layers_per_block": (5, 5, 5, 5), + "spatio_temporal_scaling": (True, True, True, True), + 
"decoder_spatio_temporal_scaling": (True, True, True), + "decoder_inject_noise": (False, False, False, False), + "downsample_type": ( + "spatial", + "temporal", + "spatiotemporal", + "spatiotemporal", + ), + "upsample_residual": (True, True, True), + "upsample_factor": (2, 2, 2), + "timestep_conditioning": False, + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-6, + "encoder_causal": True, + "decoder_causal": False, + "encoder_spatial_padding_mode": "zeros", + "decoder_spatial_padding_mode": "reflect", + "spatial_compression_ratio": 32, + "temporal_compression_ratio": 8, + }, + } + rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT + special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP + return config, rename_dict, special_keys_remap + + +def convert_ltx2_video_vae(original_state_dict: Dict[str, Any]) -> Dict[str, Any]: + config, rename_dict, special_keys_remap = get_ltx2_video_vae_config() + diffusers_config = config["diffusers_config"] + + with init_empty_weights(): + vae = AutoencoderKLLTX2Video.from_config(diffusers_config) + + # Handle official code --> diffusers key remapping via the remap dict + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in rename_dict.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(original_state_dict, key, new_key) + + # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in + # special_keys_remap + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in special_keys_remap.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + vae.load_state_dict(original_state_dict, strict=True, assign=True) + return vae + + +def get_ltx2_audio_vae_config() -> Tuple[ + Dict[str, Any], Dict[str, Any], Dict[str, Any] +]: + config = { + "model_id": "diffusers-internal-dev/new-ltx-model", + "diffusers_config": { + "base_channels": 128, + "output_channels": 2, + "ch_mult": (1, 2, 4), + "num_res_blocks": 2, + "attn_resolutions": None, + "in_channels": 2, + "resolution": 256, + "latent_channels": 8, + "norm_type": "pixel", + "causality_axis": "height", + "dropout": 0.0, + "mid_block_add_attention": False, + "sample_rate": 16000, + "mel_hop_length": 160, + "is_causal": True, + "mel_bins": 64, + "double_z": True, + }, + } + rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT + special_keys_remap = LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP + return config, rename_dict, special_keys_remap + + +def convert_ltx2_audio_vae(original_state_dict: Dict[str, Any]) -> Dict[str, Any]: + config, rename_dict, special_keys_remap = get_ltx2_audio_vae_config() + diffusers_config = config["diffusers_config"] + + with init_empty_weights(): + vae = AutoencoderKLLTX2Audio.from_config(diffusers_config) + + # Handle official code --> diffusers key remapping via the remap dict + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in rename_dict.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(original_state_dict, key, new_key) + + # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in + # special_keys_remap + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in special_keys_remap.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + vae.load_state_dict(original_state_dict, strict=True, assign=True) + return 
vae + + +def get_ltx2_vocoder_config() -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: + config = { + "model_id": "diffusers-internal-dev/new-ltx-model", + "diffusers_config": { + "in_channels": 128, + "hidden_channels": 1024, + "out_channels": 2, + "upsample_kernel_sizes": [16, 15, 8, 4, 4], + "upsample_factors": [6, 5, 2, 2, 2], + "resnet_kernel_sizes": [3, 7, 11], + "resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "leaky_relu_negative_slope": 0.1, + "output_sampling_rate": 24000, + }, + } + rename_dict = LTX_2_0_VOCODER_RENAME_DICT + special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP + return config, rename_dict, special_keys_remap + + +def convert_ltx2_vocoder(original_state_dict: Dict[str, Any]) -> Dict[str, Any]: + config, rename_dict, special_keys_remap = get_ltx2_vocoder_config() + diffusers_config = config["diffusers_config"] + + with init_empty_weights(): + vocoder = LTX2Vocoder.from_config(diffusers_config) + + # Handle official code --> diffusers key remapping via the remap dict + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in rename_dict.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(original_state_dict, key, new_key) + + # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in + # special_keys_remap + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in special_keys_remap.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + vocoder.load_state_dict(original_state_dict, strict=True, assign=True) + return vocoder + + +def get_model_state_dict_from_combined_ckpt( + combined_ckpt: Dict[str, Any], prefix: str +) -> Dict[str, Any]: + # Ensure that the key prefix ends with a dot (.) + if not prefix.endswith("."): + prefix = prefix + "." + + model_state_dict = {} + for param_name, param in combined_ckpt.items(): + if param_name.startswith(prefix): + model_state_dict[param_name.replace(prefix, "")] = param + + if prefix == "model.diffusion_model.": + # Some checkpoints store the text connector projection outside the diffusion model prefix. + connector_key = "text_embedding_projection.aggregate_embed.weight" + if connector_key in combined_ckpt and connector_key not in model_state_dict: + model_state_dict[connector_key] = combined_ckpt[connector_key] + + return model_state_dict + + +def dequantize_state_dict(state_dict: Dict[str, Any]): + keys = list(state_dict.keys()) + state_out = {} + for k in keys: + if k.endswith( + (".weight_scale", ".weight_scale_2", ".pre_quant_scale", ".input_scale") + ): + continue + + t = state_dict[k] + + if k.endswith(".weight"): + prefix = k[: -len(".weight")] + wscale_k = prefix + ".weight_scale" + if wscale_k in state_dict: + w_q = t + w_scale = state_dict[wscale_k] + # Comfy quant = absmax per-tensor weight quant, nothing fancy + w_bf16 = w_q.to(torch.bfloat16) * w_scale.to(torch.bfloat16) + state_out[k] = w_bf16 + continue + + state_out[k] = t + return state_out + + +def convert_comfy_gemma3_to_transformers(sd: dict): + out = {} + + sd = dequantize_state_dict(sd) + + for k, v in sd.items(): + nk = k + + # Vision tower weights: checkpoint has "vision_model.*" + # model expects "model.vision_tower.vision_model.*" + if k.startswith("vision_model."): + nk = "model.vision_tower." 
+ k + + # MM projector: checkpoint has "multi_modal_projector.*" + # model expects "model.multi_modal_projector.*" + elif k.startswith("multi_modal_projector."): + nk = "model." + k + + # Language model: checkpoint has "model.layers.*", "model.embed_tokens.*", "model.norm.*" + # model expects "model.language_model.layers.*", etc. + elif k == "model.embed_tokens.weight": + nk = "model.language_model.embed_tokens.weight" + elif k.startswith("model.layers."): + nk = "model.language_model.layers." + k[len("model.layers.") :] + elif k.startswith("model.norm."): + nk = "model.language_model.norm." + k[len("model.norm.") :] + + # (optional) common DDP prefix + if nk.startswith("module."): + nk = nk[len("module.") :] + + # skip spiece_model + if nk == "spiece_model": + continue + + out[nk] = v + + # If lm_head is missing but embeddings exist, many Gemma-family models tie these weights. + # Add it so strict loading won't complain (or just load strict=False and call tie_weights()). + if ( + "lm_head.weight" not in out + and "model.language_model.embed_tokens.weight" in out + ): + out["lm_head.weight"] = out["model.language_model.embed_tokens.weight"] + + return out + + +def convert_lora_original_to_diffusers( + lora_state_dict: Dict[str, Any], +) -> Dict[str, Any]: + out: Dict[str, Any] = {} + rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT + + for k, v in lora_state_dict.items(): + # Keep the "diffusion_model." prefix as-is, but apply the transformer remaps to the rest + prefix = "" + rest = k + if rest.startswith("diffusion_model."): + prefix = "diffusion_model." + rest = rest[len(prefix) :] + + nk = rest + + # Same simple 1:1 remaps as the transformer + for replace_key, rename_key in rename_dict.items(): + nk = nk.replace(replace_key, rename_key) + + # Same special-case remap as the transformer (applies to LoRA keys too) + if nk.startswith("adaln_single."): + nk = nk.replace("adaln_single.", "time_embed.", 1) + elif nk.startswith("audio_adaln_single."): + nk = nk.replace("audio_adaln_single.", "audio_time_embed.", 1) + + out[prefix + nk] = v + + return out + + +def convert_lora_diffusers_to_original( + lora_state_dict: Dict[str, Any], +) -> Dict[str, Any]: + out: Dict[str, Any] = {} + + inv_rename = {v: k for k, v in LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT.items()} + inv_items = sorted(inv_rename.items(), key=lambda kv: len(kv[0]), reverse=True) + + for k, v in lora_state_dict.items(): + # Keep the "diffusion_model." prefix as-is, but invert remaps on the rest + prefix = "" + rest = k + if rest.startswith("diffusion_model."): + prefix = "diffusion_model." 
+ rest = rest[len(prefix) :] + + nk = rest + + # Inverse of the adaln_single special-case + if nk.startswith("time_embed."): + nk = nk.replace("time_embed.", "adaln_single.", 1) + elif nk.startswith("audio_time_embed."): + nk = nk.replace("audio_time_embed.", "audio_adaln_single.", 1) + + # Inverse 1:1 remaps + for diffusers_key, original_key in inv_items: + nk = nk.replace(diffusers_key, original_key) + + out[prefix + nk] = v + + return out diff --git a/extensions_built_in/diffusion_models/ltx2/ltx2.py b/extensions_built_in/diffusion_models/ltx2/ltx2.py new file mode 100644 index 00000000..ddd9568d --- /dev/null +++ b/extensions_built_in/diffusion_models/ltx2/ltx2.py @@ -0,0 +1,852 @@ +from functools import partial +import os +from typing import List, Optional + +import torch +import torchaudio +from transformers import Gemma3Config +import yaml +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO +from toolkit.models.base_model import BaseModel +from toolkit.basic import flush +from toolkit.prompt_utils import PromptEmbeds +from toolkit.samplers.custom_flowmatch_sampler import ( + CustomFlowMatchEulerDiscreteScheduler, +) +from accelerate import init_empty_weights +from toolkit.accelerator import unwrap_model +from optimum.quanto import freeze +from toolkit.util.quantize import quantize, get_qtype, quantize_model +from toolkit.memory_management import MemoryManager +from safetensors.torch import load_file + +try: + from diffusers import LTX2Pipeline + from diffusers.models.autoencoders import ( + AutoencoderKLLTX2Audio, + AutoencoderKLLTX2Video, + ) + from diffusers.models.transformers import LTX2VideoTransformer3DModel + from diffusers.pipelines.ltx2.export_utils import encode_video + from transformers import ( + Gemma3ForConditionalGeneration, + GemmaTokenizerFast, + ) + from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder + from diffusers.pipelines.ltx2.connectors import LTX2TextConnectors + from .convert_ltx2_to_diffusers import ( + get_model_state_dict_from_combined_ckpt, + convert_ltx2_transformer, + convert_ltx2_video_vae, + convert_ltx2_audio_vae, + convert_ltx2_vocoder, + convert_ltx2_connectors, + dequantize_state_dict, + convert_comfy_gemma3_to_transformers, + convert_lora_original_to_diffusers, + convert_lora_diffusers_to_original, + ) +except ImportError as e: + print("Diffusers import error:", e) + raise ImportError( + "Diffusers is out of date. Update diffusers to the latest version by doing pip uninstall diffusers and then pip install -r requirements.txt" + ) + + +scheduler_config = { + "base_image_seq_len": 1024, + "base_shift": 0.95, + "invert_sigmas": False, + "max_image_seq_len": 4096, + "max_shift": 2.05, + "num_train_timesteps": 1000, + "shift": 1.0, + "shift_terminal": 0.1, + "stochastic_sampling": False, + "time_shift_type": "exponential", + "use_beta_sigmas": False, + "use_dynamic_shifting": True, + "use_exponential_sigmas": False, + "use_karras_sigmas": False, +} + +dit_prefix = "model.diffusion_model." +vae_prefix = "vae." +audio_vae_prefix = "audio_vae." +vocoder_prefix = "vocoder." + + +def new_save_image_function( + self: GenerateImageConfig, + image, # will contain a dict that can be dumped ditectly into encode_video, just add output_path to it. 
+ count: int = 0, + max_count: int = 0, + **kwargs, +): + # this replaces gen image config save image function so we can save the video with sound from ltx2 + image["output_path"] = self.get_image_path(count, max_count) + # make sample directory if it does not exist + os.makedirs(os.path.dirname(image["output_path"]), exist_ok=True) + encode_video(**image) + flush() + + +def blank_log_image_function(self, *args, **kwargs): + # todo handle wandb logging of videos with audio + return + + +class AudioProcessor(torch.nn.Module): + """Converts audio waveforms to log-mel spectrograms with optional resampling.""" + + def __init__( + self, + sample_rate: int, + mel_bins: int, + mel_hop_length: int, + n_fft: int, + ) -> None: + super().__init__() + self.sample_rate = sample_rate + self.mel_transform = torchaudio.transforms.MelSpectrogram( + sample_rate=sample_rate, + n_fft=n_fft, + win_length=n_fft, + hop_length=mel_hop_length, + f_min=0.0, + f_max=sample_rate / 2.0, + n_mels=mel_bins, + window_fn=torch.hann_window, + center=True, + pad_mode="reflect", + power=1.0, + mel_scale="slaney", + norm="slaney", + ) + + def resample_waveform( + self, + waveform: torch.Tensor, + source_rate: int, + target_rate: int, + ) -> torch.Tensor: + """Resample waveform to target sample rate if needed.""" + if source_rate == target_rate: + return waveform + resampled = torchaudio.functional.resample(waveform, source_rate, target_rate) + return resampled.to(device=waveform.device, dtype=waveform.dtype) + + def waveform_to_mel( + self, + waveform: torch.Tensor, + waveform_sample_rate: int, + ) -> torch.Tensor: + """Convert waveform to log-mel spectrogram [batch, channels, time, n_mels].""" + waveform = self.resample_waveform( + waveform, waveform_sample_rate, self.sample_rate + ) + + mel = self.mel_transform(waveform) + mel = torch.log(torch.clamp(mel, min=1e-5)) + + mel = mel.to(device=waveform.device, dtype=waveform.dtype) + return mel.permute(0, 1, 3, 2).contiguous() + + +class LTX2Model(BaseModel): + arch = "ltx2" + + def __init__( + self, + device, + model_config: ModelConfig, + dtype="bf16", + custom_pipeline=None, + noise_scheduler=None, + **kwargs, + ): + super().__init__( + device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs + ) + self.is_flow_matching = True + self.is_transformer = True + self.target_lora_modules = ["LTX2VideoTransformer3DModel"] + # defines if the model supports model paths. 
Only some will + self.supports_model_paths = True + # use the new format on this new model by default + self.use_old_lokr_format = False + self.audio_processor = None + + # static method to get the noise scheduler + @staticmethod + def get_train_scheduler(): + return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + + def get_bucket_divisibility(self): + return 32 + + def load_model(self): + dtype = self.torch_dtype + self.print_and_status_update("Loading LTX2 model") + model_path = self.model_config.name_or_path + base_model_path = self.model_config.extras_name_or_path + + combined_state_dict = None + + self.print_and_status_update("Loading transformer") + # if we have a safetensors file it is a mono checkpoint + if os.path.exists(model_path) and model_path.endswith(".safetensors"): + combined_state_dict = load_file(model_path) + combined_state_dict = dequantize_state_dict(combined_state_dict) + + if combined_state_dict is not None: + original_dit_ckpt = get_model_state_dict_from_combined_ckpt( + combined_state_dict, dit_prefix + ) + transformer = convert_ltx2_transformer(original_dit_ckpt) + transformer = transformer.to(dtype) + else: + transformer_path = model_path + transformer_subfolder = "transformer" + if os.path.exists(transformer_path): + transformer_subfolder = None + transformer_path = os.path.join(transformer_path, "transformer") + # check if the path is a full checkpoint. + te_folder_path = os.path.join(model_path, "text_encoder") + # if we have the te, this folder is a full checkpoint, use it as the base + if os.path.exists(te_folder_path): + base_model_path = model_path + + transformer = LTX2VideoTransformer3DModel.from_pretrained( + transformer_path, subfolder=transformer_subfolder, torch_dtype=dtype + ) + + if self.model_config.quantize: + self.print_and_status_update("Quantizing Transformer") + quantize_model(self, transformer) + flush() + + if ( + self.model_config.layer_offloading + and self.model_config.layer_offloading_transformer_percent > 0 + ): + ignore_modules = [] + for block in transformer.transformer_blocks: + ignore_modules.append(block.scale_shift_table) + ignore_modules.append(block.audio_scale_shift_table) + ignore_modules.append(block.video_a2v_cross_attn_scale_shift_table) + ignore_modules.append(block.audio_a2v_cross_attn_scale_shift_table) + ignore_modules.append(transformer.scale_shift_table) + ignore_modules.append(transformer.audio_scale_shift_table) + MemoryManager.attach( + transformer, + self.device_torch, + offload_percent=self.model_config.layer_offloading_transformer_percent, + ignore_modules=ignore_modules, + ) + + if self.model_config.low_vram: + self.print_and_status_update("Moving transformer to CPU") + transformer.to("cpu") + + flush() + + self.print_and_status_update("Loading text encoder") + if ( + self.model_config.te_name_or_path is not None + and self.model_config.te_name_or_path.endswith(".safetensors") + ): + # load from comfyui gemma3 checkpoint + tokenizer = GemmaTokenizerFast.from_pretrained( + "Lightricks/LTX-2", subfolder="tokenizer" + ) + + with init_empty_weights(): + text_encoder = Gemma3ForConditionalGeneration( + Gemma3Config( + **{ + "boi_token_index": 255999, + "bos_token_id": 2, + "eoi_token_index": 256000, + "eos_token_id": 106, + "image_token_index": 262144, + "initializer_range": 0.02, + "mm_tokens_per_image": 256, + "model_type": "gemma3", + "pad_token_id": 0, + "text_config": { + "attention_bias": False, + "attention_dropout": 0.0, + "attn_logit_softcapping": None, + "cache_implementation": "hybrid", + 
"final_logit_softcapping": None, + "head_dim": 256, + "hidden_activation": "gelu_pytorch_tanh", + "hidden_size": 3840, + "initializer_range": 0.02, + "intermediate_size": 15360, + "max_position_embeddings": 131072, + "model_type": "gemma3_text", + "num_attention_heads": 16, + "num_hidden_layers": 48, + "num_key_value_heads": 8, + "query_pre_attn_scalar": 256, + "rms_norm_eps": 1e-06, + "rope_local_base_freq": 10000, + "rope_scaling": {"factor": 8.0, "rope_type": "linear"}, + "rope_theta": 1000000, + "sliding_window": 1024, + "sliding_window_pattern": 6, + "torch_dtype": "bfloat16", + "use_cache": True, + "vocab_size": 262208, + }, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.3", + "unsloth_fixed": True, + "vision_config": { + "attention_dropout": 0.0, + "hidden_act": "gelu_pytorch_tanh", + "hidden_size": 1152, + "image_size": 896, + "intermediate_size": 4304, + "layer_norm_eps": 1e-06, + "model_type": "siglip_vision_model", + "num_attention_heads": 16, + "num_channels": 3, + "num_hidden_layers": 27, + "patch_size": 14, + "torch_dtype": "bfloat16", + "vision_use_head": False, + }, + } + ) + ) + te_state_dict = load_file(self.model_config.te_name_or_path) + te_state_dict = convert_comfy_gemma3_to_transformers(te_state_dict) + for key in te_state_dict: + te_state_dict[key] = te_state_dict[key].to(dtype) + + text_encoder.load_state_dict(te_state_dict, assign=True, strict=True) + del te_state_dict + flush() + else: + if self.model_config.te_name_or_path is not None: + te_path = self.model_config.te_name_or_path + else: + te_path = base_model_path + tokenizer = GemmaTokenizerFast.from_pretrained( + te_path, subfolder="tokenizer" + ) + text_encoder = Gemma3ForConditionalGeneration.from_pretrained( + te_path, subfolder="text_encoder", dtype=dtype + ) + + # remove the vision tower + text_encoder.model.vision_tower = None + flush() + + if ( + self.model_config.layer_offloading + and self.model_config.layer_offloading_text_encoder_percent > 0 + ): + MemoryManager.attach( + text_encoder, + self.device_torch, + offload_percent=self.model_config.layer_offloading_text_encoder_percent, + ignore_modules=[ + text_encoder.model.language_model.base_model.embed_tokens + ], + ) + + text_encoder.to(self.device_torch, dtype=dtype) + flush() + + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing Text Encoder") + quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te)) + freeze(text_encoder) + flush() + + self.print_and_status_update("Loading VAEs and other components") + if combined_state_dict is not None: + original_vae_ckpt = get_model_state_dict_from_combined_ckpt( + combined_state_dict, vae_prefix + ) + vae = convert_ltx2_video_vae(original_vae_ckpt).to(dtype) + del original_vae_ckpt + original_audio_vae_ckpt = get_model_state_dict_from_combined_ckpt( + combined_state_dict, audio_vae_prefix + ) + audio_vae = convert_ltx2_audio_vae(original_audio_vae_ckpt).to(dtype) + del original_audio_vae_ckpt + original_connectors_ckpt = get_model_state_dict_from_combined_ckpt( + combined_state_dict, dit_prefix + ) + connectors = convert_ltx2_connectors(original_connectors_ckpt).to(dtype) + del original_connectors_ckpt + original_vocoder_ckpt = get_model_state_dict_from_combined_ckpt( + combined_state_dict, vocoder_prefix + ) + vocoder = convert_ltx2_vocoder(original_vocoder_ckpt).to(dtype) + del original_vocoder_ckpt + del combined_state_dict + flush() + else: + vae = AutoencoderKLLTX2Video.from_pretrained( + base_model_path, subfolder="vae", torch_dtype=dtype + ) + 
audio_vae = AutoencoderKLLTX2Audio.from_pretrained( + base_model_path, subfolder="audio_vae", torch_dtype=dtype + ) + + connectors = LTX2TextConnectors.from_pretrained( + base_model_path, subfolder="connectors", torch_dtype=dtype + ) + + vocoder = LTX2Vocoder.from_pretrained( + base_model_path, subfolder="vocoder", torch_dtype=dtype + ) + + self.noise_scheduler = LTX2Model.get_train_scheduler() + + self.print_and_status_update("Making pipe") + + pipe: LTX2Pipeline = LTX2Pipeline( + scheduler=self.noise_scheduler, + vae=vae, + audio_vae=audio_vae, + text_encoder=None, + tokenizer=tokenizer, + connectors=connectors, + transformer=None, + vocoder=vocoder, + ) + # for quantization, it works best to do these after making the pipe + pipe.text_encoder = text_encoder + pipe.transformer = transformer + + self.print_and_status_update("Preparing Model") + + text_encoder = [pipe.text_encoder] + tokenizer = [pipe.tokenizer] + + # leave it on cpu for now + if not self.low_vram: + pipe.transformer = pipe.transformer.to(self.device_torch) + + flush() + # just to make sure everything is on the right device and dtype + text_encoder[0].to(self.device_torch) + text_encoder[0].requires_grad_(False) + text_encoder[0].eval() + flush() + + # save it to the model class + self.vae = vae + self.text_encoder = text_encoder # list of text encoders + self.tokenizer = tokenizer # list of tokenizers + self.model = pipe.transformer + self.pipeline = pipe + + self.audio_processor = AudioProcessor( + sample_rate=pipe.audio_sampling_rate, + mel_bins=audio_vae.config.mel_bins, + mel_hop_length=pipe.audio_hop_length, + n_fft=1024, # todo get this from vae if we can, I couldnt find it. + ).to(self.device_torch, dtype=torch.float32) + + self.print_and_status_update("Model Loaded") + + @torch.no_grad() + def encode_images(self, image_list: List[torch.Tensor], device=None, dtype=None): + if device is None: + device = self.vae_device_torch + if dtype is None: + dtype = self.vae_torch_dtype + + if self.vae.device == torch.device("cpu"): + self.vae.to(device) + self.vae.eval() + self.vae.requires_grad_(False) + + image_list = [image.to(device, dtype=dtype) for image in image_list] + + # Normalize shapes + norm_images = [] + for image in image_list: + if image.ndim == 3: + # (C, H, W) -> (C, 1, H, W) + norm_images.append(image.unsqueeze(1)) + elif image.ndim == 4: + # (T, C, H, W) -> (C, T, H, W) + norm_images.append(image.permute(1, 0, 2, 3)) + else: + raise ValueError(f"Invalid image shape: {image.shape}") + + # Stack to (B, C, T, H, W) + images = torch.stack(norm_images) + + latents = self.vae.encode(images).latent_dist.mode() + + # Normalize latents across the channel dimension [B, C, F, H, W] + scaling_factor = 1.0 + latents_mean = self.pipeline.vae.latents_mean.view(1, -1, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents_std = self.pipeline.vae.latents_std.view(1, -1, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = (latents - latents_mean) * scaling_factor / latents_std + + return latents.to(device, dtype=dtype) + + def get_generation_pipeline(self): + scheduler = LTX2Model.get_train_scheduler() + + pipeline: LTX2Pipeline = LTX2Pipeline( + scheduler=scheduler, + vae=unwrap_model(self.pipeline.vae), + audio_vae=unwrap_model(self.pipeline.audio_vae), + text_encoder=None, + tokenizer=unwrap_model(self.pipeline.tokenizer), + connectors=unwrap_model(self.pipeline.connectors), + transformer=None, + vocoder=unwrap_model(self.pipeline.vocoder), + ) + pipeline.transformer = unwrap_model(self.model) + 
pipeline.text_encoder = unwrap_model(self.text_encoder[0]) + + # if self.low_vram: + # pipeline.enable_model_cpu_offload(device=self.device_torch) + + pipeline = pipeline.to(self.device_torch) + + return pipeline + + def generate_single_image( + self, + pipeline: LTX2Pipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + if self.model.device == torch.device("cpu"): + self.model.to(self.device_torch) + + is_video = gen_config.num_frames > 1 + # override the generate single image to handle video + audio generation + if is_video: + gen_config._orig_save_image_function = gen_config.save_image + gen_config.save_image = partial(new_save_image_function, gen_config) + gen_config.log_image = partial(blank_log_image_function, gen_config) + # set output extension to mp4 + gen_config.output_ext = "mp4" + + # reactivate progress bar since this is slooooow + pipeline.set_progress_bar_config(disable=False) + pipeline = pipeline.to(self.device_torch) + + # make sure dimensions are valid + bd = self.get_bucket_divisibility() + gen_config.height = (gen_config.height // bd) * bd + gen_config.width = (gen_config.width // bd) * bd + + # frames must be divisible by 8 then + 1. so 1, 9, 17, 25, etc. + if gen_config.num_frames != 1: + if (gen_config.num_frames - 1) % 8 != 0: + gen_config.num_frames = ((gen_config.num_frames - 1) // 8) * 8 + 1 + + if self.low_vram: + # set vae to tile decode + pipeline.vae.enable_tiling( + tile_sample_min_height=256, + tile_sample_min_width=256, + tile_sample_min_num_frames=8, + tile_sample_stride_height=224, + tile_sample_stride_width=224, + tile_sample_stride_num_frames=4, + ) + + video, audio = pipeline( + prompt_embeds=conditional_embeds.text_embeds.to( + self.device_torch, dtype=self.torch_dtype + ), + prompt_attention_mask=conditional_embeds.attention_mask.to( + self.device_torch + ), + negative_prompt_embeds=unconditional_embeds.text_embeds.to( + self.device_torch, dtype=self.torch_dtype + ), + negative_prompt_attention_mask=unconditional_embeds.attention_mask.to( + self.device_torch + ), + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + num_frames=gen_config.num_frames, + generator=generator, + return_dict=False, + output_type="np" if is_video else "pil", + **extra, + ) + if self.low_vram: + # Restore no tiling + pipeline.vae.use_tiling = False + + if is_video: + # redurn as a dict, we will handle it with an override function + video = (video * 255).round().astype("uint8") + video = torch.from_numpy(video) + return { + "video": video[0], + "fps": gen_config.fps, + "audio": audio[0].float().cpu(), + "audio_sample_rate": pipeline.vocoder.config.output_sampling_rate, # should be 24000 + "output_path": None, + } + else: + # shape = [1, frames, channels, height, width] + # make sure this is right + video = video[0] # list of pil images + audio = audio[0] # tensor + if gen_config.num_frames > 1: + return video # return the frames. 
+ else: + # get just the first image + img = video[0] + return img + + def encode_audio(self, batch: "DataLoaderBatchDTO"): + if self.pipeline.audio_vae.device == torch.device("cpu"): + self.pipeline.audio_vae.to(self.device_torch) + + output_tensor = None + audio_num_frames = None + + # do them seperatly for now + for audio_data in batch.audio_data: + waveform = audio_data["waveform"].to( + device=self.device_torch, dtype=torch.float32 + ) + sample_rate = audio_data["sample_rate"] + + # Add batch dimension if needed: [channels, samples] -> [batch, channels, samples] + if waveform.dim() == 2: + waveform = waveform.unsqueeze(0) + + if waveform.shape[1] == 1: + # make sure it is stereo + waveform = waveform.repeat(1, 2, 1) + + # Convert waveform to mel spectrogram using AudioProcessor + mel_spectrogram = self.audio_processor.waveform_to_mel(waveform, waveform_sample_rate=sample_rate) + mel_spectrogram = mel_spectrogram.to(dtype=self.torch_dtype) + + # Encode mel spectrogram to latents + latents = self.pipeline.audio_vae.encode(mel_spectrogram.to(self.device_torch, dtype=self.torch_dtype)).latent_dist.mode() + + if audio_num_frames is None: + audio_num_frames = latents.shape[2] #(latents is [B, C, T, F]) + + packed_latents = self.pipeline._pack_audio_latents( + latents, + # patch_size=self.pipeline.transformer.config.audio_patch_size, + # patch_size_t=self.pipeline.transformer.config.audio_patch_size_t, + ) # [B, L, C * M] + if output_tensor is None: + output_tensor = packed_latents + else: + output_tensor = torch.cat([output_tensor, packed_latents], dim=0) + + # normalize latents, opposite of (latents * latents_std) + latents_mean + latents_mean = self.pipeline.audio_vae.latents_mean + latents_std = self.pipeline.audio_vae.latents_std + output_tensor = (output_tensor - latents_mean) / latents_std + return output_tensor, audio_num_frames + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + batch: "DataLoaderBatchDTO" = None, + **kwargs, + ): + with torch.no_grad(): + if self.model.device == torch.device("cpu"): + self.model.to(self.device_torch) + + batch_size, C, latent_num_frames, latent_height, latent_width = ( + latent_model_input.shape + ) + + # todo get this somehow + frame_rate = 24 + # check frame dimension + # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. 
+ packed_latents = self.pipeline._pack_latents( + latent_model_input, + patch_size=self.pipeline.transformer_spatial_patch_size, + patch_size_t=self.pipeline.transformer_temporal_patch_size, + ) + + if batch.audio_tensor is not None: + # use audio from the batch if available + #(1, 190, 128) + raw_audio_latents, audio_num_frames = self.encode_audio(batch) + # add the audio targets to the batch for loss calculation later + audio_noise = torch.randn_like(raw_audio_latents) + batch.audio_target = (audio_noise - raw_audio_latents).detach() + audio_latents = self.add_noise( + raw_audio_latents, + audio_noise, + timestep, + ).to(self.device_torch, dtype=self.torch_dtype) + + else: + # no audio + num_mel_bins = self.pipeline.audio_vae.config.mel_bins + # latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + num_channels_latents_audio = ( + self.pipeline.audio_vae.config.latent_channels + ) + # audio latents are (1, 126, 128), audio_num_frames = 126 + audio_latents, audio_num_frames = self.pipeline.prepare_audio_latents( + batch_size, + num_channels_latents=num_channels_latents_audio, + num_mel_bins=num_mel_bins, + num_frames=batch.tensor.shape[1], + frame_rate=frame_rate, + sampling_rate=self.pipeline.audio_sampling_rate, + hop_length=self.pipeline.audio_hop_length, + dtype=torch.float32, + device=self.transformer.device, + generator=None, + latents=None, + ) + + if self.pipeline.connectors.device != self.transformer.device: + self.pipeline.connectors.to(self.transformer.device) + + # TODO this is how diffusers does this on inference, not sure I understand why, check this + additive_attention_mask = ( + 1 - text_embeddings.attention_mask.to(self.transformer.dtype) + ) * -1000000.0 + ( + connector_prompt_embeds, + connector_audio_prompt_embeds, + connector_attention_mask, + ) = self.pipeline.connectors( + text_embeddings.text_embeds, additive_attention_mask, additive_mask=True + ) + + # compute video and audio positional ids + video_coords = self.transformer.rope.prepare_video_coords( + packed_latents.shape[0], + latent_num_frames, + latent_height, + latent_width, + packed_latents.device, + fps=frame_rate, + ) + audio_coords = self.transformer.audio_rope.prepare_audio_coords( + audio_latents.shape[0], audio_num_frames, audio_latents.device + ) + + noise_pred_video, noise_pred_audio = self.transformer( + hidden_states=packed_latents, + audio_hidden_states=audio_latents.to(self.transformer.dtype), + encoder_hidden_states=connector_prompt_embeds, + audio_encoder_hidden_states=connector_audio_prompt_embeds, + timestep=timestep, + encoder_attention_mask=connector_attention_mask, + audio_encoder_attention_mask=connector_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_coords, + audio_coords=audio_coords, + # rope_interpolation_scale=rope_interpolation_scale, + attention_kwargs=None, + return_dict=False, + ) + + # add audio latent to batch if we had audio + if batch.audio_target is not None: + batch.audio_pred = noise_pred_audio + + unpacked_output = self.pipeline._unpack_latents( + latents=noise_pred_video, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + patch_size=self.pipeline.transformer_spatial_patch_size, + patch_size_t=self.pipeline.transformer_temporal_patch_size, + ) + + return unpacked_output + + def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: + if self.pipeline.text_encoder.device != self.device_torch: + 
self.pipeline.text_encoder.to(self.device_torch) + + prompt_embeds, prompt_attention_mask, _, _ = self.pipeline.encode_prompt( + prompt, + do_classifier_free_guidance=False, + device=self.device_torch, + ) + pe = PromptEmbeds([prompt_embeds, None]) + pe.attention_mask = prompt_attention_mask + return pe + + def get_model_has_grad(self): + return False + + def get_te_has_grad(self): + return False + + def save_model(self, output_path, meta, save_dtype): + transformer: LTX2VideoTransformer3DModel = unwrap_model(self.model) + transformer.save_pretrained( + save_directory=os.path.join(output_path, "transformer"), + safe_serialization=True, + ) + + meta_path = os.path.join(output_path, "aitk_meta.yaml") + with open(meta_path, "w") as f: + yaml.dump(meta, f) + + def get_loss_target(self, *args, **kwargs): + noise = kwargs.get("noise") + batch = kwargs.get("batch") + return (noise - batch.latents).detach() + + def get_base_model_version(self): + return "ltx2" + + def get_transformer_block_names(self) -> Optional[List[str]]: + return ["transformer_blocks"] + + def convert_lora_weights_before_save(self, state_dict): + new_sd = {} + for key, value in state_dict.items(): + new_key = key.replace("transformer.", "diffusion_model.") + new_sd[new_key] = value + new_sd = convert_lora_diffusers_to_original(new_sd) + return new_sd + + def convert_lora_weights_before_load(self, state_dict): + state_dict = convert_lora_original_to_diffusers(state_dict) + new_sd = {} + for key, value in state_dict.items(): + new_key = key.replace("diffusion_model.", "transformer.") + new_sd[new_key] = value + return new_sd diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 2b504e21..152cd131 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -859,6 +859,11 @@ class SDTrainer(BaseSDTrainProcess): loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma) loss = loss.mean() + + # check for audio loss + if batch.audio_pred is not None and batch.audio_target is not None: + audio_loss = torch.nn.functional.mse_loss(batch.audio_pred.float(), batch.audio_target.float(), reduction="mean") + loss = loss + audio_loss # check for additional losses if self.adapter is not None and hasattr(self.adapter, "additional_loss") and self.adapter.additional_loss is not None: diff --git a/requirements.txt b/requirements.txt index 97d88e92..2ab3621d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ torchao==0.10.0 safetensors -git+https://github.com/huggingface/diffusers@f6b6a7181eb44f0120b29cd897c129275f366c2a +git+https://github.com/huggingface/diffusers@8600b4c10d67b0ce200f664204358747bd53c775 transformers==4.57.3 lycoris-lora==1.8.3 flatten_json @@ -36,4 +36,6 @@ opencv-python pytorch-wavelets==1.3.0 matplotlib==3.10.1 setuptools==69.5.1 -scipy==1.12.0 \ No newline at end of file +scipy==1.12.0 +av==16.0.1 +torchcodec \ No newline at end of file diff --git a/testing/test_ltx_dataloader.py b/testing/test_ltx_dataloader.py new file mode 100644 index 00000000..d27426d5 --- /dev/null +++ b/testing/test_ltx_dataloader.py @@ -0,0 +1,234 @@ +import time + +from torch.utils.data import DataLoader +import sys +import os +import argparse +from tqdm import tqdm +import torch +from torchvision.io import write_video +import subprocess + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO 
+from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader_setup_epoch +from toolkit.config_modules import DatasetConfig + +parser = argparse.ArgumentParser() +# parser.add_argument('dataset_folder', type=str, default='input') +parser.add_argument('dataset_folder', type=str) +parser.add_argument('--epochs', type=int, default=1) +parser.add_argument('--num_frames', type=int, default=121) +parser.add_argument('--output_path', type=str, default='output/dataset_test') + + +args = parser.parse_args() + +if args.output_path is None: + raise ValueError('output_path is required for this test script') + +if args.output_path is not None: + args.output_path = os.path.abspath(args.output_path) + os.makedirs(args.output_path, exist_ok=True) + +dataset_folder = args.dataset_folder +resolution = 512 +bucket_tolerance = 64 +batch_size = 1 +frame_rate = 24 + + +## make fake sd +class FakeSD: + def __init__(self): + self.use_raw_control_images = False + + def encode_control_in_text_embeddings(self, *args, **kwargs): + return None + + def get_bucket_divisibility(self): + return 32 + +dataset_config = DatasetConfig( + dataset_path=dataset_folder, + resolution=resolution, + default_caption='default', + buckets=True, + bucket_tolerance=bucket_tolerance, + shrink_video_to_frames=True, + num_frames=args.num_frames, + do_i2v=True, + fps=frame_rate, + do_audio=True, + debug=True, + audio_preserve_pitch=False, + audio_normalize=True + +) + +dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size, sd=FakeSD()) + + +def _tensor_to_uint8_video(frames_fchw: torch.Tensor) -> torch.Tensor: + """ + frames_fchw: [F, C, H, W] float/uint8 + returns: [F, H, W, C] uint8 on CPU + """ + x = frames_fchw.detach() + + if x.dtype != torch.uint8: + x = x.to(torch.float32) + + # Heuristic: if negatives exist, assume [-1,1] normalization; else assume [0,1] + if torch.isfinite(x).all(): + if x.min().item() < 0.0: + x = x * 0.5 + 0.5 + x = x.clamp(0.0, 1.0) + x = (x * 255.0).round().to(torch.uint8) + else: + x = x.to(torch.uint8) + + # [F,C,H,W] -> [F,H,W,C] + x = x.permute(0, 2, 3, 1).contiguous().cpu() + return x + + +def _mux_with_ffmpeg(video_in: str, wav_in: str, mp4_out: str): + # Copy video stream, encode audio to AAC, align to shortest + subprocess.run( + [ + "ffmpeg", + "-y", + "-hide_banner", + "-loglevel", + "error", + "-i", + video_in, + "-i", + wav_in, + "-c:v", + "copy", + "-c:a", + "aac", + "-shortest", + mp4_out, + ], + check=True, + ) + + +# run through an epoch ang check sizes +dataloader_iterator = iter(dataloader) +idx = 0 +for epoch in range(args.epochs): + for batch in tqdm(dataloader): + batch: 'DataLoaderBatchDTO' + img_batch = batch.tensor + frames = 1 + if len(img_batch.shape) == 5: + frames = img_batch.shape[1] + batch_size, frames, channels, height, width = img_batch.shape + else: + batch_size, channels, height, width = img_batch.shape + + # load audio + audio_tensor = batch.audio_tensor # all file items contatinated on the batch dimension + audio_data = batch.audio_data # list of raw audio data per item in the batch + + # llm save the videos here with audio and video as mp4 + fps = getattr(dataset_config, "fps", None) + if fps is None or fps <= 0: + fps = 1.0 + + # Ensure we can iterate items even if batch_size > 1 + for b in range(batch_size): + # Get per-item frames as [F,C,H,W] + if len(img_batch.shape) == 5: + frames_fchw = img_batch[b] + else: + # single image: [C,H,W] -> [1,C,H,W] + frames_fchw = img_batch[b].unsqueeze(0) + + video_uint8 = 
_tensor_to_uint8_video(frames_fchw) + out_mp4 = os.path.join(args.output_path, f"{idx:06d}_{b:02d}.mp4") + + # Pick audio for this item (prefer audio_data list; fallback to audio_tensor) + item_audio = None + item_sr = None + + if isinstance(audio_data, (list, tuple)) and len(audio_data) > b: + ad = audio_data[b] + if isinstance(ad, dict) and ("waveform" in ad) and ("sample_rate" in ad) and ad["waveform"] is not None: + item_audio = ad["waveform"] + item_sr = int(ad["sample_rate"]) + elif audio_tensor is not None and torch.is_tensor(audio_tensor): + # audio_tensor expected [B, C, L] (or [C,L] if batch collate differs) + if audio_tensor.dim() == 3 and audio_tensor.shape[0] > b: + item_audio = audio_tensor[b] + elif audio_tensor.dim() == 2 and b == 0: + item_audio = audio_tensor + if item_audio is not None: + # best-effort sample rate from audio_data if present but not per-item dict + if isinstance(audio_data, dict) and "sample_rate" in audio_data: + try: + item_sr = int(audio_data["sample_rate"]) + except Exception: + item_sr = None + + # Write mp4 (with audio if available) using ffmpeg muxing (torchvision audio muxing is unreliable) + tmp_video = out_mp4 + ".tmp_video.mp4" + tmp_wav = out_mp4 + ".tmp_audio.wav" + try: + # Always write video-only first + write_video(tmp_video, video_uint8, fps=float(fps), video_codec="libx264") + + if item_audio is not None and item_sr is not None and item_audio.numel() > 0: + import torchaudio + + wav = item_audio.detach() + # torchaudio.save expects [channels, samples] + if wav.dim() == 1: + wav = wav.unsqueeze(0) + torchaudio.save(tmp_wav, wav.cpu().to(torch.float32), int(item_sr)) + + # Mux to final mp4 + _mux_with_ffmpeg(tmp_video, tmp_wav, out_mp4) + else: + # No audio: just move video into place + os.replace(tmp_video, out_mp4) + + except Exception as e: + # Best-effort fallback: leave a playable video-only file + try: + if os.path.exists(tmp_video): + os.replace(tmp_video, out_mp4) + else: + write_video(out_mp4, video_uint8, fps=float(fps), video_codec="libx264") + except Exception: + raise + + if hasattr(dataset_config, 'debug') and dataset_config.debug: + print(f"Warning: failed to mux audio into mp4 for {out_mp4}: {e}") + + finally: + # Cleanup temps (don't leave separate wavs lying around) + try: + if os.path.exists(tmp_video): + os.remove(tmp_video) + except Exception: + pass + try: + if os.path.exists(tmp_wav): + os.remove(tmp_wav) + except Exception: + pass + + time.sleep(0.2) + + idx += 1 + # if not last epoch + if epoch < args.epochs - 1: + trigger_dataloader_setup_epoch(dataloader) + +print('done') diff --git a/toolkit/audio/preserve_pitch.py b/toolkit/audio/preserve_pitch.py new file mode 100644 index 00000000..501c139d --- /dev/null +++ b/toolkit/audio/preserve_pitch.py @@ -0,0 +1,75 @@ +import math +import torch +import torch.nn.functional as F +import torchaudio + +def time_stretch_preserve_pitch(waveform: torch.Tensor, sample_rate: int, target_samples: int) -> torch.Tensor: + """ + waveform: [C, L] float tensor (CPU or GPU) + returns: [C, target_samples] float tensor + Pitch-preserving time stretch to match target_samples. 
+ """ + + + if waveform.dim() == 1: + waveform = waveform.unsqueeze(0) + + waveform = waveform.to(torch.float32) + + src_len = waveform.shape[-1] + if src_len == 0 or target_samples <= 0: + return waveform[..., :0] + + if src_len == target_samples: + return waveform + + # rate > 1.0 speeds up (shorter), rate < 1.0 slows down (longer) + rate = float(src_len) / float(target_samples) + + # Use sample_rate to pick STFT params + win_seconds = 0.046 + hop_seconds = 0.0115 + + n_fft_target = int(sample_rate * win_seconds) + n_fft = 1 << max(8, int(math.floor(math.log2(max(256, n_fft_target))))) # >=256, pow2 + win_length = n_fft + hop_length = max(64, int(sample_rate * hop_seconds)) + hop_length = min(hop_length, win_length // 2) + + window = torch.hann_window(win_length, device=waveform.device, dtype=waveform.dtype) + + stft = torch.stft( + waveform, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + center=True, + return_complex=True, + ) # [C, F, T] complex + + # IMPORTANT: n_freq must match STFT's frequency bins (n_fft//2 + 1) + stretcher = torchaudio.transforms.TimeStretch( + n_freq=stft.shape[-2], + hop_length=hop_length, + fixed_rate=rate, + ).to(waveform.device) + + stft_stretched = stretcher(stft) # [C, F, T'] + + stretched = torch.istft( + stft_stretched, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + center=True, + length=target_samples, + ) + + if stretched.shape[-1] > target_samples: + stretched = stretched[..., :target_samples] + elif stretched.shape[-1] < target_samples: + stretched = F.pad(stretched, (0, target_samples - stretched.shape[-1])) + + return stretched diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 7d2895cd..974e0cb6 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -207,6 +207,9 @@ class NetworkConfig: # -1 automatically finds the largest factor self.lokr_factor = kwargs.get('lokr_factor', -1) + # Use the old lokr format + self.old_lokr_format = kwargs.get('old_lokr_format', False) + # for multi stage models self.split_multistage_loras = kwargs.get('split_multistage_loras', True) @@ -672,6 +675,9 @@ class ModelConfig: # kwargs to pass to the model self.model_kwargs = kwargs.get("model_kwargs", {}) + # model paths for models that support it + self.model_paths = kwargs.get("model_paths", {}) + # allow frontend to pass arch with a color like arch:tag # but remove the tag if self.arch is not None: @@ -956,7 +962,7 @@ class DatasetConfig: # it will select a random start frame and pull the frames at the given fps # this could have various issues with shorter videos and videos with variable fps # I recommend trimming your videos to the desired length and using shrink_video_to_frames(default) - self.fps: int = kwargs.get('fps', 16) + self.fps: int = kwargs.get('fps', 24) # debug the frame count and frame selection. You dont need this. It is for debugging. 
self.debug: bool = kwargs.get('debug', False) @@ -972,6 +978,9 @@ class DatasetConfig: self.fast_image_size: bool = kwargs.get('fast_image_size', False) self.do_i2v: bool = kwargs.get('do_i2v', True) # do image to video on models that are both t2i and i2v capable + self.do_audio: bool = kwargs.get('do_audio', False) # load audio from video files for models that support it + self.audio_preserve_pitch: bool = kwargs.get('audio_preserve_pitch', False) # preserve pitch when stretching audio to fit num_frames + self.audio_normalize: bool = kwargs.get('audio_normalize', False) # normalize audio volume levels when loading def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]: diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py index d72364bd..5e76563e 100644 --- a/toolkit/data_transfer_object/data_loader.py +++ b/toolkit/data_transfer_object/data_loader.py @@ -123,9 +123,13 @@ class FileItemDTO( self.is_reg = self.dataset_config.is_reg self.prior_reg = self.dataset_config.prior_reg self.tensor: Union[torch.Tensor, None] = None + self.audio_data = None + self.audio_tensor = None def cleanup(self): self.tensor = None + self.audio_data = None + self.audio_tensor = None self.cleanup_latent() self.cleanup_text_embedding() self.cleanup_control() @@ -154,6 +158,13 @@ class DataLoaderBatchDTO: self.clip_image_embeds_unconditional: Union[List[dict], None] = None self.sigmas: Union[torch.Tensor, None] = None # can be added elseware and passed along training code self.extra_values: Union[torch.Tensor, None] = torch.tensor([x.extra_values for x in self.file_items]) if len(self.file_items[0].extra_values) > 0 else None + self.audio_data: Union[List, None] = [x.audio_data for x in self.file_items] if self.file_items[0].audio_data is not None else None + self.audio_tensor: Union[torch.Tensor, None] = None + + # just for holding noise and preds during training + self.audio_target: Union[torch.Tensor, None] = None + self.audio_pred: Union[torch.Tensor, None] = None + if not is_latents_cached: # only return a tensor if latents are not cached self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items]) @@ -304,6 +315,21 @@ class DataLoaderBatchDTO: y.text_embeds = [y.text_embeds] prompt_embeds_list.append(y) self.prompt_embeds = concat_prompt_embeds(prompt_embeds_list) + + if any([x.audio_tensor is not None for x in self.file_items]): + # find one to use as a base + base_audio_tensor = None + for x in self.file_items: + if x.audio_tensor is not None: + base_audio_tensor = x.audio_tensor + break + audio_tensors = [] + for x in self.file_items: + if x.audio_tensor is None: + audio_tensors.append(torch.zeros_like(base_audio_tensor)) + else: + audio_tensors.append(x.audio_tensor) + self.audio_tensor = torch.cat([x.unsqueeze(0) for x in audio_tensors]) except Exception as e: @@ -336,6 +362,10 @@ class DataLoaderBatchDTO: del self.latents del self.tensor del self.control_tensor + del self.audio_tensor + del self.audio_data + del self.audio_target + del self.audio_pred for file_item in self.file_items: file_item.cleanup() diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 3490806b..3b542f56 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -16,6 +16,7 @@ from safetensors.torch import load_file, save_file from tqdm import tqdm from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor +from toolkit.audio.preserve_pitch import 
time_stretch_preserve_pitch from toolkit.basic import flush, value_map from toolkit.buckets import get_bucket_for_image_size, get_resolution from toolkit.config_modules import ControlTypes @@ -467,6 +468,8 @@ class ImageProcessingDTOMixin: if not self.dataset_config.buckets: raise Exception('Buckets required for video processing') + do_audio = self.dataset_config.do_audio + try: # Use OpenCV to capture video frames cap = cv2.VideoCapture(self.path) @@ -596,6 +599,84 @@ class ImageProcessingDTOMixin: # Stack frames into tensor [frames, channels, height, width] self.tensor = torch.stack(frames) + + # ------------------------------ + # Audio extraction + stretching + # ------------------------------ + if do_audio: + # Default to "no audio" unless we successfully extract it + self.audio_data = None + self.audio_tensor = None + + try: + import torchaudio + import torch.nn.functional as F + + # Compute the time range of the selected frames in the *source* video + # Include the last frame by extending to the next frame boundary. + if video_fps and video_fps > 0 and len(frames_to_extract) > 0: + clip_start_frame = int(frames_to_extract[0]) + clip_end_frame = int(frames_to_extract[-1]) + clip_start_time = clip_start_frame / float(video_fps) + clip_end_time = (clip_end_frame + 1) / float(video_fps) + source_duration = max(0.0, clip_end_time - clip_start_time) + else: + clip_start_time = 0.0 + clip_end_time = 0.0 + source_duration = 0.0 + + # Target duration is how this sampled/stretched clip is interpreted for training + # (i.e. num_frames at the configured dataset FPS). + if hasattr(self.dataset_config, "fps") and self.dataset_config.fps and self.dataset_config.fps > 0: + target_duration = float(self.dataset_config.num_frames) / float(self.dataset_config.fps) + else: + target_duration = source_duration + + waveform, sample_rate = torchaudio.load(self.path) # [channels, samples] + + if self.dataset_config.audio_normalize: + peak = waveform.abs().amax() # global peak across channels + eps = 1e-9 + target_peak = 0.999 # ~ -0.01 dBFS + gain = target_peak / (peak + eps) + waveform = waveform * gain + + # Slice to the selected clip region (when we have a meaningful time range) + if source_duration > 0.0: + start_sample = int(round(clip_start_time * sample_rate)) + end_sample = int(round(clip_end_time * sample_rate)) + start_sample = max(0, min(start_sample, waveform.shape[-1])) + end_sample = max(0, min(end_sample, waveform.shape[-1])) + if end_sample > start_sample: + waveform = waveform[..., start_sample:end_sample] + else: + # No valid audio segment + waveform = None + else: + # If we can't compute a meaningful time range, treat as no-audio + waveform = None + + if waveform is not None and waveform.numel() > 0: + target_samples = int(round(target_duration * sample_rate)) + if target_samples > 0 and waveform.shape[-1] != target_samples: + # Time-stretch/shrink to match the video clip duration implied by dataset FPS. + if self.dataset_config.audio_preserve_pitch: + waveform = time_stretch_preserve_pitch(waveform, sample_rate, target_samples) # waveform is [C, L] + else: + # Use linear interpolation over the time axis. 
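As a standalone comparison of the two resize paths used here, with dummy tensors and a made-up target length that are independent of the dataloader state above:

import torch
import torch.nn.functional as F
from toolkit.audio.preserve_pitch import time_stretch_preserve_pitch

wave = torch.randn(2, 48_000)  # [C, L] dummy stereo clip, 1 s at 48 kHz
target = 60_000                # pretend the clip must cover 1.25 s of frames

# Fast path: linearly resample the time axis (pitch falls as the audio slows down).
fast = F.interpolate(wave.unsqueeze(0), size=target, mode="linear", align_corners=False).squeeze(0)

# Pitch-preserving path: phase-vocoder stretch, then exact-length trim/pad.
slow = time_stretch_preserve_pitch(wave, 48_000, target)

assert fast.shape == slow.shape == (2, target)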
+ wf = waveform.unsqueeze(0) # [1, C, L] + wf = F.interpolate(wf, size=target_samples, mode="linear", align_corners=False) + waveform = wf.squeeze(0) # [C, L] + + self.audio_tensor = waveform + self.audio_data = {"waveform": waveform, "sample_rate": int(sample_rate)} + + except Exception as e: + # Keep behavior identical for non-audio datasets; for audio datasets, just skip if missing/broken. + if hasattr(self.dataset_config, 'debug') and self.dataset_config.debug: + print_acc(f"Could not extract/stretch audio for {self.path}: {e}") + self.audio_data = None + self.audio_tensor = None # Only log success in debug mode if hasattr(self.dataset_config, 'debug') and self.dataset_config.debug: diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index cd454656..5cb19229 100644 --- a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -265,11 +265,18 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): self.peft_format = peft_format self.is_transformer = is_transformer + # use the old format for older models unless the user has specified otherwise + self.use_old_lokr_format = False + if self.network_config is not None and hasattr(self.network_config, 'old_lokr_format'): + self.use_old_lokr_format = self.network_config.old_lokr_format + # also allow a false from the model itself + if base_model is not None and not base_model.use_old_lokr_format: + self.use_old_lokr_format = False # always do peft for flux only for now if self.is_flux or self.is_v3 or self.is_lumina2 or is_transformer: - # don't do peft format for lokr - if self.network_type.lower() != "lokr": + # don't do peft format for lokr if using old format + if self.network_type.lower() != "lokr" or not self.use_old_lokr_format: self.peft_format = True if self.peft_format: diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py index 9671a8c4..23bd9a9c 100644 --- a/toolkit/models/base_model.py +++ b/toolkit/models/base_model.py @@ -185,6 +185,11 @@ class BaseModel: self.has_multiple_control_images = False # do not resize control images self.use_raw_control_images = False + # defines if the model supports model paths. 
Only some will + self.supports_model_paths = False + + # use new lokr format (default false for old models for backwards compatibility) + self.use_old_lokr_format = True # properties for old arch for backwards compatibility @property diff --git a/toolkit/models/lokr.py b/toolkit/models/lokr.py index 3d7e6ca6..486c0121 100644 --- a/toolkit/models/lokr.py +++ b/toolkit/models/lokr.py @@ -11,6 +11,7 @@ from toolkit.network_mixins import ToolkitModuleMixin from typing import TYPE_CHECKING, Union, List from optimum.quanto import QBytesTensor, QTensor +from torchao.dtypes import AffineQuantizedTensor if TYPE_CHECKING: @@ -284,17 +285,26 @@ class LokrModule(ToolkitModuleMixin, nn.Module): org_sd[weight_key] = merged_weight.to(orig_dtype) self.org_module[0].load_state_dict(org_sd) - def get_orig_weight(self): + def get_orig_weight(self, device): weight = self.org_module[0].weight + if weight.device != device: + weight = weight.to(device) if isinstance(weight, QTensor) or isinstance(weight, QBytesTensor): return weight.dequantize().data.detach() + elif isinstance(weight, AffineQuantizedTensor): + return weight.dequantize().data.detach() else: return weight.data.detach() - def get_orig_bias(self): + def get_orig_bias(self, device): if hasattr(self.org_module[0], 'bias') and self.org_module[0].bias is not None: - if isinstance(self.org_module[0].bias, QTensor) or isinstance(self.org_module[0].bias, QBytesTensor): - return self.org_module[0].bias.dequantize().data.detach() + bias = self.org_module[0].bias + if bias.device != device: + bias = bias.to(device) + if isinstance(bias, QTensor) or isinstance(bias, QBytesTensor): + return bias.dequantize().data.detach() + elif isinstance(bias, AffineQuantizedTensor): + return bias.dequantize().data.detach() else: return self.org_module[0].bias.data.detach() return None @@ -305,7 +315,7 @@ class LokrModule(ToolkitModuleMixin, nn.Module): orig_dtype = x.dtype - orig_weight = self.get_orig_weight() + orig_weight = self.get_orig_weight(x.device) lokr_weight = self.get_weight(orig_weight).to(dtype=orig_weight.dtype) multiplier = self.network_ref().torch_multiplier @@ -319,7 +329,7 @@ class LokrModule(ToolkitModuleMixin, nn.Module): orig_weight + lokr_weight * multiplier ) - bias = self.get_orig_bias() + bias = self.get_orig_bias(x.device) if bias is not None: bias = bias.to(weight.device, dtype=weight.dtype) output = self.op( diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py index 0cea3013..b8556f12 100644 --- a/toolkit/network_mixins.py +++ b/toolkit/network_mixins.py @@ -546,7 +546,8 @@ class ToolkitNetworkMixin: new_save_dict = {} for key, value in save_dict.items(): - if key.endswith('.alpha'): + # lokr needs alpha + if key.endswith('.alpha') and self.network_type.lower() != "lokr": continue new_key = key new_key = new_key.replace('lora_down', 'lora_A') @@ -558,7 +559,7 @@ class ToolkitNetworkMixin: save_dict = new_save_dict - if self.network_type.lower() == "lokr": + if self.network_type.lower() == "lokr" and self.use_old_lokr_format: new_save_dict = {} for key, value in save_dict.items(): # lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1 to lycoris_transformer_blocks_7_attn_to_v.lokr_w1 @@ -632,7 +633,7 @@ class ToolkitNetworkMixin: # lora_down = lora_A # lora_up = lora_B # no alpha - if load_key.endswith('.alpha'): + if load_key.endswith('.alpha') and self.network_type.lower() != "lokr": continue load_key = load_key.replace('lora_A', 'lora_down') load_key = load_key.replace('lora_B', 'lora_up') @@ -640,6 +641,13 @@ class 
ToolkitNetworkMixin: load_key = load_key.replace('.', '$$') load_key = load_key.replace('$$lora_down$$', '.lora_down.') load_key = load_key.replace('$$lora_up$$', '.lora_up.') + + # patch lokr, not sure why we need to but whatever + if self.network_type.lower() == "lokr": + load_key = load_key.replace('$$lokr_w1', '.lokr_w1') + load_key = load_key.replace('$$lokr_w2', '.lokr_w2') + if load_key.endswith('$$alpha'): + load_key = load_key[:-7] + '.alpha' if self.network_type.lower() == "lokr": # lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1 to lycoris_transformer_blocks_7_attn_to_v.lokr_w1 diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index ab9a57f5..d2b34b9f 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -223,6 +223,11 @@ class StableDiffusion: self.has_multiple_control_images = False # do not resize control images self.use_raw_control_images = False + # defines if the model supports model paths. Only some will + self.supports_model_paths = False + + # use new lokr format (default false for old models for backwards compatibility) + self.use_old_lokr_format = True # properties for old arch for backwards compatibility @property diff --git a/ui/src/app/api/img/delete/route.ts b/ui/src/app/api/img/delete/route.ts index d213c1c6..3db7019a 100644 --- a/ui/src/app/api/img/delete/route.ts +++ b/ui/src/app/api/img/delete/route.ts @@ -15,7 +15,7 @@ export async function POST(request: Request) { } // make sure it is an image - if (!/\.(jpg|jpeg|png|bmp|gif|tiff|webp)$/i.test(imgPath.toLowerCase())) { + if (!/\.(jpg|jpeg|png|bmp|gif|tiff|webp|mp4)$/i.test(imgPath.toLowerCase())) { return NextResponse.json({ error: 'Not an image' }, { status: 400 }); } diff --git a/ui/src/app/api/jobs/[jobID]/samples/route.ts b/ui/src/app/api/jobs/[jobID]/samples/route.ts index 2a98a6ea..c0f2ab33 100644 --- a/ui/src/app/api/jobs/[jobID]/samples/route.ts +++ b/ui/src/app/api/jobs/[jobID]/samples/route.ts @@ -29,7 +29,7 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s const samples = fs .readdirSync(samplesFolder) .filter(file => { - return file.endsWith('.png') || file.endsWith('.jpg') || file.endsWith('.jpeg') || file.endsWith('.webp'); + return file.endsWith('.png') || file.endsWith('.jpg') || file.endsWith('.jpeg') || file.endsWith('.webp') || file.endsWith('.mp4'); }) .map(file => { return path.join(samplesFolder, file); diff --git a/ui/src/app/jobs/new/SimpleJob.tsx b/ui/src/app/jobs/new/SimpleJob.tsx index 876b8ffd..82d89690 100644 --- a/ui/src/app/jobs/new/SimpleJob.tsx +++ b/ui/src/app/jobs/new/SimpleJob.tsx @@ -862,6 +862,48 @@ export default function SimpleJob({ docKey="datasets.do_i2v" /> )} + {modelArch?.additionalSections?.includes('datasets.do_audio') && ( + { + if (!value) { + setJobConfig(undefined, `config.process[0].datasets[${i}].do_audio`); + } else { + setJobConfig(value, `config.process[0].datasets[${i}].do_audio`); + } + }} + docKey="datasets.do_audio" + /> + )} + {modelArch?.additionalSections?.includes('datasets.audio_normalize') && ( + { + if (!value) { + setJobConfig(undefined, `config.process[0].datasets[${i}].audio_normalize`); + } else { + setJobConfig(value, `config.process[0].datasets[${i}].audio_normalize`); + } + }} + docKey="datasets.audio_normalize" + /> + )} + {modelArch?.additionalSections?.includes('datasets.audio_preserve_pitch') && ( + { + if (!value) { + setJobConfig(undefined, `config.process[0].datasets[${i}].audio_preserve_pitch`); + } else { + 
setJobConfig(value, `config.process[0].datasets[${i}].audio_preserve_pitch`); + } + }} + docKey="datasets.audio_preserve_pitch" + /> + )} { // Sort by label, case-insensitive return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' }); diff --git a/ui/src/app/jobs/new/utils.ts b/ui/src/app/jobs/new/utils.ts index e216333a..e69b912b 100644 --- a/ui/src/app/jobs/new/utils.ts +++ b/ui/src/app/jobs/new/utils.ts @@ -2,6 +2,25 @@ import { GroupedSelectOption, JobConfig, SelectOption } from '@/types'; import { modelArchs, ModelArch } from './options'; import { objectCopy } from '@/utils/basic'; +const expandDatasetDefaults = ( + defaults: { [key: string]: any }, + numDatasets: number, +): { [key: string]: any } => { + // expands the defaults for datasets[x] to datasets[0], datasets[1], etc. + const expandedDefaults: { [key: string]: any } = { ...defaults }; + for (const key in defaults) { + if (key.includes('datasets[x].')) { + for (let i = 0; i < numDatasets; i++) { + const datasetKey = key.replace('datasets[x].', `datasets[${i}].`); + const v = defaults[key]; + expandedDefaults[datasetKey] = Array.isArray(v) ? [...v] : objectCopy(v); + } + delete expandedDefaults[key]; + } + } + return expandedDefaults; +}; + export const handleModelArchChange = ( currentArchName: string, newArchName: string, @@ -39,16 +58,11 @@ export const handleModelArchChange = ( } } - // revert defaults from previous model - for (const key in currentArch.defaults) { - setJobConfig(currentArch.defaults[key][1], key); - } + const numDatasets = jobConfig.config.process[0].datasets.length; + + let currentDefaults = expandDatasetDefaults(currentArch.defaults || {}, numDatasets); + let newDefaults = expandDatasetDefaults(newArch?.defaults || {}, numDatasets); - if (newArch?.defaults) { - for (const key in newArch.defaults) { - setJobConfig(newArch.defaults[key][0], key); - } - } // set new model setJobConfig(newArchName, 'config.process[0].model.arch'); @@ -79,27 +93,27 @@ export const handleModelArchChange = ( if (newDataset.control_path_1 && newDataset.control_path_1 !== '') { newDataset.control_path = newDataset.control_path_1; } - if (newDataset.control_path_1) { + if ('control_path_1' in newDataset) { delete newDataset.control_path_1; } - if (newDataset.control_path_2) { + if ('control_path_2' in newDataset) { delete newDataset.control_path_2; } - if (newDataset.control_path_3) { + if ('control_path_3' in newDataset) { delete newDataset.control_path_3; } } else { // does not have control images - if (newDataset.control_path) { + if ('control_path' in newDataset) { delete newDataset.control_path; } - if (newDataset.control_path_1) { + if ('control_path_1' in newDataset) { delete newDataset.control_path_1; } - if (newDataset.control_path_2) { + if ('control_path_2' in newDataset) { delete newDataset.control_path_2; } - if (newDataset.control_path_3) { + if ('control_path_3' in newDataset) { delete newDataset.control_path_3; } } @@ -120,4 +134,13 @@ export const handleModelArchChange = ( return newSample; }); setJobConfig(samples, 'config.process[0].sample.samples'); + + // revert defaults from previous model + for (const key in currentDefaults) { + setJobConfig(currentDefaults[key][1], key); + } + + for (const key in newDefaults) { + setJobConfig(newDefaults[key][0], key); + } }; diff --git a/ui/src/components/SampleImageCard.tsx b/ui/src/components/SampleImageCard.tsx index 7f01bc8f..dc1c213e 100644 --- a/ui/src/components/SampleImageCard.tsx +++ b/ui/src/components/SampleImageCard.tsx @@ -56,18 +56,6 @@ const 
SampleImageCard: React.FC = ({ return () => observer.disconnect(); }, [observerRoot, rootMargin]); - // Pause video when leaving viewport - useEffect(() => { - if (!isVideo(imageUrl)) return; - const v = videoRef.current; - if (!v) return; - if (!isVisible && !v.paused) { - try { - v.pause(); - } catch {} - } - }, [isVisible, imageUrl]); - const handleLoad = () => setLoaded(true); return ( @@ -81,9 +69,11 @@ const SampleImageCard: React.FC = ({ src={`/api/img/${encodeURIComponent(imageUrl)}`} className="w-full h-full object-cover" preload="none" + onLoad={handleLoad} playsInline muted loop + autoPlay controls={false} /> ) : ( diff --git a/ui/src/components/SampleImageViewer.tsx b/ui/src/components/SampleImageViewer.tsx index 3e0613d8..e8804a9a 100644 --- a/ui/src/components/SampleImageViewer.tsx +++ b/ui/src/components/SampleImageViewer.tsx @@ -7,6 +7,7 @@ import { Cog } from 'lucide-react'; import { Menu, MenuButton, MenuItem, MenuItems } from '@headlessui/react'; import { openConfirm } from './ConfirmModal'; import { apiClient } from '@/utils/api'; +import { isVideo } from '@/utils/basic'; interface Props { imgPath: string | null; // current image path @@ -200,13 +201,24 @@ export default function SampleImageViewer({ className="relative transform rounded-lg bg-gray-800 text-left shadow-xl transition-all data-closed:translate-y-4 data-closed:opacity-0 data-enter:duration-300 data-enter:ease-out data-leave:duration-200 data-leave:ease-in max-w-[95%] max-h-[95vh] data-closed:sm:translate-y-0 data-closed:sm:scale-95 flex flex-col overflow-hidden" >
-          {imgPath && (
-            <img ... alt="Sample Image" ... />
-          )}
+          {imgPath &&
+            (isVideo(imgPath) ? (
+              <video {/* make full width */} ... />
+            ) : (
+              ...
+            ))}
[element attributes in this hunk were lost in extraction; only the structure shown above is recoverable]
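The three audio options documented in docs.tsx below correspond one-to-one to the new DatasetConfig fields. A minimal dataset configuration enabling them, mirroring the test script added earlier (the dataset path is an illustrative placeholder), looks like:

from toolkit.config_modules import DatasetConfig

dataset = DatasetConfig(
    dataset_path="/path/to/videos",    # placeholder path
    resolution=512,
    buckets=True,
    num_frames=121,
    fps=24,                      # new default; keeps audio and frame timing in sync
    do_i2v=True,
    do_audio=True,               # load the audio track from each video
    audio_normalize=True,        # peak-normalize the waveform per clip
    audio_preserve_pitch=False,  # linear resize; set True for the phase-vocoder stretch
)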
diff --git a/ui/src/docs.tsx b/ui/src/docs.tsx index 291d9ee3..82fc0d1b 100644 --- a/ui/src/docs.tsx +++ b/ui/src/docs.tsx @@ -107,6 +107,35 @@ const docs: { [key: string]: ConfigDoc } = { ), }, + 'datasets.do_audio': { + title: 'Do Audio', + description: ( + <> + For models that support audio with video, this option will load the audio from the video and resize it to match + the video sequence. Since the video is automatically resized, the audio may drop or raise in pitch to match the new + speed of the video. It is important to prep your dataset to have the proper length before training. + + ), + }, + 'datasets.audio_normalize': { + title: 'Audio Normalize', + description: ( + <> + When loading audio, this will normalize the audio volume to the max peaks. Useful if your dataset has varying audio + volumes. Warning, do not use if you have clips with full silence you want to keep, as it will raise the volume of those clips. + + ), + }, + 'datasets.audio_preserve_pitch': { + title: 'Audio Preserve Pitch', + description: ( + <> + When loading audio to match the number of frames requested, this option will preserve the pitch of the audio if + the length does not match training target. It is recommended to have a dataset that matches your target length, + as this option can add sound distortions. + + ), + }, 'datasets.flip': { title: 'Flip X and Flip Y', description: ( diff --git a/ui/src/types.ts b/ui/src/types.ts index 7844ce09..be4af898 100644 --- a/ui/src/types.ts +++ b/ui/src/types.ts @@ -96,7 +96,11 @@ export interface DatasetConfig { control_path?: string | null; num_frames: number; shrink_video_to_frames: boolean; - do_i2v: boolean; + do_i2v?: boolean; + do_audio?: boolean; + audio_normalize?: boolean; + audio_preserve_pitch?: boolean; + fps?: number; flip_x: boolean; flip_y: boolean; control_path_1?: string | null; diff --git a/ui/src/utils/hooks.tsx b/ui/src/utils/hooks.tsx index f96af344..3c3cac15 100644 --- a/ui/src/utils/hooks.tsx +++ b/ui/src/utils/hooks.tsx @@ -17,21 +17,14 @@ export function setNestedValue(obj: T, value: V, path?: string): T { } // Split the path into segments - const pathArray = path.split('.').flatMap(segment => { - // Handle array notation like 'process[0]' - const arrayMatch = segment.match(/^([^\[]+)(\[\d+\])+/); - if (arrayMatch) { - const propName = arrayMatch[1]; - const indices = segment - .substring(propName.length) - .match(/\[(\d+)\]/g) - ?.map(idx => parseInt(idx.substring(1, idx.length - 1))); + const pathArray: Array = []; + const re = /([^[.\]]+)|\[(\d+)\]/g; + let m: RegExpExecArray | null; - // Return property name followed by array indices - return [propName, ...(indices || [])]; - } - return segment; - }); + while ((m = re.exec(path)) !== null) { + if (m[1] !== undefined) pathArray.push(m[1]); + else pathArray.push(Number(m[2])); + } // Navigate to the target location let current: any = result; @@ -43,8 +36,18 @@ export function setNestedValue(obj: T, value: V, path?: string): T { if (!Array.isArray(current)) { throw new Error(`Cannot access index ${key} of non-array`); } - // Create a copy of the array to maintain immutability - current = [...current]; + + // Ensure the indexed element exists and is copied/created immutably + const nextKey = pathArray[i + 1]; + const existing = current[key]; + + if (existing === undefined) { + current[key] = typeof nextKey === 'number' ? 
[] : {}; + } else if (Array.isArray(existing)) { + current[key] = [...existing]; + } else if (typeof existing === 'object' && existing !== null) { + current[key] = { ...existing }; + } // else: primitives stay as-is } else { // For object properties, create a new object if it doesn't exist if (current[key] === undefined) { @@ -63,7 +66,11 @@ export function setNestedValue(obj: T, value: V, path?: string): T { // Set the value at the final path segment const finalKey = pathArray[pathArray.length - 1]; - current[finalKey] = value; + if (value === undefined) { + delete current[finalKey]; + } else { + current[finalKey] = value; + } return result; } diff --git a/version.py b/version.py index f9af3954..3db333f0 100644 --- a/version.py +++ b/version.py @@ -1 +1 @@ -VERSION = "0.7.16" +VERSION = "0.7.17"
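Finally, the path handling in the reworked setNestedValue above boils down to splitting a path such as config.process[0].datasets[1].fps into string keys and integer indices. The same idea, sketched in Python purely for illustration (parse_path is a hypothetical helper, not part of the UI code):

import re

def parse_path(path: str):
    """Split 'config.process[0].datasets[1].fps' into keys and integer indices."""
    segments = []
    for name, index in re.findall(r"([^[.\]]+)|\[(\d+)\]", path):
        segments.append(name if name else int(index))
    return segments

assert parse_path("config.process[0].datasets[1].fps") == [
    "config", "process", 0, "datasets", 1, "fps",
]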