Shrink text embeds to max token length for LTX-2. Drastically reduces cached text embedding sizes

This commit is contained in:
Jaret Burkett
2026-01-28 12:54:49 -07:00
parent ea912d2d7b
commit 1ce2428722
7 changed files with 130 additions and 27 deletions


@@ -213,6 +213,9 @@ class LTX2Model(BaseModel):
# use the new format on this new model by default
self.use_old_lokr_format = False
self.audio_processor = None
# Gemma needs left-side padding
self.te_padding_side = "left"
# static method to get the noise scheduler
@staticmethod
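As an aside (not part of the commit), a minimal sketch of why a decoder-only text encoder such as Gemma wants left-side padding: with right padding the final positions of short prompts are pad tokens, while left padding keeps every sequence ending on real text. The checkpoint name below is purely illustrative.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/gemma-2-2b")  # illustrative checkpoint, not necessarily the one LTX-2 uses
tok.padding_side = "left"
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

batch = tok(
    ["a cat", "a highly detailed photo of a cat"],
    padding="longest",
    return_tensors="pt",
)
# With left padding each attention_mask row starts with 0s and ends with 1s,
# so the last position of every sequence is a real token.
print(batch.attention_mask)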
@@ -627,6 +630,10 @@ class LTX2Model(BaseModel):
tile_sample_stride_width=224,
tile_sample_stride_num_frames=4,
)
# We only encode and store the minimum prompt tokens, but need them padded to 1024 for LTX2
conditional_embeds = self.pad_embeds(conditional_embeds)
unconditional_embeds = self.pad_embeds(unconditional_embeds)
video, audio = pipeline(
prompt_embeds=conditional_embeds.text_embeds.to(
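A back-of-the-envelope sketch (assumed embedding width and dtype) of why caching only the real prompt tokens pays off: a typical prompt is far shorter than the 1024-token target, so the stored tensor shrinks roughly in proportion to the prompt length.

import torch

# Assumed per-token width; the real value depends on the Gemma hidden size,
# how many hidden states get stacked, and the pipeline's pack scale factor.
dim = 3840

full = torch.zeros(1, 1024, dim, dtype=torch.bfloat16)   # padded to 1024 tokens
trimmed = torch.zeros(1, 64, dim, dtype=torch.bfloat16)  # a 64-token prompt

bytes_full = full.numel() * full.element_size()
bytes_trimmed = trimmed.numel() * trimmed.element_size()
print(bytes_full / 1e6, "MB vs", bytes_trimmed / 1e6, "MB")  # ~7.9 MB vs ~0.5 MB, 16x smaller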
@@ -731,6 +738,29 @@ class LTX2Model(BaseModel):
latents_std = self.pipeline.audio_vae.latents_std
output_tensor = (output_tensor - latents_mean) / latents_std
return output_tensor
def pad_embeds(self, embeds: PromptEmbeds):
    # the LTX-2 connector requires 1024 tokens for good results. Any smaller and it degrades.
    target_length = 1024
    current_length = embeds.text_embeds.shape[1]
    if current_length < target_length:
        pad_length = target_length - current_length
        # prepend zeros so sequences stay left-padded, matching the text encoder
        pad_tensor = torch.zeros(
            (embeds.text_embeds.shape[0], pad_length, embeds.text_embeds.shape[2]),
            device=embeds.text_embeds.device,
            dtype=embeds.text_embeds.dtype,
        )
        embeds.text_embeds = torch.cat([pad_tensor, embeds.text_embeds], dim=1)
        if embeds.attention_mask is not None:
            pad_mask = torch.zeros(
                (embeds.attention_mask.shape[0], pad_length),
                device=embeds.attention_mask.device,
                dtype=embeds.attention_mask.dtype,
            )
            embeds.attention_mask = torch.cat(
                [pad_mask, embeds.attention_mask], dim=1
            )
    return embeds
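A quick standalone check of the padding behavior (assumed embedding width; the helper below mirrors pad_embeds rather than calling the toolkit class):

import torch

def pad_left_to(target_length, text_embeds, attention_mask):
    # Mirrors pad_embeds above: zero-pad on the left so sequences stay
    # aligned to the end, matching the tokenizer's padding_side="left".
    pad = target_length - text_embeds.shape[1]
    if pad <= 0:
        return text_embeds, attention_mask
    zeros = torch.zeros(
        text_embeds.shape[0], pad, text_embeds.shape[2],
        dtype=text_embeds.dtype, device=text_embeds.device,
    )
    mask_zeros = torch.zeros(
        attention_mask.shape[0], pad,
        dtype=attention_mask.dtype, device=attention_mask.device,
    )
    return (
        torch.cat([zeros, text_embeds], dim=1),
        torch.cat([mask_zeros, attention_mask], dim=1),
    )

emb = torch.randn(1, 37, 3840)                 # 37 real tokens, assumed width
mask = torch.ones(1, 37, dtype=torch.long)
emb, mask = pad_left_to(1024, emb, mask)
print(emb.shape, mask.shape, int(mask.sum()))  # torch.Size([1, 1024, 3840]) torch.Size([1, 1024]) 37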
def get_noise_prediction(
self,
@@ -743,6 +773,9 @@ class LTX2Model(BaseModel):
with torch.no_grad():
if self.model.device == torch.device("cpu"):
self.model.to(self.device_torch)
# We only encode and store the minimum prompt tokens, but need them padded to 1024 for LTX2
text_embeddings = self.pad_embeds(text_embeddings)
batch_size, C, latent_num_frames, latent_height, latent_width = (
latent_model_input.shape
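For reference, a tiny sketch (made-up sizes) of the 5-D video-latent layout being unpacked here; the real values depend on the VAE's spatial/temporal compression and the requested resolution and frame count.

import torch

latent_model_input = torch.randn(1, 128, 7, 16, 24)  # made-up (B, C, F, H, W) latent
batch_size, C, latent_num_frames, latent_height, latent_width = latent_model_input.shape
print(batch_size, C, latent_num_frames, latent_height, latent_width)  # 1 128 7 16 24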
@@ -916,11 +949,58 @@ class LTX2Model(BaseModel):
if self.pipeline.text_encoder.device != self.device_torch:
self.pipeline.text_encoder.to(self.device_torch)
device = self.device_torch
scale_factor = 8
batch_size = len(prompt)
# Gemma expects left padding for chat-style prompts
self.tokenizer[0].padding_side = "left"
if self.tokenizer[0].pad_token is None:
self.tokenizer[0].pad_token = self.tokenizer[0].eos_token
prompt = [p.strip() for p in prompt]
text_inputs = self.tokenizer[0](
    prompt,
    # padding="max_length",
    padding="longest",
    max_length=1024,
    truncation=True,
    add_special_tokens=True,
    return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask
text_input_ids = text_input_ids.to(device)
prompt_attention_mask = prompt_attention_mask.to(device)
text_encoder_outputs = self.text_encoder[0](
    input_ids=text_input_ids,
    attention_mask=prompt_attention_mask,
    output_hidden_states=True,
)
text_encoder_hidden_states = text_encoder_outputs.hidden_states
text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1)
sequence_lengths = prompt_attention_mask.sum(dim=-1)
prompt_embeds = self.pipeline._pack_text_embeds(
    text_encoder_hidden_states,
    sequence_lengths,
    device=device,
    padding_side=self.tokenizer[0].padding_side,
    scale_factor=scale_factor,
)
prompt_embeds = prompt_embeds.to(dtype=self.torch_dtype)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, 1, 1)
prompt_embeds = prompt_embeds.view(
    batch_size * 1, seq_len, -1
)
prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
prompt_attention_mask = prompt_attention_mask.repeat(1, 1)
pe = PromptEmbeds([prompt_embeds, None])
pe.attention_mask = prompt_attention_mask
return pe
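A small shape sketch (toy sizes) of what the stacking step above produces before _pack_text_embeds: with output_hidden_states=True the encoder returns one hidden state per layer plus the embedding output, and stacking on a new trailing dim gives (batch, seq_len, hidden, num_states).

import torch

# Toy stand-ins: 3 "hidden states" (e.g. embeddings plus 2 layers), each (B, L, H).
hidden_states = tuple(torch.randn(2, 5, 8) for _ in range(3))
stacked = torch.stack(hidden_states, dim=-1)
print(stacked.shape)  # torch.Size([2, 5, 8, 3])

_pack_text_embeds then reduces this stack to a single 3-D (batch, seq_len, dim) tensor, as the unpack of prompt_embeds.shape above implies; the exact packing (and the role of scale_factor=8) is internal to the LTX-2 pipeline.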