[diffusion] model: LTX-2 Support PR3 (#19151)

This commit is contained in:
GMI Xiao Jin
2026-02-24 00:55:28 -08:00
committed by GitHub
parent b2166a2f33
commit fcfd964d7d
11 changed files with 237 additions and 70 deletions

View File

@@ -19,7 +19,7 @@ dependencies = [
"aiohttp",
"apache-tvm-ffi>=0.1.5,<0.2",
"anthropic>=0.20.0",
"av ; sys_platform == 'linux' and (platform_machine == 'aarch64' or platform_machine == 'arm64' or platform_machine == 'armv7l')",
"av",
"blobfile==3.0.0",
"build",
"compressed-tensors",

View File

@@ -19,6 +19,7 @@ dependencies = [
"aiohttp",
"anthropic>=0.20.0",
"blobfile==3.0.0",
"av",
"build",
"compressed-tensors",
"decord2",

View File

@@ -21,6 +21,7 @@ runtime_common = [
"aiohttp",
"anthropic>=0.20.0",
"blobfile==3.0.0",
"av",
"build",
"compressed-tensors",
"decord2",

View File

@@ -164,7 +164,7 @@ class LTX2ArchConfig(DiTArchConfig):
self.audio_num_attention_heads * self.audio_attention_head_dim
)
if self.audio_positional_embedding_max_pos is None:
self.audio_positional_embedding_max_pos = [2048]
self.audio_positional_embedding_max_pos = [20]
@dataclass

View File

@@ -2,7 +2,6 @@ import dataclasses
from dataclasses import field
from typing import Callable
import numpy as np
import torch
from sglang.multimodal_gen.configs.models.dits.ltx_2 import LTX2Config
@@ -139,6 +138,23 @@ class LTX2PipelineConfig(PipelineConfig):
def vae_temporal_compression(self):
return getattr(self.vae_config.arch_config, "temporal_compression_ratio", 8)
def prepare_latent_shape(self, batch, batch_size, num_frames):
"""Return packed latent shape [B, seq, C] directly."""
height = batch.height // self.vae_scale_factor
width = batch.width // self.vae_scale_factor
post_patch_num_frames = num_frames // self.patch_size_t
post_patch_height = height // self.patch_size
post_patch_width = width // self.patch_size
seq_len = post_patch_num_frames * post_patch_height * post_patch_width
num_channels = (
self.in_channels * self.patch_size_t * self.patch_size * self.patch_size
)
shape = (batch_size, seq_len, num_channels)
return shape
def prepare_audio_latent_shape(self, batch, batch_size, num_frames):
# Adapted from diffusers pipeline prepare_audio_latents
duration_s = num_frames / batch.fps
@@ -159,7 +175,7 @@ class LTX2PipelineConfig(PipelineConfig):
# Default to 8
num_channels_latents = self.audio_vae_config.arch_config.latent_channels
shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins)
shape = (batch_size, latent_length, num_channels_latents * latent_mel_bins)
return shape
@@ -184,7 +200,7 @@ class LTX2PipelineConfig(PipelineConfig):
steps = int(num_inference_steps)
if steps <= 0:
raise ValueError(f"num_inference_steps must be positive, got {steps}")
return np.linspace(1.0, 1.0 / float(steps), steps).tolist()
return [1.0 - i / steps for i in range(steps)]
return sigmas
def tokenize_prompt(self, prompt: list[str], tokenizer, tok_kwargs) -> dict:
@@ -210,6 +226,10 @@ class LTX2PipelineConfig(PipelineConfig):
return text_inputs
def maybe_pack_latents(self, latents, batch_size, batch):
# If already packed (3D shape [B, seq, C]), skip packing
if latents.dim() == 3:
return latents
# Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
# The patch dimensions are then permuted and collapsed into the channel dimension of shape:
# [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
@@ -338,6 +358,10 @@ class LTX2PipelineConfig(PipelineConfig):
return super().gather_latents_for_sp(latents)
def maybe_pack_audio_latents(self, latents, batch_size, batch):
# If already packed (3D shape [B, T, C*F]), skip packing
if latents.dim() == 3:
return latents
# Audio latents shape: [B, C, L, M], where L is the latent audio length and M is the number of mel bins
# We need to pack them if patch_size/patch_size_t are defined for audio (not standard DiT patch size)

View File

@@ -34,7 +34,7 @@ def apply_split_rotary_emb(
# The cos/sin batch dim may only be broadcastable, so take batch size from x
b = x.shape[0]
_, h, t, _ = cos.shape
x = x.reshape(b, t, h, -1).swapaxes(1, 2)
x = x.reshape(b, t, h, -1).transpose(1, 2)
needs_reshape = True
# Split last dim (2*r) into (d=2, r)
@@ -46,7 +46,7 @@ def apply_split_rotary_emb(
r = last // 2
# (..., 2, r)
split_x = x.reshape(*x.shape[:-1], 2, r).float() # Explicitly upcast to float
split_x = x.reshape(*x.shape[:-1], 2, r)
first_x = split_x[..., :1, :] # (..., 1, r)
second_x = split_x[..., 1:, :] # (..., 1, r)
@@ -63,7 +63,7 @@ def apply_split_rotary_emb(
out = out.reshape(*out.shape[:-2], last)
if needs_reshape:
out = out.swapaxes(1, 2).reshape(b, t, -1)
out = out.transpose(1, 2).reshape(b, t, -1)
out = out.to(dtype=x_dtype)
return out
@@ -232,6 +232,7 @@ class LTX2RotaryPosEmbed1d(nn.Module):
batch_size: int,
pos: int,
device: Union[str, torch.device],
dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Get 1D position ids
grid_1d = torch.arange(pos, dtype=torch.float32, device=device)
@@ -297,6 +298,9 @@ class LTX2RotaryPosEmbed1d(nn.Module):
cos_freqs = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2)
sin_freqs = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2)
if dtype is not None:
cos_freqs = cos_freqs.to(dtype)
sin_freqs = sin_freqs.to(dtype)
return cos_freqs, sin_freqs
@@ -460,7 +464,9 @@ class LTX2ConnectorTransformer1d(nn.Module):
attention_mask = torch.zeros_like(attention_mask)
# 2. Calculate 1D RoPE positional embeddings
rotary_emb = self.rope(batch_size, seq_len, device=hidden_states.device)
rotary_emb = self.rope(
batch_size, seq_len, device=hidden_states.device, dtype=hidden_states.dtype
)
# 3. Run 1D transformer blocks
for block in self.transformer_blocks:

View File

@@ -66,7 +66,7 @@ def apply_split_rotary_emb(
)
r = last // 2
split_x = x.reshape(*x.shape[:-1], 2, r).float()
split_x = x.reshape(*x.shape[:-1], 2, r)
first_x = split_x[..., :1, :]
second_x = split_x[..., 1:, :]
@@ -137,6 +137,7 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
self.causal_offset = int(causal_offset)
self.modality = modality
self.coords_dtype = torch.bfloat16 if modality == "video" else torch.float32
if self.modality not in ["video", "audio"]:
raise ValueError(
f"Modality {modality} is not supported. Supported modalities are `video` and `audio`."
@@ -243,6 +244,7 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
device = device or coords.device
num_pos_dims = coords.shape[1]
coords = coords.to(self.coords_dtype)
if coords.ndim == 4:
coords_start, coords_end = coords.chunk(2, dim=-1)
coords = (coords_start + coords_end) / 2.0
@@ -307,7 +309,9 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
cos_freqs = torch.swapaxes(cos_freq, 1, 2)
sin_freqs = torch.swapaxes(sin_freq, 1, 2)
return cos_freqs, sin_freqs
# Cast to bf16 to match model weights dtype. coords_dtype controls
# intermediate coordinate precision (fp32 for audio) and differs.
return cos_freqs.to(torch.bfloat16), sin_freqs.to(torch.bfloat16)
def rms_norm(x: torch.Tensor, eps: float) -> torch.Tensor:
@@ -1121,7 +1125,9 @@ class LTX2VideoTransformer3DModel(CachableDiT, OffloadableDiTMixin):
if hasattr(arch.rope_type, "value")
else str(arch.rope_type)
)
rope_double_precision = bool(getattr(arch, "double_precision_rope", True))
rope_double_precision = bool(
hf_config.get("rope_double_precision", arch.double_precision_rope)
)
causal_offset = int(hf_config.get("causal_offset", 1))
pos_embed_max_pos = int(arch.positional_embedding_max_pos[0])
@@ -1351,27 +1357,30 @@ class LTX2VideoTransformer3DModel(CachableDiT, OffloadableDiTMixin):
self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier
)
hidden_dtype = hidden_states.dtype
temb_ca_scale_shift, _ = self.av_ca_video_scale_shift_adaln_single(
timestep.flatten()
timestep.flatten(), hidden_dtype=hidden_dtype
)
temb_ca_scale_shift = temb_ca_scale_shift.view(
batch_size, -1, temb_ca_scale_shift.shape[-1]
)
temb_ca_gate, _ = self.av_ca_a2v_gate_adaln_single(
timestep.flatten() * ts_ca_mult
timestep.flatten() * self.av_ca_timestep_scale_multiplier,
hidden_dtype=hidden_dtype,
)
temb_ca_gate = temb_ca_gate.view(batch_size, -1, temb_ca_gate.shape[-1])
temb_ca_audio_scale_shift, _ = self.av_ca_audio_scale_shift_adaln_single(
audio_timestep.flatten()
audio_timestep.flatten(), hidden_dtype=audio_hidden_states.dtype
)
temb_ca_audio_scale_shift = temb_ca_audio_scale_shift.view(
batch_size, -1, temb_ca_audio_scale_shift.shape[-1]
)
temb_ca_audio_gate, _ = self.av_ca_v2a_gate_adaln_single(
audio_timestep.flatten() * ts_ca_mult
audio_timestep.flatten() * self.av_ca_timestep_scale_multiplier,
hidden_dtype=audio_hidden_states.dtype,
)
temb_ca_audio_gate = temb_ca_audio_gate.view(
batch_size, -1, temb_ca_audio_gate.shape[-1]
@@ -1413,7 +1422,8 @@ class LTX2VideoTransformer3DModel(CachableDiT, OffloadableDiTMixin):
device=hidden_states.device, dtype=hidden_states.dtype
) + embedded_timestep[:, :, None].to(dtype=hidden_states.dtype)
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
hidden_states = self.norm_out(hidden_states)
with torch.autocast(device_type=hidden_states.device.type, enabled=False):
hidden_states = self.norm_out(hidden_states)
hidden_states = hidden_states * (1 + scale) + shift
hidden_states, _ = self.proj_out(hidden_states)
@@ -1425,7 +1435,8 @@ class LTX2VideoTransformer3DModel(CachableDiT, OffloadableDiTMixin):
audio_scale_shift_values[:, :, 0],
audio_scale_shift_values[:, :, 1],
)
audio_hidden_states = self.audio_norm_out(audio_hidden_states)
with torch.autocast(device_type=audio_hidden_states.device.type, enabled=False):
audio_hidden_states = self.audio_norm_out(audio_hidden_states)
audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift
audio_hidden_states, _ = self.audio_proj_out(audio_hidden_states)

View File

@@ -90,6 +90,11 @@ class Gemma3MLP(nn.Module):
return x
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
class Gemma3Attention(nn.Module):
def __init__(
self,
@@ -170,6 +175,19 @@ class Gemma3Attention(nn.Module):
is_neox_style=True,
)
# NOTE(gmixiaojin): The shared RotaryEmbedding above computes inv_freq on
# GPU and uses the x1*cos - x2*sin formula, which causes slight
# numerical differences vs HuggingFace (see the NOTE in
# rotary_embedding.py:_compute_inv_freq). For HF-exact alignment we
# precompute inv_freq on CPU and use rotate_half in self.rotary_emb().
freq_indices = (
torch.arange(0, self.head_dim, 2, dtype=torch.int64).float() / self.head_dim
)
inv_freq = 1.0 / (self.rope_theta**freq_indices)
if rope_scaling and rope_scaling.get("factor"):
inv_freq = inv_freq / float(rope_scaling["factor"])
self.register_buffer("_hf_inv_freq", inv_freq, persistent=False)
# Local Attention not support attention mask, we use global attention instead.
# self.attn = LocalAttention(
# self.num_heads,
@@ -189,6 +207,23 @@ class Gemma3Attention(nn.Module):
dim=self.head_dim, eps=config.text_config.rms_norm_eps
)
def rotary_emb(self, positions, q, k):
"""Apply RoPE using HF-exact formula with precomputed inv_freq."""
positions_flat = positions.flatten().float()
num_tokens = positions_flat.shape[0]
with torch.autocast(device_type=q.device.type, enabled=False):
freqs = torch.outer(positions_flat, self._hf_inv_freq.float())
emb = freqs.repeat(1, 2)
cos = emb.cos().to(q.dtype).unsqueeze(1)
sin = emb.sin().to(q.dtype).unsqueeze(1)
q = q.reshape(num_tokens, -1, self.head_dim)
k = k.reshape(num_tokens, -1, self.head_dim)
q = q * cos + _rotate_half(q) * sin
k = k * cos + _rotate_half(k) * sin
return q, k
def forward(
self,
positions: torch.Tensor,
@@ -209,16 +244,19 @@ class Gemma3Attention(nn.Module):
# Apply RoPE
q, k = self.rotary_emb(positions, q, k)
q = q.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
k = k.reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim)
# TODO(FlamingoPg): Support LocalAttention
query = q.transpose(1, 2)
key = k.transpose(1, 2)
value = v.transpose(1, 2)
min_val = torch.finfo(query.dtype).min
attn_mask = torch.zeros(
(seq_len, seq_len),
device=hidden_states.device,
dtype=torch.float32,
dtype=query.dtype,
)
causal = torch.triu(
torch.ones(
@@ -226,18 +264,18 @@ class Gemma3Attention(nn.Module):
),
diagonal=1,
)
attn_mask = attn_mask.masked_fill(causal, float("-inf"))
attn_mask = attn_mask.masked_fill(causal, min_val)
if self.is_sliding and self.sliding_window is not None:
idx = torch.arange(seq_len, device=hidden_states.device)
dist = idx[None, :] - idx[:, None]
too_far = dist > self.sliding_window
attn_mask = attn_mask.masked_fill(too_far, float("-inf"))
attn_mask = attn_mask.masked_fill(too_far, min_val)
key_pad = ~attention_mask.to(torch.bool)
attn_mask = attn_mask[None, None, :, :].expand(batch_size, 1, seq_len, seq_len)
attn_mask = attn_mask.masked_fill(
key_pad[:, None, None, :].expand(batch_size, 1, seq_len, seq_len),
float("-inf"),
min_val,
)
attn_kwargs = {
@@ -707,7 +745,8 @@ class Gemma3TextModel(nn.Module):
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) * self.embed_scale
out = self.embed_tokens(input_ids)
return out * torch.tensor(self.embed_scale, device=out.device, dtype=out.dtype)
def forward(
self,
@@ -735,7 +774,6 @@ class Gemma3TextModel(nn.Module):
position_ids = torch.arange(
0, hidden_states.shape[1], device=hidden_states.device
).unsqueeze(0)
position_ids = position_ids + 1
all_hidden_states: tuple[Any, ...] | None = () if output_hidden_states else None

View File

@@ -1,7 +1,12 @@
import inspect
import json
import math
import os
import numpy as np
import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (
ComposedPipelineBase,
)
@@ -50,14 +55,9 @@ def prepare_mu(batch: Req, server_args: ServerArgs):
vae_arch, "temporal_compression_ratio", None
) or getattr(server_args.pipeline_config, "vae_temporal_compression", None)
latent_num_frames = (int(num_frames) - 1) // int(vae_temporal_compression) + 1
latent_height = int(height) // int(vae_scale_factor)
latent_width = int(width) // int(vae_scale_factor)
video_sequence_length = latent_num_frames * latent_height * latent_width
# Values from LTX2Pipeline in diffusers
mu = calculate_shift(
video_sequence_length,
4096,
base_seq_len=1024,
max_seq_len=4096,
base_shift=0.95,
@@ -101,6 +101,17 @@ def _filter_kwargs_for_cls(cls, kwargs):
return {k: v for k, v in kwargs.items() if k in sig.parameters}
class LTX2FlowMatchScheduler(FlowMatchEulerDiscreteScheduler):
"""Override ``_time_shift_exponential`` to use torch f32 instead of numpy f64."""
def _time_shift_exponential(self, mu, sigma, t):
if isinstance(t, np.ndarray):
t_torch = torch.from_numpy(t).to(torch.float32)
result = math.exp(mu) / (math.exp(mu) + (1 / t_torch - 1) ** sigma)
return result.numpy()
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
class LTX2Pipeline(ComposedPipelineBase):
# NOTE: must match `model_index.json`'s `_class_name` for native dispatch.
pipeline_name = "LTX2Pipeline"
@@ -116,6 +127,10 @@ class LTX2Pipeline(ComposedPipelineBase):
"connectors",
]
def initialize_pipeline(self, server_args: ServerArgs):
orig = self.get_module("scheduler")
self.modules["scheduler"] = LTX2FlowMatchScheduler.from_config(orig.config)
def create_pipeline_stages(self, server_args: ServerArgs):
self.add_stages(
[

View File

@@ -37,7 +37,12 @@ class LTX2AVDecodingStage(DecodingStage):
vae_dtype != torch.float32
) and not server_args.disable_autocast
latents = self.scale_and_shift(latents, server_args)
original_dtype = vae_dtype
self.vae.to(torch.bfloat16)
latents = latents.to(torch.bfloat16)
std = self.vae.latents_std.view(1, -1, 1, 1, 1).to(latents)
mean = self.vae.latents_mean.view(1, -1, 1, 1, 1).to(latents)
latents = latents * std + mean
latents = server_args.pipeline_config.preprocess_decoding(
latents, server_args, vae=self.vae
)
@@ -52,8 +57,6 @@ class LTX2AVDecodingStage(DecodingStage):
self.vae.enable_tiling()
except Exception:
pass
if not vae_autocast_enabled:
latents = latents.to(vae_dtype)
decode_output = self.vae.decode(latents)
if isinstance(decode_output, tuple):
video = decode_output[0]
@@ -62,6 +65,7 @@ class LTX2AVDecodingStage(DecodingStage):
else:
video = decode_output
self.vae.to(original_dtype)
video = self.video_processor.postprocess_video(video, output_type="np")
output_batch = OutputBatch(

View File

@@ -1,6 +1,10 @@
import copy
import math
import time
from io import BytesIO
import av
import numpy as np
import PIL.Image
import torch
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
@@ -107,6 +111,66 @@ class LTX2AVDenoisingStage(DenoisingStage):
) -> PIL.Image.Image:
return img.resize((width, height), resample=PIL.Image.Resampling.BILINEAR)
@staticmethod
def _apply_video_codec_compression(
img_array: np.ndarray, crf: int = 33
) -> np.ndarray:
"""Encode as a single H.264 frame and decode back to simulate compression artifacts."""
if crf == 0:
return img_array
height, width = img_array.shape[0] // 2 * 2, img_array.shape[1] // 2 * 2
img_array = img_array[:height, :width]
buffer = BytesIO()
container = av.open(buffer, mode="w", format="mp4")
stream = container.add_stream(
"libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"}
)
stream.height, stream.width = height, width
frame = av.VideoFrame.from_ndarray(img_array, format="rgb24").reformat(
format="yuv420p"
)
container.mux(stream.encode(frame))
container.mux(stream.encode())
container.close()
buffer.seek(0)
container = av.open(buffer)
decoded = next(container.decode(container.streams.video[0]))
container.close()
return decoded.to_ndarray(format="rgb24")
@staticmethod
def _resize_center_crop_tensor(
img: PIL.Image.Image,
*,
width: int,
height: int,
device: torch.device,
dtype: torch.dtype,
apply_codec_compression: bool = True,
codec_crf: int = 33,
) -> torch.Tensor:
"""Resize, center-crop, and normalize to [1, C, 1, H, W] tensor in [-1, 1]."""
img_array = np.array(img).astype(np.uint8)[..., :3]
if apply_codec_compression:
img_array = LTX2AVDenoisingStage._apply_video_codec_compression(
img_array, crf=codec_crf
)
tensor = (
torch.from_numpy(img_array.astype(np.float32))
.permute(2, 0, 1)
.unsqueeze(0)
.to(device=device)
)
src_h, src_w = tensor.shape[2], tensor.shape[3]
scale = max(height / src_h, width / src_w)
new_h, new_w = math.ceil(src_h * scale), math.ceil(src_w * scale)
tensor = torch.nn.functional.interpolate(
tensor, size=(new_h, new_w), mode="bilinear", align_corners=False
)
top, left = (new_h - height) // 2, (new_w - width) // 2
tensor = tensor[:, :, top : top + height, left : left + width]
return ((tensor / 127.5 - 1.0).to(dtype=dtype)).unsqueeze(2)
@staticmethod
def _pil_to_normed_tensor(img: PIL.Image.Image) -> torch.Tensor:
# PIL -> numpy [0,1] -> torch [B,C,H,W], then [-1,1]
@@ -155,31 +219,33 @@ class LTX2AVDenoisingStage(DenoisingStage):
)
img = load_image(image_path)
img = self._resize_center_crop(
batch.condition_image = self._resize_center_crop(
img, width=int(batch.width), height=int(batch.height)
)
batch.condition_image = img
latents_device = (
batch.latents.device
if isinstance(batch.latents, torch.Tensor)
else torch.device("cpu")
)
image_tensor = self._pil_to_normed_tensor(img).to(
latents_device, dtype=torch.float32
)
# [B, C, H, W] -> [B, C, 1, H, W]
video_condition = image_tensor.unsqueeze(2)
self.vae = self.vae.to(latents_device)
vae_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.vae_precision]
encode_dtype = batch.latents.dtype
original_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.vae_precision]
self.vae = self.vae.to(device=latents_device, dtype=encode_dtype)
vae_autocast_enabled = (
vae_dtype != torch.float32
original_dtype != torch.float32
) and not server_args.disable_autocast
video_condition = self._resize_center_crop_tensor(
img,
width=int(batch.width),
height=int(batch.height),
device=latents_device,
dtype=encode_dtype,
)
with torch.autocast(
device_type=current_platform.device_type,
dtype=vae_dtype,
dtype=original_dtype,
enabled=vae_autocast_enabled,
):
try:
@@ -188,7 +254,7 @@ class LTX2AVDenoisingStage(DenoisingStage):
except Exception:
pass
if not vae_autocast_enabled:
video_condition = video_condition.to(vae_dtype)
video_condition = video_condition.to(encode_dtype)
latent_dist: DiagonalGaussianDistribution = self.vae.encode(video_condition)
if isinstance(latent_dist, AutoencoderKLOutput):
@@ -204,19 +270,10 @@ class LTX2AVDenoisingStage(DenoisingStage):
else:
raise ValueError(f"Unsupported encode_sample_mode: {mode}")
# Match the normalized latent space used by this pipeline (inverse of DecodingStage.scale_and_shift).
scaling_factor, shift_factor = (
server_args.pipeline_config.get_decode_scale_and_shift(
device=latent.device, dtype=latent.dtype, vae=self.vae
)
)
if isinstance(shift_factor, torch.Tensor):
shift_factor = shift_factor.to(latent.device)
if isinstance(scaling_factor, torch.Tensor):
scaling_factor = scaling_factor.to(latent.device)
if shift_factor is not None:
latent = latent - shift_factor
latent = latent * scaling_factor
# Per-channel normalization: normalized = (x - mean) / std
mean = self.vae.latents_mean.view(1, -1, 1, 1, 1).to(latent)
std = self.vae.latents_std.view(1, -1, 1, 1, 1).to(latent)
latent = (latent - mean) / std
packed = server_args.pipeline_config.maybe_pack_latents(
latent, latent.shape[0], batch
@@ -248,6 +305,7 @@ class LTX2AVDenoisingStage(DenoisingStage):
batch.height,
)
self.vae.to(original_dtype)
if server_args.vae_cpu_offload:
self.vae = self.vae.to("cpu")
@@ -481,18 +539,24 @@ class LTX2AVDenoisingStage(DenoisingStage):
# Velocity -> denoised (x0): x0 = x - sigma * v
sigma_val = float(sigma.item())
denoised_video = latents.float() - sigma_val * v_pos
denoised_audio = audio_latents.float() - sigma_val * a_v_pos
denoised_video = (latents.float() - sigma_val * v_pos).to(
latents.dtype
)
denoised_audio = (
audio_latents.float() - sigma_val * a_v_pos
).to(audio_latents.dtype)
if (
batch.do_classifier_free_guidance
and v_neg is not None
and a_v_neg is not None
):
denoised_video_neg = latents.float() - sigma_val * v_neg
denoised_video_neg = (
latents.float() - sigma_val * v_neg
).to(latents.dtype)
denoised_audio_neg = (
audio_latents.float() - sigma_val * a_v_neg
)
).to(audio_latents.dtype)
denoised_video = denoised_video + (
batch.guidance_scale - 1.0
) * (denoised_video - denoised_video_neg)
@@ -517,17 +581,20 @@ class LTX2AVDenoisingStage(DenoisingStage):
v_video = torch.zeros_like(denoised_video)
v_audio = torch.zeros_like(denoised_audio)
else:
v_video = (latents.float() - denoised_video) / sigma_val
v_video = (
(latents.float() - denoised_video.float()) / sigma_val
).to(latents.dtype)
v_audio = (
audio_latents.float() - denoised_audio
) / sigma_val
(audio_latents.float() - denoised_audio.float())
/ sigma_val
).to(audio_latents.dtype)
latents = (latents.float() + v_video * dt).to(
latents = (latents.float() + v_video.float() * dt).to(
dtype=latents.dtype
)
audio_latents = (audio_latents.float() + v_audio * dt).to(
dtype=audio_latents.dtype
)
audio_latents = (
audio_latents.float() + v_audio.float() * dt
).to(dtype=audio_latents.dtype)
if do_ti2v:
latents[:, :num_img_tokens, :] = batch.image_latent[