mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-14 06:57:35 +00:00
Add LTX-2 Support (#644)
* WIP, adding support for LTX2 * Training on images working * Fix loading comfy models * Handle converting and deconverting lora so it matches original format * Reworked UI to handle LTX and proper dataset default overwriting. * Update the way lokr saves so it is more compatible with comfy * Audio loading and synchronization/resampling is working * Add audio to training. Does it work? Maybe, still testing. * Fixed fps default issue for sound * Have UI set fps for accurate audio mapping on LTX * Added audio processing options to the UI for LTX * Clean up requirements
This commit is contained in:
@@ -7,6 +7,7 @@ from .wan22 import Wan225bModel, Wan2214bModel, Wan2214bI2VModel
|
||||
from .qwen_image import QwenImageModel, QwenImageEditModel, QwenImageEditPlusModel
|
||||
from .flux2 import Flux2Model
|
||||
from .z_image import ZImageModel
|
||||
from .ltx2 import LTX2Model
|
||||
|
||||
AI_TOOLKIT_MODELS = [
|
||||
# put a list of models here
|
||||
@@ -25,4 +26,5 @@ AI_TOOLKIT_MODELS = [
|
||||
QwenImageEditPlusModel,
|
||||
Flux2Model,
|
||||
ZImageModel,
|
||||
LTX2Model,
|
||||
]
|
||||
|
||||
1
extensions_built_in/diffusion_models/ltx2/__init__.py
Normal file
1
extensions_built_in/diffusion_models/ltx2/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .ltx2 import LTX2Model
|
||||
@@ -0,0 +1,648 @@
|
||||
# ref https://github.com/huggingface/diffusers/blob/249ae1f853be8775c7d0b3a26ca4fbc4f6aa9998/scripts/convert_ltx2_to_diffusers.py
|
||||
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLLTX2Audio,
|
||||
AutoencoderKLLTX2Video,
|
||||
LTX2VideoTransformer3DModel,
|
||||
)
|
||||
from diffusers.pipelines.ltx2 import LTX2TextConnectors, LTX2Vocoder
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
CTX = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
|
||||
|
||||
# Key-rename tables mapping original LTX-2 checkpoint names to diffusers names.
# Each table is applied as ordered substring replacements, so dict insertion
# order matters where one original key is a substring of another.

LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = {
    # Input patchify projections
    "patchify_proj": "proj_in",
    "audio_patchify_proj": "audio_proj_in",
    # Modulation parameters.
    # adaln_single -> time_embed and audio_adaln_single -> audio_time_embed are
    # handled separately (via a special-keys handler) because those original
    # keys are substrings of the modulation parameter names below.
    "av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
    "av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
    "av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
    "av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
    # Transformer blocks: per-block cross-attention modulation parameters
    "scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
    "scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
    # Attention QK norms
    "q_norm": "norm_q",
    "k_norm": "norm_k",
}

LTX_2_0_VIDEO_VAE_RENAME_DICT = {
    # Encoder: original flat block list -> diffusers blocks + downsamplers
    "down_blocks.0": "down_blocks.0",
    "down_blocks.1": "down_blocks.0.downsamplers.0",
    "down_blocks.2": "down_blocks.1",
    "down_blocks.3": "down_blocks.1.downsamplers.0",
    "down_blocks.4": "down_blocks.2",
    "down_blocks.5": "down_blocks.2.downsamplers.0",
    "down_blocks.6": "down_blocks.3",
    "down_blocks.7": "down_blocks.3.downsamplers.0",
    "down_blocks.8": "mid_block",
    # Decoder: original flat block list -> diffusers blocks + upsamplers
    "up_blocks.0": "mid_block",
    "up_blocks.1": "up_blocks.0.upsamplers.0",
    "up_blocks.2": "up_blocks.0",
    "up_blocks.3": "up_blocks.1.upsamplers.0",
    "up_blocks.4": "up_blocks.1",
    "up_blocks.5": "up_blocks.2.upsamplers.0",
    "up_blocks.6": "up_blocks.2",
    # Common (all 3D ResNets)
    "res_blocks": "resnets",
    "per_channel_statistics.mean-of-means": "latents_mean",
    "per_channel_statistics.std-of-means": "latents_std",
}

LTX_2_0_AUDIO_VAE_RENAME_DICT = {
    "per_channel_statistics.mean-of-means": "latents_mean",
    "per_channel_statistics.std-of-means": "latents_std",
}

LTX_2_0_VOCODER_RENAME_DICT = {
    "ups": "upsamplers",
    "resblocks": "resnets",
    "conv_pre": "conv_in",
    "conv_post": "conv_out",
}

LTX_2_0_TEXT_ENCODER_RENAME_DICT = {
    "video_embeddings_connector": "video_connector",
    "audio_embeddings_connector": "audio_connector",
    "transformer_1d_blocks": "transformer_blocks",
    # Attention QK norms
    "q_norm": "norm_q",
    "k_norm": "norm_k",
}
|
||||
|
||||
|
||||
def update_state_dict_inplace(
    state_dict: Dict[str, Any], old_key: str, new_key: str
) -> None:
    """Rename one entry of ``state_dict`` in place; the value is preserved."""
    value = state_dict.pop(old_key)
    state_dict[new_key] = value
|
||||
|
||||
|
||||
def remove_keys_inplace(key: str, state_dict: Dict[str, Any]) -> None:
    """Special-keys handler: drop ``key`` from ``state_dict`` in place."""
    del state_dict[key]
|
||||
|
||||
|
||||
def convert_ltx2_transformer_adaln_single(key: str, state_dict: Dict[str, Any]) -> None:
    """Remap ``adaln_single.*`` -> ``time_embed.*`` and ``audio_adaln_single.*``
    -> ``audio_time_embed.*`` entries of ``state_dict`` in place.

    Only ``.weight`` / ``.bias`` entries are touched; anything else is ignored.
    """
    # Skip everything that is not a weight or bias parameter.
    if ".weight" not in key and ".bias" not in key:
        return

    for old_prefix, new_prefix in (
        ("adaln_single.", "time_embed."),
        ("audio_adaln_single.", "audio_time_embed."),
    ):
        if key.startswith(old_prefix):
            state_dict[key.replace(old_prefix, new_prefix)] = state_dict.pop(key)
|
||||
|
||||
|
||||
def convert_ltx2_audio_vae_per_channel_statistics(
    key: str, state_dict: Dict[str, Any]
) -> None:
    """Move ``per_channel_statistics.*`` entries under the ``decoder.`` prefix in place."""
    if not key.startswith("per_channel_statistics"):
        return
    state_dict[f"decoder.{key}"] = state_dict.pop(key)
|
||||
|
||||
|
||||
# Special-case handlers, applied when the map key appears as a substring of a
# state-dict key. Handlers mutate the state dict in place.

LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = {
    # Connector weights live in a separate module; drop them from the transformer.
    "video_embeddings_connector": remove_keys_inplace,
    "audio_embeddings_connector": remove_keys_inplace,
    # Matches both adaln_single and audio_adaln_single (substring match).
    "adaln_single": convert_ltx2_transformer_adaln_single,
}

LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = {
    "connectors.": "",
    "video_embeddings_connector": "video_connector",
    "audio_embeddings_connector": "audio_connector",
    "transformer_1d_blocks": "transformer_blocks",
    "text_embedding_projection.aggregate_embed": "text_proj_in",
    # Attention QK norms
    "q_norm": "norm_q",
    "k_norm": "norm_k",
}

LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
    # Statistics not used by the diffusers VAE; drop them.
    "per_channel_statistics.channel": remove_keys_inplace,
    "per_channel_statistics.mean-of-stds": remove_keys_inplace,
}

# No special handling needed for the audio VAE or the vocoder.
LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {}

LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
|
||||
def split_transformer_and_connector_state_dict(
    state_dict: Dict[str, Any],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Partition a combined checkpoint into (transformer, connector) state dicts.

    A key belongs to the connector part when it starts with any known connector
    prefix; both original spellings and already-renamed diffusers spellings are
    matched so the split works on either format.
    """
    connector_prefixes = (
        "video_embeddings_connector",
        "audio_embeddings_connector",
        "transformer_1d_blocks",
        "text_embedding_projection.aggregate_embed",
        "connectors.",
        "video_connector",
        "audio_connector",
        "text_proj_in",
    )

    transformer_state_dict: Dict[str, Any] = {}
    connector_state_dict: Dict[str, Any] = {}
    for key, value in state_dict.items():
        target = (
            connector_state_dict
            if key.startswith(connector_prefixes)
            else transformer_state_dict
        )
        target[key] = value

    return transformer_state_dict, connector_state_dict
|
||||
|
||||
|
||||
def get_ltx2_transformer_config() -> Tuple[
    Dict[str, Any], Dict[str, Any], Dict[str, Any]
]:
    """Return ``(config, rename_dict, special_keys_remap)`` for the LTX-2 transformer."""
    config = {
        "model_id": "diffusers-internal-dev/new-ltx-model",
        "diffusers_config": {
            # Video stream
            "in_channels": 128,
            "out_channels": 128,
            "patch_size": 1,
            "patch_size_t": 1,
            "num_attention_heads": 32,
            "attention_head_dim": 128,
            "cross_attention_dim": 4096,
            "vae_scale_factors": (8, 32, 32),
            "pos_embed_max_pos": 20,
            "base_height": 2048,
            "base_width": 2048,
            # Audio stream
            "audio_in_channels": 128,
            "audio_out_channels": 128,
            "audio_patch_size": 1,
            "audio_patch_size_t": 1,
            "audio_num_attention_heads": 32,
            "audio_attention_head_dim": 64,
            "audio_cross_attention_dim": 2048,
            "audio_scale_factor": 4,
            "audio_pos_embed_max_pos": 20,
            "audio_sampling_rate": 16000,
            "audio_hop_length": 160,
            # Shared transformer settings
            "num_layers": 48,
            "activation_fn": "gelu-approximate",
            "qk_norm": "rms_norm_across_heads",
            "norm_elementwise_affine": False,
            "norm_eps": 1e-6,
            "caption_channels": 3840,
            "attention_bias": True,
            "attention_out_bias": True,
            "rope_theta": 10000.0,
            "rope_double_precision": True,
            "causal_offset": 1,
            "timestep_scale_multiplier": 1000,
            "cross_attn_timestep_scale_multiplier": 1000,
            "rope_type": "split",
        },
    }
    return (
        config,
        LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT,
        LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP,
    )
|
||||
|
||||
|
||||
def get_ltx2_connectors_config() -> Tuple[
    Dict[str, Any], Dict[str, Any], Dict[str, Any]
]:
    """Return ``(config, rename_dict, special_keys_remap)`` for the text connectors."""
    config = {
        "model_id": "diffusers-internal-dev/new-ltx-model",
        "diffusers_config": {
            "caption_channels": 3840,
            "text_proj_in_factor": 49,
            # Video connector
            "video_connector_num_attention_heads": 30,
            "video_connector_attention_head_dim": 128,
            "video_connector_num_layers": 2,
            "video_connector_num_learnable_registers": 128,
            # Audio connector
            "audio_connector_num_attention_heads": 30,
            "audio_connector_attention_head_dim": 128,
            "audio_connector_num_layers": 2,
            "audio_connector_num_learnable_registers": 128,
            # RoPE settings
            "connector_rope_base_seq_len": 4096,
            "rope_theta": 10000.0,
            "rope_double_precision": True,
            "causal_temporal_positioning": False,
            "rope_type": "split",
        },
    }

    # No special-case handlers are needed for the connectors.
    return config, LTX_2_0_CONNECTORS_KEYS_RENAME_DICT, {}
|
||||
|
||||
|
||||
def convert_ltx2_transformer(
    original_state_dict: Dict[str, Any],
) -> "LTX2VideoTransformer3DModel":
    """Convert an original-format LTX-2 checkpoint into a loaded diffusers
    ``LTX2VideoTransformer3DModel``.

    Connector weights mixed into the checkpoint are split off and ignored here;
    see ``convert_ltx2_connectors`` for that half.

    Note: the return annotation previously said ``Dict[str, Any]`` but the
    function returns the loaded model, which is what callers use.
    """
    config, rename_dict, special_keys_remap = get_ltx2_transformer_config()
    diffusers_config = config["diffusers_config"]

    transformer_state_dict, _ = split_transformer_and_connector_state_dict(
        original_state_dict
    )

    # Build the model on the meta device; real tensors are assigned below.
    with init_empty_weights():
        transformer = LTX2VideoTransformer3DModel.from_config(diffusers_config)

    # Handle official code --> diffusers key remapping via the remap dict.
    for key in list(transformer_state_dict.keys()):
        new_key = key
        for replace_key, rename_key in rename_dict.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(transformer_state_dict, key, new_key)

    # Handle any special logic which can't be expressed by a simple 1:1
    # remapping with the handlers in special_keys_remap.
    for key in list(transformer_state_dict.keys()):
        for special_key, handler_fn_inplace in special_keys_remap.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, transformer_state_dict)

    transformer.load_state_dict(transformer_state_dict, strict=True, assign=True)
    return transformer
|
||||
|
||||
|
||||
def convert_ltx2_connectors(original_state_dict: Dict[str, Any]) -> LTX2TextConnectors:
    """Convert original-format connector weights into a loaded ``LTX2TextConnectors``.

    Raises:
        ValueError: if the checkpoint contains no connector weights at all.
    """
    config, rename_dict, special_keys_remap = get_ltx2_connectors_config()
    diffusers_config = config["diffusers_config"]

    _, connector_state_dict = split_transformer_and_connector_state_dict(
        original_state_dict
    )
    if len(connector_state_dict) == 0:
        raise ValueError("No connector weights found in the provided state dict.")

    # Build on the meta device; real tensors are assigned below.
    with init_empty_weights():
        connectors = LTX2TextConnectors.from_config(diffusers_config)

    # Simple 1:1 substring renames (original -> diffusers).
    for key in list(connector_state_dict.keys()):
        new_key = key
        for replace_key, rename_key in rename_dict.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(connector_state_dict, key, new_key)

    # Special-case handlers (currently none are registered for connectors).
    for key in list(connector_state_dict.keys()):
        for special_key, handler_fn_inplace in special_keys_remap.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, connector_state_dict)

    connectors.load_state_dict(connector_state_dict, strict=True, assign=True)
    return connectors
|
||||
|
||||
|
||||
def get_ltx2_video_vae_config() -> Tuple[
    Dict[str, Any], Dict[str, Any], Dict[str, Any]
]:
    """Return ``(config, rename_dict, special_keys_remap)`` for the video VAE."""
    config = {
        "model_id": "diffusers-internal-dev/dummy-ltx2",
        "diffusers_config": {
            "in_channels": 3,
            "out_channels": 3,
            "latent_channels": 128,
            # Encoder
            "block_out_channels": (256, 512, 1024, 2048),
            "down_block_types": (
                "LTX2VideoDownBlock3D",
                "LTX2VideoDownBlock3D",
                "LTX2VideoDownBlock3D",
                "LTX2VideoDownBlock3D",
            ),
            "layers_per_block": (4, 6, 6, 2, 2),
            "spatio_temporal_scaling": (True, True, True, True),
            "downsample_type": (
                "spatial",
                "temporal",
                "spatiotemporal",
                "spatiotemporal",
            ),
            "encoder_causal": True,
            "encoder_spatial_padding_mode": "zeros",
            # Decoder
            "decoder_block_out_channels": (256, 512, 1024),
            "decoder_layers_per_block": (5, 5, 5, 5),
            "decoder_spatio_temporal_scaling": (True, True, True),
            "decoder_inject_noise": (False, False, False, False),
            "upsample_residual": (True, True, True),
            "upsample_factor": (2, 2, 2),
            "decoder_causal": False,
            "decoder_spatial_padding_mode": "reflect",
            # Common
            "timestep_conditioning": False,
            "patch_size": 4,
            "patch_size_t": 1,
            "resnet_norm_eps": 1e-6,
            "spatial_compression_ratio": 32,
            "temporal_compression_ratio": 8,
        },
    }
    return config, LTX_2_0_VIDEO_VAE_RENAME_DICT, LTX_2_0_VAE_SPECIAL_KEYS_REMAP
|
||||
|
||||
|
||||
def convert_ltx2_video_vae(
    original_state_dict: Dict[str, Any],
) -> "AutoencoderKLLTX2Video":
    """Convert an original-format LTX-2 video VAE state dict into a loaded
    diffusers ``AutoencoderKLLTX2Video``.

    The input dict is mutated in place during conversion. (Return annotation
    previously said ``Dict[str, Any]``; the function returns the model.)
    """
    config, rename_dict, special_keys_remap = get_ltx2_video_vae_config()
    diffusers_config = config["diffusers_config"]

    # Build on the meta device; real tensors are assigned below.
    with init_empty_weights():
        vae = AutoencoderKLLTX2Video.from_config(diffusers_config)

    # Handle official code --> diffusers key remapping via the remap dict.
    for key in list(original_state_dict.keys()):
        new_key = key
        for replace_key, rename_key in rename_dict.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(original_state_dict, key, new_key)

    # Handle any special logic which can't be expressed by a simple 1:1
    # remapping with the handlers in special_keys_remap.
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in special_keys_remap.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    vae.load_state_dict(original_state_dict, strict=True, assign=True)
    return vae
|
||||
|
||||
|
||||
def get_ltx2_audio_vae_config() -> Tuple[
    Dict[str, Any], Dict[str, Any], Dict[str, Any]
]:
    """Return ``(config, rename_dict, special_keys_remap)`` for the audio VAE."""
    config = {
        "model_id": "diffusers-internal-dev/new-ltx-model",
        "diffusers_config": {
            "base_channels": 128,
            "output_channels": 2,
            "ch_mult": (1, 2, 4),
            "num_res_blocks": 2,
            "attn_resolutions": None,
            "in_channels": 2,
            "resolution": 256,
            "latent_channels": 8,
            "norm_type": "pixel",
            "causality_axis": "height",
            "dropout": 0.0,
            "mid_block_add_attention": False,
            # Mel-spectrogram front-end settings
            "sample_rate": 16000,
            "mel_hop_length": 160,
            "is_causal": True,
            "mel_bins": 64,
            "double_z": True,
        },
    }
    return config, LTX_2_0_AUDIO_VAE_RENAME_DICT, LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP
|
||||
|
||||
|
||||
def convert_ltx2_audio_vae(
    original_state_dict: Dict[str, Any],
) -> "AutoencoderKLLTX2Audio":
    """Convert an original-format LTX-2 audio VAE state dict into a loaded
    diffusers ``AutoencoderKLLTX2Audio``.

    The input dict is mutated in place during conversion. (Return annotation
    previously said ``Dict[str, Any]``; the function returns the model.)
    """
    config, rename_dict, special_keys_remap = get_ltx2_audio_vae_config()
    diffusers_config = config["diffusers_config"]

    # Build on the meta device; real tensors are assigned below.
    with init_empty_weights():
        vae = AutoencoderKLLTX2Audio.from_config(diffusers_config)

    # Handle official code --> diffusers key remapping via the remap dict.
    for key in list(original_state_dict.keys()):
        new_key = key
        for replace_key, rename_key in rename_dict.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(original_state_dict, key, new_key)

    # Handle any special logic which can't be expressed by a simple 1:1
    # remapping with the handlers in special_keys_remap.
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in special_keys_remap.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    vae.load_state_dict(original_state_dict, strict=True, assign=True)
    return vae
|
||||
|
||||
|
||||
def get_ltx2_vocoder_config() -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
    """Return ``(config, rename_dict, special_keys_remap)`` for the vocoder."""
    config = {
        "model_id": "diffusers-internal-dev/new-ltx-model",
        "diffusers_config": {
            "in_channels": 128,
            "hidden_channels": 1024,
            "out_channels": 2,
            # Upsampling stack (kernel sizes pair with the factors below)
            "upsample_kernel_sizes": [16, 15, 8, 4, 4],
            "upsample_factors": [6, 5, 2, 2, 2],
            # Multi-receptive-field resnet stack
            "resnet_kernel_sizes": [3, 7, 11],
            "resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            "leaky_relu_negative_slope": 0.1,
            "output_sampling_rate": 24000,
        },
    }
    return config, LTX_2_0_VOCODER_RENAME_DICT, LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP
|
||||
|
||||
|
||||
def convert_ltx2_vocoder(original_state_dict: Dict[str, Any]) -> "LTX2Vocoder":
    """Convert an original-format LTX-2 vocoder state dict into a loaded
    diffusers ``LTX2Vocoder``.

    The input dict is mutated in place during conversion. (Return annotation
    previously said ``Dict[str, Any]``; the function returns the model.)
    """
    config, rename_dict, special_keys_remap = get_ltx2_vocoder_config()
    diffusers_config = config["diffusers_config"]

    # Build on the meta device; real tensors are assigned below.
    with init_empty_weights():
        vocoder = LTX2Vocoder.from_config(diffusers_config)

    # Handle official code --> diffusers key remapping via the remap dict.
    for key in list(original_state_dict.keys()):
        new_key = key
        for replace_key, rename_key in rename_dict.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(original_state_dict, key, new_key)

    # Handle any special logic which can't be expressed by a simple 1:1
    # remapping with the handlers in special_keys_remap.
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in special_keys_remap.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    vocoder.load_state_dict(original_state_dict, strict=True, assign=True)
    return vocoder
|
||||
|
||||
|
||||
def get_model_state_dict_from_combined_ckpt(
    combined_ckpt: Dict[str, Any], prefix: str
) -> Dict[str, Any]:
    """Extract the sub-state-dict whose keys start with ``prefix``, with the
    prefix stripped from each key.

    Bug fix: only the *leading* prefix is removed now. The previous
    ``param_name.replace(prefix, "")`` stripped every occurrence of the prefix,
    corrupting any key that also contained the prefix in its interior.
    """
    # Ensure that the key prefix ends with a dot (.)
    if not prefix.endswith("."):
        prefix = prefix + "."

    model_state_dict = {}
    for param_name, param in combined_ckpt.items():
        if param_name.startswith(prefix):
            model_state_dict[param_name[len(prefix):]] = param

    if prefix == "model.diffusion_model.":
        # Some checkpoints store the text connector projection outside the
        # diffusion model prefix; pull it in so connector conversion finds it.
        connector_key = "text_embedding_projection.aggregate_embed.weight"
        if connector_key in combined_ckpt and connector_key not in model_state_dict:
            model_state_dict[connector_key] = combined_ckpt[connector_key]

    return model_state_dict
|
||||
|
||||
|
||||
def dequantize_state_dict(state_dict: Dict[str, Any]):
    """Expand Comfy-style absmax per-tensor quantized weights to bfloat16.

    Every ``*.weight`` tensor with a matching ``*.weight_scale`` entry is
    multiplied by its scale in bfloat16; scale/auxiliary entries are dropped
    from the returned dict and all other entries pass through unchanged.
    """
    scale_suffixes = (
        ".weight_scale",
        ".weight_scale_2",
        ".pre_quant_scale",
        ".input_scale",
    )

    state_out = {}
    for key in list(state_dict.keys()):
        # Scale tensors are consumed below, never emitted.
        if key.endswith(scale_suffixes):
            continue

        tensor = state_dict[key]

        if key.endswith(".weight"):
            scale_key = key[: -len(".weight")] + ".weight_scale"
            if scale_key in state_dict:
                # Comfy quant = absmax per-tensor weight quant, nothing fancy.
                scale = state_dict[scale_key]
                state_out[key] = tensor.to(torch.bfloat16) * scale.to(torch.bfloat16)
                continue

        state_out[key] = tensor

    return state_out
|
||||
|
||||
|
||||
def convert_comfy_gemma3_to_transformers(sd: dict):
    """Remap a ComfyUI Gemma-3 checkpoint to the key layout that
    ``transformers`` expects.

    Also dequantizes Comfy-quantized weights first, drops the tokenizer blob,
    and ties ``lm_head`` to the token embeddings when the checkpoint omits it.
    """
    sd = dequantize_state_dict(sd)

    out = {}
    for key, value in sd.items():
        new_key = key

        if key.startswith("vision_model."):
            # Vision tower: "vision_model.*" -> "model.vision_tower.vision_model.*"
            new_key = "model.vision_tower." + key
        elif key.startswith("multi_modal_projector."):
            # MM projector: "multi_modal_projector.*" -> "model.multi_modal_projector.*"
            new_key = "model." + key
        elif key == "model.embed_tokens.weight":
            # Language-model keys gain a "language_model" segment.
            new_key = "model.language_model.embed_tokens.weight"
        elif key.startswith("model.layers."):
            new_key = "model.language_model.layers." + key[len("model.layers."):]
        elif key.startswith("model.norm."):
            new_key = "model.language_model.norm." + key[len("model.norm."):]

        # (optional) common DDP prefix
        if new_key.startswith("module."):
            new_key = new_key[len("module."):]

        # Tokenizer blob, not a model weight — skip it.
        if new_key == "spiece_model":
            continue

        out[new_key] = value

    # Many Gemma-family models tie lm_head to the embeddings. Add it so strict
    # loading won't complain (alternatively: load strict=False + tie_weights()).
    if (
        "lm_head.weight" not in out
        and "model.language_model.embed_tokens.weight" in out
    ):
        out["lm_head.weight"] = out["model.language_model.embed_tokens.weight"]

    return out
|
||||
|
||||
|
||||
def convert_lora_original_to_diffusers(
    lora_state_dict: Dict[str, Any],
) -> Dict[str, Any]:
    """Remap LoRA keys from the original LTX-2 naming to diffusers naming.

    A leading ``diffusion_model.`` prefix is preserved untouched; the
    transformer rename rules (plus the adaln_single special case) are applied
    to the remainder of each key.
    """
    diffusion_prefix = "diffusion_model."
    out: Dict[str, Any] = {}

    for key, value in lora_state_dict.items():
        prefix, rest = "", key
        if key.startswith(diffusion_prefix):
            prefix, rest = diffusion_prefix, key[len(diffusion_prefix):]

        # Same simple 1:1 remaps as the transformer conversion.
        for replace_key, rename_key in LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT.items():
            rest = rest.replace(replace_key, rename_key)

        # Same special-case remap as the transformer (applies to LoRA keys too).
        if rest.startswith("adaln_single."):
            rest = rest.replace("adaln_single.", "time_embed.", 1)
        elif rest.startswith("audio_adaln_single."):
            rest = rest.replace("audio_adaln_single.", "audio_time_embed.", 1)

        out[prefix + rest] = value

    return out
|
||||
|
||||
|
||||
def convert_lora_diffusers_to_original(
    lora_state_dict: Dict[str, Any],
) -> Dict[str, Any]:
    """Inverse of ``convert_lora_original_to_diffusers``.

    The inverse 1:1 renames are applied longest-target-first so that substring
    targets do not clobber each other during replacement.
    """
    diffusion_prefix = "diffusion_model."
    inverse_renames = sorted(
        ((v, k) for k, v in LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT.items()),
        key=lambda pair: len(pair[0]),
        reverse=True,
    )

    out: Dict[str, Any] = {}
    for key, value in lora_state_dict.items():
        # Keep the "diffusion_model." prefix as-is, invert remaps on the rest.
        prefix, rest = "", key
        if key.startswith(diffusion_prefix):
            prefix, rest = diffusion_prefix, key[len(diffusion_prefix):]

        # Inverse of the adaln_single special-case.
        if rest.startswith("time_embed."):
            rest = rest.replace("time_embed.", "adaln_single.", 1)
        elif rest.startswith("audio_time_embed."):
            rest = rest.replace("audio_time_embed.", "audio_adaln_single.", 1)

        # Inverse 1:1 remaps, longest diffusers key first.
        for diffusers_key, original_key in inverse_renames:
            rest = rest.replace(diffusers_key, original_key)

        out[prefix + rest] = value

    return out
|
||||
852
extensions_built_in/diffusion_models/ltx2/ltx2.py
Normal file
852
extensions_built_in/diffusion_models/ltx2/ltx2.py
Normal file
@@ -0,0 +1,852 @@
|
||||
from functools import partial
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
from transformers import Gemma3Config
|
||||
import yaml
|
||||
from toolkit.config_modules import GenerateImageConfig, ModelConfig
|
||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||
from toolkit.models.base_model import BaseModel
|
||||
from toolkit.basic import flush
|
||||
from toolkit.prompt_utils import PromptEmbeds
|
||||
from toolkit.samplers.custom_flowmatch_sampler import (
|
||||
CustomFlowMatchEulerDiscreteScheduler,
|
||||
)
|
||||
from accelerate import init_empty_weights
|
||||
from toolkit.accelerator import unwrap_model
|
||||
from optimum.quanto import freeze
|
||||
from toolkit.util.quantize import quantize, get_qtype, quantize_model
|
||||
from toolkit.memory_management import MemoryManager
|
||||
from safetensors.torch import load_file
|
||||
|
||||
try:
|
||||
from diffusers import LTX2Pipeline
|
||||
from diffusers.models.autoencoders import (
|
||||
AutoencoderKLLTX2Audio,
|
||||
AutoencoderKLLTX2Video,
|
||||
)
|
||||
from diffusers.models.transformers import LTX2VideoTransformer3DModel
|
||||
from diffusers.pipelines.ltx2.export_utils import encode_video
|
||||
from transformers import (
|
||||
Gemma3ForConditionalGeneration,
|
||||
GemmaTokenizerFast,
|
||||
)
|
||||
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
|
||||
from diffusers.pipelines.ltx2.connectors import LTX2TextConnectors
|
||||
from .convert_ltx2_to_diffusers import (
|
||||
get_model_state_dict_from_combined_ckpt,
|
||||
convert_ltx2_transformer,
|
||||
convert_ltx2_video_vae,
|
||||
convert_ltx2_audio_vae,
|
||||
convert_ltx2_vocoder,
|
||||
convert_ltx2_connectors,
|
||||
dequantize_state_dict,
|
||||
convert_comfy_gemma3_to_transformers,
|
||||
convert_lora_original_to_diffusers,
|
||||
convert_lora_diffusers_to_original,
|
||||
)
|
||||
except ImportError as e:
|
||||
print("Diffusers import error:", e)
|
||||
raise ImportError(
|
||||
"Diffusers is out of date. Update diffusers to the latest version by doing pip uninstall diffusers and then pip install -r requirements.txt"
|
||||
)
|
||||
|
||||
|
||||
# Flow-match Euler scheduler configuration used for LTX-2 training/sampling.
scheduler_config = {
    "base_image_seq_len": 1024,
    "base_shift": 0.95,
    "invert_sigmas": False,
    "max_image_seq_len": 4096,
    "max_shift": 2.05,
    "num_train_timesteps": 1000,
    "shift": 1.0,
    "shift_terminal": 0.1,
    "stochastic_sampling": False,
    "time_shift_type": "exponential",
    "use_beta_sigmas": False,
    "use_dynamic_shifting": True,
    "use_exponential_sigmas": False,
    "use_karras_sigmas": False,
}

# Key prefixes of each sub-model inside a combined (mono) LTX-2 checkpoint.
dit_prefix = "model.diffusion_model."
vae_prefix = "vae."
audio_vae_prefix = "audio_vae."
vocoder_prefix = "vocoder."
|
||||
|
||||
|
||||
def new_save_image_function(
    self: GenerateImageConfig,
    image,  # dict of encode_video kwargs; we only add "output_path" before dumping it directly in
    count: int = 0,
    max_count: int = 0,
    **kwargs,
):
    """Replacement for GenerateImageConfig's save-image hook so LTX-2 samples
    are written as videos with sound via diffusers' ``encode_video`` helper."""
    output_path = self.get_image_path(count, max_count)
    image["output_path"] = output_path
    # Make the sample directory if it does not exist.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    encode_video(**image)
    flush()
|
||||
|
||||
|
||||
def blank_log_image_function(self, *args, **kwargs):
    """No-op image logger.

    TODO: handle wandb logging of videos with audio.
    """
    return None
|
||||
|
||||
|
||||
class AudioProcessor(torch.nn.Module):
    """Converts audio waveforms to log-mel spectrograms with optional resampling."""

    def __init__(
        self,
        sample_rate: int,
        mel_bins: int,
        mel_hop_length: int,
        n_fft: int,
    ) -> None:
        super().__init__()
        # Target rate; inputs at other rates are resampled before the mel transform.
        self.sample_rate = sample_rate
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            win_length=n_fft,
            hop_length=mel_hop_length,
            f_min=0.0,
            f_max=sample_rate / 2.0,
            n_mels=mel_bins,
            window_fn=torch.hann_window,
            center=True,
            pad_mode="reflect",
            power=1.0,
            mel_scale="slaney",
            norm="slaney",
        )

    def resample_waveform(
        self,
        waveform: torch.Tensor,
        source_rate: int,
        target_rate: int,
    ) -> torch.Tensor:
        """Resample ``waveform`` to ``target_rate``; no-op when rates match."""
        if source_rate == target_rate:
            return waveform
        resampled = torchaudio.functional.resample(waveform, source_rate, target_rate)
        # Keep the caller's device/dtype after resampling.
        return resampled.to(device=waveform.device, dtype=waveform.dtype)

    def waveform_to_mel(
        self,
        waveform: torch.Tensor,
        waveform_sample_rate: int,
    ) -> torch.Tensor:
        """Convert a waveform to a log-mel spectrogram permuted to
        [batch, channels, time, n_mels].

        NOTE(review): assumes ``waveform`` is batched [batch, channels, samples]
        (the final permute swaps the last two mel-transform axes) — confirm
        against callers.
        """
        waveform = self.resample_waveform(
            waveform, waveform_sample_rate, self.sample_rate
        )

        # Log compression with a floor to avoid log(0).
        mel = self.mel_transform(waveform)
        mel = torch.log(torch.clamp(mel, min=1e-5))

        mel = mel.to(device=waveform.device, dtype=waveform.dtype)
        return mel.permute(0, 1, 3, 2).contiguous()
