mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Add LTX-2 Support (#644)
* WIP, adding support for LTX2
* Training on images working
* Fix loading comfy models
* Handle converting and deconverting LoRA so it matches the original format
* Reworked UI to handle LTX and proper dataset default overwriting
* Update the way LoKr saves so it is more compatible with comfy
* Audio loading and synchronization/resampling is working
* Add audio to training. Does it work? Maybe, still testing.
* Fixed fps default issue for sound
* Have UI set fps for accurate audio mapping on LTX
* Added audio processing options to the UI for LTX
* Clean up requirements
852 extensions_built_in/diffusion_models/ltx2/ltx2.py Normal file
@@ -0,0 +1,852 @@
from functools import partial
import os
from typing import List, Optional

import torch
import torchaudio
from transformers import Gemma3Config
import yaml
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.models.base_model import BaseModel
from toolkit.basic import flush
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import (
    CustomFlowMatchEulerDiscreteScheduler,
)
from accelerate import init_empty_weights
from toolkit.accelerator import unwrap_model
from optimum.quanto import freeze
from toolkit.util.quantize import quantize, get_qtype, quantize_model
from toolkit.memory_management import MemoryManager
from safetensors.torch import load_file

try:
    from diffusers import LTX2Pipeline
    from diffusers.models.autoencoders import (
        AutoencoderKLLTX2Audio,
        AutoencoderKLLTX2Video,
    )
    from diffusers.models.transformers import LTX2VideoTransformer3DModel
    from diffusers.pipelines.ltx2.export_utils import encode_video
    from transformers import (
        Gemma3ForConditionalGeneration,
        GemmaTokenizerFast,
    )
    from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
    from diffusers.pipelines.ltx2.connectors import LTX2TextConnectors
    from .convert_ltx2_to_diffusers import (
        get_model_state_dict_from_combined_ckpt,
        convert_ltx2_transformer,
        convert_ltx2_video_vae,
        convert_ltx2_audio_vae,
        convert_ltx2_vocoder,
        convert_ltx2_connectors,
        dequantize_state_dict,
        convert_comfy_gemma3_to_transformers,
        convert_lora_original_to_diffusers,
        convert_lora_diffusers_to_original,
    )
except ImportError as e:
    print("Diffusers import error:", e)
    raise ImportError(
        "Diffusers is out of date. Update diffusers to the latest version by doing pip uninstall diffusers and then pip install -r requirements.txt"
    )


scheduler_config = {
    "base_image_seq_len": 1024,
    "base_shift": 0.95,
    "invert_sigmas": False,
    "max_image_seq_len": 4096,
    "max_shift": 2.05,
    "num_train_timesteps": 1000,
    "shift": 1.0,
    "shift_terminal": 0.1,
    "stochastic_sampling": False,
    "time_shift_type": "exponential",
    "use_beta_sigmas": False,
    "use_dynamic_shifting": True,
    "use_exponential_sigmas": False,
    "use_karras_sigmas": False,
}

dit_prefix = "model.diffusion_model."
vae_prefix = "vae."
audio_vae_prefix = "audio_vae."
vocoder_prefix = "vocoder."


def new_save_image_function(
    self: GenerateImageConfig,
    image,  # a dict that can be passed directly into encode_video; just add output_path to it.
    count: int = 0,
    max_count: int = 0,
    **kwargs,
):
    # this replaces the GenerateImageConfig save_image function so we can save the video with sound from LTX-2
    image["output_path"] = self.get_image_path(count, max_count)
    # make the sample directory if it does not exist
    os.makedirs(os.path.dirname(image["output_path"]), exist_ok=True)
    encode_video(**image)
    flush()
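
# Illustrative sketch (not part of the training flow itself): for videos,
# generate_single_image below returns a dict shaped like the one this function
# forwards to encode_video once output_path has been filled in:
#
#     {
#         "video": <uint8 frame tensor>,
#         "fps": gen_config.fps,
#         "audio": <float waveform tensor>,
#         "audio_sample_rate": 24000,
#         "output_path": "/path/to/sample.mp4",  # hypothetical path
#     }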


def blank_log_image_function(self, *args, **kwargs):
    # todo handle wandb logging of videos with audio
    return


class AudioProcessor(torch.nn.Module):
    """Converts audio waveforms to log-mel spectrograms with optional resampling."""

    def __init__(
        self,
        sample_rate: int,
        mel_bins: int,
        mel_hop_length: int,
        n_fft: int,
    ) -> None:
        super().__init__()
        self.sample_rate = sample_rate
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            win_length=n_fft,
            hop_length=mel_hop_length,
            f_min=0.0,
            f_max=sample_rate / 2.0,
            n_mels=mel_bins,
            window_fn=torch.hann_window,
            center=True,
            pad_mode="reflect",
            power=1.0,
            mel_scale="slaney",
            norm="slaney",
        )

    def resample_waveform(
        self,
        waveform: torch.Tensor,
        source_rate: int,
        target_rate: int,
    ) -> torch.Tensor:
        """Resample waveform to target sample rate if needed."""
        if source_rate == target_rate:
            return waveform
        resampled = torchaudio.functional.resample(waveform, source_rate, target_rate)
        return resampled.to(device=waveform.device, dtype=waveform.dtype)

    def waveform_to_mel(
        self,
        waveform: torch.Tensor,
        waveform_sample_rate: int,
    ) -> torch.Tensor:
        """Convert waveform to log-mel spectrogram [batch, channels, time, n_mels]."""
        waveform = self.resample_waveform(
            waveform, waveform_sample_rate, self.sample_rate
        )

        mel = self.mel_transform(waveform)
        mel = torch.log(torch.clamp(mel, min=1e-5))

        mel = mel.to(device=waveform.device, dtype=waveform.dtype)
        return mel.permute(0, 1, 3, 2).contiguous()
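
# Example usage sketch for AudioProcessor (the concrete sample_rate / mel settings
# here are assumptions for illustration; the real values come from the pipeline and
# audio VAE config in LTX2Model.load_model below):
#
#     processor = AudioProcessor(sample_rate=16000, mel_bins=64, mel_hop_length=256, n_fft=1024)
#     waveform = torch.randn(1, 2, 4 * 44100)  # [batch, channels, samples] at 44.1 kHz
#     mel = processor.waveform_to_mel(waveform, waveform_sample_rate=44100)
#     # mel is [batch, channels, time, n_mels], ready for the audio VAE encoder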


class LTX2Model(BaseModel):
    arch = "ltx2"

    def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype="bf16",
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs,
    ):
        super().__init__(
            device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs
        )
        self.is_flow_matching = True
        self.is_transformer = True
        self.target_lora_modules = ["LTX2VideoTransformer3DModel"]
        # defines if the model supports model paths. Only some models do.
        self.supports_model_paths = True
        # use the new format on this new model by default
        self.use_old_lokr_format = False
        self.audio_processor = None

    # static method to get the noise scheduler
    @staticmethod
    def get_train_scheduler():
        return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)

    def get_bucket_divisibility(self):
        return 32

    def load_model(self):
        dtype = self.torch_dtype
        self.print_and_status_update("Loading LTX2 model")
        model_path = self.model_config.name_or_path
        base_model_path = self.model_config.extras_name_or_path

        combined_state_dict = None

        self.print_and_status_update("Loading transformer")
        # if we have a safetensors file, it is a single combined checkpoint
        if os.path.exists(model_path) and model_path.endswith(".safetensors"):
            combined_state_dict = load_file(model_path)
            combined_state_dict = dequantize_state_dict(combined_state_dict)

        if combined_state_dict is not None:
            original_dit_ckpt = get_model_state_dict_from_combined_ckpt(
                combined_state_dict, dit_prefix
            )
            transformer = convert_ltx2_transformer(original_dit_ckpt)
            transformer = transformer.to(dtype)
        else:
            transformer_path = model_path
            transformer_subfolder = "transformer"
            if os.path.exists(transformer_path):
                transformer_subfolder = None
                transformer_path = os.path.join(transformer_path, "transformer")
                # check if the path is a full checkpoint.
                te_folder_path = os.path.join(model_path, "text_encoder")
                # if we have the te, this folder is a full checkpoint, use it as the base
                if os.path.exists(te_folder_path):
                    base_model_path = model_path

            transformer = LTX2VideoTransformer3DModel.from_pretrained(
                transformer_path, subfolder=transformer_subfolder, torch_dtype=dtype
            )

        if self.model_config.quantize:
            self.print_and_status_update("Quantizing Transformer")
            quantize_model(self, transformer)
            flush()

        if (
            self.model_config.layer_offloading
            and self.model_config.layer_offloading_transformer_percent > 0
        ):
            ignore_modules = []
            for block in transformer.transformer_blocks:
                ignore_modules.append(block.scale_shift_table)
                ignore_modules.append(block.audio_scale_shift_table)
                ignore_modules.append(block.video_a2v_cross_attn_scale_shift_table)
                ignore_modules.append(block.audio_a2v_cross_attn_scale_shift_table)
            ignore_modules.append(transformer.scale_shift_table)
            ignore_modules.append(transformer.audio_scale_shift_table)
            MemoryManager.attach(
                transformer,
                self.device_torch,
                offload_percent=self.model_config.layer_offloading_transformer_percent,
                ignore_modules=ignore_modules,
            )

        if self.model_config.low_vram:
            self.print_and_status_update("Moving transformer to CPU")
            transformer.to("cpu")

        flush()

        self.print_and_status_update("Loading text encoder")
        if (
            self.model_config.te_name_or_path is not None
            and self.model_config.te_name_or_path.endswith(".safetensors")
        ):
            # load from comfyui gemma3 checkpoint
            tokenizer = GemmaTokenizerFast.from_pretrained(
                "Lightricks/LTX-2", subfolder="tokenizer"
            )

            with init_empty_weights():
                text_encoder = Gemma3ForConditionalGeneration(
                    Gemma3Config(
                        **{
                            "boi_token_index": 255999,
                            "bos_token_id": 2,
                            "eoi_token_index": 256000,
                            "eos_token_id": 106,
                            "image_token_index": 262144,
                            "initializer_range": 0.02,
                            "mm_tokens_per_image": 256,
                            "model_type": "gemma3",
                            "pad_token_id": 0,
                            "text_config": {
                                "attention_bias": False,
                                "attention_dropout": 0.0,
                                "attn_logit_softcapping": None,
                                "cache_implementation": "hybrid",
                                "final_logit_softcapping": None,
                                "head_dim": 256,
                                "hidden_activation": "gelu_pytorch_tanh",
                                "hidden_size": 3840,
                                "initializer_range": 0.02,
                                "intermediate_size": 15360,
                                "max_position_embeddings": 131072,
                                "model_type": "gemma3_text",
                                "num_attention_heads": 16,
                                "num_hidden_layers": 48,
                                "num_key_value_heads": 8,
                                "query_pre_attn_scalar": 256,
                                "rms_norm_eps": 1e-06,
                                "rope_local_base_freq": 10000,
                                "rope_scaling": {"factor": 8.0, "rope_type": "linear"},
                                "rope_theta": 1000000,
                                "sliding_window": 1024,
                                "sliding_window_pattern": 6,
                                "torch_dtype": "bfloat16",
                                "use_cache": True,
                                "vocab_size": 262208,
                            },
                            "torch_dtype": "bfloat16",
                            "transformers_version": "4.51.3",
                            "unsloth_fixed": True,
                            "vision_config": {
                                "attention_dropout": 0.0,
                                "hidden_act": "gelu_pytorch_tanh",
                                "hidden_size": 1152,
                                "image_size": 896,
                                "intermediate_size": 4304,
                                "layer_norm_eps": 1e-06,
                                "model_type": "siglip_vision_model",
                                "num_attention_heads": 16,
                                "num_channels": 3,
                                "num_hidden_layers": 27,
                                "patch_size": 14,
                                "torch_dtype": "bfloat16",
                                "vision_use_head": False,
                            },
                        }
                    )
                )
            te_state_dict = load_file(self.model_config.te_name_or_path)
            te_state_dict = convert_comfy_gemma3_to_transformers(te_state_dict)
            for key in te_state_dict:
                te_state_dict[key] = te_state_dict[key].to(dtype)

            text_encoder.load_state_dict(te_state_dict, assign=True, strict=True)
            del te_state_dict
            flush()
        else:
            if self.model_config.te_name_or_path is not None:
                te_path = self.model_config.te_name_or_path
            else:
                te_path = base_model_path
            tokenizer = GemmaTokenizerFast.from_pretrained(
                te_path, subfolder="tokenizer"
            )
            text_encoder = Gemma3ForConditionalGeneration.from_pretrained(
                te_path, subfolder="text_encoder", dtype=dtype
            )

        # remove the vision tower
        text_encoder.model.vision_tower = None
        flush()

        if (
            self.model_config.layer_offloading
            and self.model_config.layer_offloading_text_encoder_percent > 0
        ):
            MemoryManager.attach(
                text_encoder,
                self.device_torch,
                offload_percent=self.model_config.layer_offloading_text_encoder_percent,
                ignore_modules=[
                    text_encoder.model.language_model.base_model.embed_tokens
                ],
            )

        text_encoder.to(self.device_torch, dtype=dtype)
        flush()

        if self.model_config.quantize_te:
            self.print_and_status_update("Quantizing Text Encoder")
            quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te))
            freeze(text_encoder)
            flush()

        self.print_and_status_update("Loading VAEs and other components")
        if combined_state_dict is not None:
            original_vae_ckpt = get_model_state_dict_from_combined_ckpt(
                combined_state_dict, vae_prefix
            )
            vae = convert_ltx2_video_vae(original_vae_ckpt).to(dtype)
            del original_vae_ckpt
            original_audio_vae_ckpt = get_model_state_dict_from_combined_ckpt(
                combined_state_dict, audio_vae_prefix
            )
            audio_vae = convert_ltx2_audio_vae(original_audio_vae_ckpt).to(dtype)
            del original_audio_vae_ckpt
            original_connectors_ckpt = get_model_state_dict_from_combined_ckpt(
                combined_state_dict, dit_prefix
            )
            connectors = convert_ltx2_connectors(original_connectors_ckpt).to(dtype)
            del original_connectors_ckpt
            original_vocoder_ckpt = get_model_state_dict_from_combined_ckpt(
                combined_state_dict, vocoder_prefix
            )
            vocoder = convert_ltx2_vocoder(original_vocoder_ckpt).to(dtype)
            del original_vocoder_ckpt
            del combined_state_dict
            flush()
        else:
            vae = AutoencoderKLLTX2Video.from_pretrained(
                base_model_path, subfolder="vae", torch_dtype=dtype
            )
            audio_vae = AutoencoderKLLTX2Audio.from_pretrained(
                base_model_path, subfolder="audio_vae", torch_dtype=dtype
            )

            connectors = LTX2TextConnectors.from_pretrained(
                base_model_path, subfolder="connectors", torch_dtype=dtype
            )

            vocoder = LTX2Vocoder.from_pretrained(
                base_model_path, subfolder="vocoder", torch_dtype=dtype
            )

        self.noise_scheduler = LTX2Model.get_train_scheduler()

        self.print_and_status_update("Making pipe")

        pipe: LTX2Pipeline = LTX2Pipeline(
            scheduler=self.noise_scheduler,
            vae=vae,
            audio_vae=audio_vae,
            text_encoder=None,
            tokenizer=tokenizer,
            connectors=connectors,
            transformer=None,
            vocoder=vocoder,
        )
        # for quantization, it works best to do these after making the pipe
        pipe.text_encoder = text_encoder
        pipe.transformer = transformer

        self.print_and_status_update("Preparing Model")

        text_encoder = [pipe.text_encoder]
        tokenizer = [pipe.tokenizer]

        # leave it on cpu for now
        if not self.low_vram:
            pipe.transformer = pipe.transformer.to(self.device_torch)

        flush()
        # just to make sure everything is on the right device and dtype
        text_encoder[0].to(self.device_torch)
        text_encoder[0].requires_grad_(False)
        text_encoder[0].eval()
        flush()

        # save it to the model class
        self.vae = vae
        self.text_encoder = text_encoder  # list of text encoders
        self.tokenizer = tokenizer  # list of tokenizers
        self.model = pipe.transformer
        self.pipeline = pipe

        self.audio_processor = AudioProcessor(
            sample_rate=pipe.audio_sampling_rate,
            mel_bins=audio_vae.config.mel_bins,
            mel_hop_length=pipe.audio_hop_length,
            n_fft=1024,  # todo: get this from the vae config if possible; I couldn't find it.
        ).to(self.device_torch, dtype=torch.float32)

        self.print_and_status_update("Model Loaded")

    @torch.no_grad()
    def encode_images(self, image_list: List[torch.Tensor], device=None, dtype=None):
        if device is None:
            device = self.vae_device_torch
        if dtype is None:
            dtype = self.vae_torch_dtype

        if self.vae.device == torch.device("cpu"):
            self.vae.to(device)
        self.vae.eval()
        self.vae.requires_grad_(False)

        image_list = [image.to(device, dtype=dtype) for image in image_list]

        # Normalize shapes
        norm_images = []
        for image in image_list:
            if image.ndim == 3:
                # (C, H, W) -> (C, 1, H, W)
                norm_images.append(image.unsqueeze(1))
            elif image.ndim == 4:
                # (T, C, H, W) -> (C, T, H, W)
                norm_images.append(image.permute(1, 0, 2, 3))
            else:
                raise ValueError(f"Invalid image shape: {image.shape}")

        # Stack to (B, C, T, H, W)
        images = torch.stack(norm_images)

        latents = self.vae.encode(images).latent_dist.mode()

        # Normalize latents across the channel dimension [B, C, F, H, W]
        scaling_factor = 1.0
        latents_mean = self.pipeline.vae.latents_mean.view(1, -1, 1, 1, 1).to(
            latents.device, latents.dtype
        )
        latents_std = self.pipeline.vae.latents_std.view(1, -1, 1, 1, 1).to(
            latents.device, latents.dtype
        )
        latents = (latents - latents_mean) * scaling_factor / latents_std

        return latents.to(device, dtype=dtype)
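
    # Note on the normalization above: with scaling_factor = 1.0 this is
    #     latents = (latents - latents_mean) / latents_std
    # i.e. the inverse of the decode-time mapping latents * latents_std + latents_mean,
    # matching the comment in encode_audio below.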

    def get_generation_pipeline(self):
        scheduler = LTX2Model.get_train_scheduler()

        pipeline: LTX2Pipeline = LTX2Pipeline(
            scheduler=scheduler,
            vae=unwrap_model(self.pipeline.vae),
            audio_vae=unwrap_model(self.pipeline.audio_vae),
            text_encoder=None,
            tokenizer=unwrap_model(self.pipeline.tokenizer),
            connectors=unwrap_model(self.pipeline.connectors),
            transformer=None,
            vocoder=unwrap_model(self.pipeline.vocoder),
        )
        pipeline.transformer = unwrap_model(self.model)
        pipeline.text_encoder = unwrap_model(self.text_encoder[0])

        # if self.low_vram:
        #     pipeline.enable_model_cpu_offload(device=self.device_torch)

        pipeline = pipeline.to(self.device_torch)

        return pipeline

    def generate_single_image(
        self,
        pipeline: LTX2Pipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        if self.model.device == torch.device("cpu"):
            self.model.to(self.device_torch)

        is_video = gen_config.num_frames > 1
        # override the save/log functions to handle video + audio generation
        if is_video:
            gen_config._orig_save_image_function = gen_config.save_image
            gen_config.save_image = partial(new_save_image_function, gen_config)
            gen_config.log_image = partial(blank_log_image_function, gen_config)
            # set output extension to mp4
            gen_config.output_ext = "mp4"

        # reactivate the progress bar since this is slow
        pipeline.set_progress_bar_config(disable=False)
        pipeline = pipeline.to(self.device_torch)

        # make sure dimensions are valid
        bd = self.get_bucket_divisibility()
        gen_config.height = (gen_config.height // bd) * bd
        gen_config.width = (gen_config.width // bd) * bd

        # num_frames must be a multiple of 8 plus 1, i.e. 1, 9, 17, 25, etc.
        if gen_config.num_frames != 1:
            if (gen_config.num_frames - 1) % 8 != 0:
                gen_config.num_frames = ((gen_config.num_frames - 1) // 8) * 8 + 1

        if self.low_vram:
            # set vae to tiled decode
            pipeline.vae.enable_tiling(
                tile_sample_min_height=256,
                tile_sample_min_width=256,
                tile_sample_min_num_frames=8,
                tile_sample_stride_height=224,
                tile_sample_stride_width=224,
                tile_sample_stride_num_frames=4,
            )

        video, audio = pipeline(
            prompt_embeds=conditional_embeds.text_embeds.to(
                self.device_torch, dtype=self.torch_dtype
            ),
            prompt_attention_mask=conditional_embeds.attention_mask.to(
                self.device_torch
            ),
            negative_prompt_embeds=unconditional_embeds.text_embeds.to(
                self.device_torch, dtype=self.torch_dtype
            ),
            negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(
                self.device_torch
            ),
            height=gen_config.height,
            width=gen_config.width,
            num_inference_steps=gen_config.num_inference_steps,
            guidance_scale=gen_config.guidance_scale,
            latents=gen_config.latents,
            num_frames=gen_config.num_frames,
            generator=generator,
            return_dict=False,
            output_type="np" if is_video else "pil",
            **extra,
        )
        if self.low_vram:
            # Restore no tiling
            pipeline.vae.use_tiling = False

        if is_video:
            # return as a dict; we handle it with the override function above
            video = (video * 255).round().astype("uint8")
            video = torch.from_numpy(video)
            return {
                "video": video[0],
                "fps": gen_config.fps,
                "audio": audio[0].float().cpu(),
                "audio_sample_rate": pipeline.vocoder.config.output_sampling_rate,  # should be 24000
                "output_path": None,
            }
        else:
            # shape = [1, frames, channels, height, width]
            # make sure this is right
            video = video[0]  # list of pil images
            audio = audio[0]  # tensor
            if gen_config.num_frames > 1:
                return video  # return the frames.
            else:
                # get just the first image
                img = video[0]
                return img

    def encode_audio(self, batch: "DataLoaderBatchDTO"):
        if self.pipeline.audio_vae.device == torch.device("cpu"):
            self.pipeline.audio_vae.to(self.device_torch)

        output_tensor = None
        audio_num_frames = None

        # do them separately for now
        for audio_data in batch.audio_data:
            waveform = audio_data["waveform"].to(
                device=self.device_torch, dtype=torch.float32
            )
            sample_rate = audio_data["sample_rate"]

            # Add batch dimension if needed: [channels, samples] -> [batch, channels, samples]
            if waveform.dim() == 2:
                waveform = waveform.unsqueeze(0)

            if waveform.shape[1] == 1:
                # make sure it is stereo
                waveform = waveform.repeat(1, 2, 1)

            # Convert waveform to mel spectrogram using AudioProcessor
            mel_spectrogram = self.audio_processor.waveform_to_mel(
                waveform, waveform_sample_rate=sample_rate
            )
            mel_spectrogram = mel_spectrogram.to(dtype=self.torch_dtype)

            # Encode mel spectrogram to latents
            latents = self.pipeline.audio_vae.encode(
                mel_spectrogram.to(self.device_torch, dtype=self.torch_dtype)
            ).latent_dist.mode()

            if audio_num_frames is None:
                audio_num_frames = latents.shape[2]  # (latents is [B, C, T, F])

            packed_latents = self.pipeline._pack_audio_latents(
                latents,
                # patch_size=self.pipeline.transformer.config.audio_patch_size,
                # patch_size_t=self.pipeline.transformer.config.audio_patch_size_t,
            )  # [B, L, C * M]
            if output_tensor is None:
                output_tensor = packed_latents
            else:
                output_tensor = torch.cat([output_tensor, packed_latents], dim=0)

        # normalize latents, the opposite of (latents * latents_std) + latents_mean
        latents_mean = self.pipeline.audio_vae.latents_mean
        latents_std = self.pipeline.audio_vae.latents_std
        output_tensor = (output_tensor - latents_mean) / latents_std
        return output_tensor, audio_num_frames
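
    # Shape walkthrough for encode_audio (illustrative, based on the inline comments above):
    #   waveform [B, 2, samples] -> log-mel [B, 2, T, n_mels] -> audio VAE latents [B, C, T, F]
    #   -> _pack_audio_latents -> [B, L, C * M], which get_noise_prediction below feeds to the
    #   transformer as audio_hidden_states.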

    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        batch: "DataLoaderBatchDTO" = None,
        **kwargs,
    ):
        with torch.no_grad():
            if self.model.device == torch.device("cpu"):
                self.model.to(self.device_torch)

            batch_size, C, latent_num_frames, latent_height, latent_width = (
                latent_model_input.shape
            )

            # todo get this somehow
            frame_rate = 24
            # check frame dimension
            # Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
            packed_latents = self.pipeline._pack_latents(
                latent_model_input,
                patch_size=self.pipeline.transformer_spatial_patch_size,
                patch_size_t=self.pipeline.transformer_temporal_patch_size,
            )

            if batch.audio_tensor is not None:
                # use audio from the batch if available
                # (1, 190, 128)
                raw_audio_latents, audio_num_frames = self.encode_audio(batch)
                # add the audio targets to the batch for loss calculation later
                audio_noise = torch.randn_like(raw_audio_latents)
                batch.audio_target = (audio_noise - raw_audio_latents).detach()
                audio_latents = self.add_noise(
                    raw_audio_latents,
                    audio_noise,
                    timestep,
                ).to(self.device_torch, dtype=self.torch_dtype)

            else:
                # no audio
                num_mel_bins = self.pipeline.audio_vae.config.mel_bins
                # latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
                num_channels_latents_audio = (
                    self.pipeline.audio_vae.config.latent_channels
                )
                # audio latents are (1, 126, 128), audio_num_frames = 126
                audio_latents, audio_num_frames = self.pipeline.prepare_audio_latents(
                    batch_size,
                    num_channels_latents=num_channels_latents_audio,
                    num_mel_bins=num_mel_bins,
                    num_frames=batch.tensor.shape[1],
                    frame_rate=frame_rate,
                    sampling_rate=self.pipeline.audio_sampling_rate,
                    hop_length=self.pipeline.audio_hop_length,
                    dtype=torch.float32,
                    device=self.transformer.device,
                    generator=None,
                    latents=None,
                )

            if self.pipeline.connectors.device != self.transformer.device:
                self.pipeline.connectors.to(self.transformer.device)

            # TODO this is how diffusers does this on inference, not sure I understand why, check this
            additive_attention_mask = (
                1 - text_embeddings.attention_mask.to(self.transformer.dtype)
            ) * -1000000.0
            (
                connector_prompt_embeds,
                connector_audio_prompt_embeds,
                connector_attention_mask,
            ) = self.pipeline.connectors(
                text_embeddings.text_embeds, additive_attention_mask, additive_mask=True
            )

            # compute video and audio positional ids
            video_coords = self.transformer.rope.prepare_video_coords(
                packed_latents.shape[0],
                latent_num_frames,
                latent_height,
                latent_width,
                packed_latents.device,
                fps=frame_rate,
            )
            audio_coords = self.transformer.audio_rope.prepare_audio_coords(
                audio_latents.shape[0], audio_num_frames, audio_latents.device
            )

        noise_pred_video, noise_pred_audio = self.transformer(
            hidden_states=packed_latents,
            audio_hidden_states=audio_latents.to(self.transformer.dtype),
            encoder_hidden_states=connector_prompt_embeds,
            audio_encoder_hidden_states=connector_audio_prompt_embeds,
            timestep=timestep,
            encoder_attention_mask=connector_attention_mask,
            audio_encoder_attention_mask=connector_attention_mask,
            num_frames=latent_num_frames,
            height=latent_height,
            width=latent_width,
            fps=frame_rate,
            audio_num_frames=audio_num_frames,
            video_coords=video_coords,
            audio_coords=audio_coords,
            # rope_interpolation_scale=rope_interpolation_scale,
            attention_kwargs=None,
            return_dict=False,
        )

        # add the audio latent prediction to the batch if we had audio
        if batch.audio_target is not None:
            batch.audio_pred = noise_pred_audio

        unpacked_output = self.pipeline._unpack_latents(
            latents=noise_pred_video,
            num_frames=latent_num_frames,
            height=latent_height,
            width=latent_width,
            patch_size=self.pipeline.transformer_spatial_patch_size,
            patch_size_t=self.pipeline.transformer_temporal_patch_size,
        )

        return unpacked_output

    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
        if self.pipeline.text_encoder.device != self.device_torch:
            self.pipeline.text_encoder.to(self.device_torch)

        prompt_embeds, prompt_attention_mask, _, _ = self.pipeline.encode_prompt(
            prompt,
            do_classifier_free_guidance=False,
            device=self.device_torch,
        )
        pe = PromptEmbeds([prompt_embeds, None])
        pe.attention_mask = prompt_attention_mask
        return pe

    def get_model_has_grad(self):
        return False

    def get_te_has_grad(self):
        return False

    def save_model(self, output_path, meta, save_dtype):
        transformer: LTX2VideoTransformer3DModel = unwrap_model(self.model)
        transformer.save_pretrained(
            save_directory=os.path.join(output_path, "transformer"),
            safe_serialization=True,
        )

        meta_path = os.path.join(output_path, "aitk_meta.yaml")
        with open(meta_path, "w") as f:
            yaml.dump(meta, f)

    def get_loss_target(self, *args, **kwargs):
        noise = kwargs.get("noise")
        batch = kwargs.get("batch")
        return (noise - batch.latents).detach()
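
    # Flow-matching target: the video loss target is (noise - clean_latents), and the audio
    # branch follows the same convention via batch.audio_target / batch.audio_pred set in
    # get_noise_prediction above (the trainer presumably combines the two losses).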

    def get_base_model_version(self):
        return "ltx2"

    def get_transformer_block_names(self) -> Optional[List[str]]:
        return ["transformer_blocks"]

    def convert_lora_weights_before_save(self, state_dict):
        new_sd = {}
        for key, value in state_dict.items():
            new_key = key.replace("transformer.", "diffusion_model.")
            new_sd[new_key] = value
        new_sd = convert_lora_diffusers_to_original(new_sd)
        return new_sd

    def convert_lora_weights_before_load(self, state_dict):
        state_dict = convert_lora_original_to_diffusers(state_dict)
        new_sd = {}
        for key, value in state_dict.items():
            new_key = key.replace("diffusion_model.", "transformer.")
            new_sd[new_key] = value
        return new_sd
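
    # Round-trip sketch for the LoRA key conversion above (hypothetical key name):
    #   on save: "transformer.transformer_blocks.0.attn1.to_q.lora_A.weight"
    #            -> "diffusion_model.transformer_blocks.0.attn1.to_q.lora_A.weight"
    #            -> original LTX-2 layout via convert_lora_diffusers_to_original
    #   on load: the reverse, so checkpoints saved in the original/Comfy-compatible format
    #            map back onto the diffusers-style "transformer." module names used here.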