Add LTX-2 Support (#644)

* WIP, adding support for LTX2

* Training on images working

* Fix loading comfy models

* Handle converting and deconverting lora so it matches original format

* Reworked ui to handle ltx and proper dataset default overwriting.

* Update the way lokr saves so it is more compatible with comfy

* Audio loading and synchronization/resampling is working

* Add audio to training. Does it work? Maybe, still testing.

* Fixed fps default issue for sound

* Have ui set fps for accurate audio mapping on ltx

* Added audio processing options to the ui for ltx

* Clean up requirements
Jaret Burkett
2026-01-13 04:55:30 -07:00
committed by GitHub
parent 6870ab490f
commit 5b5aadadb8
28 changed files with 2180 additions and 71 deletions


@@ -7,6 +7,7 @@ from .wan22 import Wan225bModel, Wan2214bModel, Wan2214bI2VModel
from .qwen_image import QwenImageModel, QwenImageEditModel, QwenImageEditPlusModel
from .flux2 import Flux2Model
from .z_image import ZImageModel
from .ltx2 import LTX2Model
AI_TOOLKIT_MODELS = [
# put a list of models here
@@ -25,4 +26,5 @@ AI_TOOLKIT_MODELS = [
QwenImageEditPlusModel,
Flux2Model,
ZImageModel,
LTX2Model,
]


@@ -0,0 +1 @@
from .ltx2 import LTX2Model


@@ -0,0 +1,648 @@
# ref https://github.com/huggingface/diffusers/blob/249ae1f853be8775c7d0b3a26ca4fbc4f6aa9998/scripts/convert_ltx2_to_diffusers.py
from contextlib import nullcontext
from typing import Any, Dict, Tuple
import torch
from accelerate import init_empty_weights
from diffusers import (
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
LTX2VideoTransformer3DModel,
)
from diffusers.pipelines.ltx2 import LTX2TextConnectors, LTX2Vocoder
from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available() else nullcontext
LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = {
# Input Patchify Projections
"patchify_proj": "proj_in",
"audio_patchify_proj": "audio_proj_in",
# Modulation Parameters
# Handle adaln_single --> time_embed, audio_adaln_single --> audio_time_embed separately as the original keys are
# substrings of the other modulation parameters below
"av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
"av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
"av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
"av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
# Transformer Blocks
# Per-Block Cross Attention Modulation Parameters
"scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
"scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
# Attention QK Norms
"q_norm": "norm_q",
"k_norm": "norm_k",
}
LTX_2_0_VIDEO_VAE_RENAME_DICT = {
# Encoder
"down_blocks.0": "down_blocks.0",
"down_blocks.1": "down_blocks.0.downsamplers.0",
"down_blocks.2": "down_blocks.1",
"down_blocks.3": "down_blocks.1.downsamplers.0",
"down_blocks.4": "down_blocks.2",
"down_blocks.5": "down_blocks.2.downsamplers.0",
"down_blocks.6": "down_blocks.3",
"down_blocks.7": "down_blocks.3.downsamplers.0",
"down_blocks.8": "mid_block",
# Decoder
"up_blocks.0": "mid_block",
"up_blocks.1": "up_blocks.0.upsamplers.0",
"up_blocks.2": "up_blocks.0",
"up_blocks.3": "up_blocks.1.upsamplers.0",
"up_blocks.4": "up_blocks.1",
"up_blocks.5": "up_blocks.2.upsamplers.0",
"up_blocks.6": "up_blocks.2",
# Common
# For all 3D ResNets
"res_blocks": "resnets",
"per_channel_statistics.mean-of-means": "latents_mean",
"per_channel_statistics.std-of-means": "latents_std",
}
LTX_2_0_AUDIO_VAE_RENAME_DICT = {
"per_channel_statistics.mean-of-means": "latents_mean",
"per_channel_statistics.std-of-means": "latents_std",
}
LTX_2_0_VOCODER_RENAME_DICT = {
"ups": "upsamplers",
"resblocks": "resnets",
"conv_pre": "conv_in",
"conv_post": "conv_out",
}
LTX_2_0_TEXT_ENCODER_RENAME_DICT = {
"video_embeddings_connector": "video_connector",
"audio_embeddings_connector": "audio_connector",
"transformer_1d_blocks": "transformer_blocks",
# Attention QK Norms
"q_norm": "norm_q",
"k_norm": "norm_k",
}
def update_state_dict_inplace(
state_dict: Dict[str, Any], old_key: str, new_key: str
) -> None:
state_dict[new_key] = state_dict.pop(old_key)
def remove_keys_inplace(key: str, state_dict: Dict[str, Any]) -> None:
state_dict.pop(key)
def convert_ltx2_transformer_adaln_single(key: str, state_dict: Dict[str, Any]) -> None:
# Skip if not a weight or bias
if ".weight" not in key and ".bias" not in key:
return
if key.startswith("adaln_single."):
new_key = key.replace("adaln_single.", "time_embed.")
param = state_dict.pop(key)
state_dict[new_key] = param
if key.startswith("audio_adaln_single."):
new_key = key.replace("audio_adaln_single.", "audio_time_embed.")
param = state_dict.pop(key)
state_dict[new_key] = param
return
def convert_ltx2_audio_vae_per_channel_statistics(
key: str, state_dict: Dict[str, Any]
) -> None:
if key.startswith("per_channel_statistics"):
new_key = ".".join(["decoder", key])
param = state_dict.pop(key)
state_dict[new_key] = param
return
LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = {
"video_embeddings_connector": remove_keys_inplace,
"audio_embeddings_connector": remove_keys_inplace,
"adaln_single": convert_ltx2_transformer_adaln_single,
}
LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = {
"connectors.": "",
"video_embeddings_connector": "video_connector",
"audio_embeddings_connector": "audio_connector",
"transformer_1d_blocks": "transformer_blocks",
"text_embedding_projection.aggregate_embed": "text_proj_in",
# Attention QK Norms
"q_norm": "norm_q",
"k_norm": "norm_k",
}
LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
"per_channel_statistics.channel": remove_keys_inplace,
"per_channel_statistics.mean-of-stds": remove_keys_inplace,
}
LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {}
LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {}
def split_transformer_and_connector_state_dict(
state_dict: Dict[str, Any],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
connector_prefixes = (
"video_embeddings_connector",
"audio_embeddings_connector",
"transformer_1d_blocks",
"text_embedding_projection.aggregate_embed",
"connectors.",
"video_connector",
"audio_connector",
"text_proj_in",
)
transformer_state_dict, connector_state_dict = {}, {}
for key, value in state_dict.items():
if key.startswith(connector_prefixes):
connector_state_dict[key] = value
else:
transformer_state_dict[key] = value
return transformer_state_dict, connector_state_dict
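For illustration, a hypothetical split (the key names below are made up): connector weights are routed by prefix, everything else stays with the transformer.
example_sd = {
    "video_embeddings_connector.proj.weight": 1,
    "transformer_blocks.0.attn1.to_q.weight": 2,
}
transformer_sd, connector_sd = split_transformer_and_connector_state_dict(example_sd)
# transformer_sd -> {"transformer_blocks.0.attn1.to_q.weight": 2}
# connector_sd   -> {"video_embeddings_connector.proj.weight": 1}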
def get_ltx2_transformer_config() -> Tuple[
Dict[str, Any], Dict[str, Any], Dict[str, Any]
]:
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"diffusers_config": {
"in_channels": 128,
"out_channels": 128,
"patch_size": 1,
"patch_size_t": 1,
"num_attention_heads": 32,
"attention_head_dim": 128,
"cross_attention_dim": 4096,
"vae_scale_factors": (8, 32, 32),
"pos_embed_max_pos": 20,
"base_height": 2048,
"base_width": 2048,
"audio_in_channels": 128,
"audio_out_channels": 128,
"audio_patch_size": 1,
"audio_patch_size_t": 1,
"audio_num_attention_heads": 32,
"audio_attention_head_dim": 64,
"audio_cross_attention_dim": 2048,
"audio_scale_factor": 4,
"audio_pos_embed_max_pos": 20,
"audio_sampling_rate": 16000,
"audio_hop_length": 160,
"num_layers": 48,
"activation_fn": "gelu-approximate",
"qk_norm": "rms_norm_across_heads",
"norm_elementwise_affine": False,
"norm_eps": 1e-6,
"caption_channels": 3840,
"attention_bias": True,
"attention_out_bias": True,
"rope_theta": 10000.0,
"rope_double_precision": True,
"causal_offset": 1,
"timestep_scale_multiplier": 1000,
"cross_attn_timestep_scale_multiplier": 1000,
"rope_type": "split",
},
}
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def get_ltx2_connectors_config() -> Tuple[
Dict[str, Any], Dict[str, Any], Dict[str, Any]
]:
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"diffusers_config": {
"caption_channels": 3840,
"text_proj_in_factor": 49,
"video_connector_num_attention_heads": 30,
"video_connector_attention_head_dim": 128,
"video_connector_num_layers": 2,
"video_connector_num_learnable_registers": 128,
"audio_connector_num_attention_heads": 30,
"audio_connector_attention_head_dim": 128,
"audio_connector_num_layers": 2,
"audio_connector_num_learnable_registers": 128,
"connector_rope_base_seq_len": 4096,
"rope_theta": 10000.0,
"rope_double_precision": True,
"causal_temporal_positioning": False,
"rope_type": "split",
},
}
rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT
special_keys_remap = {}
return config, rename_dict, special_keys_remap
def convert_ltx2_transformer(original_state_dict: Dict[str, Any]) -> Dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_transformer_config()
diffusers_config = config["diffusers_config"]
transformer_state_dict, _ = split_transformer_and_connector_state_dict(
original_state_dict
)
with init_empty_weights():
transformer = LTX2VideoTransformer3DModel.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(transformer_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(transformer_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(transformer_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, transformer_state_dict)
transformer.load_state_dict(transformer_state_dict, strict=True, assign=True)
return transformer
def convert_ltx2_connectors(original_state_dict: Dict[str, Any]) -> LTX2TextConnectors:
config, rename_dict, special_keys_remap = get_ltx2_connectors_config()
diffusers_config = config["diffusers_config"]
_, connector_state_dict = split_transformer_and_connector_state_dict(
original_state_dict
)
if len(connector_state_dict) == 0:
raise ValueError("No connector weights found in the provided state dict.")
with init_empty_weights():
connectors = LTX2TextConnectors.from_config(diffusers_config)
for key in list(connector_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(connector_state_dict, key, new_key)
for key in list(connector_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, connector_state_dict)
connectors.load_state_dict(connector_state_dict, strict=True, assign=True)
return connectors
def get_ltx2_video_vae_config() -> Tuple[
Dict[str, Any], Dict[str, Any], Dict[str, Any]
]:
config = {
"model_id": "diffusers-internal-dev/dummy-ltx2",
"diffusers_config": {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (256, 512, 1024, 2048),
"down_block_types": (
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
),
"decoder_block_out_channels": (256, 512, 1024),
"layers_per_block": (4, 6, 6, 2, 2),
"decoder_layers_per_block": (5, 5, 5, 5),
"spatio_temporal_scaling": (True, True, True, True),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (False, False, False, False),
"downsample_type": (
"spatial",
"temporal",
"spatiotemporal",
"spatiotemporal",
),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": False,
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
"encoder_causal": True,
"decoder_causal": False,
"encoder_spatial_padding_mode": "zeros",
"decoder_spatial_padding_mode": "reflect",
"spatial_compression_ratio": 32,
"temporal_compression_ratio": 8,
},
}
rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def convert_ltx2_video_vae(original_state_dict: Dict[str, Any]) -> Dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_video_vae_config()
diffusers_config = config["diffusers_config"]
with init_empty_weights():
vae = AutoencoderKLLTX2Video.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
vae.load_state_dict(original_state_dict, strict=True, assign=True)
return vae
def get_ltx2_audio_vae_config() -> Tuple[
Dict[str, Any], Dict[str, Any], Dict[str, Any]
]:
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"diffusers_config": {
"base_channels": 128,
"output_channels": 2,
"ch_mult": (1, 2, 4),
"num_res_blocks": 2,
"attn_resolutions": None,
"in_channels": 2,
"resolution": 256,
"latent_channels": 8,
"norm_type": "pixel",
"causality_axis": "height",
"dropout": 0.0,
"mid_block_add_attention": False,
"sample_rate": 16000,
"mel_hop_length": 160,
"is_causal": True,
"mel_bins": 64,
"double_z": True,
},
}
rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT
special_keys_remap = LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def convert_ltx2_audio_vae(original_state_dict: Dict[str, Any]) -> Dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_audio_vae_config()
diffusers_config = config["diffusers_config"]
with init_empty_weights():
vae = AutoencoderKLLTX2Audio.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
vae.load_state_dict(original_state_dict, strict=True, assign=True)
return vae
def get_ltx2_vocoder_config() -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"diffusers_config": {
"in_channels": 128,
"hidden_channels": 1024,
"out_channels": 2,
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
"upsample_factors": [6, 5, 2, 2, 2],
"resnet_kernel_sizes": [3, 7, 11],
"resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"leaky_relu_negative_slope": 0.1,
"output_sampling_rate": 24000,
},
}
rename_dict = LTX_2_0_VOCODER_RENAME_DICT
special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def convert_ltx2_vocoder(original_state_dict: Dict[str, Any]) -> Dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_vocoder_config()
diffusers_config = config["diffusers_config"]
with init_empty_weights():
vocoder = LTX2Vocoder.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
vocoder.load_state_dict(original_state_dict, strict=True, assign=True)
return vocoder
def get_model_state_dict_from_combined_ckpt(
combined_ckpt: Dict[str, Any], prefix: str
) -> Dict[str, Any]:
# Ensure that the key prefix ends with a dot (.)
if not prefix.endswith("."):
prefix = prefix + "."
model_state_dict = {}
for param_name, param in combined_ckpt.items():
if param_name.startswith(prefix):
model_state_dict[param_name.replace(prefix, "")] = param
if prefix == "model.diffusion_model.":
# Some checkpoints store the text connector projection outside the diffusion model prefix.
connector_key = "text_embedding_projection.aggregate_embed.weight"
if connector_key in combined_ckpt and connector_key not in model_state_dict:
model_state_dict[connector_key] = combined_ckpt[connector_key]
return model_state_dict
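For illustration, a hypothetical call (key names made up): entries under the given prefix are returned with the prefix stripped, everything else is ignored.
combined = {
    "model.diffusion_model.proj_in.weight": 1,
    "vae.encoder.conv_in.weight": 2,
}
dit_sd = get_model_state_dict_from_combined_ckpt(combined, "model.diffusion_model")
# dit_sd -> {"proj_in.weight": 1}; the "vae." entry is left out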
def dequantize_state_dict(state_dict: Dict[str, Any]):
keys = list(state_dict.keys())
state_out = {}
for k in keys:
if k.endswith(
(".weight_scale", ".weight_scale_2", ".pre_quant_scale", ".input_scale")
):
continue
t = state_dict[k]
if k.endswith(".weight"):
prefix = k[: -len(".weight")]
wscale_k = prefix + ".weight_scale"
if wscale_k in state_dict:
w_q = t
w_scale = state_dict[wscale_k]
# Comfy quant = absmax per-tensor weight quant, nothing fancy
w_bf16 = w_q.to(torch.bfloat16) * w_scale.to(torch.bfloat16)
state_out[k] = w_bf16
continue
state_out[k] = t
return state_out
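A minimal sketch of the dequantization arithmetic, assuming a toy int8 absmax checkpoint entry (the key names below are made up; torch is already imported at the top of this file).
quantized_sd = {
    "blocks.0.to_q.weight": torch.tensor([[-127, 64], [12, 127]], dtype=torch.int8),
    "blocks.0.to_q.weight_scale": torch.tensor(0.02),
}
recovered = dequantize_state_dict(quantized_sd)
# the ".weight_scale" entry is dropped and the weight becomes
# weight.to(bfloat16) * scale.to(bfloat16), i.e. roughly [[-2.54, 1.28], [0.24, 2.54]]
print(recovered["blocks.0.to_q.weight"].dtype)  # torch.bfloat16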
def convert_comfy_gemma3_to_transformers(sd: dict):
out = {}
sd = dequantize_state_dict(sd)
for k, v in sd.items():
nk = k
# Vision tower weights: checkpoint has "vision_model.*"
# model expects "model.vision_tower.vision_model.*"
if k.startswith("vision_model."):
nk = "model.vision_tower." + k
# MM projector: checkpoint has "multi_modal_projector.*"
# model expects "model.multi_modal_projector.*"
elif k.startswith("multi_modal_projector."):
nk = "model." + k
# Language model: checkpoint has "model.layers.*", "model.embed_tokens.*", "model.norm.*"
# model expects "model.language_model.layers.*", etc.
elif k == "model.embed_tokens.weight":
nk = "model.language_model.embed_tokens.weight"
elif k.startswith("model.layers."):
nk = "model.language_model.layers." + k[len("model.layers.") :]
elif k.startswith("model.norm."):
nk = "model.language_model.norm." + k[len("model.norm.") :]
# (optional) common DDP prefix
if nk.startswith("module."):
nk = nk[len("module.") :]
# skip spiece_model
if nk == "spiece_model":
continue
out[nk] = v
# If lm_head is missing but embeddings exist, many Gemma-family models tie these weights.
# Add it so strict loading won't complain (or just load strict=False and call tie_weights()).
if (
"lm_head.weight" not in out
and "model.language_model.embed_tokens.weight" in out
):
out["lm_head.weight"] = out["model.language_model.embed_tokens.weight"]
return out
def convert_lora_original_to_diffusers(
lora_state_dict: Dict[str, Any],
) -> Dict[str, Any]:
out: Dict[str, Any] = {}
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
for k, v in lora_state_dict.items():
# Keep the "diffusion_model." prefix as-is, but apply the transformer remaps to the rest
prefix = ""
rest = k
if rest.startswith("diffusion_model."):
prefix = "diffusion_model."
rest = rest[len(prefix) :]
nk = rest
# Same simple 1:1 remaps as the transformer
for replace_key, rename_key in rename_dict.items():
nk = nk.replace(replace_key, rename_key)
# Same special-case remap as the transformer (applies to LoRA keys too)
if nk.startswith("adaln_single."):
nk = nk.replace("adaln_single.", "time_embed.", 1)
elif nk.startswith("audio_adaln_single."):
nk = nk.replace("audio_adaln_single.", "audio_time_embed.", 1)
out[prefix + nk] = v
return out
def convert_lora_diffusers_to_original(
lora_state_dict: Dict[str, Any],
) -> Dict[str, Any]:
out: Dict[str, Any] = {}
inv_rename = {v: k for k, v in LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT.items()}
inv_items = sorted(inv_rename.items(), key=lambda kv: len(kv[0]), reverse=True)
for k, v in lora_state_dict.items():
# Keep the "diffusion_model." prefix as-is, but invert remaps on the rest
prefix = ""
rest = k
if rest.startswith("diffusion_model."):
prefix = "diffusion_model."
rest = rest[len(prefix) :]
nk = rest
# Inverse of the adaln_single special-case
if nk.startswith("time_embed."):
nk = nk.replace("time_embed.", "adaln_single.", 1)
elif nk.startswith("audio_time_embed."):
nk = nk.replace("audio_time_embed.", "audio_adaln_single.", 1)
# Inverse 1:1 remaps
for diffusers_key, original_key in inv_items:
nk = nk.replace(diffusers_key, original_key)
out[prefix + nk] = v
return out
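A hedged round-trip check of the two LoRA converters above (the key below is hypothetical): original/ComfyUI-style keys keep the "diffusion_model." prefix while the module names inside are remapped to diffusers names and back.
original_lora = {"diffusion_model.patchify_proj.lora_A.weight": torch.zeros(8, 128)}
as_diffusers = convert_lora_original_to_diffusers(original_lora)
assert "diffusion_model.proj_in.lora_A.weight" in as_diffusers
roundtrip = convert_lora_diffusers_to_original(as_diffusers)
assert set(roundtrip) == set(original_lora)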


@@ -0,0 +1,852 @@
from functools import partial
import os
from typing import List, Optional
import torch
import torchaudio
from transformers import Gemma3Config
import yaml
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.models.base_model import BaseModel
from toolkit.basic import flush
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import (
CustomFlowMatchEulerDiscreteScheduler,
)
from accelerate import init_empty_weights
from toolkit.accelerator import unwrap_model
from optimum.quanto import freeze
from toolkit.util.quantize import quantize, get_qtype, quantize_model
from toolkit.memory_management import MemoryManager
from safetensors.torch import load_file
try:
from diffusers import LTX2Pipeline
from diffusers.models.autoencoders import (
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
)
from diffusers.models.transformers import LTX2VideoTransformer3DModel
from diffusers.pipelines.ltx2.export_utils import encode_video
from transformers import (
Gemma3ForConditionalGeneration,
GemmaTokenizerFast,
)
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
from diffusers.pipelines.ltx2.connectors import LTX2TextConnectors
from .convert_ltx2_to_diffusers import (
get_model_state_dict_from_combined_ckpt,
convert_ltx2_transformer,
convert_ltx2_video_vae,
convert_ltx2_audio_vae,
convert_ltx2_vocoder,
convert_ltx2_connectors,
dequantize_state_dict,
convert_comfy_gemma3_to_transformers,
convert_lora_original_to_diffusers,
convert_lora_diffusers_to_original,
)
except ImportError as e:
print("Diffusers import error:", e)
raise ImportError(
"Diffusers is out of date. Update diffusers to the latest version by doing pip uninstall diffusers and then pip install -r requirements.txt"
)
scheduler_config = {
"base_image_seq_len": 1024,
"base_shift": 0.95,
"invert_sigmas": False,
"max_image_seq_len": 4096,
"max_shift": 2.05,
"num_train_timesteps": 1000,
"shift": 1.0,
"shift_terminal": 0.1,
"stochastic_sampling": False,
"time_shift_type": "exponential",
"use_beta_sigmas": False,
"use_dynamic_shifting": True,
"use_exponential_sigmas": False,
"use_karras_sigmas": False,
}
dit_prefix = "model.diffusion_model."
vae_prefix = "vae."
audio_vae_prefix = "audio_vae."
vocoder_prefix = "vocoder."
def new_save_image_function(
self: GenerateImageConfig,
image, # will contain a dict that can be dumped directly into encode_video, just add output_path to it.
count: int = 0,
max_count: int = 0,
**kwargs,
):
# this replaces the GenerateImageConfig save_image function so we can save the video with sound from ltx2
image["output_path"] = self.get_image_path(count, max_count)
# make sample directory if it does not exist
os.makedirs(os.path.dirname(image["output_path"]), exist_ok=True)
encode_video(**image)
flush()
def blank_log_image_function(self, *args, **kwargs):
# todo handle wandb logging of videos with audio
return
class AudioProcessor(torch.nn.Module):
"""Converts audio waveforms to log-mel spectrograms with optional resampling."""
def __init__(
self,
sample_rate: int,
mel_bins: int,
mel_hop_length: int,
n_fft: int,
) -> None:
super().__init__()
self.sample_rate = sample_rate
self.mel_transform = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate,
n_fft=n_fft,
win_length=n_fft,
hop_length=mel_hop_length,
f_min=0.0,
f_max=sample_rate / 2.0,
n_mels=mel_bins,
window_fn=torch.hann_window,
center=True,
pad_mode="reflect",
power=1.0,
mel_scale="slaney",
norm="slaney",
)
def resample_waveform(
self,
waveform: torch.Tensor,
source_rate: int,
target_rate: int,
) -> torch.Tensor:
"""Resample waveform to target sample rate if needed."""
if source_rate == target_rate:
return waveform
resampled = torchaudio.functional.resample(waveform, source_rate, target_rate)
return resampled.to(device=waveform.device, dtype=waveform.dtype)
def waveform_to_mel(
self,
waveform: torch.Tensor,
waveform_sample_rate: int,
) -> torch.Tensor:
"""Convert waveform to log-mel spectrogram [batch, channels, time, n_mels]."""
waveform = self.resample_waveform(
waveform, waveform_sample_rate, self.sample_rate
)
mel = self.mel_transform(waveform)
mel = torch.log(torch.clamp(mel, min=1e-5))
mel = mel.to(device=waveform.device, dtype=waveform.dtype)
return mel.permute(0, 1, 3, 2).contiguous()
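A hedged usage sketch of AudioProcessor, reusing the audio settings that appear elsewhere in this file (16 kHz sample rate, 64 mel bins, hop length 160, n_fft 1024); the 48 kHz stereo input is an arbitrary example.
processor = AudioProcessor(sample_rate=16000, mel_bins=64, mel_hop_length=160, n_fft=1024)
waveform = torch.randn(1, 2, 48000)  # [batch, channels, samples] = 1 s of stereo at 48 kHz
mel = processor.waveform_to_mel(waveform, waveform_sample_rate=48000)
# resampled to 16 kHz (16000 samples), so time frames = 16000 // 160 + 1 = 101
print(mel.shape)  # torch.Size([1, 2, 101, 64])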
class LTX2Model(BaseModel):
arch = "ltx2"
def __init__(
self,
device,
model_config: ModelConfig,
dtype="bf16",
custom_pipeline=None,
noise_scheduler=None,
**kwargs,
):
super().__init__(
device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs
)
self.is_flow_matching = True
self.is_transformer = True
self.target_lora_modules = ["LTX2VideoTransformer3DModel"]
# defines if the model supports model paths. Only some will
self.supports_model_paths = True
# use the new format on this new model by default
self.use_old_lokr_format = False
self.audio_processor = None
# static method to get the noise scheduler
@staticmethod
def get_train_scheduler():
return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
def get_bucket_divisibility(self):
return 32
def load_model(self):
dtype = self.torch_dtype
self.print_and_status_update("Loading LTX2 model")
model_path = self.model_config.name_or_path
base_model_path = self.model_config.extras_name_or_path
combined_state_dict = None
self.print_and_status_update("Loading transformer")
# if we have a safetensors file it is a mono checkpoint
if os.path.exists(model_path) and model_path.endswith(".safetensors"):
combined_state_dict = load_file(model_path)
combined_state_dict = dequantize_state_dict(combined_state_dict)
if combined_state_dict is not None:
original_dit_ckpt = get_model_state_dict_from_combined_ckpt(
combined_state_dict, dit_prefix
)
transformer = convert_ltx2_transformer(original_dit_ckpt)
transformer = transformer.to(dtype)
else:
transformer_path = model_path
transformer_subfolder = "transformer"
if os.path.exists(transformer_path):
transformer_subfolder = None
transformer_path = os.path.join(transformer_path, "transformer")
# check if the path is a full checkpoint.
te_folder_path = os.path.join(model_path, "text_encoder")
# if we have the te, this folder is a full checkpoint, use it as the base
if os.path.exists(te_folder_path):
base_model_path = model_path
transformer = LTX2VideoTransformer3DModel.from_pretrained(
transformer_path, subfolder=transformer_subfolder, torch_dtype=dtype
)
if self.model_config.quantize:
self.print_and_status_update("Quantizing Transformer")
quantize_model(self, transformer)
flush()
if (
self.model_config.layer_offloading
and self.model_config.layer_offloading_transformer_percent > 0
):
ignore_modules = []
for block in transformer.transformer_blocks:
ignore_modules.append(block.scale_shift_table)
ignore_modules.append(block.audio_scale_shift_table)
ignore_modules.append(block.video_a2v_cross_attn_scale_shift_table)
ignore_modules.append(block.audio_a2v_cross_attn_scale_shift_table)
ignore_modules.append(transformer.scale_shift_table)
ignore_modules.append(transformer.audio_scale_shift_table)
MemoryManager.attach(
transformer,
self.device_torch,
offload_percent=self.model_config.layer_offloading_transformer_percent,
ignore_modules=ignore_modules,
)
if self.model_config.low_vram:
self.print_and_status_update("Moving transformer to CPU")
transformer.to("cpu")
flush()
self.print_and_status_update("Loading text encoder")
if (
self.model_config.te_name_or_path is not None
and self.model_config.te_name_or_path.endswith(".safetensors")
):
# load from comfyui gemma3 checkpoint
tokenizer = GemmaTokenizerFast.from_pretrained(
"Lightricks/LTX-2", subfolder="tokenizer"
)
with init_empty_weights():
text_encoder = Gemma3ForConditionalGeneration(
Gemma3Config(
**{
"boi_token_index": 255999,
"bos_token_id": 2,
"eoi_token_index": 256000,
"eos_token_id": 106,
"image_token_index": 262144,
"initializer_range": 0.02,
"mm_tokens_per_image": 256,
"model_type": "gemma3",
"pad_token_id": 0,
"text_config": {
"attention_bias": False,
"attention_dropout": 0.0,
"attn_logit_softcapping": None,
"cache_implementation": "hybrid",
"final_logit_softcapping": None,
"head_dim": 256,
"hidden_activation": "gelu_pytorch_tanh",
"hidden_size": 3840,
"initializer_range": 0.02,
"intermediate_size": 15360,
"max_position_embeddings": 131072,
"model_type": "gemma3_text",
"num_attention_heads": 16,
"num_hidden_layers": 48,
"num_key_value_heads": 8,
"query_pre_attn_scalar": 256,
"rms_norm_eps": 1e-06,
"rope_local_base_freq": 10000,
"rope_scaling": {"factor": 8.0, "rope_type": "linear"},
"rope_theta": 1000000,
"sliding_window": 1024,
"sliding_window_pattern": 6,
"torch_dtype": "bfloat16",
"use_cache": True,
"vocab_size": 262208,
},
"torch_dtype": "bfloat16",
"transformers_version": "4.51.3",
"unsloth_fixed": True,
"vision_config": {
"attention_dropout": 0.0,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1152,
"image_size": 896,
"intermediate_size": 4304,
"layer_norm_eps": 1e-06,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 27,
"patch_size": 14,
"torch_dtype": "bfloat16",
"vision_use_head": False,
},
}
)
)
te_state_dict = load_file(self.model_config.te_name_or_path)
te_state_dict = convert_comfy_gemma3_to_transformers(te_state_dict)
for key in te_state_dict:
te_state_dict[key] = te_state_dict[key].to(dtype)
text_encoder.load_state_dict(te_state_dict, assign=True, strict=True)
del te_state_dict
flush()
else:
if self.model_config.te_name_or_path is not None:
te_path = self.model_config.te_name_or_path
else:
te_path = base_model_path
tokenizer = GemmaTokenizerFast.from_pretrained(
te_path, subfolder="tokenizer"
)
text_encoder = Gemma3ForConditionalGeneration.from_pretrained(
te_path, subfolder="text_encoder", dtype=dtype
)
# remove the vision tower
text_encoder.model.vision_tower = None
flush()
if (
self.model_config.layer_offloading
and self.model_config.layer_offloading_text_encoder_percent > 0
):
MemoryManager.attach(
text_encoder,
self.device_torch,
offload_percent=self.model_config.layer_offloading_text_encoder_percent,
ignore_modules=[
text_encoder.model.language_model.base_model.embed_tokens
],
)
text_encoder.to(self.device_torch, dtype=dtype)
flush()
if self.model_config.quantize_te:
self.print_and_status_update("Quantizing Text Encoder")
quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te))
freeze(text_encoder)
flush()
self.print_and_status_update("Loading VAEs and other components")
if combined_state_dict is not None:
original_vae_ckpt = get_model_state_dict_from_combined_ckpt(
combined_state_dict, vae_prefix
)
vae = convert_ltx2_video_vae(original_vae_ckpt).to(dtype)
del original_vae_ckpt
original_audio_vae_ckpt = get_model_state_dict_from_combined_ckpt(
combined_state_dict, audio_vae_prefix
)
audio_vae = convert_ltx2_audio_vae(original_audio_vae_ckpt).to(dtype)
del original_audio_vae_ckpt
original_connectors_ckpt = get_model_state_dict_from_combined_ckpt(
combined_state_dict, dit_prefix
)
connectors = convert_ltx2_connectors(original_connectors_ckpt).to(dtype)
del original_connectors_ckpt
original_vocoder_ckpt = get_model_state_dict_from_combined_ckpt(
combined_state_dict, vocoder_prefix
)
vocoder = convert_ltx2_vocoder(original_vocoder_ckpt).to(dtype)
del original_vocoder_ckpt
del combined_state_dict
flush()
else:
vae = AutoencoderKLLTX2Video.from_pretrained(
base_model_path, subfolder="vae", torch_dtype=dtype
)
audio_vae = AutoencoderKLLTX2Audio.from_pretrained(
base_model_path, subfolder="audio_vae", torch_dtype=dtype
)
connectors = LTX2TextConnectors.from_pretrained(
base_model_path, subfolder="connectors", torch_dtype=dtype
)
vocoder = LTX2Vocoder.from_pretrained(
base_model_path, subfolder="vocoder", torch_dtype=dtype
)
self.noise_scheduler = LTX2Model.get_train_scheduler()
self.print_and_status_update("Making pipe")
pipe: LTX2Pipeline = LTX2Pipeline(
scheduler=self.noise_scheduler,
vae=vae,
audio_vae=audio_vae,
text_encoder=None,
tokenizer=tokenizer,
connectors=connectors,
transformer=None,
vocoder=vocoder,
)
# for quantization, it works best to do these after making the pipe
pipe.text_encoder = text_encoder
pipe.transformer = transformer
self.print_and_status_update("Preparing Model")
text_encoder = [pipe.text_encoder]
tokenizer = [pipe.tokenizer]
# leave it on cpu for now
if not self.low_vram:
pipe.transformer = pipe.transformer.to(self.device_torch)
flush()
# just to make sure everything is on the right device and dtype
text_encoder[0].to(self.device_torch)
text_encoder[0].requires_grad_(False)
text_encoder[0].eval()
flush()
# save it to the model class
self.vae = vae
self.text_encoder = text_encoder # list of text encoders
self.tokenizer = tokenizer # list of tokenizers
self.model = pipe.transformer
self.pipeline = pipe
self.audio_processor = AudioProcessor(
sample_rate=pipe.audio_sampling_rate,
mel_bins=audio_vae.config.mel_bins,
mel_hop_length=pipe.audio_hop_length,
n_fft=1024, # todo get this from the vae if we can, I couldn't find it.
).to(self.device_torch, dtype=torch.float32)
self.print_and_status_update("Model Loaded")
@torch.no_grad()
def encode_images(self, image_list: List[torch.Tensor], device=None, dtype=None):
if device is None:
device = self.vae_device_torch
if dtype is None:
dtype = self.vae_torch_dtype
if self.vae.device == torch.device("cpu"):
self.vae.to(device)
self.vae.eval()
self.vae.requires_grad_(False)
image_list = [image.to(device, dtype=dtype) for image in image_list]
# Normalize shapes
norm_images = []
for image in image_list:
if image.ndim == 3:
# (C, H, W) -> (C, 1, H, W)
norm_images.append(image.unsqueeze(1))
elif image.ndim == 4:
# (T, C, H, W) -> (C, T, H, W)
norm_images.append(image.permute(1, 0, 2, 3))
else:
raise ValueError(f"Invalid image shape: {image.shape}")
# Stack to (B, C, T, H, W)
images = torch.stack(norm_images)
latents = self.vae.encode(images).latent_dist.mode()
# Normalize latents across the channel dimension [B, C, F, H, W]
scaling_factor = 1.0
latents_mean = self.pipeline.vae.latents_mean.view(1, -1, 1, 1, 1).to(
latents.device, latents.dtype
)
latents_std = self.pipeline.vae.latents_std.view(1, -1, 1, 1, 1).to(
latents.device, latents.dtype
)
latents = (latents - latents_mean) * scaling_factor / latents_std
return latents.to(device, dtype=dtype)
def get_generation_pipeline(self):
scheduler = LTX2Model.get_train_scheduler()
pipeline: LTX2Pipeline = LTX2Pipeline(
scheduler=scheduler,
vae=unwrap_model(self.pipeline.vae),
audio_vae=unwrap_model(self.pipeline.audio_vae),
text_encoder=None,
tokenizer=unwrap_model(self.pipeline.tokenizer),
connectors=unwrap_model(self.pipeline.connectors),
transformer=None,
vocoder=unwrap_model(self.pipeline.vocoder),
)
pipeline.transformer = unwrap_model(self.model)
pipeline.text_encoder = unwrap_model(self.text_encoder[0])
# if self.low_vram:
# pipeline.enable_model_cpu_offload(device=self.device_torch)
pipeline = pipeline.to(self.device_torch)
return pipeline
def generate_single_image(
self,
pipeline: LTX2Pipeline,
gen_config: GenerateImageConfig,
conditional_embeds: PromptEmbeds,
unconditional_embeds: PromptEmbeds,
generator: torch.Generator,
extra: dict,
):
if self.model.device == torch.device("cpu"):
self.model.to(self.device_torch)
is_video = gen_config.num_frames > 1
# override the generate single image to handle video + audio generation
if is_video:
gen_config._orig_save_image_function = gen_config.save_image
gen_config.save_image = partial(new_save_image_function, gen_config)
gen_config.log_image = partial(blank_log_image_function, gen_config)
# set output extension to mp4
gen_config.output_ext = "mp4"
# reactivate progress bar since this is slooooow
pipeline.set_progress_bar_config(disable=False)
pipeline = pipeline.to(self.device_torch)
# make sure dimensions are valid
bd = self.get_bucket_divisibility()
gen_config.height = (gen_config.height // bd) * bd
gen_config.width = (gen_config.width // bd) * bd
# frame count must be a multiple of 8 plus 1, i.e. 1, 9, 17, 25, etc.
if gen_config.num_frames != 1:
if (gen_config.num_frames - 1) % 8 != 0:
gen_config.num_frames = ((gen_config.num_frames - 1) // 8) * 8 + 1
if self.low_vram:
# set vae to tile decode
pipeline.vae.enable_tiling(
tile_sample_min_height=256,
tile_sample_min_width=256,
tile_sample_min_num_frames=8,
tile_sample_stride_height=224,
tile_sample_stride_width=224,
tile_sample_stride_num_frames=4,
)
video, audio = pipeline(
prompt_embeds=conditional_embeds.text_embeds.to(
self.device_torch, dtype=self.torch_dtype
),
prompt_attention_mask=conditional_embeds.attention_mask.to(
self.device_torch
),
negative_prompt_embeds=unconditional_embeds.text_embeds.to(
self.device_torch, dtype=self.torch_dtype
),
negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(
self.device_torch
),
height=gen_config.height,
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
guidance_scale=gen_config.guidance_scale,
latents=gen_config.latents,
num_frames=gen_config.num_frames,
generator=generator,
return_dict=False,
output_type="np" if is_video else "pil",
**extra,
)
if self.low_vram:
# Restore no tiling
pipeline.vae.use_tiling = False
if is_video:
# return as a dict; we will handle it with an override function
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)
return {
"video": video[0],
"fps": gen_config.fps,
"audio": audio[0].float().cpu(),
"audio_sample_rate": pipeline.vocoder.config.output_sampling_rate, # should be 24000
"output_path": None,
}
else:
# shape = [1, frames, channels, height, width]
# make sure this is right
video = video[0] # list of pil images
audio = audio[0] # tensor
if gen_config.num_frames > 1:
return video # return the frames.
else:
# get just the first image
img = video[0]
return img
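A hedged arithmetic example of the dimension snapping in generate_single_image above (input values are arbitrary):
height, width, num_frames = 723, 1290, 50
bd = 32  # get_bucket_divisibility()
height, width = (height // bd) * bd, (width // bd) * bd  # -> 704, 1280
if num_frames != 1 and (num_frames - 1) % 8 != 0:
    num_frames = ((num_frames - 1) // 8) * 8 + 1  # -> 49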
def encode_audio(self, batch: "DataLoaderBatchDTO"):
if self.pipeline.audio_vae.device == torch.device("cpu"):
self.pipeline.audio_vae.to(self.device_torch)
output_tensor = None
audio_num_frames = None
# do them separately for now
for audio_data in batch.audio_data:
waveform = audio_data["waveform"].to(
device=self.device_torch, dtype=torch.float32
)
sample_rate = audio_data["sample_rate"]
# Add batch dimension if needed: [channels, samples] -> [batch, channels, samples]
if waveform.dim() == 2:
waveform = waveform.unsqueeze(0)
if waveform.shape[1] == 1:
# make sure it is stereo
waveform = waveform.repeat(1, 2, 1)
# Convert waveform to mel spectrogram using AudioProcessor
mel_spectrogram = self.audio_processor.waveform_to_mel(waveform, waveform_sample_rate=sample_rate)
mel_spectrogram = mel_spectrogram.to(dtype=self.torch_dtype)
# Encode mel spectrogram to latents
latents = self.pipeline.audio_vae.encode(mel_spectrogram.to(self.device_torch, dtype=self.torch_dtype)).latent_dist.mode()
if audio_num_frames is None:
audio_num_frames = latents.shape[2] #(latents is [B, C, T, F])
packed_latents = self.pipeline._pack_audio_latents(
latents,
# patch_size=self.pipeline.transformer.config.audio_patch_size,
# patch_size_t=self.pipeline.transformer.config.audio_patch_size_t,
) # [B, L, C * M]
if output_tensor is None:
output_tensor = packed_latents
else:
output_tensor = torch.cat([output_tensor, packed_latents], dim=0)
# normalize latents, opposite of (latents * latents_std) + latents_mean
latents_mean = self.pipeline.audio_vae.latents_mean
latents_std = self.pipeline.audio_vae.latents_std
output_tensor = (output_tensor - latents_mean) / latents_std
return output_tensor, audio_num_frames
def get_noise_prediction(
self,
latent_model_input: torch.Tensor,
timestep: torch.Tensor, # 0 to 1000 scale
text_embeddings: PromptEmbeds,
batch: "DataLoaderBatchDTO" = None,
**kwargs,
):
with torch.no_grad():
if self.model.device == torch.device("cpu"):
self.model.to(self.device_torch)
batch_size, C, latent_num_frames, latent_height, latent_width = (
latent_model_input.shape
)
# todo get this somehow
frame_rate = 24
# check frame dimension
# Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
packed_latents = self.pipeline._pack_latents(
latent_model_input,
patch_size=self.pipeline.transformer_spatial_patch_size,
patch_size_t=self.pipeline.transformer_temporal_patch_size,
)
if batch.audio_tensor is not None:
# use audio from the batch if available
#(1, 190, 128)
raw_audio_latents, audio_num_frames = self.encode_audio(batch)
# add the audio targets to the batch for loss calculation later
audio_noise = torch.randn_like(raw_audio_latents)
batch.audio_target = (audio_noise - raw_audio_latents).detach()
audio_latents = self.add_noise(
raw_audio_latents,
audio_noise,
timestep,
).to(self.device_torch, dtype=self.torch_dtype)
else:
# no audio
num_mel_bins = self.pipeline.audio_vae.config.mel_bins
# latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
num_channels_latents_audio = (
self.pipeline.audio_vae.config.latent_channels
)
# audio latents are (1, 126, 128), audio_num_frames = 126
audio_latents, audio_num_frames = self.pipeline.prepare_audio_latents(
batch_size,
num_channels_latents=num_channels_latents_audio,
num_mel_bins=num_mel_bins,
num_frames=batch.tensor.shape[1],
frame_rate=frame_rate,
sampling_rate=self.pipeline.audio_sampling_rate,
hop_length=self.pipeline.audio_hop_length,
dtype=torch.float32,
device=self.transformer.device,
generator=None,
latents=None,
)
if self.pipeline.connectors.device != self.transformer.device:
self.pipeline.connectors.to(self.transformer.device)
# TODO this is how diffusers does this on inference, not sure I understand why, check this
additive_attention_mask = (
1 - text_embeddings.attention_mask.to(self.transformer.dtype)
) * -1000000.0
(
connector_prompt_embeds,
connector_audio_prompt_embeds,
connector_attention_mask,
) = self.pipeline.connectors(
text_embeddings.text_embeds, additive_attention_mask, additive_mask=True
)
# compute video and audio positional ids
video_coords = self.transformer.rope.prepare_video_coords(
packed_latents.shape[0],
latent_num_frames,
latent_height,
latent_width,
packed_latents.device,
fps=frame_rate,
)
audio_coords = self.transformer.audio_rope.prepare_audio_coords(
audio_latents.shape[0], audio_num_frames, audio_latents.device
)
noise_pred_video, noise_pred_audio = self.transformer(
hidden_states=packed_latents,
audio_hidden_states=audio_latents.to(self.transformer.dtype),
encoder_hidden_states=connector_prompt_embeds,
audio_encoder_hidden_states=connector_audio_prompt_embeds,
timestep=timestep,
encoder_attention_mask=connector_attention_mask,
audio_encoder_attention_mask=connector_attention_mask,
num_frames=latent_num_frames,
height=latent_height,
width=latent_width,
fps=frame_rate,
audio_num_frames=audio_num_frames,
video_coords=video_coords,
audio_coords=audio_coords,
# rope_interpolation_scale=rope_interpolation_scale,
attention_kwargs=None,
return_dict=False,
)
# add audio latent to batch if we had audio
if batch.audio_target is not None:
batch.audio_pred = noise_pred_audio
unpacked_output = self.pipeline._unpack_latents(
latents=noise_pred_video,
num_frames=latent_num_frames,
height=latent_height,
width=latent_width,
patch_size=self.pipeline.transformer_spatial_patch_size,
patch_size_t=self.pipeline.transformer_temporal_patch_size,
)
return unpacked_output
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
if self.pipeline.text_encoder.device != self.device_torch:
self.pipeline.text_encoder.to(self.device_torch)
prompt_embeds, prompt_attention_mask, _, _ = self.pipeline.encode_prompt(
prompt,
do_classifier_free_guidance=False,
device=self.device_torch,
)
pe = PromptEmbeds([prompt_embeds, None])
pe.attention_mask = prompt_attention_mask
return pe
def get_model_has_grad(self):
return False
def get_te_has_grad(self):
return False
def save_model(self, output_path, meta, save_dtype):
transformer: LTX2VideoTransformer3DModel = unwrap_model(self.model)
transformer.save_pretrained(
save_directory=os.path.join(output_path, "transformer"),
safe_serialization=True,
)
meta_path = os.path.join(output_path, "aitk_meta.yaml")
with open(meta_path, "w") as f:
yaml.dump(meta, f)
def get_loss_target(self, *args, **kwargs):
noise = kwargs.get("noise")
batch = kwargs.get("batch")
return (noise - batch.latents).detach()
def get_base_model_version(self):
return "ltx2"
def get_transformer_block_names(self) -> Optional[List[str]]:
return ["transformer_blocks"]
def convert_lora_weights_before_save(self, state_dict):
new_sd = {}
for key, value in state_dict.items():
new_key = key.replace("transformer.", "diffusion_model.")
new_sd[new_key] = value
new_sd = convert_lora_diffusers_to_original(new_sd)
return new_sd
def convert_lora_weights_before_load(self, state_dict):
state_dict = convert_lora_original_to_diffusers(state_dict)
new_sd = {}
for key, value in state_dict.items():
new_key = key.replace("diffusion_model.", "transformer.")
new_sd[new_key] = value
return new_sd
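Finally, a hedged numerical check of the velocity target returned by get_loss_target above, assuming the flow-match scheduler's add_noise follows the standard linear interpolation x_t = (1 - sigma) * latents + sigma * noise; the tensors below are arbitrary stand-ins.
x0 = torch.randn(2, 4)     # stands in for batch.latents
noise = torch.randn(2, 4)  # stands in for the sampled noise
sigma_a, sigma_b = 0.3, 0.8
x_a = (1 - sigma_a) * x0 + sigma_a * noise
x_b = (1 - sigma_b) * x0 + sigma_b * noise
velocity = (x_b - x_a) / (sigma_b - sigma_a)  # constant along the linear path
assert torch.allclose(velocity, noise - x0, atol=1e-4)  # matches (noise - batch.latents)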