|
||||
|
||||
|
||||
class LTX2Model(BaseModel):
|
||||
arch = "ltx2"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device,
|
||||
model_config: ModelConfig,
|
||||
dtype="bf16",
|
||||
custom_pipeline=None,
|
||||
noise_scheduler=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs
|
||||
)
|
||||
self.is_flow_matching = True
|
||||
self.is_transformer = True
|
||||
self.target_lora_modules = ["LTX2VideoTransformer3DModel"]
|
||||
# defines if the model supports model paths. Only some will
|
||||
self.supports_model_paths = True
|
||||
# use the new format on this new model by default
|
||||
self.use_old_lokr_format = False
|
||||
self.audio_processor = None
|
||||
|
||||
# static method to get the noise scheduler
|
||||
@staticmethod
|
||||
def get_train_scheduler():
|
||||
return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
|
||||
|
||||
def get_bucket_divisibility(self):
|
||||
return 32
|
||||
|
||||
    def load_model(self):
        """Load all LTX-2 components (transformer, Gemma3 TE, video/audio VAEs,
        connectors, vocoder), optionally quantize/offload them, and assemble the
        LTX2Pipeline stored on ``self``.

        Supports two sources:
        - a single ``.safetensors`` "mono" checkpoint (Comfy-style combined file),
          split apart via the ``get_model_state_dict_from_combined_ckpt`` helpers;
        - a diffusers-layout folder / hub repo loaded per-subfolder.

        NOTE(review): indentation was reconstructed from a mangled diff; nesting
        of the local-path / full-checkpoint detection below should be confirmed.
        """
        dtype = self.torch_dtype
        self.print_and_status_update("Loading LTX2 model")
        model_path = self.model_config.name_or_path
        base_model_path = self.model_config.extras_name_or_path

        combined_state_dict = None

        self.print_and_status_update("Loading transformer")
        # if we have a safetensors file it is a mono checkpoint
        if os.path.exists(model_path) and model_path.endswith(".safetensors"):
            combined_state_dict = load_file(model_path)
            combined_state_dict = dequantize_state_dict(combined_state_dict)

        if combined_state_dict is not None:
            # Mono checkpoint: slice out the DiT weights and convert key layout.
            original_dit_ckpt = get_model_state_dict_from_combined_ckpt(
                combined_state_dict, dit_prefix
            )
            transformer = convert_ltx2_transformer(original_dit_ckpt)
            transformer = transformer.to(dtype)
        else:
            transformer_path = model_path
            transformer_subfolder = "transformer"
            if os.path.exists(transformer_path):
                transformer_subfolder = None
                transformer_path = os.path.join(transformer_path, "transformer")
                # check if the path is a full checkpoint.
                te_folder_path = os.path.join(model_path, "text_encoder")
                # if we have the te, this folder is a full checkpoint, use it as the base
                if os.path.exists(te_folder_path):
                    base_model_path = model_path

            transformer = LTX2VideoTransformer3DModel.from_pretrained(
                transformer_path, subfolder=transformer_subfolder, torch_dtype=dtype
            )

        if self.model_config.quantize:
            self.print_and_status_update("Quantizing Transformer")
            quantize_model(self, transformer)
            flush()

        if (
            self.model_config.layer_offloading
            and self.model_config.layer_offloading_transformer_percent > 0
        ):
            # Keep the small scale/shift tables resident; offloading them is not
            # worth the transfer cost.
            ignore_modules = []
            for block in transformer.transformer_blocks:
                ignore_modules.append(block.scale_shift_table)
                ignore_modules.append(block.audio_scale_shift_table)
                ignore_modules.append(block.video_a2v_cross_attn_scale_shift_table)
                ignore_modules.append(block.audio_a2v_cross_attn_scale_shift_table)
            ignore_modules.append(transformer.scale_shift_table)
            ignore_modules.append(transformer.audio_scale_shift_table)
            MemoryManager.attach(
                transformer,
                self.device_torch,
                offload_percent=self.model_config.layer_offloading_transformer_percent,
                ignore_modules=ignore_modules,
            )

        if self.model_config.low_vram:
            self.print_and_status_update("Moving transformer to CPU")
            transformer.to("cpu")

        flush()

        self.print_and_status_update("Loading text encoder")
        if (
            self.model_config.te_name_or_path is not None
            and self.model_config.te_name_or_path.endswith(".safetensors")
        ):
            # load from comfyui gemma3 checkpoint
            tokenizer = GemmaTokenizerFast.from_pretrained(
                "Lightricks/LTX-2", subfolder="tokenizer"
            )

            # Build the Gemma3 skeleton on the meta device, then load converted
            # weights with assign=True so no real allocation happens up front.
            with init_empty_weights():
                text_encoder = Gemma3ForConditionalGeneration(
                    Gemma3Config(
                        **{
                            "boi_token_index": 255999,
                            "bos_token_id": 2,
                            "eoi_token_index": 256000,
                            "eos_token_id": 106,
                            "image_token_index": 262144,
                            "initializer_range": 0.02,
                            "mm_tokens_per_image": 256,
                            "model_type": "gemma3",
                            "pad_token_id": 0,
                            "text_config": {
                                "attention_bias": False,
                                "attention_dropout": 0.0,
                                "attn_logit_softcapping": None,
                                "cache_implementation": "hybrid",
                                "final_logit_softcapping": None,
                                "head_dim": 256,
                                "hidden_activation": "gelu_pytorch_tanh",
                                "hidden_size": 3840,
                                "initializer_range": 0.02,
                                "intermediate_size": 15360,
                                "max_position_embeddings": 131072,
                                "model_type": "gemma3_text",
                                "num_attention_heads": 16,
                                "num_hidden_layers": 48,
                                "num_key_value_heads": 8,
                                "query_pre_attn_scalar": 256,
                                "rms_norm_eps": 1e-06,
                                "rope_local_base_freq": 10000,
                                "rope_scaling": {"factor": 8.0, "rope_type": "linear"},
                                "rope_theta": 1000000,
                                "sliding_window": 1024,
                                "sliding_window_pattern": 6,
                                "torch_dtype": "bfloat16",
                                "use_cache": True,
                                "vocab_size": 262208,
                            },
                            "torch_dtype": "bfloat16",
                            "transformers_version": "4.51.3",
                            "unsloth_fixed": True,
                            "vision_config": {
                                "attention_dropout": 0.0,
                                "hidden_act": "gelu_pytorch_tanh",
                                "hidden_size": 1152,
                                "image_size": 896,
                                "intermediate_size": 4304,
                                "layer_norm_eps": 1e-06,
                                "model_type": "siglip_vision_model",
                                "num_attention_heads": 16,
                                "num_channels": 3,
                                "num_hidden_layers": 27,
                                "patch_size": 14,
                                "torch_dtype": "bfloat16",
                                "vision_use_head": False,
                            },
                        }
                    )
                )
            te_state_dict = load_file(self.model_config.te_name_or_path)
            te_state_dict = convert_comfy_gemma3_to_transformers(te_state_dict)
            for key in te_state_dict:
                te_state_dict[key] = te_state_dict[key].to(dtype)

            text_encoder.load_state_dict(te_state_dict, assign=True, strict=True)
            del te_state_dict
            flush()
        else:
            if self.model_config.te_name_or_path is not None:
                te_path = self.model_config.te_name_or_path
            else:
                te_path = base_model_path
            tokenizer = GemmaTokenizerFast.from_pretrained(
                te_path, subfolder="tokenizer"
            )
            text_encoder = Gemma3ForConditionalGeneration.from_pretrained(
                te_path, subfolder="text_encoder", dtype=dtype
            )

        # remove the vision tower (text-only conditioning; saves memory)
        text_encoder.model.vision_tower = None
        flush()

        if (
            self.model_config.layer_offloading
            and self.model_config.layer_offloading_text_encoder_percent > 0
        ):
            MemoryManager.attach(
                text_encoder,
                self.device_torch,
                offload_percent=self.model_config.layer_offloading_text_encoder_percent,
                ignore_modules=[
                    text_encoder.model.language_model.base_model.embed_tokens
                ],
            )

        text_encoder.to(self.device_torch, dtype=dtype)
        flush()

        if self.model_config.quantize_te:
            self.print_and_status_update("Quantizing Text Encoder")
            quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te))
            freeze(text_encoder)
            flush()

        self.print_and_status_update("Loading VAEs and other components")
        if combined_state_dict is not None:
            # Slice and convert each component out of the mono checkpoint.
            original_vae_ckpt = get_model_state_dict_from_combined_ckpt(
                combined_state_dict, vae_prefix
            )
            vae = convert_ltx2_video_vae(original_vae_ckpt).to(dtype)
            del original_vae_ckpt
            original_audio_vae_ckpt = get_model_state_dict_from_combined_ckpt(
                combined_state_dict, audio_vae_prefix
            )
            audio_vae = convert_ltx2_audio_vae(original_audio_vae_ckpt).to(dtype)
            del original_audio_vae_ckpt
            # NOTE(review): connectors are extracted with dit_prefix — presumably
            # they live under the DiT prefix in the combined file; confirm.
            original_connectors_ckpt = get_model_state_dict_from_combined_ckpt(
                combined_state_dict, dit_prefix
            )
            connectors = convert_ltx2_connectors(original_connectors_ckpt).to(dtype)
            del original_connectors_ckpt
            original_vocoder_ckpt = get_model_state_dict_from_combined_ckpt(
                combined_state_dict, vocoder_prefix
            )
            vocoder = convert_ltx2_vocoder(original_vocoder_ckpt).to(dtype)
            del original_vocoder_ckpt
            del combined_state_dict
            flush()
        else:
            vae = AutoencoderKLLTX2Video.from_pretrained(
                base_model_path, subfolder="vae", torch_dtype=dtype
            )
            audio_vae = AutoencoderKLLTX2Audio.from_pretrained(
                base_model_path, subfolder="audio_vae", torch_dtype=dtype
            )

            connectors = LTX2TextConnectors.from_pretrained(
                base_model_path, subfolder="connectors", torch_dtype=dtype
            )

            vocoder = LTX2Vocoder.from_pretrained(
                base_model_path, subfolder="vocoder", torch_dtype=dtype
            )

        self.noise_scheduler = LTX2Model.get_train_scheduler()

        self.print_and_status_update("Making pipe")

        pipe: LTX2Pipeline = LTX2Pipeline(
            scheduler=self.noise_scheduler,
            vae=vae,
            audio_vae=audio_vae,
            text_encoder=None,
            tokenizer=tokenizer,
            connectors=connectors,
            transformer=None,
            vocoder=vocoder,
        )
        # for quantization, it works best to do these after making the pipe
        pipe.text_encoder = text_encoder
        pipe.transformer = transformer

        self.print_and_status_update("Preparing Model")

        text_encoder = [pipe.text_encoder]
        tokenizer = [pipe.tokenizer]

        # leave it on cpu for now
        if not self.low_vram:
            pipe.transformer = pipe.transformer.to(self.device_torch)

        flush()
        # just to make sure everything is on the right device and dtype
        text_encoder[0].to(self.device_torch)
        text_encoder[0].requires_grad_(False)
        text_encoder[0].eval()
        flush()

        # save it to the model class
        self.vae = vae
        self.text_encoder = text_encoder  # list of text encoders
        self.tokenizer = tokenizer  # list of tokenizers
        self.model = pipe.transformer
        self.pipeline = pipe

        # Waveform -> mel processor used by encode_audio(); kept in float32 for
        # spectrogram numerical stability.
        self.audio_processor = AudioProcessor(
            sample_rate=pipe.audio_sampling_rate,
            mel_bins=audio_vae.config.mel_bins,
            mel_hop_length=pipe.audio_hop_length,
            n_fft=1024,  # todo get this from vae if we can, I couldnt find it.
        ).to(self.device_torch, dtype=torch.float32)

        self.print_and_status_update("Model Loaded")
