# ai-toolkit/extensions_built_in/diffusion_models/ltx2/ltx2.py

from functools import partial
import os
from typing import List, Optional
import torch
import torchaudio
from transformers import Gemma3Config
import yaml
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.models.base_model import BaseModel
from toolkit.basic import flush
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import (
CustomFlowMatchEulerDiscreteScheduler,
)
from accelerate import init_empty_weights
from toolkit.accelerator import unwrap_model
from optimum.quanto import freeze
from toolkit.util.quantize import quantize, get_qtype, quantize_model
from toolkit.memory_management import MemoryManager
from safetensors.torch import load_file
from PIL import Image
try:
from diffusers import LTX2Pipeline, LTX2ImageToVideoPipeline
from diffusers.models.autoencoders import (
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
)
from diffusers.models.transformers import LTX2VideoTransformer3DModel
from diffusers.pipelines.ltx2.export_utils import encode_video
from transformers import (
Gemma3ForConditionalGeneration,
GemmaTokenizerFast,
)
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
from diffusers.pipelines.ltx2.connectors import LTX2TextConnectors
from .convert_ltx2_to_diffusers import (
get_model_state_dict_from_combined_ckpt,
convert_ltx2_transformer,
convert_ltx2_video_vae,
convert_ltx2_audio_vae,
convert_ltx2_vocoder,
convert_ltx2_connectors,
dequantize_state_dict,
convert_comfy_gemma3_to_transformers,
convert_lora_original_to_diffusers,
convert_lora_diffusers_to_original,
)
except ImportError as e:
print("Diffusers import error:", e)
raise ImportError(
"Diffusers is out of date. Update diffusers to the latest version by doing pip uninstall diffusers and then pip install -r requirements.txt"
)
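# Flow-matching scheduler configuration shared by training and sampling (see get_train_scheduler below).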
scheduler_config = {
"base_image_seq_len": 1024,
"base_shift": 0.95,
"invert_sigmas": False,
"max_image_seq_len": 4096,
"max_shift": 2.05,
"num_train_timesteps": 1000,
"shift": 1.0,
"shift_terminal": 0.1,
"stochastic_sampling": False,
"time_shift_type": "exponential",
"use_beta_sigmas": False,
"use_dynamic_shifting": True,
"use_exponential_sigmas": False,
"use_karras_sigmas": False,
}
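# State-dict key prefixes used to split a combined single-file LTX2 checkpoint into its component models.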
dit_prefix = "model.diffusion_model."
vae_prefix = "vae."
audio_vae_prefix = "audio_vae."
vocoder_prefix = "vocoder."
def new_save_image_function(
self: GenerateImageConfig,
image,  # will contain a dict that can be passed directly to encode_video; just add output_path to it.
count: int = 0,
max_count: int = 0,
**kwargs,
):
# this replaces GenerateImageConfig.save_image so LTX2 samples can be saved as video with their audio track
image["output_path"] = self.get_image_path(count, max_count)
# make sample directory if it does not exist
os.makedirs(os.path.dirname(image["output_path"]), exist_ok=True)
encode_video(**image)
flush()
def blank_log_image_function(self, *args, **kwargs):
# todo handle wandb logging of videos with audio
return
class AudioProcessor(torch.nn.Module):
"""Converts audio waveforms to log-mel spectrograms with optional resampling."""
def __init__(
self,
sample_rate: int,
mel_bins: int,
mel_hop_length: int,
n_fft: int,
) -> None:
super().__init__()
self.sample_rate = sample_rate
self.mel_transform = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate,
n_fft=n_fft,
win_length=n_fft,
hop_length=mel_hop_length,
f_min=0.0,
f_max=sample_rate / 2.0,
n_mels=mel_bins,
window_fn=torch.hann_window,
center=True,
pad_mode="reflect",
power=1.0,
mel_scale="slaney",
norm="slaney",
)
def resample_waveform(
self,
waveform: torch.Tensor,
source_rate: int,
target_rate: int,
) -> torch.Tensor:
"""Resample waveform to target sample rate if needed."""
if source_rate == target_rate:
return waveform
resampled = torchaudio.functional.resample(waveform, source_rate, target_rate)
return resampled.to(device=waveform.device, dtype=waveform.dtype)
def waveform_to_mel(
self,
waveform: torch.Tensor,
waveform_sample_rate: int,
) -> torch.Tensor:
"""Convert waveform to log-mel spectrogram [batch, channels, time, n_mels]."""
waveform = self.resample_waveform(
waveform, waveform_sample_rate, self.sample_rate
)
mel = self.mel_transform(waveform)
mel = torch.log(torch.clamp(mel, min=1e-5))
mel = mel.to(device=waveform.device, dtype=waveform.dtype)
return mel.permute(0, 1, 3, 2).contiguous()
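# BaseModel wrapper around the LTX2 stack: the video transformer, the video and audio VAEs,
# the Gemma3 text encoder, the text connectors, and the vocoder.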
class LTX2Model(BaseModel):
arch = "ltx2"
def __init__(
self,
device,
model_config: ModelConfig,
dtype="bf16",
custom_pipeline=None,
noise_scheduler=None,
**kwargs,
):
super().__init__(
device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs
)
self.is_flow_matching = True
self.is_transformer = True
self.target_lora_modules = ["LTX2VideoTransformer3DModel"]
# defines whether the model supports local model paths. Only some models do.
self.supports_model_paths = True
# use the new format on this new model by default
self.use_old_lokr_format = False
self.audio_processor = None
# static method to get the noise scheduler
@staticmethod
def get_train_scheduler():
return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
def get_bucket_divisibility(self):
return 32
def load_model(self):
dtype = self.torch_dtype
self.print_and_status_update("Loading LTX2 model")
model_path = self.model_config.name_or_path
base_model_path = self.model_config.extras_name_or_path
combined_state_dict = None
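# name_or_path may point at a single-file .safetensors checkpoint or a diffusers-format folder;
# extras_name_or_path supplies the diffusers repo used for any remaining components.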
self.print_and_status_update("Loading transformer")
# a .safetensors file is treated as a single-file (combined) checkpoint
if os.path.exists(model_path) and model_path.endswith(".safetensors"):
combined_state_dict = load_file(model_path)
combined_state_dict = dequantize_state_dict(combined_state_dict)
if combined_state_dict is not None:
original_dit_ckpt = get_model_state_dict_from_combined_ckpt(
combined_state_dict, dit_prefix
)
transformer = convert_ltx2_transformer(original_dit_ckpt)
transformer = transformer.to(dtype)
else:
transformer_path = model_path
transformer_subfolder = "transformer"
if os.path.exists(transformer_path):
transformer_subfolder = None
transformer_path = os.path.join(transformer_path, "transformer")
# check if the path is a full checkpoint.
te_folder_path = os.path.join(model_path, "text_encoder")
# if we have the te, this folder is a full checkpoint, use it as the base
if os.path.exists(te_folder_path):
base_model_path = model_path
transformer = LTX2VideoTransformer3DModel.from_pretrained(
transformer_path, subfolder=transformer_subfolder, torch_dtype=dtype
)
if self.model_config.quantize:
self.print_and_status_update("Quantizing Transformer")
quantize_model(self, transformer)
flush()
if (
self.model_config.layer_offloading
and self.model_config.layer_offloading_transformer_percent > 0
):
ignore_modules = []
for block in transformer.transformer_blocks:
ignore_modules.append(block.scale_shift_table)
ignore_modules.append(block.audio_scale_shift_table)
ignore_modules.append(block.video_a2v_cross_attn_scale_shift_table)
ignore_modules.append(block.audio_a2v_cross_attn_scale_shift_table)
ignore_modules.append(transformer.scale_shift_table)
ignore_modules.append(transformer.audio_scale_shift_table)
MemoryManager.attach(
transformer,
self.device_torch,
offload_percent=self.model_config.layer_offloading_transformer_percent,
ignore_modules=ignore_modules,
)
if self.model_config.low_vram:
self.print_and_status_update("Moving transformer to CPU")
transformer.to("cpu")
flush()
self.print_and_status_update("Loading text encoder")
if (
self.model_config.te_name_or_path is not None
and self.model_config.te_name_or_path.endswith(".safetensors")
):
# load from a ComfyUI-style Gemma3 checkpoint
tokenizer = GemmaTokenizerFast.from_pretrained(
"Lightricks/LTX-2", subfolder="tokenizer"
)
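# Build the Gemma3 module structure on the meta device; the converted state dict is loaded into it below with assign=True.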
with init_empty_weights():
text_encoder = Gemma3ForConditionalGeneration(
Gemma3Config(
**{
"boi_token_index": 255999,
"bos_token_id": 2,
"eoi_token_index": 256000,
"eos_token_id": 106,
"image_token_index": 262144,
"initializer_range": 0.02,
"mm_tokens_per_image": 256,
"model_type": "gemma3",
"pad_token_id": 0,
"text_config": {
"attention_bias": False,
"attention_dropout": 0.0,
"attn_logit_softcapping": None,
"cache_implementation": "hybrid",
"final_logit_softcapping": None,
"head_dim": 256,
"hidden_activation": "gelu_pytorch_tanh",
"hidden_size": 3840,
"initializer_range": 0.02,
"intermediate_size": 15360,
"max_position_embeddings": 131072,
"model_type": "gemma3_text",
"num_attention_heads": 16,
"num_hidden_layers": 48,
"num_key_value_heads": 8,
"query_pre_attn_scalar": 256,
"rms_norm_eps": 1e-06,
"rope_local_base_freq": 10000,
"rope_scaling": {"factor": 8.0, "rope_type": "linear"},
"rope_theta": 1000000,
"sliding_window": 1024,
"sliding_window_pattern": 6,
"torch_dtype": "bfloat16",
"use_cache": True,
"vocab_size": 262208,
},
"torch_dtype": "bfloat16",
"transformers_version": "4.51.3",
"unsloth_fixed": True,
"vision_config": {
"attention_dropout": 0.0,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1152,
"image_size": 896,
"intermediate_size": 4304,
"layer_norm_eps": 1e-06,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 27,
"patch_size": 14,
"torch_dtype": "bfloat16",
"vision_use_head": False,
},
}
)
)
te_state_dict = load_file(self.model_config.te_name_or_path)
te_state_dict = convert_comfy_gemma3_to_transformers(te_state_dict)
for key in te_state_dict:
te_state_dict[key] = te_state_dict[key].to(dtype)
text_encoder.load_state_dict(te_state_dict, assign=True, strict=True)
del te_state_dict
flush()
else:
if self.model_config.te_name_or_path is not None:
te_path = self.model_config.te_name_or_path
else:
te_path = base_model_path
tokenizer = GemmaTokenizerFast.from_pretrained(
te_path, subfolder="tokenizer"
)
text_encoder = Gemma3ForConditionalGeneration.from_pretrained(
te_path, subfolder="text_encoder", dtype=dtype
)
# remove the vision tower
text_encoder.model.vision_tower = None
flush()
if (
self.model_config.layer_offloading
and self.model_config.layer_offloading_text_encoder_percent > 0
):
MemoryManager.attach(
text_encoder,
self.device_torch,
offload_percent=self.model_config.layer_offloading_text_encoder_percent,
ignore_modules=[
text_encoder.model.language_model.base_model.embed_tokens
],
)
text_encoder.to(self.device_torch, dtype=dtype)
flush()
if self.model_config.quantize_te:
self.print_and_status_update("Quantizing Text Encoder")
quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te))
freeze(text_encoder)
flush()
self.print_and_status_update("Loading VAEs and other components")
if combined_state_dict is not None:
original_vae_ckpt = get_model_state_dict_from_combined_ckpt(
combined_state_dict, vae_prefix
)
vae = convert_ltx2_video_vae(original_vae_ckpt).to(dtype)
del original_vae_ckpt
original_audio_vae_ckpt = get_model_state_dict_from_combined_ckpt(
combined_state_dict, audio_vae_prefix
)
audio_vae = convert_ltx2_audio_vae(original_audio_vae_ckpt).to(dtype)
del original_audio_vae_ckpt
original_connectors_ckpt = get_model_state_dict_from_combined_ckpt(
combined_state_dict, dit_prefix
)
connectors = convert_ltx2_connectors(original_connectors_ckpt).to(dtype)
del original_connectors_ckpt
original_vocoder_ckpt = get_model_state_dict_from_combined_ckpt(
combined_state_dict, vocoder_prefix
)
vocoder = convert_ltx2_vocoder(original_vocoder_ckpt).to(dtype)
del original_vocoder_ckpt
del combined_state_dict
flush()
else:
vae = AutoencoderKLLTX2Video.from_pretrained(
base_model_path, subfolder="vae", torch_dtype=dtype
)
audio_vae = AutoencoderKLLTX2Audio.from_pretrained(
base_model_path, subfolder="audio_vae", torch_dtype=dtype
)
connectors = LTX2TextConnectors.from_pretrained(
base_model_path, subfolder="connectors", torch_dtype=dtype
)
vocoder = LTX2Vocoder.from_pretrained(
base_model_path, subfolder="vocoder", torch_dtype=dtype
)
self.noise_scheduler = LTX2Model.get_train_scheduler()
self.print_and_status_update("Making pipe")
pipe: LTX2Pipeline = LTX2Pipeline(
scheduler=self.noise_scheduler,
vae=vae,
audio_vae=audio_vae,
text_encoder=None,
tokenizer=tokenizer,
connectors=connectors,
transformer=None,
vocoder=vocoder,
)
# for quantization, it works best to do these after making the pipe
pipe.text_encoder = text_encoder
pipe.transformer = transformer
self.print_and_status_update("Preparing Model")
text_encoder = [pipe.text_encoder]
tokenizer = [pipe.tokenizer]
# leave it on cpu for now
if not self.low_vram:
pipe.transformer = pipe.transformer.to(self.device_torch)
flush()
# just to make sure everything is on the right device and dtype
text_encoder[0].to(self.device_torch)
text_encoder[0].requires_grad_(False)
text_encoder[0].eval()
flush()
# save it to the model class
self.vae = vae
self.text_encoder = text_encoder # list of text encoders
self.tokenizer = tokenizer # list of tokenizers
self.model = pipe.transformer
self.pipeline = pipe
self.audio_processor = AudioProcessor(
sample_rate=pipe.audio_sampling_rate,
mel_bins=audio_vae.config.mel_bins,
mel_hop_length=pipe.audio_hop_length,
n_fft=1024,  # TODO: get this from the VAE config if possible; it couldn't be found there.
).to(self.device_torch, dtype=torch.float32)
self.print_and_status_update("Model Loaded")
@torch.no_grad()
def encode_images(self, image_list: List[torch.Tensor], device=None, dtype=None):
if device is None:
device = self.vae_device_torch
if dtype is None:
dtype = self.vae_torch_dtype
if self.vae.device == torch.device("cpu"):
self.vae.to(device)
self.vae.eval()
self.vae.requires_grad_(False)
image_list = [image.to(device, dtype=dtype) for image in image_list]
# Normalize shapes
norm_images = []
for image in image_list:
if image.ndim == 3:
# (C, H, W) -> (C, 1, H, W)
norm_images.append(image.unsqueeze(1))
elif image.ndim == 4:
# (T, C, H, W) -> (C, T, H, W)
norm_images.append(image.permute(1, 0, 2, 3))
else:
raise ValueError(f"Invalid image shape: {image.shape}")
# Stack to (B, C, T, H, W)
images = torch.stack(norm_images)
latents = self.vae.encode(images).latent_dist.mode()
# Normalize latents across the channel dimension [B, C, F, H, W]
scaling_factor = 1.0
latents_mean = self.pipeline.vae.latents_mean.view(1, -1, 1, 1, 1).to(
latents.device, latents.dtype
)
latents_std = self.pipeline.vae.latents_std.view(1, -1, 1, 1, 1).to(
latents.device, latents.dtype
)
latents = (latents - latents_mean) * scaling_factor / latents_std
return latents.to(device, dtype=dtype)
def get_generation_pipeline(self):
scheduler = LTX2Model.get_train_scheduler()
pipeline: LTX2Pipeline = LTX2Pipeline(
scheduler=scheduler,
vae=unwrap_model(self.pipeline.vae),
audio_vae=unwrap_model(self.pipeline.audio_vae),
text_encoder=None,
tokenizer=unwrap_model(self.pipeline.tokenizer),
connectors=unwrap_model(self.pipeline.connectors),
transformer=None,
vocoder=unwrap_model(self.pipeline.vocoder),
)
pipeline.transformer = unwrap_model(self.model)
pipeline.text_encoder = unwrap_model(self.text_encoder[0])
pipeline = pipeline.to(self.device_torch)
return pipeline
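# Sampling override: switches to the image-to-video pipeline when a control image is given, and for
# video outputs routes saving through new_save_image_function so the audio track is muxed in via encode_video.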
def generate_single_image(
self,
pipeline: LTX2Pipeline,
gen_config: GenerateImageConfig,
conditional_embeds: PromptEmbeds,
unconditional_embeds: PromptEmbeds,
generator: torch.Generator,
extra: dict,
):
if self.model.device == torch.device("cpu"):
self.model.to(self.device_torch)
# handle control image
if gen_config.ctrl_img is not None:
# switch to image to video pipeline
pipeline = LTX2ImageToVideoPipeline(
scheduler=pipeline.scheduler,
vae=pipeline.vae,
audio_vae=pipeline.audio_vae,
text_encoder=pipeline.text_encoder,
tokenizer=pipeline.tokenizer,
connectors=pipeline.connectors,
transformer=pipeline.transformer,
vocoder=pipeline.vocoder,
)
is_video = gen_config.num_frames > 1
# override the generate single image to handle video + audio generation
if is_video:
gen_config._orig_save_image_function = gen_config.save_image
gen_config.save_image = partial(new_save_image_function, gen_config)
gen_config.log_image = partial(blank_log_image_function, gen_config)
# set output extension to mp4
gen_config.output_ext = "mp4"
# re-enable the progress bar since video generation is slow
pipeline.set_progress_bar_config(disable=False)
pipeline = pipeline.to(self.device_torch)
# make sure dimensions are valid
bd = self.get_bucket_divisibility()
gen_config.height = (gen_config.height // bd) * bd
gen_config.width = (gen_config.width // bd) * bd
# handle control image
if gen_config.ctrl_img is not None:
control_img = Image.open(gen_config.ctrl_img).convert("RGB")
# resize the control image
control_img = control_img.resize((gen_config.width, gen_config.height), Image.LANCZOS)
# add the control image to the extra dict
extra["image"] = control_img
# num_frames must be of the form 8 * k + 1 (1, 9, 17, 25, ...); e.g. a request for 48 frames is rounded down to 41
if gen_config.num_frames != 1:
if (gen_config.num_frames - 1) % 8 != 0:
gen_config.num_frames = ((gen_config.num_frames - 1) // 8) * 8 + 1
if self.low_vram:
# set vae to tile decode
pipeline.vae.enable_tiling(
tile_sample_min_height=256,
tile_sample_min_width=256,
tile_sample_min_num_frames=8,
tile_sample_stride_height=224,
tile_sample_stride_width=224,
tile_sample_stride_num_frames=4,
)
video, audio = pipeline(
prompt_embeds=conditional_embeds.text_embeds.to(
self.device_torch, dtype=self.torch_dtype
),
prompt_attention_mask=conditional_embeds.attention_mask.to(
self.device_torch
),
negative_prompt_embeds=unconditional_embeds.text_embeds.to(
self.device_torch, dtype=self.torch_dtype
),
negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(
self.device_torch
),
height=gen_config.height,
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
guidance_scale=gen_config.guidance_scale,
latents=gen_config.latents,
num_frames=gen_config.num_frames,
generator=generator,
return_dict=False,
output_type="np" if is_video else "pil",
**extra,
)
if self.low_vram:
# disable tiling again after sampling
pipeline.vae.use_tiling = False
if is_video:
# return as a dict; it is handled by the save_image override (new_save_image_function)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)
return {
"video": video[0],
"fps": gen_config.fps,
"audio": audio[0].float().cpu(),
"audio_sample_rate": pipeline.vocoder.config.output_sampling_rate, # should be 24000
"output_path": None,
}
else:
# shape = [1, frames, channels, height, width]
# make sure this is right
video = video[0] # list of pil images
audio = audio[0] # tensor
if gen_config.num_frames > 1:
return video # return the frames.
else:
# get just the first image
img = video[0]
return img
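# Encodes batch audio: waveform -> log-mel spectrogram (AudioProcessor) -> audio VAE latents,
# then packed with _pack_audio_latents and normalized with the audio VAE's latent mean/std.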
def encode_audio(self, batch: "DataLoaderBatchDTO"):
if self.pipeline.audio_vae.device == torch.device("cpu"):
self.pipeline.audio_vae.to(self.device_torch)
output_tensor = None
audio_num_frames = None
# encode each audio clip separately for now
for audio_data in batch.audio_data:
waveform = audio_data["waveform"].to(
device=self.device_torch, dtype=torch.float32
)
sample_rate = audio_data["sample_rate"]
# Add batch dimension if needed: [channels, samples] -> [batch, channels, samples]
if waveform.dim() == 2:
waveform = waveform.unsqueeze(0)
if waveform.shape[1] == 1:
# make sure it is stereo
waveform = waveform.repeat(1, 2, 1)
# Convert waveform to mel spectrogram using AudioProcessor
mel_spectrogram = self.audio_processor.waveform_to_mel(waveform, waveform_sample_rate=sample_rate)
mel_spectrogram = mel_spectrogram.to(dtype=self.torch_dtype)
# Encode mel spectrogram to latents
latents = self.pipeline.audio_vae.encode(mel_spectrogram.to(self.device_torch, dtype=self.torch_dtype)).latent_dist.mode()
if audio_num_frames is None:
audio_num_frames = latents.shape[2] #(latents is [B, C, T, F])
packed_latents = self.pipeline._pack_audio_latents(
latents,
# patch_size=self.pipeline.transformer.config.audio_patch_size,
# patch_size_t=self.pipeline.transformer.config.audio_patch_size_t,
) # [B, L, C * M]
if output_tensor is None:
output_tensor = packed_latents
else:
output_tensor = torch.cat([output_tensor, packed_latents], dim=0)
# normalize latents, opposite of (latents * latents_std) + latents_mean
latents_mean = self.pipeline.audio_vae.latents_mean
latents_std = self.pipeline.audio_vae.latents_std
output_tensor = (output_tensor - latents_mean) / latents_std
return output_tensor, audio_num_frames
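# Joint video+audio denoising step. Video latents are packed into patch tokens; audio latents either come
# from the batch (noised) or are sampled via prepare_audio_latents, and both streams share the connector text embeddings.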
def get_noise_prediction(
self,
latent_model_input: torch.Tensor,
timestep: torch.Tensor, # 0 to 1000 scale
text_embeddings: PromptEmbeds,
batch: "DataLoaderBatchDTO" = None,
**kwargs,
):
with torch.no_grad():
if self.model.device == torch.device("cpu"):
self.model.to(self.device_torch)
batch_size, C, latent_num_frames, latent_height, latent_width = (
latent_model_input.shape
)
video_timestep = timestep.clone()
# i2v from first frame
if batch.dataset_config.do_i2v:
# videos come in (bs, num_frames, channels, height, width)
# images come in (bs, channels, height, width)
frames = batch.tensor
if len(frames.shape) == 4:
first_frames = frames
elif len(frames.shape) == 5:
first_frames = frames[:, 0]
else:
raise ValueError(f"Unknown frame shape {frames.shape}")
# the first frame doesn't have a time dim; encode_images adds it back
init_latents = self.encode_images(first_frames, device=self.device_torch, dtype=self.torch_dtype)
init_latents = init_latents.repeat(1, 1, latent_num_frames, 1, 1)
mask_shape = (batch_size, 1, latent_num_frames, latent_height, latent_width)
# The first-frame latents are conditioning and should be kept clean.
conditioning_mask = torch.zeros(mask_shape, device=self.device_torch, dtype=self.torch_dtype)
conditioning_mask[:, :, 0] = 1.0
# use conditioning mask to replace latents
latent_model_input = (
latent_model_input * (1 - conditioning_mask)
+ init_latents * conditioning_mask
)
# set video timestep
video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
# todo get this somehow
frame_rate = 24
# check frame dimension
# Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
packed_latents = self.pipeline._pack_latents(
latent_model_input,
patch_size=self.pipeline.transformer_spatial_patch_size,
patch_size_t=self.pipeline.transformer_temporal_patch_size,
)
if batch.audio_tensor is not None:
# use audio from the batch if available
# e.g. (1, 190, 128)
raw_audio_latents, audio_num_frames = self.encode_audio(batch)
# add the audio targets to the batch for loss calculation later
audio_noise = torch.randn_like(raw_audio_latents)
batch.audio_target = (audio_noise - raw_audio_latents).detach()
audio_latents = self.add_noise(
raw_audio_latents,
audio_noise,
timestep,
).to(self.device_torch, dtype=self.torch_dtype)
else:
# no audio
num_mel_bins = self.pipeline.audio_vae.config.mel_bins
# latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
num_channels_latents_audio = (
self.pipeline.audio_vae.config.latent_channels
)
# audio latents are (1, 126, 128), audio_num_frames = 126
audio_latents, audio_num_frames = self.pipeline.prepare_audio_latents(
batch_size,
num_channels_latents=num_channels_latents_audio,
num_mel_bins=num_mel_bins,
num_frames=batch.tensor.shape[1],
frame_rate=frame_rate,
sampling_rate=self.pipeline.audio_sampling_rate,
hop_length=self.pipeline.audio_hop_length,
dtype=torch.float32,
device=self.transformer.device,
generator=None,
latents=None,
)
if self.pipeline.connectors.device != self.transformer.device:
self.pipeline.connectors.to(self.transformer.device)
# TODO: this mirrors how diffusers builds the additive attention mask at inference; verify it is correct for training
additive_attention_mask = (
1 - text_embeddings.attention_mask.to(self.transformer.dtype)
) * -1000000.0
(
connector_prompt_embeds,
connector_audio_prompt_embeds,
connector_attention_mask,
) = self.pipeline.connectors(
text_embeddings.text_embeds, additive_attention_mask, additive_mask=True
)
# compute video and audio positional ids
video_coords = self.transformer.rope.prepare_video_coords(
packed_latents.shape[0],
latent_num_frames,
latent_height,
latent_width,
packed_latents.device,
fps=frame_rate,
)
audio_coords = self.transformer.audio_rope.prepare_audio_coords(
audio_latents.shape[0], audio_num_frames, audio_latents.device
)
noise_pred_video, noise_pred_audio = self.transformer(
hidden_states=packed_latents,
audio_hidden_states=audio_latents.to(self.transformer.dtype),
encoder_hidden_states=connector_prompt_embeds,
audio_encoder_hidden_states=connector_audio_prompt_embeds,
timestep=video_timestep,
audio_timestep=timestep,
encoder_attention_mask=connector_attention_mask,
audio_encoder_attention_mask=connector_attention_mask,
num_frames=latent_num_frames,
height=latent_height,
width=latent_width,
fps=frame_rate,
audio_num_frames=audio_num_frames,
video_coords=video_coords,
audio_coords=audio_coords,
# rope_interpolation_scale=rope_interpolation_scale,
attention_kwargs=None,
return_dict=False,
)
# add audio latent to batch if we had audio
if batch.audio_target is not None:
batch.audio_pred = noise_pred_audio
unpacked_output = self.pipeline._unpack_latents(
latents=noise_pred_video,
num_frames=latent_num_frames,
height=latent_height,
width=latent_width,
patch_size=self.pipeline.transformer_spatial_patch_size,
patch_size_t=self.pipeline.transformer_temporal_patch_size,
)
return unpacked_output
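# Prompt encoding uses the pipeline's encode_prompt; the attention mask is kept on the PromptEmbeds
# object so the connectors can build their additive mask later.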
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
if self.pipeline.text_encoder.device != self.device_torch:
self.pipeline.text_encoder.to(self.device_torch)
prompt_embeds, prompt_attention_mask, _, _ = self.pipeline.encode_prompt(
prompt,
do_classifier_free_guidance=False,
device=self.device_torch,
)
pe = PromptEmbeds([prompt_embeds, None])
pe.attention_mask = prompt_attention_mask
return pe
def get_model_has_grad(self):
return False
def get_te_has_grad(self):
return False
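# Saves only the transformer (diffusers format) plus the aitk_meta.yaml metadata file.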
def save_model(self, output_path, meta, save_dtype):
transformer: LTX2VideoTransformer3DModel = unwrap_model(self.model)
transformer.save_pretrained(
save_directory=os.path.join(output_path, "transformer"),
safe_serialization=True,
)
meta_path = os.path.join(output_path, "aitk_meta.yaml")
with open(meta_path, "w") as f:
yaml.dump(meta, f)
def get_loss_target(self, *args, **kwargs):
noise = kwargs.get("noise")
batch = kwargs.get("batch")
return (noise - batch.latents).detach()
def get_base_model_version(self):
return "ltx2"
def get_transformer_block_names(self) -> Optional[List[str]]:
return ["transformer_blocks"]
def convert_lora_weights_before_save(self, state_dict):
new_sd = {}
for key, value in state_dict.items():
new_key = key.replace("transformer.", "diffusion_model.")
new_sd[new_key] = value
new_sd = convert_lora_diffusers_to_original(new_sd)
return new_sd
def convert_lora_weights_before_load(self, state_dict):
state_dict = convert_lora_original_to_diffusers(state_dict)
new_sd = {}
for key, value in state_dict.items():
new_key = key.replace("diffusion_model.", "transformer.")
new_sd[new_key] = value
return new_sd