Add LTX-2 Support (#644)

* WIP, adding support for LTX2

* Training on images working

* Fix loading comfy models

* Handle converting and deconverting lora so it matches original format

* Reworked ui to handle ltx and proper dataset default overwriting.

* Update the way lokr saves so it is more compatible with comfy

* Audio loading and synchronization/resampling are working

* Add audio to training. Does it work? Maybe, still testing.

* Fixed fps default issue for sound

* Have ui set fps for accurate audio mapping on ltx

* Added audio processing options to the ui for ltx

* Clean up requirements
Jaret Burkett
2026-01-13 04:55:30 -07:00
committed by GitHub
parent 6870ab490f
commit 5b5aadadb8
28 changed files with 2180 additions and 71 deletions
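For context, the audio-related dataset options mentioned in the notes above land on DatasetConfig. A minimal sketch of enabling them from Python, mirroring the dataset test script added in this PR (paths and values are illustrative):

from toolkit.config_modules import DatasetConfig

dataset_config = DatasetConfig(
    dataset_path="/path/to/videos",  # folder of video files that contain audio tracks
    resolution=512,
    buckets=True,
    num_frames=121,                  # a multiple of 8 plus 1, matching the LTX-2 frame constraint
    fps=24,                          # default is now 24; used to map audio duration to frames
    do_audio=True,                   # load audio from the video files
    audio_preserve_pitch=False,      # pitch-preserving time stretch when fitting num_frames
    audio_normalize=True,            # normalize audio volume when loading
)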

View File

@@ -7,6 +7,7 @@ from .wan22 import Wan225bModel, Wan2214bModel, Wan2214bI2VModel
from .qwen_image import QwenImageModel, QwenImageEditModel, QwenImageEditPlusModel
from .flux2 import Flux2Model
from .z_image import ZImageModel
from .ltx2 import LTX2Model
AI_TOOLKIT_MODELS = [
# put a list of models here
@@ -25,4 +26,5 @@ AI_TOOLKIT_MODELS = [
QwenImageEditPlusModel,
Flux2Model,
ZImageModel,
LTX2Model,
]

View File

@@ -0,0 +1 @@
from .ltx2 import LTX2Model

View File

@@ -0,0 +1,648 @@
# ref https://github.com/huggingface/diffusers/blob/249ae1f853be8775c7d0b3a26ca4fbc4f6aa9998/scripts/convert_ltx2_to_diffusers.py
from contextlib import nullcontext
from typing import Any, Dict, Tuple
import torch
from accelerate import init_empty_weights
from diffusers import (
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
LTX2VideoTransformer3DModel,
)
from diffusers.pipelines.ltx2 import LTX2TextConnectors, LTX2Vocoder
from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available() else nullcontext
LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = {
# Input Patchify Projections
"patchify_proj": "proj_in",
"audio_patchify_proj": "audio_proj_in",
# Modulation Parameters
# Handle adaln_single --> time_embed, audio_adaln_single --> audio_time_embed separately as the original keys are
# substrings of the other modulation parameters below
"av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
"av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
"av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
"av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
# Transformer Blocks
# Per-Block Cross Attention Modulation Parameters
"scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
"scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
# Attention QK Norms
"q_norm": "norm_q",
"k_norm": "norm_k",
}
LTX_2_0_VIDEO_VAE_RENAME_DICT = {
# Encoder
"down_blocks.0": "down_blocks.0",
"down_blocks.1": "down_blocks.0.downsamplers.0",
"down_blocks.2": "down_blocks.1",
"down_blocks.3": "down_blocks.1.downsamplers.0",
"down_blocks.4": "down_blocks.2",
"down_blocks.5": "down_blocks.2.downsamplers.0",
"down_blocks.6": "down_blocks.3",
"down_blocks.7": "down_blocks.3.downsamplers.0",
"down_blocks.8": "mid_block",
# Decoder
"up_blocks.0": "mid_block",
"up_blocks.1": "up_blocks.0.upsamplers.0",
"up_blocks.2": "up_blocks.0",
"up_blocks.3": "up_blocks.1.upsamplers.0",
"up_blocks.4": "up_blocks.1",
"up_blocks.5": "up_blocks.2.upsamplers.0",
"up_blocks.6": "up_blocks.2",
# Common
# For all 3D ResNets
"res_blocks": "resnets",
"per_channel_statistics.mean-of-means": "latents_mean",
"per_channel_statistics.std-of-means": "latents_std",
}
LTX_2_0_AUDIO_VAE_RENAME_DICT = {
"per_channel_statistics.mean-of-means": "latents_mean",
"per_channel_statistics.std-of-means": "latents_std",
}
LTX_2_0_VOCODER_RENAME_DICT = {
"ups": "upsamplers",
"resblocks": "resnets",
"conv_pre": "conv_in",
"conv_post": "conv_out",
}
LTX_2_0_TEXT_ENCODER_RENAME_DICT = {
"video_embeddings_connector": "video_connector",
"audio_embeddings_connector": "audio_connector",
"transformer_1d_blocks": "transformer_blocks",
# Attention QK Norms
"q_norm": "norm_q",
"k_norm": "norm_k",
}
def update_state_dict_inplace(
state_dict: Dict[str, Any], old_key: str, new_key: str
) -> None:
state_dict[new_key] = state_dict.pop(old_key)
def remove_keys_inplace(key: str, state_dict: Dict[str, Any]) -> None:
state_dict.pop(key)
def convert_ltx2_transformer_adaln_single(key: str, state_dict: Dict[str, Any]) -> None:
# Skip keys that are not a weight or bias
if ".weight" not in key and ".bias" not in key:
return
if key.startswith("adaln_single."):
new_key = key.replace("adaln_single.", "time_embed.")
param = state_dict.pop(key)
state_dict[new_key] = param
if key.startswith("audio_adaln_single."):
new_key = key.replace("audio_adaln_single.", "audio_time_embed.")
param = state_dict.pop(key)
state_dict[new_key] = param
return
def convert_ltx2_audio_vae_per_channel_statistics(
key: str, state_dict: Dict[str, Any]
) -> None:
if key.startswith("per_channel_statistics"):
new_key = ".".join(["decoder", key])
param = state_dict.pop(key)
state_dict[new_key] = param
return
LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = {
"video_embeddings_connector": remove_keys_inplace,
"audio_embeddings_connector": remove_keys_inplace,
"adaln_single": convert_ltx2_transformer_adaln_single,
}
LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = {
"connectors.": "",
"video_embeddings_connector": "video_connector",
"audio_embeddings_connector": "audio_connector",
"transformer_1d_blocks": "transformer_blocks",
"text_embedding_projection.aggregate_embed": "text_proj_in",
# Attention QK Norms
"q_norm": "norm_q",
"k_norm": "norm_k",
}
LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
"per_channel_statistics.channel": remove_keys_inplace,
"per_channel_statistics.mean-of-stds": remove_keys_inplace,
}
LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {}
LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {}
def split_transformer_and_connector_state_dict(
state_dict: Dict[str, Any],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
connector_prefixes = (
"video_embeddings_connector",
"audio_embeddings_connector",
"transformer_1d_blocks",
"text_embedding_projection.aggregate_embed",
"connectors.",
"video_connector",
"audio_connector",
"text_proj_in",
)
transformer_state_dict, connector_state_dict = {}, {}
for key, value in state_dict.items():
if key.startswith(connector_prefixes):
connector_state_dict[key] = value
else:
transformer_state_dict[key] = value
return transformer_state_dict, connector_state_dict
def get_ltx2_transformer_config() -> Tuple[
Dict[str, Any], Dict[str, Any], Dict[str, Any]
]:
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"diffusers_config": {
"in_channels": 128,
"out_channels": 128,
"patch_size": 1,
"patch_size_t": 1,
"num_attention_heads": 32,
"attention_head_dim": 128,
"cross_attention_dim": 4096,
"vae_scale_factors": (8, 32, 32),
"pos_embed_max_pos": 20,
"base_height": 2048,
"base_width": 2048,
"audio_in_channels": 128,
"audio_out_channels": 128,
"audio_patch_size": 1,
"audio_patch_size_t": 1,
"audio_num_attention_heads": 32,
"audio_attention_head_dim": 64,
"audio_cross_attention_dim": 2048,
"audio_scale_factor": 4,
"audio_pos_embed_max_pos": 20,
"audio_sampling_rate": 16000,
"audio_hop_length": 160,
"num_layers": 48,
"activation_fn": "gelu-approximate",
"qk_norm": "rms_norm_across_heads",
"norm_elementwise_affine": False,
"norm_eps": 1e-6,
"caption_channels": 3840,
"attention_bias": True,
"attention_out_bias": True,
"rope_theta": 10000.0,
"rope_double_precision": True,
"causal_offset": 1,
"timestep_scale_multiplier": 1000,
"cross_attn_timestep_scale_multiplier": 1000,
"rope_type": "split",
},
}
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def get_ltx2_connectors_config() -> Tuple[
Dict[str, Any], Dict[str, Any], Dict[str, Any]
]:
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"diffusers_config": {
"caption_channels": 3840,
"text_proj_in_factor": 49,
"video_connector_num_attention_heads": 30,
"video_connector_attention_head_dim": 128,
"video_connector_num_layers": 2,
"video_connector_num_learnable_registers": 128,
"audio_connector_num_attention_heads": 30,
"audio_connector_attention_head_dim": 128,
"audio_connector_num_layers": 2,
"audio_connector_num_learnable_registers": 128,
"connector_rope_base_seq_len": 4096,
"rope_theta": 10000.0,
"rope_double_precision": True,
"causal_temporal_positioning": False,
"rope_type": "split",
},
}
rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT
special_keys_remap = {}
return config, rename_dict, special_keys_remap
def convert_ltx2_transformer(original_state_dict: Dict[str, Any]) -> Dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_transformer_config()
diffusers_config = config["diffusers_config"]
transformer_state_dict, _ = split_transformer_and_connector_state_dict(
original_state_dict
)
with init_empty_weights():
transformer = LTX2VideoTransformer3DModel.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(transformer_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(transformer_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(transformer_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, transformer_state_dict)
transformer.load_state_dict(transformer_state_dict, strict=True, assign=True)
return transformer
def convert_ltx2_connectors(original_state_dict: Dict[str, Any]) -> LTX2TextConnectors:
config, rename_dict, special_keys_remap = get_ltx2_connectors_config()
diffusers_config = config["diffusers_config"]
_, connector_state_dict = split_transformer_and_connector_state_dict(
original_state_dict
)
if len(connector_state_dict) == 0:
raise ValueError("No connector weights found in the provided state dict.")
with init_empty_weights():
connectors = LTX2TextConnectors.from_config(diffusers_config)
for key in list(connector_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(connector_state_dict, key, new_key)
for key in list(connector_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, connector_state_dict)
connectors.load_state_dict(connector_state_dict, strict=True, assign=True)
return connectors
def get_ltx2_video_vae_config() -> Tuple[
Dict[str, Any], Dict[str, Any], Dict[str, Any]
]:
config = {
"model_id": "diffusers-internal-dev/dummy-ltx2",
"diffusers_config": {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (256, 512, 1024, 2048),
"down_block_types": (
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
),
"decoder_block_out_channels": (256, 512, 1024),
"layers_per_block": (4, 6, 6, 2, 2),
"decoder_layers_per_block": (5, 5, 5, 5),
"spatio_temporal_scaling": (True, True, True, True),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (False, False, False, False),
"downsample_type": (
"spatial",
"temporal",
"spatiotemporal",
"spatiotemporal",
),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": False,
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
"encoder_causal": True,
"decoder_causal": False,
"encoder_spatial_padding_mode": "zeros",
"decoder_spatial_padding_mode": "reflect",
"spatial_compression_ratio": 32,
"temporal_compression_ratio": 8,
},
}
rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def convert_ltx2_video_vae(original_state_dict: Dict[str, Any]) -> Dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_video_vae_config()
diffusers_config = config["diffusers_config"]
with init_empty_weights():
vae = AutoencoderKLLTX2Video.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
vae.load_state_dict(original_state_dict, strict=True, assign=True)
return vae
def get_ltx2_audio_vae_config() -> Tuple[
Dict[str, Any], Dict[str, Any], Dict[str, Any]
]:
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"diffusers_config": {
"base_channels": 128,
"output_channels": 2,
"ch_mult": (1, 2, 4),
"num_res_blocks": 2,
"attn_resolutions": None,
"in_channels": 2,
"resolution": 256,
"latent_channels": 8,
"norm_type": "pixel",
"causality_axis": "height",
"dropout": 0.0,
"mid_block_add_attention": False,
"sample_rate": 16000,
"mel_hop_length": 160,
"is_causal": True,
"mel_bins": 64,
"double_z": True,
},
}
rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT
special_keys_remap = LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def convert_ltx2_audio_vae(original_state_dict: Dict[str, Any]) -> Dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_audio_vae_config()
diffusers_config = config["diffusers_config"]
with init_empty_weights():
vae = AutoencoderKLLTX2Audio.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
vae.load_state_dict(original_state_dict, strict=True, assign=True)
return vae
def get_ltx2_vocoder_config() -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"diffusers_config": {
"in_channels": 128,
"hidden_channels": 1024,
"out_channels": 2,
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
"upsample_factors": [6, 5, 2, 2, 2],
"resnet_kernel_sizes": [3, 7, 11],
"resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"leaky_relu_negative_slope": 0.1,
"output_sampling_rate": 24000,
},
}
rename_dict = LTX_2_0_VOCODER_RENAME_DICT
special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def convert_ltx2_vocoder(original_state_dict: Dict[str, Any]) -> Dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_vocoder_config()
diffusers_config = config["diffusers_config"]
with init_empty_weights():
vocoder = LTX2Vocoder.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
vocoder.load_state_dict(original_state_dict, strict=True, assign=True)
return vocoder
def get_model_state_dict_from_combined_ckpt(
combined_ckpt: Dict[str, Any], prefix: str
) -> Dict[str, Any]:
# Ensure that the key prefix ends with a dot (.)
if not prefix.endswith("."):
prefix = prefix + "."
model_state_dict = {}
for param_name, param in combined_ckpt.items():
if param_name.startswith(prefix):
model_state_dict[param_name.replace(prefix, "")] = param
if prefix == "model.diffusion_model.":
# Some checkpoints store the text connector projection outside the diffusion model prefix.
connector_key = "text_embedding_projection.aggregate_embed.weight"
if connector_key in combined_ckpt and connector_key not in model_state_dict:
model_state_dict[connector_key] = combined_ckpt[connector_key]
return model_state_dict
def dequantize_state_dict(state_dict: Dict[str, Any]):
keys = list(state_dict.keys())
state_out = {}
for k in keys:
if k.endswith(
(".weight_scale", ".weight_scale_2", ".pre_quant_scale", ".input_scale")
):
continue
t = state_dict[k]
if k.endswith(".weight"):
prefix = k[: -len(".weight")]
wscale_k = prefix + ".weight_scale"
if wscale_k in state_dict:
w_q = t
w_scale = state_dict[wscale_k]
# Comfy quant = absmax per-tensor weight quant, nothing fancy
w_bf16 = w_q.to(torch.bfloat16) * w_scale.to(torch.bfloat16)
state_out[k] = w_bf16
continue
state_out[k] = t
return state_out
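For reference, a tiny sketch of what this dequantization does for a single entry (tensor values are illustrative): the per-tensor scale key is dropped and the weight becomes the product of the quantized weight and its scale.

example_sd = {
    "blocks.0.ffn.weight": torch.randint(-127, 127, (4, 4), dtype=torch.int8),
    "blocks.0.ffn.weight_scale": torch.tensor(0.02),
}
dequantized = dequantize_state_dict(example_sd)
assert "blocks.0.ffn.weight_scale" not in dequantized
assert dequantized["blocks.0.ffn.weight"].dtype == torch.bfloat16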
def convert_comfy_gemma3_to_transformers(sd: dict):
out = {}
sd = dequantize_state_dict(sd)
for k, v in sd.items():
nk = k
# Vision tower weights: checkpoint has "vision_model.*"
# model expects "model.vision_tower.vision_model.*"
if k.startswith("vision_model."):
nk = "model.vision_tower." + k
# MM projector: checkpoint has "multi_modal_projector.*"
# model expects "model.multi_modal_projector.*"
elif k.startswith("multi_modal_projector."):
nk = "model." + k
# Language model: checkpoint has "model.layers.*", "model.embed_tokens.*", "model.norm.*"
# model expects "model.language_model.layers.*", etc.
elif k == "model.embed_tokens.weight":
nk = "model.language_model.embed_tokens.weight"
elif k.startswith("model.layers."):
nk = "model.language_model.layers." + k[len("model.layers.") :]
elif k.startswith("model.norm."):
nk = "model.language_model.norm." + k[len("model.norm.") :]
# (optional) common DDP prefix
if nk.startswith("module."):
nk = nk[len("module.") :]
# skip spiece_model
if nk == "spiece_model":
continue
out[nk] = v
# If lm_head is missing but embeddings exist, many Gemma-family models tie these weights.
# Add it so strict loading won't complain (or just load strict=False and call tie_weights()).
if (
"lm_head.weight" not in out
and "model.language_model.embed_tokens.weight" in out
):
out["lm_head.weight"] = out["model.language_model.embed_tokens.weight"]
return out
def convert_lora_original_to_diffusers(
lora_state_dict: Dict[str, Any],
) -> Dict[str, Any]:
out: Dict[str, Any] = {}
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
for k, v in lora_state_dict.items():
# Keep the "diffusion_model." prefix as-is, but apply the transformer remaps to the rest
prefix = ""
rest = k
if rest.startswith("diffusion_model."):
prefix = "diffusion_model."
rest = rest[len(prefix) :]
nk = rest
# Same simple 1:1 remaps as the transformer
for replace_key, rename_key in rename_dict.items():
nk = nk.replace(replace_key, rename_key)
# Same special-case remap as the transformer (applies to LoRA keys too)
if nk.startswith("adaln_single."):
nk = nk.replace("adaln_single.", "time_embed.", 1)
elif nk.startswith("audio_adaln_single."):
nk = nk.replace("audio_adaln_single.", "audio_time_embed.", 1)
out[prefix + nk] = v
return out
def convert_lora_diffusers_to_original(
lora_state_dict: Dict[str, Any],
) -> Dict[str, Any]:
out: Dict[str, Any] = {}
inv_rename = {v: k for k, v in LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT.items()}
inv_items = sorted(inv_rename.items(), key=lambda kv: len(kv[0]), reverse=True)
for k, v in lora_state_dict.items():
# Keep the "diffusion_model." prefix as-is, but invert remaps on the rest
prefix = ""
rest = k
if rest.startswith("diffusion_model."):
prefix = "diffusion_model."
rest = rest[len(prefix) :]
nk = rest
# Inverse of the adaln_single special-case
if nk.startswith("time_embed."):
nk = nk.replace("time_embed.", "adaln_single.", 1)
elif nk.startswith("audio_time_embed."):
nk = nk.replace("audio_time_embed.", "audio_adaln_single.", 1)
# Inverse 1:1 remaps
for diffusers_key, original_key in inv_items:
nk = nk.replace(diffusers_key, original_key)
out[prefix + nk] = v
return out
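A quick round-trip sketch for the two LoRA helpers above (module paths are illustrative; only the rename rules matter):

diffusers_lora = {
    "diffusion_model.proj_in.lora_A.weight": torch.zeros(4, 4),
    "diffusion_model.transformer_blocks.0.attn1.norm_q.lora_B.weight": torch.zeros(4, 4),
}
original_lora = convert_lora_diffusers_to_original(diffusers_lora)
# keys become "diffusion_model.patchify_proj.lora_A.weight" and
# "diffusion_model.transformer_blocks.0.attn1.q_norm.lora_B.weight"
round_tripped = convert_lora_original_to_diffusers(original_lora)
assert set(round_tripped.keys()) == set(diffusers_lora.keys())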

View File

@@ -0,0 +1,852 @@
from functools import partial
import os
from typing import List, Optional
import torch
import torchaudio
from transformers import Gemma3Config
import yaml
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.models.base_model import BaseModel
from toolkit.basic import flush
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import (
CustomFlowMatchEulerDiscreteScheduler,
)
from accelerate import init_empty_weights
from toolkit.accelerator import unwrap_model
from optimum.quanto import freeze
from toolkit.util.quantize import quantize, get_qtype, quantize_model
from toolkit.memory_management import MemoryManager
from safetensors.torch import load_file
try:
from diffusers import LTX2Pipeline
from diffusers.models.autoencoders import (
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
)
from diffusers.models.transformers import LTX2VideoTransformer3DModel
from diffusers.pipelines.ltx2.export_utils import encode_video
from transformers import (
Gemma3ForConditionalGeneration,
GemmaTokenizerFast,
)
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
from diffusers.pipelines.ltx2.connectors import LTX2TextConnectors
from .convert_ltx2_to_diffusers import (
get_model_state_dict_from_combined_ckpt,
convert_ltx2_transformer,
convert_ltx2_video_vae,
convert_ltx2_audio_vae,
convert_ltx2_vocoder,
convert_ltx2_connectors,
dequantize_state_dict,
convert_comfy_gemma3_to_transformers,
convert_lora_original_to_diffusers,
convert_lora_diffusers_to_original,
)
except ImportError as e:
print("Diffusers import error:", e)
raise ImportError(
"Diffusers is out of date. Update diffusers to the latest version by doing pip uninstall diffusers and then pip install -r requirements.txt"
)
scheduler_config = {
"base_image_seq_len": 1024,
"base_shift": 0.95,
"invert_sigmas": False,
"max_image_seq_len": 4096,
"max_shift": 2.05,
"num_train_timesteps": 1000,
"shift": 1.0,
"shift_terminal": 0.1,
"stochastic_sampling": False,
"time_shift_type": "exponential",
"use_beta_sigmas": False,
"use_dynamic_shifting": True,
"use_exponential_sigmas": False,
"use_karras_sigmas": False,
}
dit_prefix = "model.diffusion_model."
vae_prefix = "vae."
audio_vae_prefix = "audio_vae."
vocoder_prefix = "vocoder."
def new_save_image_function(
self: GenerateImageConfig,
image, # will contain a dict that can be dumped directly into encode_video, just add output_path to it.
count: int = 0,
max_count: int = 0,
**kwargs,
):
# this replaces gen image config save image function so we can save the video with sound from ltx2
image["output_path"] = self.get_image_path(count, max_count)
# make sample directory if it does not exist
os.makedirs(os.path.dirname(image["output_path"]), exist_ok=True)
encode_video(**image)
flush()
def blank_log_image_function(self, *args, **kwargs):
# todo handle wandb logging of videos with audio
return
class AudioProcessor(torch.nn.Module):
"""Converts audio waveforms to log-mel spectrograms with optional resampling."""
def __init__(
self,
sample_rate: int,
mel_bins: int,
mel_hop_length: int,
n_fft: int,
) -> None:
super().__init__()
self.sample_rate = sample_rate
self.mel_transform = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate,
n_fft=n_fft,
win_length=n_fft,
hop_length=mel_hop_length,
f_min=0.0,
f_max=sample_rate / 2.0,
n_mels=mel_bins,
window_fn=torch.hann_window,
center=True,
pad_mode="reflect",
power=1.0,
mel_scale="slaney",
norm="slaney",
)
def resample_waveform(
self,
waveform: torch.Tensor,
source_rate: int,
target_rate: int,
) -> torch.Tensor:
"""Resample waveform to target sample rate if needed."""
if source_rate == target_rate:
return waveform
resampled = torchaudio.functional.resample(waveform, source_rate, target_rate)
return resampled.to(device=waveform.device, dtype=waveform.dtype)
def waveform_to_mel(
self,
waveform: torch.Tensor,
waveform_sample_rate: int,
) -> torch.Tensor:
"""Convert waveform to log-mel spectrogram [batch, channels, time, n_mels]."""
waveform = self.resample_waveform(
waveform, waveform_sample_rate, self.sample_rate
)
mel = self.mel_transform(waveform)
mel = torch.log(torch.clamp(mel, min=1e-5))
mel = mel.to(device=waveform.device, dtype=waveform.dtype)
return mel.permute(0, 1, 3, 2).contiguous()
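A minimal usage sketch for this processor, assuming the same values load_model wires in further down (16 kHz target rate, 64 mel bins, hop length 160, n_fft 1024):

audio_processor = AudioProcessor(sample_rate=16000, mel_bins=64, mel_hop_length=160, n_fft=1024)
waveform = torch.randn(1, 2, 48000)  # [batch, channels, samples]: one second of stereo audio at 48 kHz
# resampled to 16 kHz internally, then converted to a log-mel spectrogram
mel = audio_processor.waveform_to_mel(waveform, waveform_sample_rate=48000)
print(mel.shape)  # [batch, channels, time, mel_bins], e.g. torch.Size([1, 2, 101, 64])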
class LTX2Model(BaseModel):
arch = "ltx2"
def __init__(
self,
device,
model_config: ModelConfig,
dtype="bf16",
custom_pipeline=None,
noise_scheduler=None,
**kwargs,
):
super().__init__(
device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs
)
self.is_flow_matching = True
self.is_transformer = True
self.target_lora_modules = ["LTX2VideoTransformer3DModel"]
# defines if the model supports model paths. Only some will
self.supports_model_paths = True
# use the new format on this new model by default
self.use_old_lokr_format = False
self.audio_processor = None
# static method to get the noise scheduler
@staticmethod
def get_train_scheduler():
return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
def get_bucket_divisibility(self):
return 32
def load_model(self):
dtype = self.torch_dtype
self.print_and_status_update("Loading LTX2 model")
model_path = self.model_config.name_or_path
base_model_path = self.model_config.extras_name_or_path
combined_state_dict = None
self.print_and_status_update("Loading transformer")
# if we have a safetensors file it is a combined single-file checkpoint
if os.path.exists(model_path) and model_path.endswith(".safetensors"):
combined_state_dict = load_file(model_path)
combined_state_dict = dequantize_state_dict(combined_state_dict)
if combined_state_dict is not None:
original_dit_ckpt = get_model_state_dict_from_combined_ckpt(
combined_state_dict, dit_prefix
)
transformer = convert_ltx2_transformer(original_dit_ckpt)
transformer = transformer.to(dtype)
else:
transformer_path = model_path
transformer_subfolder = "transformer"
if os.path.exists(transformer_path):
transformer_subfolder = None
transformer_path = os.path.join(transformer_path, "transformer")
# check if the path is a full checkpoint.
te_folder_path = os.path.join(model_path, "text_encoder")
# if we have the te, this folder is a full checkpoint, use it as the base
if os.path.exists(te_folder_path):
base_model_path = model_path
transformer = LTX2VideoTransformer3DModel.from_pretrained(
transformer_path, subfolder=transformer_subfolder, torch_dtype=dtype
)
if self.model_config.quantize:
self.print_and_status_update("Quantizing Transformer")
quantize_model(self, transformer)
flush()
if (
self.model_config.layer_offloading
and self.model_config.layer_offloading_transformer_percent > 0
):
ignore_modules = []
for block in transformer.transformer_blocks:
ignore_modules.append(block.scale_shift_table)
ignore_modules.append(block.audio_scale_shift_table)
ignore_modules.append(block.video_a2v_cross_attn_scale_shift_table)
ignore_modules.append(block.audio_a2v_cross_attn_scale_shift_table)
ignore_modules.append(transformer.scale_shift_table)
ignore_modules.append(transformer.audio_scale_shift_table)
MemoryManager.attach(
transformer,
self.device_torch,
offload_percent=self.model_config.layer_offloading_transformer_percent,
ignore_modules=ignore_modules,
)
if self.model_config.low_vram:
self.print_and_status_update("Moving transformer to CPU")
transformer.to("cpu")
flush()
self.print_and_status_update("Loading text encoder")
if (
self.model_config.te_name_or_path is not None
and self.model_config.te_name_or_path.endswith(".safetensors")
):
# load from comfyui gemma3 checkpoint
tokenizer = GemmaTokenizerFast.from_pretrained(
"Lightricks/LTX-2", subfolder="tokenizer"
)
with init_empty_weights():
text_encoder = Gemma3ForConditionalGeneration(
Gemma3Config(
**{
"boi_token_index": 255999,
"bos_token_id": 2,
"eoi_token_index": 256000,
"eos_token_id": 106,
"image_token_index": 262144,
"initializer_range": 0.02,
"mm_tokens_per_image": 256,
"model_type": "gemma3",
"pad_token_id": 0,
"text_config": {
"attention_bias": False,
"attention_dropout": 0.0,
"attn_logit_softcapping": None,
"cache_implementation": "hybrid",
"final_logit_softcapping": None,
"head_dim": 256,
"hidden_activation": "gelu_pytorch_tanh",
"hidden_size": 3840,
"initializer_range": 0.02,
"intermediate_size": 15360,
"max_position_embeddings": 131072,
"model_type": "gemma3_text",
"num_attention_heads": 16,
"num_hidden_layers": 48,
"num_key_value_heads": 8,
"query_pre_attn_scalar": 256,
"rms_norm_eps": 1e-06,
"rope_local_base_freq": 10000,
"rope_scaling": {"factor": 8.0, "rope_type": "linear"},
"rope_theta": 1000000,
"sliding_window": 1024,
"sliding_window_pattern": 6,
"torch_dtype": "bfloat16",
"use_cache": True,
"vocab_size": 262208,
},
"torch_dtype": "bfloat16",
"transformers_version": "4.51.3",
"unsloth_fixed": True,
"vision_config": {
"attention_dropout": 0.0,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1152,
"image_size": 896,
"intermediate_size": 4304,
"layer_norm_eps": 1e-06,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 27,
"patch_size": 14,
"torch_dtype": "bfloat16",
"vision_use_head": False,
},
}
)
)
te_state_dict = load_file(self.model_config.te_name_or_path)
te_state_dict = convert_comfy_gemma3_to_transformers(te_state_dict)
for key in te_state_dict:
te_state_dict[key] = te_state_dict[key].to(dtype)
text_encoder.load_state_dict(te_state_dict, assign=True, strict=True)
del te_state_dict
flush()
else:
if self.model_config.te_name_or_path is not None:
te_path = self.model_config.te_name_or_path
else:
te_path = base_model_path
tokenizer = GemmaTokenizerFast.from_pretrained(
te_path, subfolder="tokenizer"
)
text_encoder = Gemma3ForConditionalGeneration.from_pretrained(
te_path, subfolder="text_encoder", dtype=dtype
)
# remove the vision tower
text_encoder.model.vision_tower = None
flush()
if (
self.model_config.layer_offloading
and self.model_config.layer_offloading_text_encoder_percent > 0
):
MemoryManager.attach(
text_encoder,
self.device_torch,
offload_percent=self.model_config.layer_offloading_text_encoder_percent,
ignore_modules=[
text_encoder.model.language_model.base_model.embed_tokens
],
)
text_encoder.to(self.device_torch, dtype=dtype)
flush()
if self.model_config.quantize_te:
self.print_and_status_update("Quantizing Text Encoder")
quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te))
freeze(text_encoder)
flush()
self.print_and_status_update("Loading VAEs and other components")
if combined_state_dict is not None:
original_vae_ckpt = get_model_state_dict_from_combined_ckpt(
combined_state_dict, vae_prefix
)
vae = convert_ltx2_video_vae(original_vae_ckpt).to(dtype)
del original_vae_ckpt
original_audio_vae_ckpt = get_model_state_dict_from_combined_ckpt(
combined_state_dict, audio_vae_prefix
)
audio_vae = convert_ltx2_audio_vae(original_audio_vae_ckpt).to(dtype)
del original_audio_vae_ckpt
original_connectors_ckpt = get_model_state_dict_from_combined_ckpt(
combined_state_dict, dit_prefix
)
connectors = convert_ltx2_connectors(original_connectors_ckpt).to(dtype)
del original_connectors_ckpt
original_vocoder_ckpt = get_model_state_dict_from_combined_ckpt(
combined_state_dict, vocoder_prefix
)
vocoder = convert_ltx2_vocoder(original_vocoder_ckpt).to(dtype)
del original_vocoder_ckpt
del combined_state_dict
flush()
else:
vae = AutoencoderKLLTX2Video.from_pretrained(
base_model_path, subfolder="vae", torch_dtype=dtype
)
audio_vae = AutoencoderKLLTX2Audio.from_pretrained(
base_model_path, subfolder="audio_vae", torch_dtype=dtype
)
connectors = LTX2TextConnectors.from_pretrained(
base_model_path, subfolder="connectors", torch_dtype=dtype
)
vocoder = LTX2Vocoder.from_pretrained(
base_model_path, subfolder="vocoder", torch_dtype=dtype
)
self.noise_scheduler = LTX2Model.get_train_scheduler()
self.print_and_status_update("Making pipe")
pipe: LTX2Pipeline = LTX2Pipeline(
scheduler=self.noise_scheduler,
vae=vae,
audio_vae=audio_vae,
text_encoder=None,
tokenizer=tokenizer,
connectors=connectors,
transformer=None,
vocoder=vocoder,
)
# for quantization, it works best to do these after making the pipe
pipe.text_encoder = text_encoder
pipe.transformer = transformer
self.print_and_status_update("Preparing Model")
text_encoder = [pipe.text_encoder]
tokenizer = [pipe.tokenizer]
# leave it on cpu for now
if not self.low_vram:
pipe.transformer = pipe.transformer.to(self.device_torch)
flush()
# just to make sure everything is on the right device and dtype
text_encoder[0].to(self.device_torch)
text_encoder[0].requires_grad_(False)
text_encoder[0].eval()
flush()
# save it to the model class
self.vae = vae
self.text_encoder = text_encoder # list of text encoders
self.tokenizer = tokenizer # list of tokenizers
self.model = pipe.transformer
self.pipeline = pipe
self.audio_processor = AudioProcessor(
sample_rate=pipe.audio_sampling_rate,
mel_bins=audio_vae.config.mel_bins,
mel_hop_length=pipe.audio_hop_length,
n_fft=1024, # todo get this from vae if we can, I couldn't find it.
).to(self.device_torch, dtype=torch.float32)
self.print_and_status_update("Model Loaded")
@torch.no_grad()
def encode_images(self, image_list: List[torch.Tensor], device=None, dtype=None):
if device is None:
device = self.vae_device_torch
if dtype is None:
dtype = self.vae_torch_dtype
if self.vae.device == torch.device("cpu"):
self.vae.to(device)
self.vae.eval()
self.vae.requires_grad_(False)
image_list = [image.to(device, dtype=dtype) for image in image_list]
# Normalize shapes
norm_images = []
for image in image_list:
if image.ndim == 3:
# (C, H, W) -> (C, 1, H, W)
norm_images.append(image.unsqueeze(1))
elif image.ndim == 4:
# (T, C, H, W) -> (C, T, H, W)
norm_images.append(image.permute(1, 0, 2, 3))
else:
raise ValueError(f"Invalid image shape: {image.shape}")
# Stack to (B, C, T, H, W)
images = torch.stack(norm_images)
latents = self.vae.encode(images).latent_dist.mode()
# Normalize latents across the channel dimension [B, C, F, H, W]
scaling_factor = 1.0
latents_mean = self.pipeline.vae.latents_mean.view(1, -1, 1, 1, 1).to(
latents.device, latents.dtype
)
latents_std = self.pipeline.vae.latents_std.view(1, -1, 1, 1, 1).to(
latents.device, latents.dtype
)
latents = (latents - latents_mean) * scaling_factor / latents_std
return latents.to(device, dtype=dtype)
def get_generation_pipeline(self):
scheduler = LTX2Model.get_train_scheduler()
pipeline: LTX2Pipeline = LTX2Pipeline(
scheduler=scheduler,
vae=unwrap_model(self.pipeline.vae),
audio_vae=unwrap_model(self.pipeline.audio_vae),
text_encoder=None,
tokenizer=unwrap_model(self.pipeline.tokenizer),
connectors=unwrap_model(self.pipeline.connectors),
transformer=None,
vocoder=unwrap_model(self.pipeline.vocoder),
)
pipeline.transformer = unwrap_model(self.model)
pipeline.text_encoder = unwrap_model(self.text_encoder[0])
# if self.low_vram:
# pipeline.enable_model_cpu_offload(device=self.device_torch)
pipeline = pipeline.to(self.device_torch)
return pipeline
def generate_single_image(
self,
pipeline: LTX2Pipeline,
gen_config: GenerateImageConfig,
conditional_embeds: PromptEmbeds,
unconditional_embeds: PromptEmbeds,
generator: torch.Generator,
extra: dict,
):
if self.model.device == torch.device("cpu"):
self.model.to(self.device_torch)
is_video = gen_config.num_frames > 1
# override the generate single image to handle video + audio generation
if is_video:
gen_config._orig_save_image_function = gen_config.save_image
gen_config.save_image = partial(new_save_image_function, gen_config)
gen_config.log_image = partial(blank_log_image_function, gen_config)
# set output extension to mp4
gen_config.output_ext = "mp4"
# reactivate progress bar since this is slooooow
pipeline.set_progress_bar_config(disable=False)
pipeline = pipeline.to(self.device_torch)
# make sure dimensions are valid
bd = self.get_bucket_divisibility()
gen_config.height = (gen_config.height // bd) * bd
gen_config.width = (gen_config.width // bd) * bd
# num_frames must be a multiple of 8 plus 1, i.e. 1, 9, 17, 25, etc.
if gen_config.num_frames != 1:
if (gen_config.num_frames - 1) % 8 != 0:
gen_config.num_frames = ((gen_config.num_frames - 1) // 8) * 8 + 1
if self.low_vram:
# set vae to tile decode
pipeline.vae.enable_tiling(
tile_sample_min_height=256,
tile_sample_min_width=256,
tile_sample_min_num_frames=8,
tile_sample_stride_height=224,
tile_sample_stride_width=224,
tile_sample_stride_num_frames=4,
)
video, audio = pipeline(
prompt_embeds=conditional_embeds.text_embeds.to(
self.device_torch, dtype=self.torch_dtype
),
prompt_attention_mask=conditional_embeds.attention_mask.to(
self.device_torch
),
negative_prompt_embeds=unconditional_embeds.text_embeds.to(
self.device_torch, dtype=self.torch_dtype
),
negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(
self.device_torch
),
height=gen_config.height,
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
guidance_scale=gen_config.guidance_scale,
latents=gen_config.latents,
num_frames=gen_config.num_frames,
generator=generator,
return_dict=False,
output_type="np" if is_video else "pil",
**extra,
)
if self.low_vram:
# Restore no tiling
pipeline.vae.use_tiling = False
if is_video:
# return as a dict; we will handle it with an override function
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)
return {
"video": video[0],
"fps": gen_config.fps,
"audio": audio[0].float().cpu(),
"audio_sample_rate": pipeline.vocoder.config.output_sampling_rate, # should be 24000
"output_path": None,
}
else:
# shape = [1, frames, channels, height, width]
# make sure this is right
video = video[0] # list of pil images
audio = audio[0] # tensor
if gen_config.num_frames > 1:
return video # return the frames.
else:
# get just the first image
img = video[0]
return img
def encode_audio(self, batch: "DataLoaderBatchDTO"):
if self.pipeline.audio_vae.device == torch.device("cpu"):
self.pipeline.audio_vae.to(self.device_torch)
output_tensor = None
audio_num_frames = None
# do them separately for now
for audio_data in batch.audio_data:
waveform = audio_data["waveform"].to(
device=self.device_torch, dtype=torch.float32
)
sample_rate = audio_data["sample_rate"]
# Add batch dimension if needed: [channels, samples] -> [batch, channels, samples]
if waveform.dim() == 2:
waveform = waveform.unsqueeze(0)
if waveform.shape[1] == 1:
# make sure it is stereo
waveform = waveform.repeat(1, 2, 1)
# Convert waveform to mel spectrogram using AudioProcessor
mel_spectrogram = self.audio_processor.waveform_to_mel(waveform, waveform_sample_rate=sample_rate)
mel_spectrogram = mel_spectrogram.to(dtype=self.torch_dtype)
# Encode mel spectrogram to latents
latents = self.pipeline.audio_vae.encode(mel_spectrogram.to(self.device_torch, dtype=self.torch_dtype)).latent_dist.mode()
if audio_num_frames is None:
audio_num_frames = latents.shape[2] #(latents is [B, C, T, F])
packed_latents = self.pipeline._pack_audio_latents(
latents,
# patch_size=self.pipeline.transformer.config.audio_patch_size,
# patch_size_t=self.pipeline.transformer.config.audio_patch_size_t,
) # [B, L, C * M]
if output_tensor is None:
output_tensor = packed_latents
else:
output_tensor = torch.cat([output_tensor, packed_latents], dim=0)
# normalize latents, opposite of (latents * latents_std) + latents_mean
latents_mean = self.pipeline.audio_vae.latents_mean
latents_std = self.pipeline.audio_vae.latents_std
output_tensor = (output_tensor - latents_mean) / latents_std
return output_tensor, audio_num_frames
def get_noise_prediction(
self,
latent_model_input: torch.Tensor,
timestep: torch.Tensor, # 0 to 1000 scale
text_embeddings: PromptEmbeds,
batch: "DataLoaderBatchDTO" = None,
**kwargs,
):
with torch.no_grad():
if self.model.device == torch.device("cpu"):
self.model.to(self.device_torch)
batch_size, C, latent_num_frames, latent_height, latent_width = (
latent_model_input.shape
)
# todo get this somehow
frame_rate = 24
# check frame dimension
# Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
packed_latents = self.pipeline._pack_latents(
latent_model_input,
patch_size=self.pipeline.transformer_spatial_patch_size,
patch_size_t=self.pipeline.transformer_temporal_patch_size,
)
if batch.audio_tensor is not None:
# use audio from the batch if available
#(1, 190, 128)
raw_audio_latents, audio_num_frames = self.encode_audio(batch)
# add the audio targets to the batch for loss calculation later
audio_noise = torch.randn_like(raw_audio_latents)
batch.audio_target = (audio_noise - raw_audio_latents).detach()
audio_latents = self.add_noise(
raw_audio_latents,
audio_noise,
timestep,
).to(self.device_torch, dtype=self.torch_dtype)
else:
# no audio
num_mel_bins = self.pipeline.audio_vae.config.mel_bins
# latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
num_channels_latents_audio = (
self.pipeline.audio_vae.config.latent_channels
)
# audio latents are (1, 126, 128), audio_num_frames = 126
audio_latents, audio_num_frames = self.pipeline.prepare_audio_latents(
batch_size,
num_channels_latents=num_channels_latents_audio,
num_mel_bins=num_mel_bins,
num_frames=batch.tensor.shape[1],
frame_rate=frame_rate,
sampling_rate=self.pipeline.audio_sampling_rate,
hop_length=self.pipeline.audio_hop_length,
dtype=torch.float32,
device=self.transformer.device,
generator=None,
latents=None,
)
if self.pipeline.connectors.device != self.transformer.device:
self.pipeline.connectors.to(self.transformer.device)
# TODO this is how diffusers does this on inference, not sure I understand why, check this
additive_attention_mask = (
1 - text_embeddings.attention_mask.to(self.transformer.dtype)
) * -1000000.0
(
connector_prompt_embeds,
connector_audio_prompt_embeds,
connector_attention_mask,
) = self.pipeline.connectors(
text_embeddings.text_embeds, additive_attention_mask, additive_mask=True
)
# compute video and audio positional ids
video_coords = self.transformer.rope.prepare_video_coords(
packed_latents.shape[0],
latent_num_frames,
latent_height,
latent_width,
packed_latents.device,
fps=frame_rate,
)
audio_coords = self.transformer.audio_rope.prepare_audio_coords(
audio_latents.shape[0], audio_num_frames, audio_latents.device
)
noise_pred_video, noise_pred_audio = self.transformer(
hidden_states=packed_latents,
audio_hidden_states=audio_latents.to(self.transformer.dtype),
encoder_hidden_states=connector_prompt_embeds,
audio_encoder_hidden_states=connector_audio_prompt_embeds,
timestep=timestep,
encoder_attention_mask=connector_attention_mask,
audio_encoder_attention_mask=connector_attention_mask,
num_frames=latent_num_frames,
height=latent_height,
width=latent_width,
fps=frame_rate,
audio_num_frames=audio_num_frames,
video_coords=video_coords,
audio_coords=audio_coords,
# rope_interpolation_scale=rope_interpolation_scale,
attention_kwargs=None,
return_dict=False,
)
# add audio latent to batch if we had audio
if batch.audio_target is not None:
batch.audio_pred = noise_pred_audio
unpacked_output = self.pipeline._unpack_latents(
latents=noise_pred_video,
num_frames=latent_num_frames,
height=latent_height,
width=latent_width,
patch_size=self.pipeline.transformer_spatial_patch_size,
patch_size_t=self.pipeline.transformer_temporal_patch_size,
)
return unpacked_output
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
if self.pipeline.text_encoder.device != self.device_torch:
self.pipeline.text_encoder.to(self.device_torch)
prompt_embeds, prompt_attention_mask, _, _ = self.pipeline.encode_prompt(
prompt,
do_classifier_free_guidance=False,
device=self.device_torch,
)
pe = PromptEmbeds([prompt_embeds, None])
pe.attention_mask = prompt_attention_mask
return pe
def get_model_has_grad(self):
return False
def get_te_has_grad(self):
return False
def save_model(self, output_path, meta, save_dtype):
transformer: LTX2VideoTransformer3DModel = unwrap_model(self.model)
transformer.save_pretrained(
save_directory=os.path.join(output_path, "transformer"),
safe_serialization=True,
)
meta_path = os.path.join(output_path, "aitk_meta.yaml")
with open(meta_path, "w") as f:
yaml.dump(meta, f)
def get_loss_target(self, *args, **kwargs):
noise = kwargs.get("noise")
batch = kwargs.get("batch")
return (noise - batch.latents).detach()
def get_base_model_version(self):
return "ltx2"
def get_transformer_block_names(self) -> Optional[List[str]]:
return ["transformer_blocks"]
def convert_lora_weights_before_save(self, state_dict):
new_sd = {}
for key, value in state_dict.items():
new_key = key.replace("transformer.", "diffusion_model.")
new_sd[new_key] = value
new_sd = convert_lora_diffusers_to_original(new_sd)
return new_sd
def convert_lora_weights_before_load(self, state_dict):
state_dict = convert_lora_original_to_diffusers(state_dict)
new_sd = {}
for key, value in state_dict.items():
new_key = key.replace("diffusion_model.", "transformer.")
new_sd[new_key] = value
return new_sd

View File

@@ -859,6 +859,11 @@ class SDTrainer(BaseSDTrainProcess):
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
loss = loss.mean()
# check for audio loss
if batch.audio_pred is not None and batch.audio_target is not None:
audio_loss = torch.nn.functional.mse_loss(batch.audio_pred.float(), batch.audio_target.float(), reduction="mean")
loss = loss + audio_loss
# check for additional losses
if self.adapter is not None and hasattr(self.adapter, "additional_loss") and self.adapter.additional_loss is not None:

View File

@@ -1,6 +1,6 @@
torchao==0.10.0
safetensors
git+https://github.com/huggingface/diffusers@f6b6a7181eb44f0120b29cd897c129275f366c2a
git+https://github.com/huggingface/diffusers@8600b4c10d67b0ce200f664204358747bd53c775
transformers==4.57.3
lycoris-lora==1.8.3
flatten_json
@@ -36,4 +36,6 @@ opencv-python
pytorch-wavelets==1.3.0
matplotlib==3.10.1
setuptools==69.5.1
scipy==1.12.0
scipy==1.12.0
av==16.0.1
torchcodec

View File

@@ -0,0 +1,234 @@
import time
from torch.utils.data import DataLoader
import sys
import os
import argparse
from tqdm import tqdm
import torch
from torchvision.io import write_video
import subprocess
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader_setup_epoch
from toolkit.config_modules import DatasetConfig
parser = argparse.ArgumentParser()
# parser.add_argument('dataset_folder', type=str, default='input')
parser.add_argument('dataset_folder', type=str)
parser.add_argument('--epochs', type=int, default=1)
parser.add_argument('--num_frames', type=int, default=121)
parser.add_argument('--output_path', type=str, default='output/dataset_test')
args = parser.parse_args()
if args.output_path is None:
raise ValueError('output_path is required for this test script')
if args.output_path is not None:
args.output_path = os.path.abspath(args.output_path)
os.makedirs(args.output_path, exist_ok=True)
dataset_folder = args.dataset_folder
resolution = 512
bucket_tolerance = 64
batch_size = 1
frame_rate = 24
## make fake sd
class FakeSD:
def __init__(self):
self.use_raw_control_images = False
def encode_control_in_text_embeddings(self, *args, **kwargs):
return None
def get_bucket_divisibility(self):
return 32
dataset_config = DatasetConfig(
dataset_path=dataset_folder,
resolution=resolution,
default_caption='default',
buckets=True,
bucket_tolerance=bucket_tolerance,
shrink_video_to_frames=True,
num_frames=args.num_frames,
do_i2v=True,
fps=frame_rate,
do_audio=True,
debug=True,
audio_preserve_pitch=False,
audio_normalize=True
)
dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size, sd=FakeSD())
def _tensor_to_uint8_video(frames_fchw: torch.Tensor) -> torch.Tensor:
"""
frames_fchw: [F, C, H, W] float/uint8
returns: [F, H, W, C] uint8 on CPU
"""
x = frames_fchw.detach()
if x.dtype != torch.uint8:
x = x.to(torch.float32)
# Heuristic: if negatives exist, assume [-1,1] normalization; else assume [0,1]
if torch.isfinite(x).all():
if x.min().item() < 0.0:
x = x * 0.5 + 0.5
x = x.clamp(0.0, 1.0)
x = (x * 255.0).round().to(torch.uint8)
else:
x = x.to(torch.uint8)
# [F,C,H,W] -> [F,H,W,C]
x = x.permute(0, 2, 3, 1).contiguous().cpu()
return x
def _mux_with_ffmpeg(video_in: str, wav_in: str, mp4_out: str):
# Copy video stream, encode audio to AAC, align to shortest
subprocess.run(
[
"ffmpeg",
"-y",
"-hide_banner",
"-loglevel",
"error",
"-i",
video_in,
"-i",
wav_in,
"-c:v",
"copy",
"-c:a",
"aac",
"-shortest",
mp4_out,
],
check=True,
)
# run through an epoch and check sizes
dataloader_iterator = iter(dataloader)
idx = 0
for epoch in range(args.epochs):
for batch in tqdm(dataloader):
batch: 'DataLoaderBatchDTO'
img_batch = batch.tensor
frames = 1
if len(img_batch.shape) == 5:
frames = img_batch.shape[1]
batch_size, frames, channels, height, width = img_batch.shape
else:
batch_size, channels, height, width = img_batch.shape
# load audio
audio_tensor = batch.audio_tensor # all file items concatenated on the batch dimension
audio_data = batch.audio_data # list of raw audio data per item in the batch
# save the videos here with audio and video as mp4
fps = getattr(dataset_config, "fps", None)
if fps is None or fps <= 0:
fps = 1.0
# Ensure we can iterate items even if batch_size > 1
for b in range(batch_size):
# Get per-item frames as [F,C,H,W]
if len(img_batch.shape) == 5:
frames_fchw = img_batch[b]
else:
# single image: [C,H,W] -> [1,C,H,W]
frames_fchw = img_batch[b].unsqueeze(0)
video_uint8 = _tensor_to_uint8_video(frames_fchw)
out_mp4 = os.path.join(args.output_path, f"{idx:06d}_{b:02d}.mp4")
# Pick audio for this item (prefer audio_data list; fallback to audio_tensor)
item_audio = None
item_sr = None
if isinstance(audio_data, (list, tuple)) and len(audio_data) > b:
ad = audio_data[b]
if isinstance(ad, dict) and ("waveform" in ad) and ("sample_rate" in ad) and ad["waveform"] is not None:
item_audio = ad["waveform"]
item_sr = int(ad["sample_rate"])
elif audio_tensor is not None and torch.is_tensor(audio_tensor):
# audio_tensor expected [B, C, L] (or [C,L] if batch collate differs)
if audio_tensor.dim() == 3 and audio_tensor.shape[0] > b:
item_audio = audio_tensor[b]
elif audio_tensor.dim() == 2 and b == 0:
item_audio = audio_tensor
if item_audio is not None:
# best-effort sample rate from audio_data if present but not per-item dict
if isinstance(audio_data, dict) and "sample_rate" in audio_data:
try:
item_sr = int(audio_data["sample_rate"])
except Exception:
item_sr = None
# Write mp4 (with audio if available) using ffmpeg muxing (torchvision audio muxing is unreliable)
tmp_video = out_mp4 + ".tmp_video.mp4"
tmp_wav = out_mp4 + ".tmp_audio.wav"
try:
# Always write video-only first
write_video(tmp_video, video_uint8, fps=float(fps), video_codec="libx264")
if item_audio is not None and item_sr is not None and item_audio.numel() > 0:
import torchaudio
wav = item_audio.detach()
# torchaudio.save expects [channels, samples]
if wav.dim() == 1:
wav = wav.unsqueeze(0)
torchaudio.save(tmp_wav, wav.cpu().to(torch.float32), int(item_sr))
# Mux to final mp4
_mux_with_ffmpeg(tmp_video, tmp_wav, out_mp4)
else:
# No audio: just move video into place
os.replace(tmp_video, out_mp4)
except Exception as e:
# Best-effort fallback: leave a playable video-only file
try:
if os.path.exists(tmp_video):
os.replace(tmp_video, out_mp4)
else:
write_video(out_mp4, video_uint8, fps=float(fps), video_codec="libx264")
except Exception:
raise
if hasattr(dataset_config, 'debug') and dataset_config.debug:
print(f"Warning: failed to mux audio into mp4 for {out_mp4}: {e}")
finally:
# Cleanup temps (don't leave separate wavs lying around)
try:
if os.path.exists(tmp_video):
os.remove(tmp_video)
except Exception:
pass
try:
if os.path.exists(tmp_wav):
os.remove(tmp_wav)
except Exception:
pass
time.sleep(0.2)
idx += 1
# if not last epoch
if epoch < args.epochs - 1:
trigger_dataloader_setup_epoch(dataloader)
print('done')

View File

@@ -0,0 +1,75 @@
import math
import torch
import torch.nn.functional as F
import torchaudio
def time_stretch_preserve_pitch(waveform: torch.Tensor, sample_rate: int, target_samples: int) -> torch.Tensor:
"""
waveform: [C, L] float tensor (CPU or GPU)
returns: [C, target_samples] float tensor
Pitch-preserving time stretch to match target_samples.
"""
if waveform.dim() == 1:
waveform = waveform.unsqueeze(0)
waveform = waveform.to(torch.float32)
src_len = waveform.shape[-1]
if src_len == 0 or target_samples <= 0:
return waveform[..., :0]
if src_len == target_samples:
return waveform
# rate > 1.0 speeds up (shorter), rate < 1.0 slows down (longer)
rate = float(src_len) / float(target_samples)
# Use sample_rate to pick STFT params
win_seconds = 0.046
hop_seconds = 0.0115
n_fft_target = int(sample_rate * win_seconds)
n_fft = 1 << max(8, int(math.floor(math.log2(max(256, n_fft_target))))) # >=256, pow2
win_length = n_fft
hop_length = max(64, int(sample_rate * hop_seconds))
hop_length = min(hop_length, win_length // 2)
window = torch.hann_window(win_length, device=waveform.device, dtype=waveform.dtype)
stft = torch.stft(
waveform,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
window=window,
center=True,
return_complex=True,
) # [C, F, T] complex
# IMPORTANT: n_freq must match STFT's frequency bins (n_fft//2 + 1)
stretcher = torchaudio.transforms.TimeStretch(
n_freq=stft.shape[-2],
hop_length=hop_length,
fixed_rate=rate,
).to(waveform.device)
stft_stretched = stretcher(stft) # [C, F, T']
stretched = torch.istft(
stft_stretched,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
window=window,
center=True,
length=target_samples,
)
if stretched.shape[-1] > target_samples:
stretched = stretched[..., :target_samples]
elif stretched.shape[-1] < target_samples:
stretched = F.pad(stretched, (0, target_samples - stretched.shape[-1]))
return stretched
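A small usage sketch, assuming a clip whose audio should line up with num_frames at the dataset FPS (file path and values are illustrative):

waveform, sample_rate = torchaudio.load("clip.wav")  # [channels, samples]
num_frames, fps = 121, 24
target_samples = int(round(num_frames / fps * sample_rate))
# stretch or compress the audio to the clip duration without shifting pitch
stretched = time_stretch_preserve_pitch(waveform, sample_rate, target_samples)
assert stretched.shape[-1] == target_samples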

View File

@@ -207,6 +207,9 @@ class NetworkConfig:
# -1 automatically finds the largest factor
self.lokr_factor = kwargs.get('lokr_factor', -1)
# Use the old lokr format
self.old_lokr_format = kwargs.get('old_lokr_format', False)
# for multi stage models
self.split_multistage_loras = kwargs.get('split_multistage_loras', True)
@@ -672,6 +675,9 @@ class ModelConfig:
# kwargs to pass to the model
self.model_kwargs = kwargs.get("model_kwargs", {})
# model paths for models that support it
self.model_paths = kwargs.get("model_paths", {})
# allow frontend to pass arch with a color like arch:tag
# but remove the tag
if self.arch is not None:
@@ -956,7 +962,7 @@ class DatasetConfig:
# it will select a random start frame and pull the frames at the given fps
# this could have various issues with shorter videos and videos with variable fps
# I recommend trimming your videos to the desired length and using shrink_video_to_frames(default)
self.fps: int = kwargs.get('fps', 16)
self.fps: int = kwargs.get('fps', 24)
# debug the frame count and frame selection. You dont need this. It is for debugging.
self.debug: bool = kwargs.get('debug', False)
@@ -972,6 +978,9 @@ class DatasetConfig:
self.fast_image_size: bool = kwargs.get('fast_image_size', False)
self.do_i2v: bool = kwargs.get('do_i2v', True) # do image to video on models that are both t2i and i2v capable
self.do_audio: bool = kwargs.get('do_audio', False) # load audio from video files for models that support it
self.audio_preserve_pitch: bool = kwargs.get('audio_preserve_pitch', False) # preserve pitch when stretching audio to fit num_frames
self.audio_normalize: bool = kwargs.get('audio_normalize', False) # normalize audio volume levels when loading
def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:

View File

@@ -123,9 +123,13 @@ class FileItemDTO(
self.is_reg = self.dataset_config.is_reg
self.prior_reg = self.dataset_config.prior_reg
self.tensor: Union[torch.Tensor, None] = None
self.audio_data = None
self.audio_tensor = None
def cleanup(self):
self.tensor = None
self.audio_data = None
self.audio_tensor = None
self.cleanup_latent()
self.cleanup_text_embedding()
self.cleanup_control()
@@ -154,6 +158,13 @@ class DataLoaderBatchDTO:
self.clip_image_embeds_unconditional: Union[List[dict], None] = None
self.sigmas: Union[torch.Tensor, None] = None # can be added elsewhere and passed along training code
self.extra_values: Union[torch.Tensor, None] = torch.tensor([x.extra_values for x in self.file_items]) if len(self.file_items[0].extra_values) > 0 else None
self.audio_data: Union[List, None] = [x.audio_data for x in self.file_items] if self.file_items[0].audio_data is not None else None
self.audio_tensor: Union[torch.Tensor, None] = None
# just for holding noise and preds during training
self.audio_target: Union[torch.Tensor, None] = None
self.audio_pred: Union[torch.Tensor, None] = None
if not is_latents_cached:
# only return a tensor if latents are not cached
self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
@@ -304,6 +315,21 @@ class DataLoaderBatchDTO:
y.text_embeds = [y.text_embeds]
prompt_embeds_list.append(y)
self.prompt_embeds = concat_prompt_embeds(prompt_embeds_list)
if any([x.audio_tensor is not None for x in self.file_items]):
# find one to use as a base
base_audio_tensor = None
for x in self.file_items:
if x.audio_tensor is not None:
base_audio_tensor = x.audio_tensor
break
audio_tensors = []
for x in self.file_items:
if x.audio_tensor is None:
audio_tensors.append(torch.zeros_like(base_audio_tensor))
else:
audio_tensors.append(x.audio_tensor)
self.audio_tensor = torch.cat([x.unsqueeze(0) for x in audio_tensors])
except Exception as e:
@@ -336,6 +362,10 @@ class DataLoaderBatchDTO:
del self.latents
del self.tensor
del self.control_tensor
del self.audio_tensor
del self.audio_data
del self.audio_target
del self.audio_pred
for file_item in self.file_items:
file_item.cleanup()
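
A hedged sketch of how a training step might consume the collated batch. The attribute shapes follow from the code above (frames are stacked per item, and items without audio are zero-filled), but the step itself is illustrative and not part of this commit:

def training_step(batch):  # batch: DataLoaderBatchDTO
    # [B, frames, C, H, W] when latents are not cached
    video = batch.tensor
    audio = None
    if batch.audio_tensor is not None:
        # [B, C, samples]; items without audio are zero tensors of the same shape
        audio = batch.audio_tensor
    # ... encode, add noise, predict; batch.audio_target / batch.audio_pred are
    # reserved for holding the audio noise targets and predictions during training.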

View File

@@ -16,6 +16,7 @@ from safetensors.torch import load_file, save_file
from tqdm import tqdm
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor
from toolkit.audio.preserve_pitch import time_stretch_preserve_pitch
from toolkit.basic import flush, value_map
from toolkit.buckets import get_bucket_for_image_size, get_resolution
from toolkit.config_modules import ControlTypes
@@ -467,6 +468,8 @@ class ImageProcessingDTOMixin:
if not self.dataset_config.buckets:
raise Exception('Buckets required for video processing')
do_audio = self.dataset_config.do_audio
try:
# Use OpenCV to capture video frames
cap = cv2.VideoCapture(self.path)
@@ -596,6 +599,84 @@ class ImageProcessingDTOMixin:
# Stack frames into tensor [frames, channels, height, width]
self.tensor = torch.stack(frames)
# ------------------------------
# Audio extraction + stretching
# ------------------------------
if do_audio:
# Default to "no audio" unless we successfully extract it
self.audio_data = None
self.audio_tensor = None
try:
import torchaudio
import torch.nn.functional as F
# Compute the time range of the selected frames in the *source* video
# Include the last frame by extending to the next frame boundary.
if video_fps and video_fps > 0 and len(frames_to_extract) > 0:
clip_start_frame = int(frames_to_extract[0])
clip_end_frame = int(frames_to_extract[-1])
clip_start_time = clip_start_frame / float(video_fps)
clip_end_time = (clip_end_frame + 1) / float(video_fps)
source_duration = max(0.0, clip_end_time - clip_start_time)
else:
clip_start_time = 0.0
clip_end_time = 0.0
source_duration = 0.0
# Target duration is how this sampled/stretched clip is interpreted for training
# (i.e. num_frames at the configured dataset FPS).
if hasattr(self.dataset_config, "fps") and self.dataset_config.fps and self.dataset_config.fps > 0:
target_duration = float(self.dataset_config.num_frames) / float(self.dataset_config.fps)
else:
target_duration = source_duration
waveform, sample_rate = torchaudio.load(self.path) # [channels, samples]
if self.dataset_config.audio_normalize:
peak = waveform.abs().amax() # global peak across channels
eps = 1e-9
target_peak = 0.999 # ~ -0.01 dBFS
gain = target_peak / (peak + eps)
waveform = waveform * gain
# Slice to the selected clip region (when we have a meaningful time range)
if source_duration > 0.0:
start_sample = int(round(clip_start_time * sample_rate))
end_sample = int(round(clip_end_time * sample_rate))
start_sample = max(0, min(start_sample, waveform.shape[-1]))
end_sample = max(0, min(end_sample, waveform.shape[-1]))
if end_sample > start_sample:
waveform = waveform[..., start_sample:end_sample]
else:
# No valid audio segment
waveform = None
else:
# If we can't compute a meaningful time range, treat as no-audio
waveform = None
if waveform is not None and waveform.numel() > 0:
target_samples = int(round(target_duration * sample_rate))
if target_samples > 0 and waveform.shape[-1] != target_samples:
# Time-stretch/shrink to match the video clip duration implied by dataset FPS.
if self.dataset_config.audio_preserve_pitch:
waveform = time_stretch_preserve_pitch(waveform, sample_rate, target_samples) # waveform is [C, L]
else:
# Use linear interpolation over the time axis.
wf = waveform.unsqueeze(0) # [1, C, L]
wf = F.interpolate(wf, size=target_samples, mode="linear", align_corners=False)
waveform = wf.squeeze(0) # [C, L]
self.audio_tensor = waveform
self.audio_data = {"waveform": waveform, "sample_rate": int(sample_rate)}
except Exception as e:
# Keep behavior identical for non-audio datasets; for audio datasets, just skip if missing/broken.
if hasattr(self.dataset_config, 'debug') and self.dataset_config.debug:
print_acc(f"Could not extract/stretch audio for {self.path}: {e}")
self.audio_data = None
self.audio_tensor = None
# Only log success in debug mode
if hasattr(self.dataset_config, 'debug') and self.dataset_config.debug:
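
As a concrete, illustrative example of the duration math above, using the LTX-2 UI defaults of 121 frames at 24 fps and a hypothetical 48 kHz audio track:

num_frames = 121
dataset_fps = 24
sample_rate = 48_000   # depends on the source video; 48 kHz assumed here

target_duration = num_frames / dataset_fps                  # ~5.0417 seconds
target_samples = int(round(target_duration * sample_rate))  # 242_000 samples

# The extracted clip is then stretched (pitch-preserving or plain interpolation)
# to exactly target_samples so the audio lines up with the sampled frames.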

View File

@@ -265,11 +265,18 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
self.peft_format = peft_format
self.is_transformer = is_transformer
# use the old format for older models unless the user has specified otherwise
self.use_old_lokr_format = False
if self.network_config is not None and hasattr(self.network_config, 'old_lokr_format'):
self.use_old_lokr_format = self.network_config.old_lokr_format
# the base model itself can also force the new format
if base_model is not None and not base_model.use_old_lokr_format:
self.use_old_lokr_format = False
# always use the peft format for flux and other transformer models for now
if self.is_flux or self.is_v3 or self.is_lumina2 or is_transformer:
# don't do peft format for lokr
if self.network_type.lower() != "lokr":
# don't do peft format for lokr if using old format
if self.network_type.lower() != "lokr" or not self.use_old_lokr_format:
self.peft_format = True
if self.peft_format:

View File

@@ -185,6 +185,11 @@ class BaseModel:
self.has_multiple_control_images = False
# do not resize control images
self.use_raw_control_images = False
# defines if the model supports model paths. Only some models do.
self.supports_model_paths = False
# use the old lokr format by default on older models for backwards compatibility
self.use_old_lokr_format = True
# properties for old arch for backwards compatibility
@property

View File

@@ -11,6 +11,7 @@ from toolkit.network_mixins import ToolkitModuleMixin
from typing import TYPE_CHECKING, Union, List
from optimum.quanto import QBytesTensor, QTensor
from torchao.dtypes import AffineQuantizedTensor
if TYPE_CHECKING:
@@ -284,17 +285,26 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
org_sd[weight_key] = merged_weight.to(orig_dtype)
self.org_module[0].load_state_dict(org_sd)
def get_orig_weight(self):
def get_orig_weight(self, device):
weight = self.org_module[0].weight
if weight.device != device:
weight = weight.to(device)
if isinstance(weight, QTensor) or isinstance(weight, QBytesTensor):
return weight.dequantize().data.detach()
elif isinstance(weight, AffineQuantizedTensor):
return weight.dequantize().data.detach()
else:
return weight.data.detach()
def get_orig_bias(self):
def get_orig_bias(self, device):
if hasattr(self.org_module[0], 'bias') and self.org_module[0].bias is not None:
if isinstance(self.org_module[0].bias, QTensor) or isinstance(self.org_module[0].bias, QBytesTensor):
return self.org_module[0].bias.dequantize().data.detach()
bias = self.org_module[0].bias
if bias.device != device:
bias = bias.to(device)
if isinstance(bias, QTensor) or isinstance(bias, QBytesTensor):
return bias.dequantize().data.detach()
elif isinstance(bias, AffineQuantizedTensor):
return bias.dequantize().data.detach()
else:
return self.org_module[0].bias.data.detach()
return None
@@ -305,7 +315,7 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
orig_dtype = x.dtype
orig_weight = self.get_orig_weight()
orig_weight = self.get_orig_weight(x.device)
lokr_weight = self.get_weight(orig_weight).to(dtype=orig_weight.dtype)
multiplier = self.network_ref().torch_multiplier
@@ -319,7 +329,7 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
orig_weight
+ lokr_weight * multiplier
)
bias = self.get_orig_bias()
bias = self.get_orig_bias(x.device)
if bias is not None:
bias = bias.to(weight.device, dtype=weight.dtype)
output = self.op(

View File

@@ -546,7 +546,8 @@ class ToolkitNetworkMixin:
new_save_dict = {}
for key, value in save_dict.items():
if key.endswith('.alpha'):
# lokr needs alpha
if key.endswith('.alpha') and self.network_type.lower() != "lokr":
continue
new_key = key
new_key = new_key.replace('lora_down', 'lora_A')
@@ -558,7 +559,7 @@ class ToolkitNetworkMixin:
save_dict = new_save_dict
if self.network_type.lower() == "lokr":
if self.network_type.lower() == "lokr" and self.use_old_lokr_format:
new_save_dict = {}
for key, value in save_dict.items():
# lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1 to lycoris_transformer_blocks_7_attn_to_v.lokr_w1
@@ -632,7 +633,7 @@ class ToolkitNetworkMixin:
# lora_down = lora_A
# lora_up = lora_B
# no alpha
if load_key.endswith('.alpha'):
if load_key.endswith('.alpha') and self.network_type.lower() != "lokr":
continue
load_key = load_key.replace('lora_A', 'lora_down')
load_key = load_key.replace('lora_B', 'lora_up')
@@ -640,6 +641,13 @@ class ToolkitNetworkMixin:
load_key = load_key.replace('.', '$$')
load_key = load_key.replace('$$lora_down$$', '.lora_down.')
load_key = load_key.replace('$$lora_up$$', '.lora_up.')
# patch lokr keys so the lokr weight and alpha names keep their '.' separators
if self.network_type.lower() == "lokr":
load_key = load_key.replace('$$lokr_w1', '.lokr_w1')
load_key = load_key.replace('$$lokr_w2', '.lokr_w2')
if load_key.endswith('$$alpha'):
load_key = load_key[:-7] + '.alpha'
if self.network_type.lower() == "lokr":
# lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1 to lycoris_transformer_blocks_7_attn_to_v.lokr_w1
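
The comments above suggest the lokr key conversion is a simple prefix swap between the toolkit's internal names and the LyCORIS-style names Comfy expects. The actual replacement code is truncated in this view, so the sketch below is an assumption based on the example key pair (the load path applies the inverse mapping):

def to_lycoris_key(key: str) -> str:
    # lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1
    #   -> lycoris_transformer_blocks_7_attn_to_v.lokr_w1
    if key.startswith("lora_transformer_"):
        return "lycoris_" + key[len("lora_transformer_"):]
    return key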

View File

@@ -223,6 +223,11 @@ class StableDiffusion:
self.has_multiple_control_images = False
# do not resize control images
self.use_raw_control_images = False
# defines if the model supports model paths. Only some models do.
self.supports_model_paths = False
# use the old lokr format by default on older models for backwards compatibility
self.use_old_lokr_format = True
# properties for old arch for backwards compatibility
@property

View File

@@ -15,7 +15,7 @@ export async function POST(request: Request) {
}
// make sure it is an image
if (!/\.(jpg|jpeg|png|bmp|gif|tiff|webp)$/i.test(imgPath.toLowerCase())) {
if (!/\.(jpg|jpeg|png|bmp|gif|tiff|webp|mp4)$/i.test(imgPath.toLowerCase())) {
return NextResponse.json({ error: 'Not an image' }, { status: 400 });
}

View File

@@ -29,7 +29,7 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s
const samples = fs
.readdirSync(samplesFolder)
.filter(file => {
return file.endsWith('.png') || file.endsWith('.jpg') || file.endsWith('.jpeg') || file.endsWith('.webp');
return file.endsWith('.png') || file.endsWith('.jpg') || file.endsWith('.jpeg') || file.endsWith('.webp') || file.endsWith('.mp4');
})
.map(file => {
return path.join(samplesFolder, file);

View File

@@ -862,6 +862,48 @@ export default function SimpleJob({
docKey="datasets.do_i2v"
/>
)}
{modelArch?.additionalSections?.includes('datasets.do_audio') && (
<Checkbox
label="Do Audio"
checked={dataset.do_audio || false}
onChange={value => {
if (!value) {
setJobConfig(undefined, `config.process[0].datasets[${i}].do_audio`);
} else {
setJobConfig(value, `config.process[0].datasets[${i}].do_audio`);
}
}}
docKey="datasets.do_audio"
/>
)}
{modelArch?.additionalSections?.includes('datasets.audio_normalize') && (
<Checkbox
label="Audio Normalize"
checked={dataset.audio_normalize || false}
onChange={value => {
if (!value) {
setJobConfig(undefined, `config.process[0].datasets[${i}].audio_normalize`);
} else {
setJobConfig(value, `config.process[0].datasets[${i}].audio_normalize`);
}
}}
docKey="datasets.audio_normalize"
/>
)}
{modelArch?.additionalSections?.includes('datasets.audio_preserve_pitch') && (
<Checkbox
label="Audio Preserve Pitch"
checked={dataset.audio_preserve_pitch || false}
onChange={value => {
if (!value) {
setJobConfig(undefined, `config.process[0].datasets[${i}].audio_preserve_pitch`);
} else {
setJobConfig(value, `config.process[0].datasets[${i}].audio_preserve_pitch`);
}
}}
docKey="datasets.audio_preserve_pitch"
/>
)}
</FormGroup>
<FormGroup label="Flipping" docKey={'datasets.flip'} className="mt-2">
<Checkbox

View File

@@ -14,7 +14,6 @@ export const defaultDatasetConfig: DatasetConfig = {
controls: [],
shrink_video_to_frames: true,
num_frames: 1,
do_i2v: true,
flip_x: false,
flip_y: false,
};

View File

@@ -17,6 +17,9 @@ type AdditionalSections =
| 'datasets.control_path'
| 'datasets.multi_control_paths'
| 'datasets.do_i2v'
| 'datasets.do_audio'
| 'datasets.audio_normalize'
| 'datasets.audio_preserve_pitch'
| 'sample.ctrl_img'
| 'sample.multi_ctrl_imgs'
| 'datasets.num_frames'
@@ -288,6 +291,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.width': [768, 1024],
'config.process[0].sample.height': [768, 1024],
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
'config.process[0].datasets[x].do_i2v': [true, undefined],
},
disableSections: ['network.conv'],
additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram', 'datasets.do_i2v'],
@@ -601,6 +605,31 @@ export const modelArchs: ModelArch[] = [
disableSections: ['network.conv'],
additionalSections: ['model.low_vram', 'model.layer_offloading'],
},
{
name: 'ltx2',
label: 'LTX-2',
group: 'video',
isVideoModel: true,
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['Lightricks/LTX-2', defaultNameOrPath],
'config.process[0].model.quantize': [true, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].model.low_vram': [true, false],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].sample.num_frames': [121, 1],
'config.process[0].sample.fps': [24, 1],
'config.process[0].sample.width': [768, 1024],
'config.process[0].sample.height': [768, 1024],
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
'config.process[0].datasets[x].do_i2v': [false, undefined],
'config.process[0].datasets[x].do_audio': [true, undefined],
'config.process[0].datasets[x].fps': [24, undefined],
},
disableSections: ['network.conv'],
additionalSections: ['datasets.num_frames', 'model.layer_offloading', 'model.low_vram', 'datasets.do_audio', 'datasets.audio_normalize', 'datasets.audio_preserve_pitch'],
},
].sort((a, b) => {
// Sort by label, case-insensitive
return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' });

View File

@@ -2,6 +2,25 @@ import { GroupedSelectOption, JobConfig, SelectOption } from '@/types';
import { modelArchs, ModelArch } from './options';
import { objectCopy } from '@/utils/basic';
const expandDatasetDefaults = (
defaults: { [key: string]: any },
numDatasets: number,
): { [key: string]: any } => {
// expands the defaults for datasets[x] to datasets[0], datasets[1], etc.
const expandedDefaults: { [key: string]: any } = { ...defaults };
for (const key in defaults) {
if (key.includes('datasets[x].')) {
for (let i = 0; i < numDatasets; i++) {
const datasetKey = key.replace('datasets[x].', `datasets[${i}].`);
const v = defaults[key];
expandedDefaults[datasetKey] = Array.isArray(v) ? [...v] : objectCopy(v);
}
delete expandedDefaults[key];
}
}
return expandedDefaults;
};
export const handleModelArchChange = (
currentArchName: string,
newArchName: string,
@@ -39,16 +58,11 @@ export const handleModelArchChange = (
}
}
// revert defaults from previous model
for (const key in currentArch.defaults) {
setJobConfig(currentArch.defaults[key][1], key);
}
const numDatasets = jobConfig.config.process[0].datasets.length;
let currentDefaults = expandDatasetDefaults(currentArch.defaults || {}, numDatasets);
let newDefaults = expandDatasetDefaults(newArch?.defaults || {}, numDatasets);
if (newArch?.defaults) {
for (const key in newArch.defaults) {
setJobConfig(newArch.defaults[key][0], key);
}
}
// set new model
setJobConfig(newArchName, 'config.process[0].model.arch');
@@ -79,27 +93,27 @@ export const handleModelArchChange = (
if (newDataset.control_path_1 && newDataset.control_path_1 !== '') {
newDataset.control_path = newDataset.control_path_1;
}
if (newDataset.control_path_1) {
if ('control_path_1' in newDataset) {
delete newDataset.control_path_1;
}
if (newDataset.control_path_2) {
if ('control_path_2' in newDataset) {
delete newDataset.control_path_2;
}
if (newDataset.control_path_3) {
if ('control_path_3' in newDataset) {
delete newDataset.control_path_3;
}
} else {
// does not have control images
if (newDataset.control_path) {
if ('control_path' in newDataset) {
delete newDataset.control_path;
}
if (newDataset.control_path_1) {
if ('control_path_1' in newDataset) {
delete newDataset.control_path_1;
}
if (newDataset.control_path_2) {
if ('control_path_2' in newDataset) {
delete newDataset.control_path_2;
}
if (newDataset.control_path_3) {
if ('control_path_3' in newDataset) {
delete newDataset.control_path_3;
}
}
@@ -120,4 +134,13 @@ export const handleModelArchChange = (
return newSample;
});
setJobConfig(samples, 'config.process[0].sample.samples');
// revert defaults from previous model
for (const key in currentDefaults) {
setJobConfig(currentDefaults[key][1], key);
}
for (const key in newDefaults) {
setJobConfig(newDefaults[key][0], key);
}
};

View File

@@ -56,18 +56,6 @@ const SampleImageCard: React.FC<SampleImageCardProps> = ({
return () => observer.disconnect();
}, [observerRoot, rootMargin]);
// Pause video when leaving viewport
useEffect(() => {
if (!isVideo(imageUrl)) return;
const v = videoRef.current;
if (!v) return;
if (!isVisible && !v.paused) {
try {
v.pause();
} catch {}
}
}, [isVisible, imageUrl]);
const handleLoad = () => setLoaded(true);
return (
@@ -81,9 +69,11 @@ const SampleImageCard: React.FC<SampleImageCardProps> = ({
src={`/api/img/${encodeURIComponent(imageUrl)}`}
className="w-full h-full object-cover"
preload="none"
onLoad={handleLoad}
playsInline
muted
loop
autoPlay
controls={false}
/>
) : (

View File

@@ -7,6 +7,7 @@ import { Cog } from 'lucide-react';
import { Menu, MenuButton, MenuItem, MenuItems } from '@headlessui/react';
import { openConfirm } from './ConfirmModal';
import { apiClient } from '@/utils/api';
import { isVideo } from '@/utils/basic';
interface Props {
imgPath: string | null; // current image path
@@ -200,13 +201,24 @@ export default function SampleImageViewer({
className="relative transform rounded-lg bg-gray-800 text-left shadow-xl transition-all data-closed:translate-y-4 data-closed:opacity-0 data-enter:duration-300 data-enter:ease-out data-leave:duration-200 data-leave:ease-in max-w-[95%] max-h-[95vh] data-closed:sm:translate-y-0 data-closed:sm:scale-95 flex flex-col overflow-hidden"
>
<div className="overflow-hidden flex items-center justify-center">
{imgPath && (
<img
src={`/api/img/${encodeURIComponent(imgPath)}`}
alt="Sample Image"
className="w-auto h-auto max-w-[95vw] max-h-[82vh] object-contain"
/>
)}
{imgPath &&
(isVideo(imgPath) ? (
<video
src={`/api/img/${encodeURIComponent(imgPath)}`}
className="w-auto h-auto max-w-[95vw] max-h-[82vh] object-contain"
preload="none"
playsInline
loop
autoPlay
controls={true}
/>
) : (
<img
src={`/api/img/${encodeURIComponent(imgPath)}`}
alt="Sample Image"
className="w-auto h-auto max-w-[95vw] max-h-[82vh] object-contain"
/>
))}
</div>
{/* # make full width */}
<div className="bg-gray-950 text-sm flex justify-between items-center px-4 py-2">

View File

@@ -107,6 +107,35 @@ const docs: { [key: string]: ConfigDoc } = {
</>
),
},
'datasets.do_audio': {
title: 'Do Audio',
description: (
<>
For models that support audio alongside video, this option loads the audio track from the video file and
time-stretches it to match the sampled video sequence. Because the clip is stretched automatically, the audio may
drop or rise in pitch to match the new speed of the video. It is important to prepare your dataset at the proper
length before training.
</>
),
},
'datasets.audio_normalize': {
title: 'Audio Normalize',
description: (
<>
When loading audio, this peak-normalizes the volume of each clip. Useful if your dataset has varying audio
volumes. Warning: do not use this if you have quiet or near-silent clips you want to keep that way, as it will raise their volume.
</>
),
},
'datasets.audio_preserve_pitch': {
title: 'Audio Preserve Pitch',
description: (
<>
When audio is stretched to match the number of frames requested, this option preserves its pitch even when the
clip length does not match the training target. It is still recommended to prepare a dataset that matches your
target length, as pitch-preserving stretching can introduce audible distortion.
</>
),
},
'datasets.flip': {
title: 'Flip X and Flip Y',
description: (

View File

@@ -96,7 +96,11 @@ export interface DatasetConfig {
control_path?: string | null;
num_frames: number;
shrink_video_to_frames: boolean;
do_i2v: boolean;
do_i2v?: boolean;
do_audio?: boolean;
audio_normalize?: boolean;
audio_preserve_pitch?: boolean;
fps?: number;
flip_x: boolean;
flip_y: boolean;
control_path_1?: string | null;

View File

@@ -17,21 +17,14 @@ export function setNestedValue<T, V>(obj: T, value: V, path?: string): T {
}
// Split the path into segments
const pathArray = path.split('.').flatMap(segment => {
// Handle array notation like 'process[0]'
const arrayMatch = segment.match(/^([^\[]+)(\[\d+\])+/);
if (arrayMatch) {
const propName = arrayMatch[1];
const indices = segment
.substring(propName.length)
.match(/\[(\d+)\]/g)
?.map(idx => parseInt(idx.substring(1, idx.length - 1)));
const pathArray: Array<string | number> = [];
const re = /([^[.\]]+)|\[(\d+)\]/g;
let m: RegExpExecArray | null;
// Return property name followed by array indices
return [propName, ...(indices || [])];
}
return segment;
});
while ((m = re.exec(path)) !== null) {
if (m[1] !== undefined) pathArray.push(m[1]);
else pathArray.push(Number(m[2]));
}
// Navigate to the target location
let current: any = result;
@@ -43,8 +36,18 @@ export function setNestedValue<T, V>(obj: T, value: V, path?: string): T {
if (!Array.isArray(current)) {
throw new Error(`Cannot access index ${key} of non-array`);
}
// Create a copy of the array to maintain immutability
current = [...current];
// Ensure the indexed element exists and is copied/created immutably
const nextKey = pathArray[i + 1];
const existing = current[key];
if (existing === undefined) {
current[key] = typeof nextKey === 'number' ? [] : {};
} else if (Array.isArray(existing)) {
current[key] = [...existing];
} else if (typeof existing === 'object' && existing !== null) {
current[key] = { ...existing };
} // else: primitives stay as-is
} else {
// For object properties, create a new object if it doesn't exist
if (current[key] === undefined) {
@@ -63,7 +66,11 @@ export function setNestedValue<T, V>(obj: T, value: V, path?: string): T {
// Set the value at the final path segment
const finalKey = pathArray[pathArray.length - 1];
current[finalKey] = value;
if (value === undefined) {
delete current[finalKey];
} else {
current[finalKey] = value;
}
return result;
}

View File

@@ -1 +1 @@
VERSION = "0.7.16"
VERSION = "0.7.17"