||||
    @torch.no_grad()
    def encode_images(self, image_list: List[torch.Tensor], device=None, dtype=None):
        """Encode images/clips to normalized video-VAE latents of shape [B, C, F, H, W].

        Accepts per-item tensors of shape (C, H, W) for single images or
        (T, C, H, W) for frame sequences; all items in the list must produce the
        same (C, T, H, W) shape so they can be stacked into one batch.
        """
        if device is None:
            device = self.vae_device_torch
        if dtype is None:
            dtype = self.vae_torch_dtype

        # Lazily move the VAE off the CPU the first time it is needed.
        if self.vae.device == torch.device("cpu"):
            self.vae.to(device)
            self.vae.eval()
            self.vae.requires_grad_(False)

        image_list = [image.to(device, dtype=dtype) for image in image_list]

        # Normalize shapes
        norm_images = []
        for image in image_list:
            if image.ndim == 3:
                # (C, H, W) -> (C, 1, H, W)
                norm_images.append(image.unsqueeze(1))
            elif image.ndim == 4:
                # (T, C, H, W) -> (C, T, H, W)
                norm_images.append(image.permute(1, 0, 2, 3))
            else:
                raise ValueError(f"Invalid image shape: {image.shape}")

        # Stack to (B, C, T, H, W)
        images = torch.stack(norm_images)

        # Deterministic encoding: take the distribution mode rather than sampling.
        latents = self.vae.encode(images).latent_dist.mode()

        # Normalize latents across the channel dimension [B, C, F, H, W]
        scaling_factor = 1.0
        latents_mean = self.pipeline.vae.latents_mean.view(1, -1, 1, 1, 1).to(
            latents.device, latents.dtype
        )
        latents_std = self.pipeline.vae.latents_std.view(1, -1, 1, 1, 1).to(
            latents.device, latents.dtype
        )
        latents = (latents - latents_mean) * scaling_factor / latents_std

        return latents.to(device, dtype=dtype)
||||
    def get_generation_pipeline(self):
        """Build a fresh LTX2Pipeline for sampling, reusing the loaded components.

        A new scheduler instance is created so sampling state does not leak into
        the training scheduler. text_encoder/transformer are attached after
        construction (same as load_model) to play nice with quantized modules.
        """
        scheduler = LTX2Model.get_train_scheduler()

        pipeline: LTX2Pipeline = LTX2Pipeline(
            scheduler=scheduler,
            vae=unwrap_model(self.pipeline.vae),
            audio_vae=unwrap_model(self.pipeline.audio_vae),
            text_encoder=None,
            tokenizer=unwrap_model(self.pipeline.tokenizer),
            connectors=unwrap_model(self.pipeline.connectors),
            transformer=None,
            vocoder=unwrap_model(self.pipeline.vocoder),
        )
        pipeline.transformer = unwrap_model(self.model)
        pipeline.text_encoder = unwrap_model(self.text_encoder[0])

        # if self.low_vram:
        #     pipeline.enable_model_cpu_offload(device=self.device_torch)

        pipeline = pipeline.to(self.device_torch)

        return pipeline
