mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
904 lines · 35 KiB · Python
from functools import partial
import os
from typing import List, Optional

import torch
import torchaudio
from transformers import Gemma3Config
import yaml
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.models.base_model import BaseModel
from toolkit.basic import flush
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import (
    CustomFlowMatchEulerDiscreteScheduler,
)
from accelerate import init_empty_weights
from toolkit.accelerator import unwrap_model
from optimum.quanto import freeze
from toolkit.util.quantize import quantize, get_qtype, quantize_model
from toolkit.memory_management import MemoryManager
from safetensors.torch import load_file
from PIL import Image

try:
    from diffusers import LTX2Pipeline, LTX2ImageToVideoPipeline
    from diffusers.models.autoencoders import (
        AutoencoderKLLTX2Audio,
        AutoencoderKLLTX2Video,
    )
    from diffusers.models.transformers import LTX2VideoTransformer3DModel
    from diffusers.pipelines.ltx2.export_utils import encode_video
    from transformers import (
        Gemma3ForConditionalGeneration,
        GemmaTokenizerFast,
    )
    from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
    from diffusers.pipelines.ltx2.connectors import LTX2TextConnectors
    from .convert_ltx2_to_diffusers import (
        get_model_state_dict_from_combined_ckpt,
        convert_ltx2_transformer,
        convert_ltx2_video_vae,
        convert_ltx2_audio_vae,
        convert_ltx2_vocoder,
        convert_ltx2_connectors,
        dequantize_state_dict,
        convert_comfy_gemma3_to_transformers,
        convert_lora_original_to_diffusers,
        convert_lora_diffusers_to_original,
    )
except ImportError as e:
    print("Diffusers import error:", e)
    raise ImportError(
        "Diffusers is out of date. Update it to the latest version by running "
        "`pip uninstall diffusers` followed by `pip install -r requirements.txt`."
    )


scheduler_config = {
    "base_image_seq_len": 1024,
    "base_shift": 0.95,
    "invert_sigmas": False,
    "max_image_seq_len": 4096,
    "max_shift": 2.05,
    "num_train_timesteps": 1000,
    "shift": 1.0,
    "shift_terminal": 0.1,
    "stochastic_sampling": False,
    "time_shift_type": "exponential",
    "use_beta_sigmas": False,
    "use_dynamic_shifting": True,
    "use_exponential_sigmas": False,
    "use_karras_sigmas": False,
}

dit_prefix = "model.diffusion_model."
vae_prefix = "vae."
audio_vae_prefix = "audio_vae."
vocoder_prefix = "vocoder."
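# These prefixes select the sub-model weights inside a combined single-file checkpoint
# (the original LTX-2 release format). get_model_state_dict_from_combined_ckpt is
# expected to return only the entries under a given prefix with the prefix stripped,
# roughly (hypothetical sketch):
#     {k[len(prefix):]: v for k, v in ckpt.items() if k.startswith(prefix)}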


def new_save_image_function(
    self: GenerateImageConfig,
    image,  # a dict that can be passed directly to encode_video; we only add output_path to it
    count: int = 0,
    max_count: int = 0,
    **kwargs,
):
    # this replaces the GenerateImageConfig save_image function so we can save
    # the video with sound from LTX-2
    image["output_path"] = self.get_image_path(count, max_count)
    # make the sample directory if it does not exist
    os.makedirs(os.path.dirname(image["output_path"]), exist_ok=True)
    encode_video(**image)
    flush()


def blank_log_image_function(self, *args, **kwargs):
    # todo handle wandb logging of videos with audio
    return
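# Note: generate_single_image() below swaps these in per run with functools.partial,
# e.g. gen_config.save_image = partial(new_save_image_function, gen_config), so `self`
# here is the GenerateImageConfig instance rather than a normal bound-method receiver.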


class AudioProcessor(torch.nn.Module):
    """Converts audio waveforms to log-mel spectrograms with optional resampling."""

    def __init__(
        self,
        sample_rate: int,
        mel_bins: int,
        mel_hop_length: int,
        n_fft: int,
    ) -> None:
        super().__init__()
        self.sample_rate = sample_rate
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            win_length=n_fft,
            hop_length=mel_hop_length,
            f_min=0.0,
            f_max=sample_rate / 2.0,
            n_mels=mel_bins,
            window_fn=torch.hann_window,
            center=True,
            pad_mode="reflect",
            power=1.0,
            mel_scale="slaney",
            norm="slaney",
        )

    def resample_waveform(
        self,
        waveform: torch.Tensor,
        source_rate: int,
        target_rate: int,
    ) -> torch.Tensor:
        """Resample waveform to target sample rate if needed."""
        if source_rate == target_rate:
            return waveform
        resampled = torchaudio.functional.resample(waveform, source_rate, target_rate)
        return resampled.to(device=waveform.device, dtype=waveform.dtype)

    def waveform_to_mel(
        self,
        waveform: torch.Tensor,
        waveform_sample_rate: int,
    ) -> torch.Tensor:
        """Convert waveform to log-mel spectrogram [batch, channels, time, n_mels]."""
        waveform = self.resample_waveform(
            waveform, waveform_sample_rate, self.sample_rate
        )

        mel = self.mel_transform(waveform)
        mel = torch.log(torch.clamp(mel, min=1e-5))

        mel = mel.to(device=waveform.device, dtype=waveform.dtype)
        return mel.permute(0, 1, 3, 2).contiguous()
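# Illustrative usage (hypothetical values; the real rate/bin/hop settings come from the
# pipeline and audio VAE config in load_model below):
#     processor = AudioProcessor(sample_rate=48000, mel_bins=128, mel_hop_length=512, n_fft=1024)
#     waveform = torch.randn(1, 2, 48000)  # [batch, channels, samples]
#     mel = processor.waveform_to_mel(waveform, waveform_sample_rate=44100)
#     # mel has shape [batch, channels, time_frames, mel_bins]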


class LTX2Model(BaseModel):
    arch = "ltx2"

    def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype="bf16",
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs,
    ):
        super().__init__(
            device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs
        )
        self.is_flow_matching = True
        self.is_transformer = True
        self.target_lora_modules = ["LTX2VideoTransformer3DModel"]
        # defines if the model supports model paths. Only some will
        self.supports_model_paths = True
        # use the new format on this new model by default
        self.use_old_lokr_format = False
        self.audio_processor = None

    # static method to get the noise scheduler
    @staticmethod
    def get_train_scheduler():
        return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)

    def get_bucket_divisibility(self):
        return 32
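    # 32 presumably reflects the total spatial downsampling applied to pixels before the
    # transformer sees them (VAE spatial compression times the spatial patch size); this
    # is inferred from how the value is used below, not from a documented constant.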

    def load_model(self):
        dtype = self.torch_dtype
        self.print_and_status_update("Loading LTX2 model")
        model_path = self.model_config.name_or_path
        base_model_path = self.model_config.extras_name_or_path

        combined_state_dict = None

        self.print_and_status_update("Loading transformer")
        # if we have a safetensors file, it is a combined single-file checkpoint
        if os.path.exists(model_path) and model_path.endswith(".safetensors"):
            combined_state_dict = load_file(model_path)
            combined_state_dict = dequantize_state_dict(combined_state_dict)

        if combined_state_dict is not None:
            original_dit_ckpt = get_model_state_dict_from_combined_ckpt(
                combined_state_dict, dit_prefix
            )
            transformer = convert_ltx2_transformer(original_dit_ckpt)
            transformer = transformer.to(dtype)
        else:
            transformer_path = model_path
            transformer_subfolder = "transformer"
            if os.path.exists(transformer_path):
                transformer_subfolder = None
                transformer_path = os.path.join(transformer_path, "transformer")
                # check if the path is a full checkpoint.
                te_folder_path = os.path.join(model_path, "text_encoder")
                # if we have the te, this folder is a full checkpoint, use it as the base
                if os.path.exists(te_folder_path):
                    base_model_path = model_path

            transformer = LTX2VideoTransformer3DModel.from_pretrained(
                transformer_path, subfolder=transformer_subfolder, torch_dtype=dtype
            )

        if self.model_config.quantize:
            self.print_and_status_update("Quantizing Transformer")
            quantize_model(self, transformer)
            flush()
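        # The scale/shift tables below are bare nn.Parameters that the forward pass
        # indexes directly, so they are kept out of layer offloading; presumably they
        # need to stay resident on the compute device instead of being swapped in and out.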
        if (
            self.model_config.layer_offloading
            and self.model_config.layer_offloading_transformer_percent > 0
        ):
            ignore_modules = []
            for block in transformer.transformer_blocks:
                ignore_modules.append(block.scale_shift_table)
                ignore_modules.append(block.audio_scale_shift_table)
                ignore_modules.append(block.video_a2v_cross_attn_scale_shift_table)
                ignore_modules.append(block.audio_a2v_cross_attn_scale_shift_table)
            ignore_modules.append(transformer.scale_shift_table)
            ignore_modules.append(transformer.audio_scale_shift_table)
            MemoryManager.attach(
                transformer,
                self.device_torch,
                offload_percent=self.model_config.layer_offloading_transformer_percent,
                ignore_modules=ignore_modules,
            )

        if self.model_config.low_vram:
            self.print_and_status_update("Moving transformer to CPU")
            transformer.to("cpu")

        flush()

self.print_and_status_update("Loading text encoder")
|
|
if (
|
|
self.model_config.te_name_or_path is not None
|
|
and self.model_config.te_name_or_path.endswith(".safetensors")
|
|
):
|
|
# load from comfyui gemma3 checkpoint
|
|
tokenizer = GemmaTokenizerFast.from_pretrained(
|
|
"Lightricks/LTX-2", subfolder="tokenizer"
|
|
)
|
|
|
|
with init_empty_weights():
|
|
text_encoder = Gemma3ForConditionalGeneration(
|
|
Gemma3Config(
|
|
**{
|
|
"boi_token_index": 255999,
|
|
"bos_token_id": 2,
|
|
"eoi_token_index": 256000,
|
|
"eos_token_id": 106,
|
|
"image_token_index": 262144,
|
|
"initializer_range": 0.02,
|
|
"mm_tokens_per_image": 256,
|
|
"model_type": "gemma3",
|
|
"pad_token_id": 0,
|
|
"text_config": {
|
|
"attention_bias": False,
|
|
"attention_dropout": 0.0,
|
|
"attn_logit_softcapping": None,
|
|
"cache_implementation": "hybrid",
|
|
"final_logit_softcapping": None,
|
|
"head_dim": 256,
|
|
"hidden_activation": "gelu_pytorch_tanh",
|
|
"hidden_size": 3840,
|
|
"initializer_range": 0.02,
|
|
"intermediate_size": 15360,
|
|
"max_position_embeddings": 131072,
|
|
"model_type": "gemma3_text",
|
|
"num_attention_heads": 16,
|
|
"num_hidden_layers": 48,
|
|
"num_key_value_heads": 8,
|
|
"query_pre_attn_scalar": 256,
|
|
"rms_norm_eps": 1e-06,
|
|
"rope_local_base_freq": 10000,
|
|
"rope_scaling": {"factor": 8.0, "rope_type": "linear"},
|
|
"rope_theta": 1000000,
|
|
"sliding_window": 1024,
|
|
"sliding_window_pattern": 6,
|
|
"torch_dtype": "bfloat16",
|
|
"use_cache": True,
|
|
"vocab_size": 262208,
|
|
},
|
|
"torch_dtype": "bfloat16",
|
|
"transformers_version": "4.51.3",
|
|
"unsloth_fixed": True,
|
|
"vision_config": {
|
|
"attention_dropout": 0.0,
|
|
"hidden_act": "gelu_pytorch_tanh",
|
|
"hidden_size": 1152,
|
|
"image_size": 896,
|
|
"intermediate_size": 4304,
|
|
"layer_norm_eps": 1e-06,
|
|
"model_type": "siglip_vision_model",
|
|
"num_attention_heads": 16,
|
|
"num_channels": 3,
|
|
"num_hidden_layers": 27,
|
|
"patch_size": 14,
|
|
"torch_dtype": "bfloat16",
|
|
"vision_use_head": False,
|
|
},
|
|
}
|
|
)
|
|
)
|
|
te_state_dict = load_file(self.model_config.te_name_or_path)
|
|
te_state_dict = convert_comfy_gemma3_to_transformers(te_state_dict)
|
|
for key in te_state_dict:
|
|
te_state_dict[key] = te_state_dict[key].to(dtype)
|
|
|
|
text_encoder.load_state_dict(te_state_dict, assign=True, strict=True)
|
|
del te_state_dict
|
|
flush()
|
|
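            # Building the model under init_empty_weights() puts it on the meta device,
            # so no real memory is allocated until load_state_dict(..., assign=True)
            # attaches the converted ComfyUI tensors as the actual parameters.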
        else:
            if self.model_config.te_name_or_path is not None:
                te_path = self.model_config.te_name_or_path
            else:
                te_path = base_model_path
            tokenizer = GemmaTokenizerFast.from_pretrained(
                te_path, subfolder="tokenizer"
            )
            text_encoder = Gemma3ForConditionalGeneration.from_pretrained(
                te_path, subfolder="text_encoder", dtype=dtype
            )

        # remove the vision tower
        text_encoder.model.vision_tower = None
        flush()

        if (
            self.model_config.layer_offloading
            and self.model_config.layer_offloading_text_encoder_percent > 0
        ):
            MemoryManager.attach(
                text_encoder,
                self.device_torch,
                offload_percent=self.model_config.layer_offloading_text_encoder_percent,
                ignore_modules=[
                    text_encoder.model.language_model.base_model.embed_tokens
                ],
            )

        text_encoder.to(self.device_torch, dtype=dtype)
        flush()

        if self.model_config.quantize_te:
            self.print_and_status_update("Quantizing Text Encoder")
            quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te))
            freeze(text_encoder)
            flush()
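        # Note on the pair above: quantize() from optimum.quanto only wraps the supported
        # layers, while freeze() quantizes the weights in place, which is what actually
        # releases the full-precision copies.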
self.print_and_status_update("Loading VAEs and other components")
|
|
if combined_state_dict is not None:
|
|
original_vae_ckpt = get_model_state_dict_from_combined_ckpt(
|
|
combined_state_dict, vae_prefix
|
|
)
|
|
vae = convert_ltx2_video_vae(original_vae_ckpt).to(dtype)
|
|
del original_vae_ckpt
|
|
original_audio_vae_ckpt = get_model_state_dict_from_combined_ckpt(
|
|
combined_state_dict, audio_vae_prefix
|
|
)
|
|
audio_vae = convert_ltx2_audio_vae(original_audio_vae_ckpt).to(dtype)
|
|
del original_audio_vae_ckpt
|
|
original_connectors_ckpt = get_model_state_dict_from_combined_ckpt(
|
|
combined_state_dict, dit_prefix
|
|
)
|
|
connectors = convert_ltx2_connectors(original_connectors_ckpt).to(dtype)
|
|
del original_connectors_ckpt
|
|
original_vocoder_ckpt = get_model_state_dict_from_combined_ckpt(
|
|
combined_state_dict, vocoder_prefix
|
|
)
|
|
vocoder = convert_ltx2_vocoder(original_vocoder_ckpt).to(dtype)
|
|
del original_vocoder_ckpt
|
|
del combined_state_dict
|
|
flush()
|
|
else:
|
|
vae = AutoencoderKLLTX2Video.from_pretrained(
|
|
base_model_path, subfolder="vae", torch_dtype=dtype
|
|
)
|
|
audio_vae = AutoencoderKLLTX2Audio.from_pretrained(
|
|
base_model_path, subfolder="audio_vae", torch_dtype=dtype
|
|
)
|
|
|
|
connectors = LTX2TextConnectors.from_pretrained(
|
|
base_model_path, subfolder="connectors", torch_dtype=dtype
|
|
)
|
|
|
|
vocoder = LTX2Vocoder.from_pretrained(
|
|
base_model_path, subfolder="vocoder", torch_dtype=dtype
|
|
)
|
|
|
|
        self.noise_scheduler = LTX2Model.get_train_scheduler()

        self.print_and_status_update("Making pipe")

        pipe: LTX2Pipeline = LTX2Pipeline(
            scheduler=self.noise_scheduler,
            vae=vae,
            audio_vae=audio_vae,
            text_encoder=None,
            tokenizer=tokenizer,
            connectors=connectors,
            transformer=None,
            vocoder=vocoder,
        )
        # for quantization, it works best to attach these after making the pipe
        pipe.text_encoder = text_encoder
        pipe.transformer = transformer

        self.print_and_status_update("Preparing Model")

        text_encoder = [pipe.text_encoder]
        tokenizer = [pipe.tokenizer]

        # leave it on cpu for now
        if not self.low_vram:
            pipe.transformer = pipe.transformer.to(self.device_torch)

        flush()
        # just to make sure everything is on the right device and dtype
        text_encoder[0].to(self.device_torch)
        text_encoder[0].requires_grad_(False)
        text_encoder[0].eval()
        flush()

        # save it to the model class
        self.vae = vae
        self.text_encoder = text_encoder  # list of text encoders
        self.tokenizer = tokenizer  # list of tokenizers
        self.model = pipe.transformer
        self.pipeline = pipe

        self.audio_processor = AudioProcessor(
            sample_rate=pipe.audio_sampling_rate,
            mel_bins=audio_vae.config.mel_bins,
            mel_hop_length=pipe.audio_hop_length,
            n_fft=1024,  # todo: get this from the VAE config if possible; I couldn't find it there
        ).to(self.device_torch, dtype=torch.float32)

        self.print_and_status_update("Model Loaded")

    @torch.no_grad()
    def encode_images(self, image_list: List[torch.Tensor], device=None, dtype=None):
        if device is None:
            device = self.vae_device_torch
        if dtype is None:
            dtype = self.vae_torch_dtype

        if self.vae.device == torch.device("cpu"):
            self.vae.to(device)
        self.vae.eval()
        self.vae.requires_grad_(False)

        image_list = [image.to(device, dtype=dtype) for image in image_list]

        # Normalize shapes
        norm_images = []
        for image in image_list:
            if image.ndim == 3:
                # (C, H, W) -> (C, 1, H, W)
                norm_images.append(image.unsqueeze(1))
            elif image.ndim == 4:
                # (T, C, H, W) -> (C, T, H, W)
                norm_images.append(image.permute(1, 0, 2, 3))
            else:
                raise ValueError(f"Invalid image shape: {image.shape}")

        # Stack to (B, C, T, H, W)
        images = torch.stack(norm_images)

        latents = self.vae.encode(images).latent_dist.mode()

        # Normalize latents across the channel dimension [B, C, F, H, W]
        scaling_factor = 1.0
        latents_mean = self.pipeline.vae.latents_mean.view(1, -1, 1, 1, 1).to(
            latents.device, latents.dtype
        )
        latents_std = self.pipeline.vae.latents_std.view(1, -1, 1, 1, 1).to(
            latents.device, latents.dtype
        )
        latents = (latents - latents_mean) * scaling_factor / latents_std

        return latents.to(device, dtype=dtype)
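    # The normalization above is the inverse of what decoding expects: latents are stored
    # as (latents - latents_mean) / latents_std (scaling_factor is 1.0), so a decode path
    # would apply latents * latents_std + latents_mean before calling vae.decode().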

    def get_generation_pipeline(self):
        scheduler = LTX2Model.get_train_scheduler()

        pipeline: LTX2Pipeline = LTX2Pipeline(
            scheduler=scheduler,
            vae=unwrap_model(self.pipeline.vae),
            audio_vae=unwrap_model(self.pipeline.audio_vae),
            text_encoder=None,
            tokenizer=unwrap_model(self.pipeline.tokenizer),
            connectors=unwrap_model(self.pipeline.connectors),
            transformer=None,
            vocoder=unwrap_model(self.pipeline.vocoder),
        )
        pipeline.transformer = unwrap_model(self.model)
        pipeline.text_encoder = unwrap_model(self.text_encoder[0])

        pipeline = pipeline.to(self.device_torch)

        return pipeline

    def generate_single_image(
        self,
        pipeline: LTX2Pipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        if self.model.device == torch.device("cpu"):
            self.model.to(self.device_torch)

        # handle control image
        if gen_config.ctrl_img is not None:
            # switch to the image-to-video pipeline
            pipeline = LTX2ImageToVideoPipeline(
                scheduler=pipeline.scheduler,
                vae=pipeline.vae,
                audio_vae=pipeline.audio_vae,
                text_encoder=pipeline.text_encoder,
                tokenizer=pipeline.tokenizer,
                connectors=pipeline.connectors,
                transformer=pipeline.transformer,
                vocoder=pipeline.vocoder,
            )

        is_video = gen_config.num_frames > 1
        # override the save/log functions so video + audio generation is handled here
        if is_video:
            gen_config._orig_save_image_function = gen_config.save_image
            gen_config.save_image = partial(new_save_image_function, gen_config)
            gen_config.log_image = partial(blank_log_image_function, gen_config)
            # set output extension to mp4
            gen_config.output_ext = "mp4"

        # reactivate progress bar since this is slooooow
        pipeline.set_progress_bar_config(disable=False)
        pipeline = pipeline.to(self.device_torch)

        # make sure dimensions are valid
        bd = self.get_bucket_divisibility()
        gen_config.height = (gen_config.height // bd) * bd
        gen_config.width = (gen_config.width // bd) * bd

        # handle control image
        if gen_config.ctrl_img is not None:
            control_img = Image.open(gen_config.ctrl_img).convert("RGB")
            # resize the control image
            control_img = control_img.resize(
                (gen_config.width, gen_config.height), Image.LANCZOS
            )
            # add the control image to the extra dict
            extra["image"] = control_img

        # the frame count must be a multiple of 8 plus 1, i.e. 1, 9, 17, 25, ...
        if gen_config.num_frames != 1:
            if (gen_config.num_frames - 1) % 8 != 0:
                gen_config.num_frames = ((gen_config.num_frames - 1) // 8) * 8 + 1
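        # e.g. a requested 48 frames is snapped down to 41 (5 * 8 + 1), while 49 already
        # satisfies the constraint and passes through unchanged.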
        if self.low_vram:
            # set vae to tiled decode
            pipeline.vae.enable_tiling(
                tile_sample_min_height=256,
                tile_sample_min_width=256,
                tile_sample_min_num_frames=8,
                tile_sample_stride_height=224,
                tile_sample_stride_width=224,
                tile_sample_stride_num_frames=4,
            )

        video, audio = pipeline(
            prompt_embeds=conditional_embeds.text_embeds.to(
                self.device_torch, dtype=self.torch_dtype
            ),
            prompt_attention_mask=conditional_embeds.attention_mask.to(
                self.device_torch
            ),
            negative_prompt_embeds=unconditional_embeds.text_embeds.to(
                self.device_torch, dtype=self.torch_dtype
            ),
            negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(
                self.device_torch
            ),
            height=gen_config.height,
            width=gen_config.width,
            num_inference_steps=gen_config.num_inference_steps,
            guidance_scale=gen_config.guidance_scale,
            latents=gen_config.latents,
            num_frames=gen_config.num_frames,
            generator=generator,
            return_dict=False,
            output_type="np" if is_video else "pil",
            **extra,
        )
        if self.low_vram:
            # restore no tiling
            pipeline.vae.use_tiling = False

        if is_video:
            # return as a dict; the overridden save_image function will handle it
            video = (video * 255).round().astype("uint8")
            video = torch.from_numpy(video)
            return {
                "video": video[0],
                "fps": gen_config.fps,
                "audio": audio[0].float().cpu(),
                "audio_sample_rate": pipeline.vocoder.config.output_sampling_rate,  # should be 24000
                "output_path": None,
            }
        else:
            # shape = [1, frames, channels, height, width]
            # make sure this is right
            video = video[0]  # list of PIL images
            audio = audio[0]  # tensor
            if gen_config.num_frames > 1:
                return video  # return the frames
            else:
                # get just the first image
                img = video[0]
                return img

    def encode_audio(self, batch: "DataLoaderBatchDTO"):
        if self.pipeline.audio_vae.device == torch.device("cpu"):
            self.pipeline.audio_vae.to(self.device_torch)

        output_tensor = None
        audio_num_frames = None

        # do them separately for now
        for audio_data in batch.audio_data:
            waveform = audio_data["waveform"].to(
                device=self.device_torch, dtype=torch.float32
            )
            sample_rate = audio_data["sample_rate"]

            # Add batch dimension if needed: [channels, samples] -> [batch, channels, samples]
            if waveform.dim() == 2:
                waveform = waveform.unsqueeze(0)

            if waveform.shape[1] == 1:
                # make sure it is stereo
                waveform = waveform.repeat(1, 2, 1)

            # Convert waveform to mel spectrogram using AudioProcessor
            mel_spectrogram = self.audio_processor.waveform_to_mel(
                waveform, waveform_sample_rate=sample_rate
            )
            mel_spectrogram = mel_spectrogram.to(dtype=self.torch_dtype)

            # Encode mel spectrogram to latents
            latents = self.pipeline.audio_vae.encode(
                mel_spectrogram.to(self.device_torch, dtype=self.torch_dtype)
            ).latent_dist.mode()

            if audio_num_frames is None:
                audio_num_frames = latents.shape[2]  # (latents is [B, C, T, F])

            packed_latents = self.pipeline._pack_audio_latents(
                latents,
                # patch_size=self.pipeline.transformer.config.audio_patch_size,
                # patch_size_t=self.pipeline.transformer.config.audio_patch_size_t,
            )  # [B, L, C * M]
            if output_tensor is None:
                output_tensor = packed_latents
            else:
                output_tensor = torch.cat([output_tensor, packed_latents], dim=0)

        # normalize latents, the opposite of (latents * latents_std) + latents_mean
        latents_mean = self.pipeline.audio_vae.latents_mean
        latents_std = self.pipeline.audio_vae.latents_std
        output_tensor = (output_tensor - latents_mean) / latents_std
        return output_tensor, audio_num_frames
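    # encode_audio returns latents already packed to [B, L, C * M] (a sequence of audio
    # tokens) plus the latent frame count; get_noise_prediction below noises these packed
    # latents directly and uses (noise - latents) as the audio training target.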

    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        batch: "DataLoaderBatchDTO" = None,
        **kwargs,
    ):
        with torch.no_grad():
            if self.model.device == torch.device("cpu"):
                self.model.to(self.device_torch)

            batch_size, C, latent_num_frames, latent_height, latent_width = (
                latent_model_input.shape
            )

            video_timestep = timestep.clone()

            # i2v from first frame
            if batch.dataset_config.do_i2v:
                # videos come in (bs, num_frames, channels, height, width)
                # images come in (bs, channels, height, width)
                frames = batch.tensor
                if len(frames.shape) == 4:
                    first_frames = frames
                elif len(frames.shape) == 5:
                    first_frames = frames[:, 0]
                else:
                    raise ValueError(f"Unknown frame shape {frames.shape}")
                # the first frame doesn't have a time dim; encode_images adds it back
                init_latents = self.encode_images(
                    first_frames, device=self.device_torch, dtype=self.torch_dtype
                )
                init_latents = init_latents.repeat(1, 1, latent_num_frames, 1, 1)
                mask_shape = (
                    batch_size,
                    1,
                    latent_num_frames,
                    latent_height,
                    latent_width,
                )
                # The first condition is the image latents, and those should be kept clean.
                conditioning_mask = torch.zeros(
                    mask_shape, device=self.device_torch, dtype=self.torch_dtype
                )
                conditioning_mask[:, :, 0] = 1.0

                # use the conditioning mask to replace latents
                latent_model_input = (
                    latent_model_input * (1 - conditioning_mask)
                    + init_latents * conditioning_mask
                )

                # set the video timestep
                video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
            # todo get this somehow
            frame_rate = 24
            # check frame dimension
            # Unpacked latents of shape [B, C, F, H, W] are patched into tokens of
            # shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
            packed_latents = self.pipeline._pack_latents(
                latent_model_input,
                patch_size=self.pipeline.transformer_spatial_patch_size,
                patch_size_t=self.pipeline.transformer_temporal_patch_size,
            )

            if batch.audio_tensor is not None:
                # use audio from the batch if available
                # (1, 190, 128)
                raw_audio_latents, audio_num_frames = self.encode_audio(batch)
                # add the audio targets to the batch for loss calculation later
                audio_noise = torch.randn_like(raw_audio_latents)
                batch.audio_target = (audio_noise - raw_audio_latents).detach()
                audio_latents = self.add_noise(
                    raw_audio_latents,
                    audio_noise,
                    timestep,
                ).to(self.device_torch, dtype=self.torch_dtype)

            else:
                # no audio
                num_mel_bins = self.pipeline.audio_vae.config.mel_bins
                # latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
                num_channels_latents_audio = (
                    self.pipeline.audio_vae.config.latent_channels
                )
                # audio latents are (1, 126, 128), audio_num_frames = 126
                audio_latents, audio_num_frames = self.pipeline.prepare_audio_latents(
                    batch_size,
                    num_channels_latents=num_channels_latents_audio,
                    num_mel_bins=num_mel_bins,
                    num_frames=batch.tensor.shape[1],
                    frame_rate=frame_rate,
                    sampling_rate=self.pipeline.audio_sampling_rate,
                    hop_length=self.pipeline.audio_hop_length,
                    dtype=torch.float32,
                    device=self.transformer.device,
                    generator=None,
                    latents=None,
                )

            if self.pipeline.connectors.device != self.transformer.device:
                self.pipeline.connectors.to(self.transformer.device)

            # TODO this is how diffusers does it at inference; not sure I understand why, check this
            additive_attention_mask = (
                1 - text_embeddings.attention_mask.to(self.transformer.dtype)
            ) * -1000000.0
            (
                connector_prompt_embeds,
                connector_audio_prompt_embeds,
                connector_attention_mask,
            ) = self.pipeline.connectors(
                text_embeddings.text_embeds, additive_attention_mask, additive_mask=True
            )

            # compute video and audio positional ids
            video_coords = self.transformer.rope.prepare_video_coords(
                packed_latents.shape[0],
                latent_num_frames,
                latent_height,
                latent_width,
                packed_latents.device,
                fps=frame_rate,
            )
            audio_coords = self.transformer.audio_rope.prepare_audio_coords(
                audio_latents.shape[0], audio_num_frames, audio_latents.device
            )

        noise_pred_video, noise_pred_audio = self.transformer(
            hidden_states=packed_latents,
            audio_hidden_states=audio_latents.to(self.transformer.dtype),
            encoder_hidden_states=connector_prompt_embeds,
            audio_encoder_hidden_states=connector_audio_prompt_embeds,
            timestep=video_timestep,
            audio_timestep=timestep,
            encoder_attention_mask=connector_attention_mask,
            audio_encoder_attention_mask=connector_attention_mask,
            num_frames=latent_num_frames,
            height=latent_height,
            width=latent_width,
            fps=frame_rate,
            audio_num_frames=audio_num_frames,
            video_coords=video_coords,
            audio_coords=audio_coords,
            # rope_interpolation_scale=rope_interpolation_scale,
            attention_kwargs=None,
            return_dict=False,
        )

        # add the audio prediction to the batch if we had audio
        if batch.audio_target is not None:
            batch.audio_pred = noise_pred_audio

        unpacked_output = self.pipeline._unpack_latents(
            latents=noise_pred_video,
            num_frames=latent_num_frames,
            height=latent_height,
            width=latent_width,
            patch_size=self.pipeline.transformer_spatial_patch_size,
            patch_size_t=self.pipeline.transformer_temporal_patch_size,
        )

        return unpacked_output

    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
        if self.pipeline.text_encoder.device != self.device_torch:
            self.pipeline.text_encoder.to(self.device_torch)

        prompt_embeds, prompt_attention_mask, _, _ = self.pipeline.encode_prompt(
            prompt,
            do_classifier_free_guidance=False,
            device=self.device_torch,
        )
        pe = PromptEmbeds([prompt_embeds, None])
        pe.attention_mask = prompt_attention_mask
        return pe

    def get_model_has_grad(self):
        return False

    def get_te_has_grad(self):
        return False

    def save_model(self, output_path, meta, save_dtype):
        transformer: LTX2VideoTransformer3DModel = unwrap_model(self.model)
        transformer.save_pretrained(
            save_directory=os.path.join(output_path, "transformer"),
            safe_serialization=True,
        )

        meta_path = os.path.join(output_path, "aitk_meta.yaml")
        with open(meta_path, "w") as f:
            yaml.dump(meta, f)

    def get_loss_target(self, *args, **kwargs):
        noise = kwargs.get("noise")
        batch = kwargs.get("batch")
        return (noise - batch.latents).detach()
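    # Flow-matching velocity target: with x_t interpolated between the clean latents and
    # noise, the model is trained to predict (noise - latents), mirroring the audio target
    # built in get_noise_prediction. (A simplified sketch; the exact interpolation depends
    # on the scheduler's shift settings.)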

    def get_base_model_version(self):
        return "ltx2"

    def get_transformer_block_names(self) -> Optional[List[str]]:
        return ["transformer_blocks"]

    def convert_lora_weights_before_save(self, state_dict):
        new_sd = {}
        for key, value in state_dict.items():
            new_key = key.replace("transformer.", "diffusion_model.")
            new_sd[new_key] = value
        new_sd = convert_lora_diffusers_to_original(new_sd)
        return new_sd

    def convert_lora_weights_before_load(self, state_dict):
        state_dict = convert_lora_original_to_diffusers(state_dict)
        new_sd = {}
        for key, value in state_dict.items():
            new_key = key.replace("diffusion_model.", "transformer.")
            new_sd[new_key] = value
        return new_sd
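    # These two hooks are inverses of each other: on save, LoRA keys are moved from the
    # diffusers-style "transformer." prefix to the original "diffusion_model." prefix and
    # run through convert_lora_diffusers_to_original; on load, the reverse mapping restores
    # the "transformer." prefix so the weights can be applied to the diffusers transformer.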