mirror of
https://github.com/kvcache-ai/sglang.git
synced 2026-04-20 14:29:32 +00:00
[diffusion] model: LTX-2 Support PR3 (#19151)
This commit is contained in:
@@ -19,7 +19,7 @@ dependencies = [
|
||||
"aiohttp",
|
||||
"apache-tvm-ffi>=0.1.5,<0.2",
|
||||
"anthropic>=0.20.0",
|
||||
"av ; sys_platform == 'linux' and (platform_machine == 'aarch64' or platform_machine == 'arm64' or platform_machine == 'armv7l')",
|
||||
"av",
|
||||
"blobfile==3.0.0",
|
||||
"build",
|
||||
"compressed-tensors",
|
||||
|
||||
@@ -19,6 +19,7 @@ dependencies = [
|
||||
"aiohttp",
|
||||
"anthropic>=0.20.0",
|
||||
"blobfile==3.0.0",
|
||||
"av",
|
||||
"build",
|
||||
"compressed-tensors",
|
||||
"decord2",
|
||||
|
||||
@@ -21,6 +21,7 @@ runtime_common = [
|
||||
"aiohttp",
|
||||
"anthropic>=0.20.0",
|
||||
"blobfile==3.0.0",
|
||||
"av",
|
||||
"build",
|
||||
"compressed-tensors",
|
||||
"decord2",
|
||||
|
||||
@@ -164,7 +164,7 @@ class LTX2ArchConfig(DiTArchConfig):
|
||||
self.audio_num_attention_heads * self.audio_attention_head_dim
|
||||
)
|
||||
if self.audio_positional_embedding_max_pos is None:
|
||||
self.audio_positional_embedding_max_pos = [2048]
|
||||
self.audio_positional_embedding_max_pos = [20]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -2,7 +2,6 @@ import dataclasses
|
||||
from dataclasses import field
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from sglang.multimodal_gen.configs.models.dits.ltx_2 import LTX2Config
|
||||
@@ -139,6 +138,23 @@ class LTX2PipelineConfig(PipelineConfig):
|
||||
def vae_temporal_compression(self):
|
||||
return getattr(self.vae_config.arch_config, "temporal_compression_ratio", 8)
|
||||
|
||||
def prepare_latent_shape(self, batch, batch_size, num_frames):
|
||||
"""Return packed latent shape [B, seq, C] directly."""
|
||||
height = batch.height // self.vae_scale_factor
|
||||
width = batch.width // self.vae_scale_factor
|
||||
|
||||
post_patch_num_frames = num_frames // self.patch_size_t
|
||||
post_patch_height = height // self.patch_size
|
||||
post_patch_width = width // self.patch_size
|
||||
seq_len = post_patch_num_frames * post_patch_height * post_patch_width
|
||||
|
||||
num_channels = (
|
||||
self.in_channels * self.patch_size_t * self.patch_size * self.patch_size
|
||||
)
|
||||
|
||||
shape = (batch_size, seq_len, num_channels)
|
||||
return shape
|
||||
|
||||
def prepare_audio_latent_shape(self, batch, batch_size, num_frames):
|
||||
# Adapted from diffusers pipeline prepare_audio_latents
|
||||
duration_s = num_frames / batch.fps
|
||||
@@ -159,7 +175,7 @@ class LTX2PipelineConfig(PipelineConfig):
|
||||
# Default to 8
|
||||
num_channels_latents = self.audio_vae_config.arch_config.latent_channels
|
||||
|
||||
shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins)
|
||||
shape = (batch_size, latent_length, num_channels_latents * latent_mel_bins)
|
||||
|
||||
return shape
|
||||
|
||||
@@ -184,7 +200,7 @@ class LTX2PipelineConfig(PipelineConfig):
|
||||
steps = int(num_inference_steps)
|
||||
if steps <= 0:
|
||||
raise ValueError(f"num_inference_steps must be positive, got {steps}")
|
||||
return np.linspace(1.0, 1.0 / float(steps), steps).tolist()
|
||||
return [1.0 - i / steps for i in range(steps)]
|
||||
return sigmas
|
||||
|
||||
def tokenize_prompt(self, prompt: list[str], tokenizer, tok_kwargs) -> dict:
|
||||
@@ -210,6 +226,10 @@ class LTX2PipelineConfig(PipelineConfig):
|
||||
return text_inputs
|
||||
|
||||
def maybe_pack_latents(self, latents, batch_size, batch):
|
||||
# If already packed (3D shape [B, seq, C]), skip packing
|
||||
if latents.dim() == 3:
|
||||
return latents
|
||||
|
||||
# Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
|
||||
# The patch dimensions are then permuted and collapsed into the channel dimension of shape:
|
||||
# [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
|
||||
@@ -338,6 +358,10 @@ class LTX2PipelineConfig(PipelineConfig):
|
||||
return super().gather_latents_for_sp(latents)
|
||||
|
||||
def maybe_pack_audio_latents(self, latents, batch_size, batch):
|
||||
# If already packed (3D shape [B, T, C*F]), skip packing
|
||||
if latents.dim() == 3:
|
||||
return latents
|
||||
|
||||
# Audio latents shape: [B, C, L, M], where L is the latent audio length and M is the number of mel bins
|
||||
# We need to pack them if patch_size/patch_size_t are defined for audio (not standard DiT patch size)
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ def apply_split_rotary_emb(
|
||||
# The cos/sin batch dim may only be broadcastable, so take batch size from x
|
||||
b = x.shape[0]
|
||||
_, h, t, _ = cos.shape
|
||||
x = x.reshape(b, t, h, -1).swapaxes(1, 2)
|
||||
x = x.reshape(b, t, h, -1).transpose(1, 2)
|
||||
needs_reshape = True
|
||||
|
||||
# Split last dim (2*r) into (d=2, r)
|
||||
@@ -46,7 +46,7 @@ def apply_split_rotary_emb(
|
||||
r = last // 2
|
||||
|
||||
# (..., 2, r)
|
||||
split_x = x.reshape(*x.shape[:-1], 2, r).float() # Explicitly upcast to float
|
||||
split_x = x.reshape(*x.shape[:-1], 2, r)
|
||||
first_x = split_x[..., :1, :] # (..., 1, r)
|
||||
second_x = split_x[..., 1:, :] # (..., 1, r)
|
||||
|
||||
@@ -63,7 +63,7 @@ def apply_split_rotary_emb(
|
||||
out = out.reshape(*out.shape[:-2], last)
|
||||
|
||||
if needs_reshape:
|
||||
out = out.swapaxes(1, 2).reshape(b, t, -1)
|
||||
out = out.transpose(1, 2).reshape(b, t, -1)
|
||||
|
||||
out = out.to(dtype=x_dtype)
|
||||
return out
|
||||
@@ -232,6 +232,7 @@ class LTX2RotaryPosEmbed1d(nn.Module):
|
||||
batch_size: int,
|
||||
pos: int,
|
||||
device: Union[str, torch.device],
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# 1. Get 1D position ids
|
||||
grid_1d = torch.arange(pos, dtype=torch.float32, device=device)
|
||||
@@ -297,6 +298,9 @@ class LTX2RotaryPosEmbed1d(nn.Module):
|
||||
cos_freqs = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2)
|
||||
sin_freqs = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2)
|
||||
|
||||
if dtype is not None:
|
||||
cos_freqs = cos_freqs.to(dtype)
|
||||
sin_freqs = sin_freqs.to(dtype)
|
||||
return cos_freqs, sin_freqs
|
||||
|
||||
|
||||
@@ -460,7 +464,9 @@ class LTX2ConnectorTransformer1d(nn.Module):
|
||||
attention_mask = torch.zeros_like(attention_mask)
|
||||
|
||||
# 2. Calculate 1D RoPE positional embeddings
|
||||
rotary_emb = self.rope(batch_size, seq_len, device=hidden_states.device)
|
||||
rotary_emb = self.rope(
|
||||
batch_size, seq_len, device=hidden_states.device, dtype=hidden_states.dtype
|
||||
)
|
||||
|
||||
# 3. Run 1D transformer blocks
|
||||
for block in self.transformer_blocks:
|
||||
|
||||
@@ -66,7 +66,7 @@ def apply_split_rotary_emb(
|
||||
)
|
||||
r = last // 2
|
||||
|
||||
split_x = x.reshape(*x.shape[:-1], 2, r).float()
|
||||
split_x = x.reshape(*x.shape[:-1], 2, r)
|
||||
first_x = split_x[..., :1, :]
|
||||
second_x = split_x[..., 1:, :]
|
||||
|
||||
@@ -137,6 +137,7 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
|
||||
self.causal_offset = int(causal_offset)
|
||||
|
||||
self.modality = modality
|
||||
self.coords_dtype = torch.bfloat16 if modality == "video" else torch.float32
|
||||
if self.modality not in ["video", "audio"]:
|
||||
raise ValueError(
|
||||
f"Modality {modality} is not supported. Supported modalities are `video` and `audio`."
|
||||
@@ -243,6 +244,7 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
|
||||
device = device or coords.device
|
||||
num_pos_dims = coords.shape[1]
|
||||
|
||||
coords = coords.to(self.coords_dtype)
|
||||
if coords.ndim == 4:
|
||||
coords_start, coords_end = coords.chunk(2, dim=-1)
|
||||
coords = (coords_start + coords_end) / 2.0
|
||||
@@ -307,7 +309,9 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
|
||||
cos_freqs = torch.swapaxes(cos_freq, 1, 2)
|
||||
sin_freqs = torch.swapaxes(sin_freq, 1, 2)
|
||||
|
||||
return cos_freqs, sin_freqs
|
||||
# Cast to bf16 to match model weights dtype. coords_dtype controls
|
||||
# intermediate coordinate precision (fp32 for audio) and differs.
|
||||
return cos_freqs.to(torch.bfloat16), sin_freqs.to(torch.bfloat16)
|
||||
|
||||
|
||||
def rms_norm(x: torch.Tensor, eps: float) -> torch.Tensor:
|
||||
@@ -1121,7 +1125,9 @@ class LTX2VideoTransformer3DModel(CachableDiT, OffloadableDiTMixin):
|
||||
if hasattr(arch.rope_type, "value")
|
||||
else str(arch.rope_type)
|
||||
)
|
||||
rope_double_precision = bool(getattr(arch, "double_precision_rope", True))
|
||||
rope_double_precision = bool(
|
||||
hf_config.get("rope_double_precision", arch.double_precision_rope)
|
||||
)
|
||||
causal_offset = int(hf_config.get("causal_offset", 1))
|
||||
|
||||
pos_embed_max_pos = int(arch.positional_embedding_max_pos[0])
|
||||
@@ -1351,27 +1357,30 @@ class LTX2VideoTransformer3DModel(CachableDiT, OffloadableDiTMixin):
|
||||
self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier
|
||||
)
|
||||
|
||||
hidden_dtype = hidden_states.dtype
|
||||
temb_ca_scale_shift, _ = self.av_ca_video_scale_shift_adaln_single(
|
||||
timestep.flatten()
|
||||
timestep.flatten(), hidden_dtype=hidden_dtype
|
||||
)
|
||||
temb_ca_scale_shift = temb_ca_scale_shift.view(
|
||||
batch_size, -1, temb_ca_scale_shift.shape[-1]
|
||||
)
|
||||
|
||||
temb_ca_gate, _ = self.av_ca_a2v_gate_adaln_single(
|
||||
timestep.flatten() * ts_ca_mult
|
||||
timestep.flatten() * self.av_ca_timestep_scale_multiplier,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
temb_ca_gate = temb_ca_gate.view(batch_size, -1, temb_ca_gate.shape[-1])
|
||||
|
||||
temb_ca_audio_scale_shift, _ = self.av_ca_audio_scale_shift_adaln_single(
|
||||
audio_timestep.flatten()
|
||||
audio_timestep.flatten(), hidden_dtype=audio_hidden_states.dtype
|
||||
)
|
||||
temb_ca_audio_scale_shift = temb_ca_audio_scale_shift.view(
|
||||
batch_size, -1, temb_ca_audio_scale_shift.shape[-1]
|
||||
)
|
||||
|
||||
temb_ca_audio_gate, _ = self.av_ca_v2a_gate_adaln_single(
|
||||
audio_timestep.flatten() * ts_ca_mult
|
||||
audio_timestep.flatten() * self.av_ca_timestep_scale_multiplier,
|
||||
hidden_dtype=audio_hidden_states.dtype,
|
||||
)
|
||||
temb_ca_audio_gate = temb_ca_audio_gate.view(
|
||||
batch_size, -1, temb_ca_audio_gate.shape[-1]
|
||||
@@ -1413,7 +1422,8 @@ class LTX2VideoTransformer3DModel(CachableDiT, OffloadableDiTMixin):
|
||||
device=hidden_states.device, dtype=hidden_states.dtype
|
||||
) + embedded_timestep[:, :, None].to(dtype=hidden_states.dtype)
|
||||
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
with torch.autocast(device_type=hidden_states.device.type, enabled=False):
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
hidden_states, _ = self.proj_out(hidden_states)
|
||||
|
||||
@@ -1425,7 +1435,8 @@ class LTX2VideoTransformer3DModel(CachableDiT, OffloadableDiTMixin):
|
||||
audio_scale_shift_values[:, :, 0],
|
||||
audio_scale_shift_values[:, :, 1],
|
||||
)
|
||||
audio_hidden_states = self.audio_norm_out(audio_hidden_states)
|
||||
with torch.autocast(device_type=audio_hidden_states.device.type, enabled=False):
|
||||
audio_hidden_states = self.audio_norm_out(audio_hidden_states)
|
||||
audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift
|
||||
audio_hidden_states, _ = self.audio_proj_out(audio_hidden_states)
|
||||
|
||||
|
||||
@@ -90,6 +90,11 @@ class Gemma3MLP(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
|
||||
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
class Gemma3Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -170,6 +175,19 @@ class Gemma3Attention(nn.Module):
|
||||
is_neox_style=True,
|
||||
)
|
||||
|
||||
# NOTE(gmixiaojin): The shared RotaryEmbedding above computes inv_freq on
|
||||
# GPU and uses the x1*cos - x2*sin formula, which causes slight
|
||||
# numerical differences vs HuggingFace (see the NOTE in
|
||||
# rotary_embedding.py:_compute_inv_freq). For HF-exact alignment we
|
||||
# precompute inv_freq on CPU and use rotate_half in self.rotary_emb().
|
||||
freq_indices = (
|
||||
torch.arange(0, self.head_dim, 2, dtype=torch.int64).float() / self.head_dim
|
||||
)
|
||||
inv_freq = 1.0 / (self.rope_theta**freq_indices)
|
||||
if rope_scaling and rope_scaling.get("factor"):
|
||||
inv_freq = inv_freq / float(rope_scaling["factor"])
|
||||
self.register_buffer("_hf_inv_freq", inv_freq, persistent=False)
|
||||
|
||||
# Local Attention not support attention mask, we use global attention instead.
|
||||
# self.attn = LocalAttention(
|
||||
# self.num_heads,
|
||||
@@ -189,6 +207,23 @@ class Gemma3Attention(nn.Module):
|
||||
dim=self.head_dim, eps=config.text_config.rms_norm_eps
|
||||
)
|
||||
|
||||
def rotary_emb(self, positions, q, k):
|
||||
"""Apply RoPE using HF-exact formula with precomputed inv_freq."""
|
||||
positions_flat = positions.flatten().float()
|
||||
num_tokens = positions_flat.shape[0]
|
||||
|
||||
with torch.autocast(device_type=q.device.type, enabled=False):
|
||||
freqs = torch.outer(positions_flat, self._hf_inv_freq.float())
|
||||
emb = freqs.repeat(1, 2)
|
||||
cos = emb.cos().to(q.dtype).unsqueeze(1)
|
||||
sin = emb.sin().to(q.dtype).unsqueeze(1)
|
||||
|
||||
q = q.reshape(num_tokens, -1, self.head_dim)
|
||||
k = k.reshape(num_tokens, -1, self.head_dim)
|
||||
q = q * cos + _rotate_half(q) * sin
|
||||
k = k * cos + _rotate_half(k) * sin
|
||||
return q, k
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
@@ -209,16 +244,19 @@ class Gemma3Attention(nn.Module):
|
||||
|
||||
# Apply RoPE
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
q = q.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
|
||||
k = k.reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim)
|
||||
|
||||
# TODO(FlamingoPg): Support LocalAttention
|
||||
query = q.transpose(1, 2)
|
||||
key = k.transpose(1, 2)
|
||||
value = v.transpose(1, 2)
|
||||
|
||||
min_val = torch.finfo(query.dtype).min
|
||||
attn_mask = torch.zeros(
|
||||
(seq_len, seq_len),
|
||||
device=hidden_states.device,
|
||||
dtype=torch.float32,
|
||||
dtype=query.dtype,
|
||||
)
|
||||
causal = torch.triu(
|
||||
torch.ones(
|
||||
@@ -226,18 +264,18 @@ class Gemma3Attention(nn.Module):
|
||||
),
|
||||
diagonal=1,
|
||||
)
|
||||
attn_mask = attn_mask.masked_fill(causal, float("-inf"))
|
||||
attn_mask = attn_mask.masked_fill(causal, min_val)
|
||||
if self.is_sliding and self.sliding_window is not None:
|
||||
idx = torch.arange(seq_len, device=hidden_states.device)
|
||||
dist = idx[None, :] - idx[:, None]
|
||||
too_far = dist > self.sliding_window
|
||||
attn_mask = attn_mask.masked_fill(too_far, float("-inf"))
|
||||
attn_mask = attn_mask.masked_fill(too_far, min_val)
|
||||
|
||||
key_pad = ~attention_mask.to(torch.bool)
|
||||
attn_mask = attn_mask[None, None, :, :].expand(batch_size, 1, seq_len, seq_len)
|
||||
attn_mask = attn_mask.masked_fill(
|
||||
key_pad[:, None, None, :].expand(batch_size, 1, seq_len, seq_len),
|
||||
float("-inf"),
|
||||
min_val,
|
||||
)
|
||||
|
||||
attn_kwargs = {
|
||||
@@ -707,7 +745,8 @@ class Gemma3TextModel(nn.Module):
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids) * self.embed_scale
|
||||
out = self.embed_tokens(input_ids)
|
||||
return out * torch.tensor(self.embed_scale, device=out.device, dtype=out.dtype)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -735,7 +774,6 @@ class Gemma3TextModel(nn.Module):
|
||||
position_ids = torch.arange(
|
||||
0, hidden_states.shape[1], device=hidden_states.device
|
||||
).unsqueeze(0)
|
||||
position_ids = position_ids + 1
|
||||
|
||||
all_hidden_states: tuple[Any, ...] | None = () if output_hidden_states else None
|
||||
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
import inspect
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import FlowMatchEulerDiscreteScheduler
|
||||
|
||||
from sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (
|
||||
ComposedPipelineBase,
|
||||
)
|
||||
@@ -50,14 +55,9 @@ def prepare_mu(batch: Req, server_args: ServerArgs):
|
||||
vae_arch, "temporal_compression_ratio", None
|
||||
) or getattr(server_args.pipeline_config, "vae_temporal_compression", None)
|
||||
|
||||
latent_num_frames = (int(num_frames) - 1) // int(vae_temporal_compression) + 1
|
||||
latent_height = int(height) // int(vae_scale_factor)
|
||||
latent_width = int(width) // int(vae_scale_factor)
|
||||
video_sequence_length = latent_num_frames * latent_height * latent_width
|
||||
|
||||
# Values from LTX2Pipeline in diffusers
|
||||
mu = calculate_shift(
|
||||
video_sequence_length,
|
||||
4096,
|
||||
base_seq_len=1024,
|
||||
max_seq_len=4096,
|
||||
base_shift=0.95,
|
||||
@@ -101,6 +101,17 @@ def _filter_kwargs_for_cls(cls, kwargs):
|
||||
return {k: v for k, v in kwargs.items() if k in sig.parameters}
|
||||
|
||||
|
||||
class LTX2FlowMatchScheduler(FlowMatchEulerDiscreteScheduler):
|
||||
"""Override ``_time_shift_exponential`` to use torch f32 instead of numpy f64."""
|
||||
|
||||
def _time_shift_exponential(self, mu, sigma, t):
|
||||
if isinstance(t, np.ndarray):
|
||||
t_torch = torch.from_numpy(t).to(torch.float32)
|
||||
result = math.exp(mu) / (math.exp(mu) + (1 / t_torch - 1) ** sigma)
|
||||
return result.numpy()
|
||||
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
|
||||
|
||||
class LTX2Pipeline(ComposedPipelineBase):
|
||||
# NOTE: must match `model_index.json`'s `_class_name` for native dispatch.
|
||||
pipeline_name = "LTX2Pipeline"
|
||||
@@ -116,6 +127,10 @@ class LTX2Pipeline(ComposedPipelineBase):
|
||||
"connectors",
|
||||
]
|
||||
|
||||
def initialize_pipeline(self, server_args: ServerArgs):
|
||||
orig = self.get_module("scheduler")
|
||||
self.modules["scheduler"] = LTX2FlowMatchScheduler.from_config(orig.config)
|
||||
|
||||
def create_pipeline_stages(self, server_args: ServerArgs):
|
||||
self.add_stages(
|
||||
[
|
||||
|
||||
@@ -37,7 +37,12 @@ class LTX2AVDecodingStage(DecodingStage):
|
||||
vae_dtype != torch.float32
|
||||
) and not server_args.disable_autocast
|
||||
|
||||
latents = self.scale_and_shift(latents, server_args)
|
||||
original_dtype = vae_dtype
|
||||
self.vae.to(torch.bfloat16)
|
||||
latents = latents.to(torch.bfloat16)
|
||||
std = self.vae.latents_std.view(1, -1, 1, 1, 1).to(latents)
|
||||
mean = self.vae.latents_mean.view(1, -1, 1, 1, 1).to(latents)
|
||||
latents = latents * std + mean
|
||||
latents = server_args.pipeline_config.preprocess_decoding(
|
||||
latents, server_args, vae=self.vae
|
||||
)
|
||||
@@ -52,8 +57,6 @@ class LTX2AVDecodingStage(DecodingStage):
|
||||
self.vae.enable_tiling()
|
||||
except Exception:
|
||||
pass
|
||||
if not vae_autocast_enabled:
|
||||
latents = latents.to(vae_dtype)
|
||||
decode_output = self.vae.decode(latents)
|
||||
if isinstance(decode_output, tuple):
|
||||
video = decode_output[0]
|
||||
@@ -62,6 +65,7 @@ class LTX2AVDecodingStage(DecodingStage):
|
||||
else:
|
||||
video = decode_output
|
||||
|
||||
self.vae.to(original_dtype)
|
||||
video = self.video_processor.postprocess_video(video, output_type="np")
|
||||
|
||||
output_batch = OutputBatch(
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import copy
|
||||
import math
|
||||
import time
|
||||
from io import BytesIO
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
|
||||
@@ -107,6 +111,66 @@ class LTX2AVDenoisingStage(DenoisingStage):
|
||||
) -> PIL.Image.Image:
|
||||
return img.resize((width, height), resample=PIL.Image.Resampling.BILINEAR)
|
||||
|
||||
@staticmethod
|
||||
def _apply_video_codec_compression(
|
||||
img_array: np.ndarray, crf: int = 33
|
||||
) -> np.ndarray:
|
||||
"""Encode as a single H.264 frame and decode back to simulate compression artifacts."""
|
||||
if crf == 0:
|
||||
return img_array
|
||||
height, width = img_array.shape[0] // 2 * 2, img_array.shape[1] // 2 * 2
|
||||
img_array = img_array[:height, :width]
|
||||
buffer = BytesIO()
|
||||
container = av.open(buffer, mode="w", format="mp4")
|
||||
stream = container.add_stream(
|
||||
"libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"}
|
||||
)
|
||||
stream.height, stream.width = height, width
|
||||
frame = av.VideoFrame.from_ndarray(img_array, format="rgb24").reformat(
|
||||
format="yuv420p"
|
||||
)
|
||||
container.mux(stream.encode(frame))
|
||||
container.mux(stream.encode())
|
||||
container.close()
|
||||
buffer.seek(0)
|
||||
container = av.open(buffer)
|
||||
decoded = next(container.decode(container.streams.video[0]))
|
||||
container.close()
|
||||
return decoded.to_ndarray(format="rgb24")
|
||||
|
||||
@staticmethod
|
||||
def _resize_center_crop_tensor(
|
||||
img: PIL.Image.Image,
|
||||
*,
|
||||
width: int,
|
||||
height: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
apply_codec_compression: bool = True,
|
||||
codec_crf: int = 33,
|
||||
) -> torch.Tensor:
|
||||
"""Resize, center-crop, and normalize to [1, C, 1, H, W] tensor in [-1, 1]."""
|
||||
img_array = np.array(img).astype(np.uint8)[..., :3]
|
||||
if apply_codec_compression:
|
||||
img_array = LTX2AVDenoisingStage._apply_video_codec_compression(
|
||||
img_array, crf=codec_crf
|
||||
)
|
||||
tensor = (
|
||||
torch.from_numpy(img_array.astype(np.float32))
|
||||
.permute(2, 0, 1)
|
||||
.unsqueeze(0)
|
||||
.to(device=device)
|
||||
)
|
||||
src_h, src_w = tensor.shape[2], tensor.shape[3]
|
||||
scale = max(height / src_h, width / src_w)
|
||||
new_h, new_w = math.ceil(src_h * scale), math.ceil(src_w * scale)
|
||||
tensor = torch.nn.functional.interpolate(
|
||||
tensor, size=(new_h, new_w), mode="bilinear", align_corners=False
|
||||
)
|
||||
top, left = (new_h - height) // 2, (new_w - width) // 2
|
||||
tensor = tensor[:, :, top : top + height, left : left + width]
|
||||
return ((tensor / 127.5 - 1.0).to(dtype=dtype)).unsqueeze(2)
|
||||
|
||||
@staticmethod
|
||||
def _pil_to_normed_tensor(img: PIL.Image.Image) -> torch.Tensor:
|
||||
# PIL -> numpy [0,1] -> torch [B,C,H,W], then [-1,1]
|
||||
@@ -155,31 +219,33 @@ class LTX2AVDenoisingStage(DenoisingStage):
|
||||
)
|
||||
|
||||
img = load_image(image_path)
|
||||
img = self._resize_center_crop(
|
||||
batch.condition_image = self._resize_center_crop(
|
||||
img, width=int(batch.width), height=int(batch.height)
|
||||
)
|
||||
batch.condition_image = img
|
||||
|
||||
latents_device = (
|
||||
batch.latents.device
|
||||
if isinstance(batch.latents, torch.Tensor)
|
||||
else torch.device("cpu")
|
||||
)
|
||||
image_tensor = self._pil_to_normed_tensor(img).to(
|
||||
latents_device, dtype=torch.float32
|
||||
)
|
||||
# [B, C, H, W] -> [B, C, 1, H, W]
|
||||
video_condition = image_tensor.unsqueeze(2)
|
||||
|
||||
self.vae = self.vae.to(latents_device)
|
||||
vae_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.vae_precision]
|
||||
encode_dtype = batch.latents.dtype
|
||||
original_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.vae_precision]
|
||||
self.vae = self.vae.to(device=latents_device, dtype=encode_dtype)
|
||||
vae_autocast_enabled = (
|
||||
vae_dtype != torch.float32
|
||||
original_dtype != torch.float32
|
||||
) and not server_args.disable_autocast
|
||||
|
||||
video_condition = self._resize_center_crop_tensor(
|
||||
img,
|
||||
width=int(batch.width),
|
||||
height=int(batch.height),
|
||||
device=latents_device,
|
||||
dtype=encode_dtype,
|
||||
)
|
||||
|
||||
with torch.autocast(
|
||||
device_type=current_platform.device_type,
|
||||
dtype=vae_dtype,
|
||||
dtype=original_dtype,
|
||||
enabled=vae_autocast_enabled,
|
||||
):
|
||||
try:
|
||||
@@ -188,7 +254,7 @@ class LTX2AVDenoisingStage(DenoisingStage):
|
||||
except Exception:
|
||||
pass
|
||||
if not vae_autocast_enabled:
|
||||
video_condition = video_condition.to(vae_dtype)
|
||||
video_condition = video_condition.to(encode_dtype)
|
||||
|
||||
latent_dist: DiagonalGaussianDistribution = self.vae.encode(video_condition)
|
||||
if isinstance(latent_dist, AutoencoderKLOutput):
|
||||
@@ -204,19 +270,10 @@ class LTX2AVDenoisingStage(DenoisingStage):
|
||||
else:
|
||||
raise ValueError(f"Unsupported encode_sample_mode: {mode}")
|
||||
|
||||
# Match the normalized latent space used by this pipeline (inverse of DecodingStage.scale_and_shift).
|
||||
scaling_factor, shift_factor = (
|
||||
server_args.pipeline_config.get_decode_scale_and_shift(
|
||||
device=latent.device, dtype=latent.dtype, vae=self.vae
|
||||
)
|
||||
)
|
||||
if isinstance(shift_factor, torch.Tensor):
|
||||
shift_factor = shift_factor.to(latent.device)
|
||||
if isinstance(scaling_factor, torch.Tensor):
|
||||
scaling_factor = scaling_factor.to(latent.device)
|
||||
if shift_factor is not None:
|
||||
latent = latent - shift_factor
|
||||
latent = latent * scaling_factor
|
||||
# Per-channel normalization: normalized = (x - mean) / std
|
||||
mean = self.vae.latents_mean.view(1, -1, 1, 1, 1).to(latent)
|
||||
std = self.vae.latents_std.view(1, -1, 1, 1, 1).to(latent)
|
||||
latent = (latent - mean) / std
|
||||
|
||||
packed = server_args.pipeline_config.maybe_pack_latents(
|
||||
latent, latent.shape[0], batch
|
||||
@@ -248,6 +305,7 @@ class LTX2AVDenoisingStage(DenoisingStage):
|
||||
batch.height,
|
||||
)
|
||||
|
||||
self.vae.to(original_dtype)
|
||||
if server_args.vae_cpu_offload:
|
||||
self.vae = self.vae.to("cpu")
|
||||
|
||||
@@ -481,18 +539,24 @@ class LTX2AVDenoisingStage(DenoisingStage):
|
||||
|
||||
# Velocity -> denoised (x0): x0 = x - sigma * v
|
||||
sigma_val = float(sigma.item())
|
||||
denoised_video = latents.float() - sigma_val * v_pos
|
||||
denoised_audio = audio_latents.float() - sigma_val * a_v_pos
|
||||
denoised_video = (latents.float() - sigma_val * v_pos).to(
|
||||
latents.dtype
|
||||
)
|
||||
denoised_audio = (
|
||||
audio_latents.float() - sigma_val * a_v_pos
|
||||
).to(audio_latents.dtype)
|
||||
|
||||
if (
|
||||
batch.do_classifier_free_guidance
|
||||
and v_neg is not None
|
||||
and a_v_neg is not None
|
||||
):
|
||||
denoised_video_neg = latents.float() - sigma_val * v_neg
|
||||
denoised_video_neg = (
|
||||
latents.float() - sigma_val * v_neg
|
||||
).to(latents.dtype)
|
||||
denoised_audio_neg = (
|
||||
audio_latents.float() - sigma_val * a_v_neg
|
||||
)
|
||||
).to(audio_latents.dtype)
|
||||
denoised_video = denoised_video + (
|
||||
batch.guidance_scale - 1.0
|
||||
) * (denoised_video - denoised_video_neg)
|
||||
@@ -517,17 +581,20 @@ class LTX2AVDenoisingStage(DenoisingStage):
|
||||
v_video = torch.zeros_like(denoised_video)
|
||||
v_audio = torch.zeros_like(denoised_audio)
|
||||
else:
|
||||
v_video = (latents.float() - denoised_video) / sigma_val
|
||||
v_video = (
|
||||
(latents.float() - denoised_video.float()) / sigma_val
|
||||
).to(latents.dtype)
|
||||
v_audio = (
|
||||
audio_latents.float() - denoised_audio
|
||||
) / sigma_val
|
||||
(audio_latents.float() - denoised_audio.float())
|
||||
/ sigma_val
|
||||
).to(audio_latents.dtype)
|
||||
|
||||
latents = (latents.float() + v_video * dt).to(
|
||||
latents = (latents.float() + v_video.float() * dt).to(
|
||||
dtype=latents.dtype
|
||||
)
|
||||
audio_latents = (audio_latents.float() + v_audio * dt).to(
|
||||
dtype=audio_latents.dtype
|
||||
)
|
||||
audio_latents = (
|
||||
audio_latents.float() + v_audio.float() * dt
|
||||
).to(dtype=audio_latents.dtype)
|
||||
|
||||
if do_ti2v:
|
||||
latents[:, :num_img_tokens, :] = batch.image_latent[
|
||||
|
||||
Reference in New Issue
Block a user