|
||||
    def generate_single_image(
        self,
        pipeline: LTX2Pipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        """Sample one image or one video (with audio) from the pipeline.

        For videos the save/log hooks on ``gen_config`` are swapped so the caller
        saves an mp4 with muxed audio; for single images a PIL image is returned.
        """
        if self.model.device == torch.device("cpu"):
            self.model.to(self.device_torch)

        is_video = gen_config.num_frames > 1
        # override the generate single image to handle video + audio generation
        if is_video:
            gen_config._orig_save_image_function = gen_config.save_image
            gen_config.save_image = partial(new_save_image_function, gen_config)
            gen_config.log_image = partial(blank_log_image_function, gen_config)
            # set output extension to mp4
            gen_config.output_ext = "mp4"

        # reactivate progress bar since this is slooooow
        pipeline.set_progress_bar_config(disable=False)
        pipeline = pipeline.to(self.device_torch)

        # make sure dimensions are valid (round down to bucket divisibility)
        bd = self.get_bucket_divisibility()
        gen_config.height = (gen_config.height // bd) * bd
        gen_config.width = (gen_config.width // bd) * bd

        # frames must be divisible by 8 then + 1. so 1, 9, 17, 25, etc.
        if gen_config.num_frames != 1:
            if (gen_config.num_frames - 1) % 8 != 0:
                gen_config.num_frames = ((gen_config.num_frames - 1) // 8) * 8 + 1

        if self.low_vram:
            # set vae to tile decode to bound peak memory
            pipeline.vae.enable_tiling(
                tile_sample_min_height=256,
                tile_sample_min_width=256,
                tile_sample_min_num_frames=8,
                tile_sample_stride_height=224,
                tile_sample_stride_width=224,
                tile_sample_stride_num_frames=4,
            )

        video, audio = pipeline(
            prompt_embeds=conditional_embeds.text_embeds.to(
                self.device_torch, dtype=self.torch_dtype
            ),
            prompt_attention_mask=conditional_embeds.attention_mask.to(
                self.device_torch
            ),
            negative_prompt_embeds=unconditional_embeds.text_embeds.to(
                self.device_torch, dtype=self.torch_dtype
            ),
            negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(
                self.device_torch
            ),
            height=gen_config.height,
            width=gen_config.width,
            num_inference_steps=gen_config.num_inference_steps,
            guidance_scale=gen_config.guidance_scale,
            latents=gen_config.latents,
            num_frames=gen_config.num_frames,
            generator=generator,
            return_dict=False,
            output_type="np" if is_video else "pil",
            **extra,
        )
        if self.low_vram:
            # Restore no tiling
            pipeline.vae.use_tiling = False

        if is_video:
            # return as a dict, we will handle it with an override function
            video = (video * 255).round().astype("uint8")
            video = torch.from_numpy(video)
            return {
                "video": video[0],
                "fps": gen_config.fps,
                "audio": audio[0].float().cpu(),
                "audio_sample_rate": pipeline.vocoder.config.output_sampling_rate,  # should be 24000
                "output_path": None,
            }
        else:
            # shape = [1, frames, channels, height, width]
            # make sure this is right
            video = video[0]  # list of pil images
            audio = audio[0]  # tensor
            # NOTE(review): this branch only runs when is_video is False, i.e.
            # num_frames <= 1, so the check below looks unreachable — confirm.
            if gen_config.num_frames > 1:
                return video  # return the frames.
            else:
                # get just the first image
                img = video[0]
                return img
||||
    def encode_audio(self, batch: "DataLoaderBatchDTO"):
        """Encode each batch item's waveform to packed, normalized audio latents.

        Returns (output_tensor, audio_num_frames) where output_tensor is the
        packed latents [B, L, C * M] and audio_num_frames is the latent time
        dimension taken from the first item.
        """
        if self.pipeline.audio_vae.device == torch.device("cpu"):
            self.pipeline.audio_vae.to(self.device_torch)

        output_tensor = None
        audio_num_frames = None

        # do them seperatly for now
        for audio_data in batch.audio_data:
            # Mel extraction is done in float32 for numerical stability.
            waveform = audio_data["waveform"].to(
                device=self.device_torch, dtype=torch.float32
            )
            sample_rate = audio_data["sample_rate"]

            # Add batch dimension if needed: [channels, samples] -> [batch, channels, samples]
            if waveform.dim() == 2:
                waveform = waveform.unsqueeze(0)

            if waveform.shape[1] == 1:
                # make sure it is stereo (duplicate mono channel)
                waveform = waveform.repeat(1, 2, 1)

            # Convert waveform to mel spectrogram using AudioProcessor
            mel_spectrogram = self.audio_processor.waveform_to_mel(waveform, waveform_sample_rate=sample_rate)
            mel_spectrogram = mel_spectrogram.to(dtype=self.torch_dtype)

            # Encode mel spectrogram to latents (deterministic mode, no sampling)
            latents = self.pipeline.audio_vae.encode(mel_spectrogram.to(self.device_torch, dtype=self.torch_dtype)).latent_dist.mode()

            if audio_num_frames is None:
                audio_num_frames = latents.shape[2]  # (latents is [B, C, T, F])

            packed_latents = self.pipeline._pack_audio_latents(
                latents,
                # patch_size=self.pipeline.transformer.config.audio_patch_size,
                # patch_size_t=self.pipeline.transformer.config.audio_patch_size_t,
            )  # [B, L, C * M]
            if output_tensor is None:
                output_tensor = packed_latents
            else:
                output_tensor = torch.cat([output_tensor, packed_latents], dim=0)

        # normalize latents, opposite of (latents * latents_std) + latents_mean
        # NOTE(review): assumes latents_mean/std are scalars or broadcastable to
        # the packed [B, L, C*M] layout and on a compatible device — confirm.
        latents_mean = self.pipeline.audio_vae.latents_mean
        latents_std = self.pipeline.audio_vae.latents_std
        output_tensor = (output_tensor - latents_mean) / latents_std
        return output_tensor, audio_num_frames
||||
    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        batch: "DataLoaderBatchDTO" = None,
        **kwargs,
    ):
        """Forward pass for training: pack video latents, build (or noise) audio
        latents, run the joint transformer, and return the unpacked video noise
        prediction. The audio prediction/target are attached to ``batch`` for the
        trainer's audio loss.
        """
        # NOTE(review): indentation reconstructed from a mangled diff — the
        # no-grad region is assumed to cover only input preparation so the
        # transformer forward below still tracks gradients; confirm.
        with torch.no_grad():
            if self.model.device == torch.device("cpu"):
                self.model.to(self.device_torch)

            batch_size, C, latent_num_frames, latent_height, latent_width = (
                latent_model_input.shape
            )

            # todo get this somehow
            frame_rate = 24
            # check frame dimension
            # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
            packed_latents = self.pipeline._pack_latents(
                latent_model_input,
                patch_size=self.pipeline.transformer_spatial_patch_size,
                patch_size_t=self.pipeline.transformer_temporal_patch_size,
            )

            if batch.audio_tensor is not None:
                # use audio from the batch if available
                # (1, 190, 128)
                raw_audio_latents, audio_num_frames = self.encode_audio(batch)
                # add the audio targets to the batch for loss calculation later
                audio_noise = torch.randn_like(raw_audio_latents)
                # Flow-matching target for audio: velocity (noise - clean latents).
                batch.audio_target = (audio_noise - raw_audio_latents).detach()
                audio_latents = self.add_noise(
                    raw_audio_latents,
                    audio_noise,
                    timestep,
                ).to(self.device_torch, dtype=self.torch_dtype)

            else:
                # no audio: feed pure-noise audio latents of the matching length
                num_mel_bins = self.pipeline.audio_vae.config.mel_bins
                # latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
                num_channels_latents_audio = (
                    self.pipeline.audio_vae.config.latent_channels
                )
                # audio latents are (1, 126, 128), audio_num_frames = 126
                audio_latents, audio_num_frames = self.pipeline.prepare_audio_latents(
                    batch_size,
                    num_channels_latents=num_channels_latents_audio,
                    num_mel_bins=num_mel_bins,
                    num_frames=batch.tensor.shape[1],
                    frame_rate=frame_rate,
                    sampling_rate=self.pipeline.audio_sampling_rate,
                    hop_length=self.pipeline.audio_hop_length,
                    dtype=torch.float32,
                    device=self.transformer.device,
                    generator=None,
                    latents=None,
                )

            if self.pipeline.connectors.device != self.transformer.device:
                self.pipeline.connectors.to(self.transformer.device)

            # TODO this is how diffusers does this on inference, not sure I understand why, check this
            additive_attention_mask = (
                1 - text_embeddings.attention_mask.to(self.transformer.dtype)
            ) * -1000000.0
            (
                connector_prompt_embeds,
                connector_audio_prompt_embeds,
                connector_attention_mask,
            ) = self.pipeline.connectors(
                text_embeddings.text_embeds, additive_attention_mask, additive_mask=True
            )

            # compute video and audio positional ids
            video_coords = self.transformer.rope.prepare_video_coords(
                packed_latents.shape[0],
                latent_num_frames,
                latent_height,
                latent_width,
                packed_latents.device,
                fps=frame_rate,
            )
            audio_coords = self.transformer.audio_rope.prepare_audio_coords(
                audio_latents.shape[0], audio_num_frames, audio_latents.device
            )

        noise_pred_video, noise_pred_audio = self.transformer(
            hidden_states=packed_latents,
            audio_hidden_states=audio_latents.to(self.transformer.dtype),
            encoder_hidden_states=connector_prompt_embeds,
            audio_encoder_hidden_states=connector_audio_prompt_embeds,
            timestep=timestep,
            encoder_attention_mask=connector_attention_mask,
            audio_encoder_attention_mask=connector_attention_mask,
            num_frames=latent_num_frames,
            height=latent_height,
            width=latent_width,
            fps=frame_rate,
            audio_num_frames=audio_num_frames,
            video_coords=video_coords,
            audio_coords=audio_coords,
            # rope_interpolation_scale=rope_interpolation_scale,
            attention_kwargs=None,
            return_dict=False,
        )

        # add audio latent to batch if we had audio
        if batch.audio_target is not None:
            batch.audio_pred = noise_pred_audio

        unpacked_output = self.pipeline._unpack_latents(
            latents=noise_pred_video,
            num_frames=latent_num_frames,
            height=latent_height,
            width=latent_width,
            patch_size=self.pipeline.transformer_spatial_patch_size,
            patch_size_t=self.pipeline.transformer_temporal_patch_size,
        )

        return unpacked_output
|
||||
    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
        """Encode ``prompt`` through the pipeline's Gemma3 text encoder.

        The attention mask is attached to the returned PromptEmbeds; the second
        (pooled) slot is unused by LTX-2 and set to None.
        """
        if self.pipeline.text_encoder.device != self.device_torch:
            self.pipeline.text_encoder.to(self.device_torch)

        prompt_embeds, prompt_attention_mask, _, _ = self.pipeline.encode_prompt(
            prompt,
            do_classifier_free_guidance=False,
            device=self.device_torch,
        )
        pe = PromptEmbeds([prompt_embeds, None])
        pe.attention_mask = prompt_attention_mask
        return pe
|
||||
    def get_model_has_grad(self):
        # NOTE(review): presumably signals that the base transformer keeps no
        # trainable grads (only adapters train) — confirm against BaseModel.
        return False
|
||||
    def get_te_has_grad(self):
        # Text encoder is frozen for LTX-2 training.
        return False
|
||||
    def save_model(self, output_path, meta, save_dtype):
        """Save the unwrapped transformer (safetensors) and ai-toolkit metadata.

        NOTE(review): ``save_dtype`` is accepted but never applied here —
        confirm whether weights should be cast before saving.
        """
        transformer: LTX2VideoTransformer3DModel = unwrap_model(self.model)
        transformer.save_pretrained(
            save_directory=os.path.join(output_path, "transformer"),
            safe_serialization=True,
        )

        # Sidecar metadata read back by ai-toolkit when resuming.
        meta_path = os.path.join(output_path, "aitk_meta.yaml")
        with open(meta_path, "w") as f:
            yaml.dump(meta, f)
||||
def get_loss_target(self, *args, **kwargs):
|
||||
noise = kwargs.get("noise")
|
||||
batch = kwargs.get("batch")
|
||||
return (noise - batch.latents).detach()
|
||||
|
||||
    def get_base_model_version(self):
        # Version tag written into saved metadata / UI.
        return "ltx2"
|
||||
    def get_transformer_block_names(self) -> Optional[List[str]]:
        # Attribute name(s) holding the repeated transformer blocks (used for
        # per-block operations such as offloading).
        return ["transformer_blocks"]
|
||||
def convert_lora_weights_before_save(self, state_dict):
|
||||
new_sd = {}
|
||||
for key, value in state_dict.items():
|
||||
new_key = key.replace("transformer.", "diffusion_model.")
|
||||
new_sd[new_key] = value
|
||||
new_sd = convert_lora_diffusers_to_original(new_sd)
|
||||
return new_sd
|
||||
|
||||
def convert_lora_weights_before_load(self, state_dict):
|
||||
state_dict = convert_lora_original_to_diffusers(state_dict)
|
||||
new_sd = {}
|
||||
for key, value in state_dict.items():
|
||||
new_key = key.replace("diffusion_model.", "transformer.")
|
||||
new_sd[new_key] = value
|
||||
return new_sd
|
||||
@@ -859,6 +859,11 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
|
||||
|
||||
loss = loss.mean()
|
||||
|
||||
# check for audio loss
|
||||
if batch.audio_pred is not None and batch.audio_target is not None:
|
||||
audio_loss = torch.nn.functional.mse_loss(batch.audio_pred.float(), batch.audio_target.float(), reduction="mean")
|
||||
loss = loss + audio_loss
|
||||
|
||||
# check for additional losses
|
||||
if self.adapter is not None and hasattr(self.adapter, "additional_loss") and self.adapter.additional_loss is not None:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
torchao==0.10.0
|
||||
safetensors
|
||||
git+https://github.com/huggingface/diffusers@f6b6a7181eb44f0120b29cd897c129275f366c2a
|
||||
git+https://github.com/huggingface/diffusers@8600b4c10d67b0ce200f664204358747bd53c775
|
||||
transformers==4.57.3
|
||||
lycoris-lora==1.8.3
|
||||
flatten_json
|
||||
@@ -36,4 +36,6 @@ opencv-python
|
||||
pytorch-wavelets==1.3.0
|
||||
matplotlib==3.10.1
|
||||
setuptools==69.5.1
|
||||
scipy==1.12.0
|
||||
scipy==1.12.0
|
||||
av==16.0.1
|
||||
torchcodec
|
||||
234
testing/test_ltx_dataloader.py
Normal file
234
testing/test_ltx_dataloader.py
Normal file
@@ -0,0 +1,234 @@
|
||||
import time
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
import sys
|
||||
import os
|
||||
import argparse
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from torchvision.io import write_video
|
||||
import subprocess
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||
from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader_setup_epoch
|
||||
from toolkit.config_modules import DatasetConfig
|
||||
|
||||
# CLI for the dataloader smoke test: point it at a dataset folder and it will
# write each decoded batch back out as mp4 (with audio when present).
parser = argparse.ArgumentParser()
# parser.add_argument('dataset_folder', type=str, default='input')
parser.add_argument('dataset_folder', type=str)
parser.add_argument('--epochs', type=int, default=1)
parser.add_argument('--num_frames', type=int, default=121)
parser.add_argument('--output_path', type=str, default='output/dataset_test')


args = parser.parse_args()

if args.output_path is None:
    raise ValueError('output_path is required for this test script')

if args.output_path is not None:
    args.output_path = os.path.abspath(args.output_path)
    os.makedirs(args.output_path, exist_ok=True)

dataset_folder = args.dataset_folder
# Fixed settings matching a typical LTX-2 training run.
resolution = 512
bucket_tolerance = 64
batch_size = 1
frame_rate = 24
||||
## make fake sd
class FakeSD:
    """Minimal stand-in for the model wrapper the dataloader setup expects."""

    def __init__(self):
        # Dataloader checks this flag when preparing control images.
        self.use_raw_control_images = False

    def encode_control_in_text_embeddings(self, *args, **kwargs):
        # Control conditioning is not exercised by this test.
        return None

    def get_bucket_divisibility(self):
        # Match LTX-2's spatial bucket divisibility.
        return 32
||||
# Dataset configured for video + audio loading so the LTX-2 audio path is
# exercised (do_audio=True, fps fixed so audio maps to frames deterministically).
dataset_config = DatasetConfig(
    dataset_path=dataset_folder,
    resolution=resolution,
    default_caption='default',
    buckets=True,
    bucket_tolerance=bucket_tolerance,
    shrink_video_to_frames=True,
    num_frames=args.num_frames,
    do_i2v=True,
    fps=frame_rate,
    do_audio=True,
    debug=True,
    audio_preserve_pitch=False,
    audio_normalize=True

)

dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size, sd=FakeSD())
|
||||
def _tensor_to_uint8_video(frames_fchw: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
frames_fchw: [F, C, H, W] float/uint8
|
||||
returns: [F, H, W, C] uint8 on CPU
|
||||
"""
|
||||
x = frames_fchw.detach()
|
||||
|
||||
if x.dtype != torch.uint8:
|
||||
x = x.to(torch.float32)
|
||||
|
||||
# Heuristic: if negatives exist, assume [-1,1] normalization; else assume [0,1]
|
||||
if torch.isfinite(x).all():
|
||||
if x.min().item() < 0.0:
|
||||
x = x * 0.5 + 0.5
|
||||
x = x.clamp(0.0, 1.0)
|
||||
x = (x * 255.0).round().to(torch.uint8)
|
||||
else:
|
||||
x = x.to(torch.uint8)
|
||||
|
||||
# [F,C,H,W] -> [F,H,W,C]
|
||||
x = x.permute(0, 2, 3, 1).contiguous().cpu()
|
||||
return x
|
||||
|
||||
|
||||
def _mux_with_ffmpeg(video_in: str, wav_in: str, mp4_out: str):
    """Mux ``video_in`` and ``wav_in`` into ``mp4_out`` using the ffmpeg CLI.

    The video stream is copied as-is, audio is re-encoded to AAC for mp4
    compatibility, and output is trimmed to the shorter stream. Raises
    CalledProcessError if ffmpeg fails.
    """
    command = [
        "ffmpeg",
        "-y",
        "-hide_banner",
        "-loglevel", "error",
        "-i", video_in,
        "-i", wav_in,
        "-c:v", "copy",      # keep the already-encoded video stream
        "-c:a", "aac",       # encode audio to AAC
        "-shortest",         # align to the shortest input
        mp4_out,
    ]
    subprocess.run(command, check=True)
|
||||
# run through an epoch and check sizes
dataloader_iterator = iter(dataloader)
idx = 0
for epoch in range(args.epochs):
    for batch in tqdm(dataloader):
        batch: 'DataLoaderBatchDTO'
        img_batch = batch.tensor
        frames = 1
        if len(img_batch.shape) == 5:
            # Video batch: [B, F, C, H, W]
            frames = img_batch.shape[1]
            batch_size, frames, channels, height, width = img_batch.shape
        else:
            # Image batch: [B, C, H, W]
            batch_size, channels, height, width = img_batch.shape

        # load audio
        audio_tensor = batch.audio_tensor  # all file items contatinated on the batch dimension
        audio_data = batch.audio_data  # list of raw audio data per item in the batch

        # save the videos here with audio and video as mp4
        fps = getattr(dataset_config, "fps", None)
        if fps is None or fps <= 0:
            fps = 1.0

        # Ensure we can iterate items even if batch_size > 1
        for b in range(batch_size):
            # Get per-item frames as [F,C,H,W]
            if len(img_batch.shape) == 5:
                frames_fchw = img_batch[b]
            else:
                # single image: [C,H,W] -> [1,C,H,W]
                frames_fchw = img_batch[b].unsqueeze(0)

            video_uint8 = _tensor_to_uint8_video(frames_fchw)
            out_mp4 = os.path.join(args.output_path, f"{idx:06d}_{b:02d}.mp4")

            # Pick audio for this item (prefer audio_data list; fallback to audio_tensor)
            item_audio = None
            item_sr = None

            if isinstance(audio_data, (list, tuple)) and len(audio_data) > b:
                ad = audio_data[b]
                if isinstance(ad, dict) and ("waveform" in ad) and ("sample_rate" in ad) and ad["waveform"] is not None:
                    item_audio = ad["waveform"]
                    item_sr = int(ad["sample_rate"])
            elif audio_tensor is not None and torch.is_tensor(audio_tensor):
                # audio_tensor expected [B, C, L] (or [C,L] if batch collate differs)
                if audio_tensor.dim() == 3 and audio_tensor.shape[0] > b:
                    item_audio = audio_tensor[b]
                elif audio_tensor.dim() == 2 and b == 0:
                    item_audio = audio_tensor
                if item_audio is not None:
                    # best-effort sample rate from audio_data if present but not per-item dict
                    if isinstance(audio_data, dict) and "sample_rate" in audio_data:
                        try:
                            item_sr = int(audio_data["sample_rate"])
                        except Exception:
                            item_sr = None

            # Write mp4 (with audio if available) using ffmpeg muxing (torchvision audio muxing is unreliable)
            tmp_video = out_mp4 + ".tmp_video.mp4"
            tmp_wav = out_mp4 + ".tmp_audio.wav"
            try:
                # Always write video-only first
                write_video(tmp_video, video_uint8, fps=float(fps), video_codec="libx264")

                if item_audio is not None and item_sr is not None and item_audio.numel() > 0:
                    import torchaudio

                    wav = item_audio.detach()
                    # torchaudio.save expects [channels, samples]
                    if wav.dim() == 1:
                        wav = wav.unsqueeze(0)
                    torchaudio.save(tmp_wav, wav.cpu().to(torch.float32), int(item_sr))

                    # Mux to final mp4
                    _mux_with_ffmpeg(tmp_video, tmp_wav, out_mp4)
                else:
                    # No audio: just move video into place
                    os.replace(tmp_video, out_mp4)

            except Exception as e:
                # Best-effort fallback: leave a playable video-only file
                try:
                    if os.path.exists(tmp_video):
                        os.replace(tmp_video, out_mp4)
                    else:
                        write_video(out_mp4, video_uint8, fps=float(fps), video_codec="libx264")
                except Exception:
                    raise

                if hasattr(dataset_config, 'debug') and dataset_config.debug:
                    print(f"Warning: failed to mux audio into mp4 for {out_mp4}: {e}")

            finally:
                # Cleanup temps (don't leave separate wavs lying around)
                try:
                    if os.path.exists(tmp_video):
                        os.remove(tmp_video)
                except Exception:
                    pass
                try:
                    if os.path.exists(tmp_wav):
                        os.remove(tmp_wav)
                except Exception:
                    pass

        # brief pause so filesystem writes settle between batches
        time.sleep(0.2)

        idx += 1
    # if not last epoch, re-shuffle/rebuild buckets for the next pass
    if epoch < args.epochs - 1:
        trigger_dataloader_setup_epoch(dataloader)

print('done')
||||
75
toolkit/audio/preserve_pitch.py
Normal file
75
toolkit/audio/preserve_pitch.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
|
||||
def time_stretch_preserve_pitch(waveform: torch.Tensor, sample_rate: int, target_samples: int) -> torch.Tensor:
    """
    Pitch-preserving time stretch of an audio clip to an exact sample count.

    waveform: [C, L] (or [L]) float tensor, CPU or GPU.
    sample_rate: sample rate in Hz; only used to size the STFT window/hop.
    target_samples: desired output length in samples.
    returns: [C, target_samples] float32 tensor. Empty ([C, 0]) when the input
        is empty or target_samples <= 0; the float32-cast input when the length
        already matches.
    """
    # Promote mono [L] input to [1, L] so every path sees [C, L].
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)

    waveform = waveform.to(torch.float32)

    source_len = waveform.shape[-1]
    if source_len == 0 or target_samples <= 0:
        # Nothing to stretch: keep the channel count, return zero samples.
        return waveform[..., :0]

    if source_len == target_samples:
        return waveform

    # rate > 1.0 speeds the clip up (shorter); rate < 1.0 slows it down (longer).
    stretch_rate = float(source_len) / float(target_samples)

    # Derive STFT parameters from the sample rate (~46 ms window, ~11.5 ms hop).
    win_seconds = 0.046
    hop_seconds = 0.0115

    desired_fft = int(sample_rate * win_seconds)
    # Round the window down to a power of two, clamped to at least 256.
    n_fft = 1 << max(8, int(math.floor(math.log2(max(256, desired_fft)))))
    win_length = n_fft
    hop_length = min(max(64, int(sample_rate * hop_seconds)), win_length // 2)

    window = torch.hann_window(win_length, device=waveform.device, dtype=waveform.dtype)

    # Complex spectrogram, shape [C, F, T].
    spec = torch.stft(
        waveform,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=True,
        return_complex=True,
    )

    # IMPORTANT: n_freq must equal the STFT's bin count (n_fft // 2 + 1).
    stretcher = torchaudio.transforms.TimeStretch(
        n_freq=spec.shape[-2],
        hop_length=hop_length,
        fixed_rate=stretch_rate,
    ).to(waveform.device)

    stretched_spec = stretcher(spec)  # [C, F, T']

    result = torch.istft(
        stretched_spec,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=True,
        length=target_samples,
    )

    # `length=` should already pin the size; trim or zero-pad defensively.
    if result.shape[-1] > target_samples:
        result = result[..., :target_samples]
    elif result.shape[-1] < target_samples:
        result = F.pad(result, (0, target_samples - result.shape[-1]))

    return result
|
||||
@@ -207,6 +207,9 @@ class NetworkConfig:
|
||||
# -1 automatically finds the largest factor
|
||||
self.lokr_factor = kwargs.get('lokr_factor', -1)
|
||||
|
||||
# Use the old lokr format
|
||||
self.old_lokr_format = kwargs.get('old_lokr_format', False)
|
||||
|
||||
# for multi stage models
|
||||
self.split_multistage_loras = kwargs.get('split_multistage_loras', True)
|
||||
|
||||
@@ -672,6 +675,9 @@ class ModelConfig:
|
||||
# kwargs to pass to the model
|
||||
self.model_kwargs = kwargs.get("model_kwargs", {})
|
||||
|
||||
# model paths for models that support it
|
||||
self.model_paths = kwargs.get("model_paths", {})
|
||||
|
||||
# allow frontend to pass arch with a color like arch:tag
|
||||
# but remove the tag
|
||||
if self.arch is not None:
|
||||
@@ -956,7 +962,7 @@ class DatasetConfig:
|
||||
# it will select a random start frame and pull the frames at the given fps
|
||||
# this could have various issues with shorter videos and videos with variable fps
|
||||
# I recommend trimming your videos to the desired length and using shrink_video_to_frames(default)
|
||||
self.fps: int = kwargs.get('fps', 16)
|
||||
self.fps: int = kwargs.get('fps', 24)
|
||||
|
||||
# debug the frame count and frame selection. You dont need this. It is for debugging.
|
||||
self.debug: bool = kwargs.get('debug', False)
|
||||
@@ -972,6 +978,9 @@ class DatasetConfig:
|
||||
self.fast_image_size: bool = kwargs.get('fast_image_size', False)
|
||||
|
||||
self.do_i2v: bool = kwargs.get('do_i2v', True) # do image to video on models that are both t2i and i2v capable
|
||||
self.do_audio: bool = kwargs.get('do_audio', False) # load audio from video files for models that support it
|
||||
self.audio_preserve_pitch: bool = kwargs.get('audio_preserve_pitch', False) # preserve pitch when stretching audio to fit num_frames
|
||||
self.audio_normalize: bool = kwargs.get('audio_normalize', False) # normalize audio volume levels when loading
|
||||
|
||||
|
||||
def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
|
||||
|
||||
@@ -123,9 +123,13 @@ class FileItemDTO(
|
||||
self.is_reg = self.dataset_config.is_reg
|
||||
self.prior_reg = self.dataset_config.prior_reg
|
||||
self.tensor: Union[torch.Tensor, None] = None
|
||||
self.audio_data = None
|
||||
self.audio_tensor = None
|
||||
|
||||
def cleanup(self):
|
||||
self.tensor = None
|
||||
self.audio_data = None
|
||||
self.audio_tensor = None
|
||||
self.cleanup_latent()
|
||||
self.cleanup_text_embedding()
|
||||
self.cleanup_control()
|
||||
@@ -154,6 +158,13 @@ class DataLoaderBatchDTO:
|
||||
self.clip_image_embeds_unconditional: Union[List[dict], None] = None
|
||||
self.sigmas: Union[torch.Tensor, None] = None # can be added elseware and passed along training code
|
||||
self.extra_values: Union[torch.Tensor, None] = torch.tensor([x.extra_values for x in self.file_items]) if len(self.file_items[0].extra_values) > 0 else None
|
||||
self.audio_data: Union[List, None] = [x.audio_data for x in self.file_items] if self.file_items[0].audio_data is not None else None
|
||||
self.audio_tensor: Union[torch.Tensor, None] = None
|
||||
|
||||
# just for holding noise and preds during training
|
||||
self.audio_target: Union[torch.Tensor, None] = None
|
||||
self.audio_pred: Union[torch.Tensor, None] = None
|
||||
|
||||
if not is_latents_cached:
|
||||
# only return a tensor if latents are not cached
|
||||
self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
|
||||
@@ -304,6 +315,21 @@ class DataLoaderBatchDTO:
|
||||
y.text_embeds = [y.text_embeds]
|
||||
prompt_embeds_list.append(y)
|
||||
self.prompt_embeds = concat_prompt_embeds(prompt_embeds_list)
|
||||
|
||||
if any([x.audio_tensor is not None for x in self.file_items]):
|
||||
# find one to use as a base
|
||||
base_audio_tensor = None
|
||||
for x in self.file_items:
|
||||
if x.audio_tensor is not None:
|
||||
base_audio_tensor = x.audio_tensor
|
||||
break
|
||||
audio_tensors = []
|
||||
for x in self.file_items:
|
||||
if x.audio_tensor is None:
|
||||
audio_tensors.append(torch.zeros_like(base_audio_tensor))
|
||||
else:
|
||||
audio_tensors.append(x.audio_tensor)
|
||||
self.audio_tensor = torch.cat([x.unsqueeze(0) for x in audio_tensors])
|
||||
|
||||
|
||||
except Exception as e:
|
||||
@@ -336,6 +362,10 @@ class DataLoaderBatchDTO:
|
||||
del self.latents
|
||||
del self.tensor
|
||||
del self.control_tensor
|
||||
del self.audio_tensor
|
||||
del self.audio_data
|
||||
del self.audio_target
|
||||
del self.audio_pred
|
||||
for file_item in self.file_items:
|
||||
file_item.cleanup()
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor
|
||||
|
||||
from toolkit.audio.preserve_pitch import time_stretch_preserve_pitch
|
||||
from toolkit.basic import flush, value_map
|
||||
from toolkit.buckets import get_bucket_for_image_size, get_resolution
|
||||
from toolkit.config_modules import ControlTypes
|
||||
@@ -467,6 +468,8 @@ class ImageProcessingDTOMixin:
|
||||
if not self.dataset_config.buckets:
|
||||
raise Exception('Buckets required for video processing')
|
||||
|
||||
do_audio = self.dataset_config.do_audio
|
||||
|
||||
try:
|
||||
# Use OpenCV to capture video frames
|
||||
cap = cv2.VideoCapture(self.path)
|
||||
@@ -596,6 +599,84 @@ class ImageProcessingDTOMixin:
|
||||
|
||||
# Stack frames into tensor [frames, channels, height, width]
|
||||
self.tensor = torch.stack(frames)
|
||||
|
||||
# ------------------------------
|
||||
# Audio extraction + stretching
|
||||
# ------------------------------
|
||||
if do_audio:
|
||||
# Default to "no audio" unless we successfully extract it
|
||||
self.audio_data = None
|
||||
self.audio_tensor = None
|
||||
|
||||
try:
|
||||
import torchaudio
|
||||
import torch.nn.functional as F
|
||||
|
||||
# Compute the time range of the selected frames in the *source* video
|
||||
# Include the last frame by extending to the next frame boundary.
|
||||
if video_fps and video_fps > 0 and len(frames_to_extract) > 0:
|
||||
clip_start_frame = int(frames_to_extract[0])
|
||||
clip_end_frame = int(frames_to_extract[-1])
|
||||
clip_start_time = clip_start_frame / float(video_fps)
|
||||
clip_end_time = (clip_end_frame + 1) / float(video_fps)
|
||||
source_duration = max(0.0, clip_end_time - clip_start_time)
|
||||
else:
|
||||
clip_start_time = 0.0
|
||||
clip_end_time = 0.0
|
||||
source_duration = 0.0
|
||||
|
||||
# Target duration is how this sampled/stretched clip is interpreted for training
|
||||
# (i.e. num_frames at the configured dataset FPS).
|
||||
if hasattr(self.dataset_config, "fps") and self.dataset_config.fps and self.dataset_config.fps > 0:
|
||||
target_duration = float(self.dataset_config.num_frames) / float(self.dataset_config.fps)
|
||||
else:
|
||||
target_duration = source_duration
|
||||
|
||||
waveform, sample_rate = torchaudio.load(self.path) # [channels, samples]
|
||||
|
||||
if self.dataset_config.audio_normalize:
|
||||
peak = waveform.abs().amax() # global peak across channels
|
||||
eps = 1e-9
|
||||
target_peak = 0.999 # ~ -0.01 dBFS
|
||||
gain = target_peak / (peak + eps)
|
||||
waveform = waveform * gain
|
||||
|
||||
# Slice to the selected clip region (when we have a meaningful time range)
|
||||
if source_duration > 0.0:
|
||||
start_sample = int(round(clip_start_time * sample_rate))
|
||||
end_sample = int(round(clip_end_time * sample_rate))
|
||||
start_sample = max(0, min(start_sample, waveform.shape[-1]))
|
||||
end_sample = max(0, min(end_sample, waveform.shape[-1]))
|
||||
if end_sample > start_sample:
|
||||
waveform = waveform[..., start_sample:end_sample]
|
||||
else:
|
||||
# No valid audio segment
|
||||
waveform = None
|
||||
else:
|
||||
# If we can't compute a meaningful time range, treat as no-audio
|
||||
waveform = None
|
||||
|
||||
if waveform is not None and waveform.numel() > 0:
|
||||
target_samples = int(round(target_duration * sample_rate))
|
||||
if target_samples > 0 and waveform.shape[-1] != target_samples:
|
||||
# Time-stretch/shrink to match the video clip duration implied by dataset FPS.
|
||||
if self.dataset_config.audio_preserve_pitch:
|
||||
waveform = time_stretch_preserve_pitch(waveform, sample_rate, target_samples) # waveform is [C, L]
|
||||
else:
|
||||
# Use linear interpolation over the time axis.
|
||||
wf = waveform.unsqueeze(0) # [1, C, L]
|
||||
wf = F.interpolate(wf, size=target_samples, mode="linear", align_corners=False)
|
||||
waveform = wf.squeeze(0) # [C, L]
|
||||
|
||||
self.audio_tensor = waveform
|
||||
self.audio_data = {"waveform": waveform, "sample_rate": int(sample_rate)}
|
||||
|
||||
except Exception as e:
|
||||
# Keep behavior identical for non-audio datasets; for audio datasets, just skip if missing/broken.
|
||||
if hasattr(self.dataset_config, 'debug') and self.dataset_config.debug:
|
||||
print_acc(f"Could not extract/stretch audio for {self.path}: {e}")
|
||||
self.audio_data = None
|
||||
self.audio_tensor = None
|
||||
|
||||
# Only log success in debug mode
|
||||
if hasattr(self.dataset_config, 'debug') and self.dataset_config.debug:
|
||||
|
||||
@@ -265,11 +265,18 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
self.peft_format = peft_format
|
||||
self.is_transformer = is_transformer
|
||||
|
||||
# use the old format for older models unless the user has specified otherwise
|
||||
self.use_old_lokr_format = False
|
||||
if self.network_config is not None and hasattr(self.network_config, 'old_lokr_format'):
|
||||
self.use_old_lokr_format = self.network_config.old_lokr_format
|
||||
# also allow a false from the model itself
|
||||
if base_model is not None and not base_model.use_old_lokr_format:
|
||||
self.use_old_lokr_format = False
|
||||
|
||||
# always do peft for flux only for now
|
||||
if self.is_flux or self.is_v3 or self.is_lumina2 or is_transformer:
|
||||
# don't do peft format for lokr
|
||||
if self.network_type.lower() != "lokr":
|
||||
# don't do peft format for lokr if using old format
|
||||
if self.network_type.lower() != "lokr" or not self.use_old_lokr_format:
|
||||
self.peft_format = True
|
||||
|
||||
if self.peft_format:
|
||||
|
||||
@@ -185,6 +185,11 @@ class BaseModel:
|
||||
self.has_multiple_control_images = False
|
||||
# do not resize control images
|
||||
self.use_raw_control_images = False
|
||||
# defines if the model supports model paths. Only some will
|
||||
self.supports_model_paths = False
|
||||
|
||||
# use new lokr format (default false for old models for backwards compatibility)
|
||||
self.use_old_lokr_format = True
|
||||
|
||||
# properties for old arch for backwards compatibility
|
||||
@property
|
||||
|
||||
@@ -11,6 +11,7 @@ from toolkit.network_mixins import ToolkitModuleMixin
|
||||
from typing import TYPE_CHECKING, Union, List
|
||||
|
||||
from optimum.quanto import QBytesTensor, QTensor
|
||||
from torchao.dtypes import AffineQuantizedTensor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@@ -284,17 +285,26 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
|
||||
org_sd[weight_key] = merged_weight.to(orig_dtype)
|
||||
self.org_module[0].load_state_dict(org_sd)
|
||||
|
||||
def get_orig_weight(self):
|
||||
def get_orig_weight(self, device):
|
||||
weight = self.org_module[0].weight
|
||||
if weight.device != device:
|
||||
weight = weight.to(device)
|
||||
if isinstance(weight, QTensor) or isinstance(weight, QBytesTensor):
|
||||
return weight.dequantize().data.detach()
|
||||
elif isinstance(weight, AffineQuantizedTensor):
|
||||
return weight.dequantize().data.detach()
|
||||
else:
|
||||
return weight.data.detach()
|
||||
|
||||
def get_orig_bias(self):
|
||||
def get_orig_bias(self, device):
|
||||
if hasattr(self.org_module[0], 'bias') and self.org_module[0].bias is not None:
|
||||
if isinstance(self.org_module[0].bias, QTensor) or isinstance(self.org_module[0].bias, QBytesTensor):
|
||||
return self.org_module[0].bias.dequantize().data.detach()
|
||||
bias = self.org_module[0].bias
|
||||
if bias.device != device:
|
||||
bias = bias.to(device)
|
||||
if isinstance(bias, QTensor) or isinstance(bias, QBytesTensor):
|
||||
return bias.dequantize().data.detach()
|
||||
elif isinstance(bias, AffineQuantizedTensor):
|
||||
return bias.dequantize().data.detach()
|
||||
else:
|
||||
return self.org_module[0].bias.data.detach()
|
||||
return None
|
||||
@@ -305,7 +315,7 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
|
||||
|
||||
orig_dtype = x.dtype
|
||||
|
||||
orig_weight = self.get_orig_weight()
|
||||
orig_weight = self.get_orig_weight(x.device)
|
||||
lokr_weight = self.get_weight(orig_weight).to(dtype=orig_weight.dtype)
|
||||
multiplier = self.network_ref().torch_multiplier
|
||||
|
||||
@@ -319,7 +329,7 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
|
||||
orig_weight
|
||||
+ lokr_weight * multiplier
|
||||
)
|
||||
bias = self.get_orig_bias()
|
||||
bias = self.get_orig_bias(x.device)
|
||||
if bias is not None:
|
||||
bias = bias.to(weight.device, dtype=weight.dtype)
|
||||
output = self.op(
|
||||
|
||||
@@ -546,7 +546,8 @@ class ToolkitNetworkMixin:
|
||||
|
||||
new_save_dict = {}
|
||||
for key, value in save_dict.items():
|
||||
if key.endswith('.alpha'):
|
||||
# lokr needs alpha
|
||||
if key.endswith('.alpha') and self.network_type.lower() != "lokr":
|
||||
continue
|
||||
new_key = key
|
||||
new_key = new_key.replace('lora_down', 'lora_A')
|
||||
@@ -558,7 +559,7 @@ class ToolkitNetworkMixin:
|
||||
save_dict = new_save_dict
|
||||
|
||||
|
||||
if self.network_type.lower() == "lokr":
|
||||
if self.network_type.lower() == "lokr" and self.use_old_lokr_format:
|
||||
new_save_dict = {}
|
||||
for key, value in save_dict.items():
|
||||
# lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1 to lycoris_transformer_blocks_7_attn_to_v.lokr_w1
|
||||
@@ -632,7 +633,7 @@ class ToolkitNetworkMixin:
|
||||
# lora_down = lora_A
|
||||
# lora_up = lora_B
|
||||
# no alpha
|
||||
if load_key.endswith('.alpha'):
|
||||
if load_key.endswith('.alpha') and self.network_type.lower() != "lokr":
|
||||
continue
|
||||
load_key = load_key.replace('lora_A', 'lora_down')
|
||||
load_key = load_key.replace('lora_B', 'lora_up')
|
||||
@@ -640,6 +641,13 @@ class ToolkitNetworkMixin:
|
||||
load_key = load_key.replace('.', '$$')
|
||||
load_key = load_key.replace('$$lora_down$$', '.lora_down.')
|
||||
load_key = load_key.replace('$$lora_up$$', '.lora_up.')
|
||||
|
||||
# patch lokr, not sure why we need to but whatever
|
||||
if self.network_type.lower() == "lokr":
|
||||
load_key = load_key.replace('$$lokr_w1', '.lokr_w1')
|
||||
load_key = load_key.replace('$$lokr_w2', '.lokr_w2')
|
||||
if load_key.endswith('$$alpha'):
|
||||
load_key = load_key[:-7] + '.alpha'
|
||||
|
||||
if self.network_type.lower() == "lokr":
|
||||
# lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1 to lycoris_transformer_blocks_7_attn_to_v.lokr_w1
|
||||
|
||||
@@ -223,6 +223,11 @@ class StableDiffusion:
|
||||
self.has_multiple_control_images = False
|
||||
# do not resize control images
|
||||
self.use_raw_control_images = False
|
||||
# defines if the model supports model paths. Only some will
|
||||
self.supports_model_paths = False
|
||||
|
||||
# use new lokr format (default false for old models for backwards compatibility)
|
||||
self.use_old_lokr_format = True
|
||||
|
||||
# properties for old arch for backwards compatibility
|
||||
@property
|
||||
|
||||
@@ -15,7 +15,7 @@ export async function POST(request: Request) {
|
||||
}
|
||||
|
||||
// make sure it is an image
|
||||
if (!/\.(jpg|jpeg|png|bmp|gif|tiff|webp)$/i.test(imgPath.toLowerCase())) {
|
||||
if (!/\.(jpg|jpeg|png|bmp|gif|tiff|webp|mp4)$/i.test(imgPath.toLowerCase())) {
|
||||
return NextResponse.json({ error: 'Not an image' }, { status: 400 });
|
||||
}
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s
|
||||
const samples = fs
|
||||
.readdirSync(samplesFolder)
|
||||
.filter(file => {
|
||||
return file.endsWith('.png') || file.endsWith('.jpg') || file.endsWith('.jpeg') || file.endsWith('.webp');
|
||||
return file.endsWith('.png') || file.endsWith('.jpg') || file.endsWith('.jpeg') || file.endsWith('.webp') || file.endsWith('.mp4');
|
||||
})
|
||||
.map(file => {
|
||||
return path.join(samplesFolder, file);
|
||||
|
||||
@@ -862,6 +862,48 @@ export default function SimpleJob({
|
||||
docKey="datasets.do_i2v"
|
||||
/>
|
||||
)}
|
||||
{modelArch?.additionalSections?.includes('datasets.do_audio') && (
|
||||
<Checkbox
|
||||
label="Do Audio"
|
||||
checked={dataset.do_audio || false}
|
||||
onChange={value => {
|
||||
if (!value) {
|
||||
setJobConfig(undefined, `config.process[0].datasets[${i}].do_audio`);
|
||||
} else {
|
||||
setJobConfig(value, `config.process[0].datasets[${i}].do_audio`);
|
||||
}
|
||||
}}
|
||||
docKey="datasets.do_audio"
|
||||
/>
|
||||
)}
|
||||
{modelArch?.additionalSections?.includes('datasets.audio_normalize') && (
|
||||
<Checkbox
|
||||
label="Audio Normalize"
|
||||
checked={dataset.audio_normalize || false}
|
||||
onChange={value => {
|
||||
if (!value) {
|
||||
setJobConfig(undefined, `config.process[0].datasets[${i}].audio_normalize`);
|
||||
} else {
|
||||
setJobConfig(value, `config.process[0].datasets[${i}].audio_normalize`);
|
||||
}
|
||||
}}
|
||||
docKey="datasets.audio_normalize"
|
||||
/>
|
||||
)}
|
||||
{modelArch?.additionalSections?.includes('datasets.audio_preserve_pitch') && (
|
||||
<Checkbox
|
||||
label="Audio Preserve Pitch"
|
||||
checked={dataset.audio_preserve_pitch || false}
|
||||
onChange={value => {
|
||||
if (!value) {
|
||||
setJobConfig(undefined, `config.process[0].datasets[${i}].audio_preserve_pitch`);
|
||||
} else {
|
||||
setJobConfig(value, `config.process[0].datasets[${i}].audio_preserve_pitch`);
|
||||
}
|
||||
}}
|
||||
docKey="datasets.audio_preserve_pitch"
|
||||
/>
|
||||
)}
|
||||
</FormGroup>
|
||||
<FormGroup label="Flipping" docKey={'datasets.flip'} className="mt-2">
|
||||
<Checkbox
|
||||
|
||||
@@ -14,7 +14,6 @@ export const defaultDatasetConfig: DatasetConfig = {
|
||||
controls: [],
|
||||
shrink_video_to_frames: true,
|
||||
num_frames: 1,
|
||||
do_i2v: true,
|
||||
flip_x: false,
|
||||
flip_y: false,
|
||||
};
|
||||
|
||||
@@ -17,6 +17,9 @@ type AdditionalSections =
|
||||
| 'datasets.control_path'
|
||||
| 'datasets.multi_control_paths'
|
||||
| 'datasets.do_i2v'
|
||||
| 'datasets.do_audio'
|
||||
| 'datasets.audio_normalize'
|
||||
| 'datasets.audio_preserve_pitch'
|
||||
| 'sample.ctrl_img'
|
||||
| 'sample.multi_ctrl_imgs'
|
||||
| 'datasets.num_frames'
|
||||
@@ -288,6 +291,7 @@ export const modelArchs: ModelArch[] = [
|
||||
'config.process[0].sample.width': [768, 1024],
|
||||
'config.process[0].sample.height': [768, 1024],
|
||||
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
|
||||
'config.process[0].datasets[x].do_i2v': [true, undefined],
|
||||
},
|
||||
disableSections: ['network.conv'],
|
||||
additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram', 'datasets.do_i2v'],
|
||||
@@ -601,6 +605,31 @@ export const modelArchs: ModelArch[] = [
|
||||
disableSections: ['network.conv'],
|
||||
additionalSections: ['model.low_vram', 'model.layer_offloading'],
|
||||
},
|
||||
{
|
||||
name: 'ltx2',
|
||||
label: 'LTX-2',
|
||||
group: 'video',
|
||||
isVideoModel: true,
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.name_or_path': ['Lightricks/LTX-2', defaultNameOrPath],
|
||||
'config.process[0].model.quantize': [true, false],
|
||||
'config.process[0].model.quantize_te': [true, false],
|
||||
'config.process[0].model.low_vram': [true, false],
|
||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].sample.num_frames': [121, 1],
|
||||
'config.process[0].sample.fps': [24, 1],
|
||||
'config.process[0].sample.width': [768, 1024],
|
||||
'config.process[0].sample.height': [768, 1024],
|
||||
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
|
||||
'config.process[0].datasets[x].do_i2v': [false, undefined],
|
||||
'config.process[0].datasets[x].do_audio': [true, undefined],
|
||||
'config.process[0].datasets[x].fps': [24, undefined],
|
||||
},
|
||||
disableSections: ['network.conv'],
|
||||
additionalSections: ['datasets.num_frames', 'model.layer_offloading', 'model.low_vram', 'datasets.do_audio', 'datasets.audio_normalize', 'datasets.audio_preserve_pitch'],
|
||||
},
|
||||
].sort((a, b) => {
|
||||
// Sort by label, case-insensitive
|
||||
return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' });
|
||||
|
||||
@@ -2,6 +2,25 @@ import { GroupedSelectOption, JobConfig, SelectOption } from '@/types';
|
||||
import { modelArchs, ModelArch } from './options';
|
||||
import { objectCopy } from '@/utils/basic';
|
||||
|
||||
const expandDatasetDefaults = (
|
||||
defaults: { [key: string]: any },
|
||||
numDatasets: number,
|
||||
): { [key: string]: any } => {
|
||||
// expands the defaults for datasets[x] to datasets[0], datasets[1], etc.
|
||||
const expandedDefaults: { [key: string]: any } = { ...defaults };
|
||||
for (const key in defaults) {
|
||||
if (key.includes('datasets[x].')) {
|
||||
for (let i = 0; i < numDatasets; i++) {
|
||||
const datasetKey = key.replace('datasets[x].', `datasets[${i}].`);
|
||||
const v = defaults[key];
|
||||
expandedDefaults[datasetKey] = Array.isArray(v) ? [...v] : objectCopy(v);
|
||||
}
|
||||
delete expandedDefaults[key];
|
||||
}
|
||||
}
|
||||
return expandedDefaults;
|
||||
};
|
||||
|
||||
export const handleModelArchChange = (
|
||||
currentArchName: string,
|
||||
newArchName: string,
|
||||
@@ -39,16 +58,11 @@ export const handleModelArchChange = (
|
||||
}
|
||||
}
|
||||
|
||||
// revert defaults from previous model
|
||||
for (const key in currentArch.defaults) {
|
||||
setJobConfig(currentArch.defaults[key][1], key);
|
||||
}
|
||||
const numDatasets = jobConfig.config.process[0].datasets.length;
|
||||
|
||||
let currentDefaults = expandDatasetDefaults(currentArch.defaults || {}, numDatasets);
|
||||
let newDefaults = expandDatasetDefaults(newArch?.defaults || {}, numDatasets);
|
||||
|
||||
if (newArch?.defaults) {
|
||||
for (const key in newArch.defaults) {
|
||||
setJobConfig(newArch.defaults[key][0], key);
|
||||
}
|
||||
}
|
||||
// set new model
|
||||
setJobConfig(newArchName, 'config.process[0].model.arch');
|
||||
|
||||
@@ -79,27 +93,27 @@ export const handleModelArchChange = (
|
||||
if (newDataset.control_path_1 && newDataset.control_path_1 !== '') {
|
||||
newDataset.control_path = newDataset.control_path_1;
|
||||
}
|
||||
if (newDataset.control_path_1) {
|
||||
if ('control_path_1' in newDataset) {
|
||||
delete newDataset.control_path_1;
|
||||
}
|
||||
if (newDataset.control_path_2) {
|
||||
if ('control_path_2' in newDataset) {
|
||||
delete newDataset.control_path_2;
|
||||
}
|
||||
if (newDataset.control_path_3) {
|
||||
if ('control_path_3' in newDataset) {
|
||||
delete newDataset.control_path_3;
|
||||
}
|
||||
} else {
|
||||
// does not have control images
|
||||
if (newDataset.control_path) {
|
||||
if ('control_path' in newDataset) {
|
||||
delete newDataset.control_path;
|
||||
}
|
||||
if (newDataset.control_path_1) {
|
||||
if ('control_path_1' in newDataset) {
|
||||
delete newDataset.control_path_1;
|
||||
}
|
||||
if (newDataset.control_path_2) {
|
||||
if ('control_path_2' in newDataset) {
|
||||
delete newDataset.control_path_2;
|
||||
}
|
||||
if (newDataset.control_path_3) {
|
||||
if ('control_path_3' in newDataset) {
|
||||
delete newDataset.control_path_3;
|
||||
}
|
||||
}
|
||||
@@ -120,4 +134,13 @@ export const handleModelArchChange = (
|
||||
return newSample;
|
||||
});
|
||||
setJobConfig(samples, 'config.process[0].sample.samples');
|
||||
|
||||
// revert defaults from previous model
|
||||
for (const key in currentDefaults) {
|
||||
setJobConfig(currentDefaults[key][1], key);
|
||||
}
|
||||
|
||||
for (const key in newDefaults) {
|
||||
setJobConfig(newDefaults[key][0], key);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -56,18 +56,6 @@ const SampleImageCard: React.FC<SampleImageCardProps> = ({
|
||||
return () => observer.disconnect();
|
||||
}, [observerRoot, rootMargin]);
|
||||
|
||||
// Pause video when leaving viewport
|
||||
useEffect(() => {
|
||||
if (!isVideo(imageUrl)) return;
|
||||
const v = videoRef.current;
|
||||
if (!v) return;
|
||||
if (!isVisible && !v.paused) {
|
||||
try {
|
||||
v.pause();
|
||||
} catch {}
|
||||
}
|
||||
}, [isVisible, imageUrl]);
|
||||
|
||||
const handleLoad = () => setLoaded(true);
|
||||
|
||||
return (
|
||||
@@ -81,9 +69,11 @@ const SampleImageCard: React.FC<SampleImageCardProps> = ({
|
||||
src={`/api/img/${encodeURIComponent(imageUrl)}`}
|
||||
className="w-full h-full object-cover"
|
||||
preload="none"
|
||||
onLoad={handleLoad}
|
||||
playsInline
|
||||
muted
|
||||
loop
|
||||
autoPlay
|
||||
controls={false}
|
||||
/>
|
||||
) : (
|
||||
|
||||
@@ -7,6 +7,7 @@ import { Cog } from 'lucide-react';
|
||||
import { Menu, MenuButton, MenuItem, MenuItems } from '@headlessui/react';
|
||||
import { openConfirm } from './ConfirmModal';
|
||||
import { apiClient } from '@/utils/api';
|
||||
import { isVideo } from '@/utils/basic';
|
||||
|
||||
interface Props {
|
||||
imgPath: string | null; // current image path
|
||||
@@ -200,13 +201,24 @@ export default function SampleImageViewer({
|
||||
className="relative transform rounded-lg bg-gray-800 text-left shadow-xl transition-all data-closed:translate-y-4 data-closed:opacity-0 data-enter:duration-300 data-enter:ease-out data-leave:duration-200 data-leave:ease-in max-w-[95%] max-h-[95vh] data-closed:sm:translate-y-0 data-closed:sm:scale-95 flex flex-col overflow-hidden"
|
||||
>
|
||||
<div className="overflow-hidden flex items-center justify-center">
|
||||
{imgPath && (
|
||||
<img
|
||||
src={`/api/img/${encodeURIComponent(imgPath)}`}
|
||||
alt="Sample Image"
|
||||
className="w-auto h-auto max-w-[95vw] max-h-[82vh] object-contain"
|
||||
/>
|
||||
)}
|
||||
{imgPath &&
|
||||
(isVideo(imgPath) ? (
|
||||
<video
|
||||
src={`/api/img/${encodeURIComponent(imgPath)}`}
|
||||
className="w-auto h-auto max-w-[95vw] max-h-[82vh] object-contain"
|
||||
preload="none"
|
||||
playsInline
|
||||
loop
|
||||
autoPlay
|
||||
controls={true}
|
||||
/>
|
||||
) : (
|
||||
<img
|
||||
src={`/api/img/${encodeURIComponent(imgPath)}`}
|
||||
alt="Sample Image"
|
||||
className="w-auto h-auto max-w-[95vw] max-h-[82vh] object-contain"
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
{/* # make full width */}
|
||||
<div className="bg-gray-950 text-sm flex justify-between items-center px-4 py-2">
|
||||
|
||||
@@ -107,6 +107,35 @@ const docs: { [key: string]: ConfigDoc } = {
|
||||
</>
|
||||
),
|
||||
},
|
||||
'datasets.do_audio': {
|
||||
title: 'Do Audio',
|
||||
description: (
|
||||
<>
|
||||
For models that support audio with video, this option will load the audio from the video and resize it to match
|
||||
the video sequence. Since the video is automatically resized, the audio may drop or raise in pitch to match the new
|
||||
speed of the video. It is important to prep your dataset to have the proper length before training.
|
||||
</>
|
||||
),
|
||||
},
|
||||
'datasets.audio_normalize': {
|
||||
title: 'Audio Normalize',
|
||||
description: (
|
||||
<>
|
||||
When loading audio, this will normalize the audio volume to the max peaks. Useful if your dataset has varying audio
|
||||
volumes. Warning, do not use if you have clips with full silence you want to keep, as it will raise the volume of those clips.
|
||||
</>
|
||||
),
|
||||
},
|
||||
'datasets.audio_preserve_pitch': {
|
||||
title: 'Audio Preserve Pitch',
|
||||
description: (
|
||||
<>
|
||||
When loading audio to match the number of frames requested, this option will preserve the pitch of the audio if
|
||||
the length does not match training target. It is recommended to have a dataset that matches your target length,
|
||||
as this option can add sound distortions.
|
||||
</>
|
||||
),
|
||||
},
|
||||
'datasets.flip': {
|
||||
title: 'Flip X and Flip Y',
|
||||
description: (
|
||||
|
||||
@@ -96,7 +96,11 @@ export interface DatasetConfig {
|
||||
control_path?: string | null;
|
||||
num_frames: number;
|
||||
shrink_video_to_frames: boolean;
|
||||
do_i2v: boolean;
|
||||
do_i2v?: boolean;
|
||||
do_audio?: boolean;
|
||||
audio_normalize?: boolean;
|
||||
audio_preserve_pitch?: boolean;
|
||||
fps?: number;
|
||||
flip_x: boolean;
|
||||
flip_y: boolean;
|
||||
control_path_1?: string | null;
|
||||
|
||||
@@ -17,21 +17,14 @@ export function setNestedValue<T, V>(obj: T, value: V, path?: string): T {
|
||||
}
|
||||
|
||||
// Split the path into segments
|
||||
const pathArray = path.split('.').flatMap(segment => {
|
||||
// Handle array notation like 'process[0]'
|
||||
const arrayMatch = segment.match(/^([^\[]+)(\[\d+\])+/);
|
||||
if (arrayMatch) {
|
||||
const propName = arrayMatch[1];
|
||||
const indices = segment
|
||||
.substring(propName.length)
|
||||
.match(/\[(\d+)\]/g)
|
||||
?.map(idx => parseInt(idx.substring(1, idx.length - 1)));
|
||||
const pathArray: Array<string | number> = [];
|
||||
const re = /([^[.\]]+)|\[(\d+)\]/g;
|
||||
let m: RegExpExecArray | null;
|
||||
|
||||
// Return property name followed by array indices
|
||||
return [propName, ...(indices || [])];
|
||||
}
|
||||
return segment;
|
||||
});
|
||||
while ((m = re.exec(path)) !== null) {
|
||||
if (m[1] !== undefined) pathArray.push(m[1]);
|
||||
else pathArray.push(Number(m[2]));
|
||||
}
|
||||
|
||||
// Navigate to the target location
|
||||
let current: any = result;
|
||||
@@ -43,8 +36,18 @@ export function setNestedValue<T, V>(obj: T, value: V, path?: string): T {
|
||||
if (!Array.isArray(current)) {
|
||||
throw new Error(`Cannot access index ${key} of non-array`);
|
||||
}
|
||||
// Create a copy of the array to maintain immutability
|
||||
current = [...current];
|
||||
|
||||
// Ensure the indexed element exists and is copied/created immutably
|
||||
const nextKey = pathArray[i + 1];
|
||||
const existing = current[key];
|
||||
|
||||
if (existing === undefined) {
|
||||
current[key] = typeof nextKey === 'number' ? [] : {};
|
||||
} else if (Array.isArray(existing)) {
|
||||
current[key] = [...existing];
|
||||
} else if (typeof existing === 'object' && existing !== null) {
|
||||
current[key] = { ...existing };
|
||||
} // else: primitives stay as-is
|
||||
} else {
|
||||
// For object properties, create a new object if it doesn't exist
|
||||
if (current[key] === undefined) {
|
||||
@@ -63,7 +66,11 @@ export function setNestedValue<T, V>(obj: T, value: V, path?: string): T {
|
||||
|
||||
// Set the value at the final path segment
|
||||
const finalKey = pathArray[pathArray.length - 1];
|
||||
current[finalKey] = value;
|
||||
if (value === undefined) {
|
||||
delete current[finalKey];
|
||||
} else {
|
||||
current[finalKey] = value;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -1 +1 @@
|
||||
VERSION = "0.7.16"
|
||||
VERSION = "0.7.17"
|
||||
|
||||
Reference in New Issue
Block a user