mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-04-30 11:11:37 +00:00
Shrink text embeds to max token length for LTX-2. Drastically reduces cached text embedding sizes
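
In short: prompts are now tokenized with padding="longest" instead of being padded out to a fixed max_length of 1024, so the cached text embeddings only span the tokens the prompt actually uses, and the new pad_embeds helper zero-pads them back to 1024 right before they reach the model. Below is a minimal sketch of the tokenizer-side difference, separate from the diff that follows; it assumes a Hugging Face tokenizer, and the checkpoint name is only a placeholder:

from transformers import AutoTokenizer

# Placeholder checkpoint for illustration; any HF tokenizer shows the same effect.
tok = AutoTokenizer.from_pretrained("google/gemma-2-2b")
tok.padding_side = "left"  # Gemma-style left padding, as set in the diff
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

prompts = ["a red fox leaping over a frozen lake at dawn"]

# Previous behaviour (see the commented-out padding="max_length" in the diff):
# every prompt is padded to 1024 tokens, so the cached embeddings always have
# sequence length 1024.
fixed = tok(prompts, padding="max_length", max_length=1024,
            truncation=True, return_tensors="pt")

# New behaviour: pad only to the longest prompt in the batch, so the cached
# embeddings are roughly as long as the prompt itself.
short = tok(prompts, padding="longest", max_length=1024,
            truncation=True, return_tensors="pt")

print(fixed.input_ids.shape)   # torch.Size([1, 1024])
print(short.input_ids.shape)   # e.g. torch.Size([1, 12])

For a prompt of a dozen tokens that is roughly an 80x reduction in the sequence dimension of whatever gets cached downstream.
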
@@ -213,6 +213,9 @@ class LTX2Model(BaseModel):
        # use the new format on this new model by default
        self.use_old_lokr_format = False
        self.audio_processor = None

        # gemma needs left side padding
        self.te_padding_side = "left"

    # static method to get the noise scheduler
    @staticmethod
@@ -627,6 +630,10 @@ class LTX2Model(BaseModel):
            tile_sample_stride_width=224,
            tile_sample_stride_num_frames=4,
        )

        # We only encode and store the minimum prompt tokens, but need them padded to 1024 for LTX2
        conditional_embeds = self.pad_embeds(conditional_embeds)
        unconditional_embeds = self.pad_embeds(unconditional_embeds)

        video, audio = pipeline(
            prompt_embeds=conditional_embeds.text_embeds.to(
@@ -731,6 +738,29 @@ class LTX2Model(BaseModel):
        latents_std = self.pipeline.audio_vae.latents_std
        output_tensor = (output_tensor - latents_mean) / latents_std
        return output_tensor

    def pad_embeds(self, embeds: PromptEmbeds):
        # ltx-2 connector requires 1024 tokens for good results. Any smaller and it degrades.
        target_length = 1024
        current_length = embeds.text_embeds.shape[1]
        if current_length < target_length:
            pad_length = target_length - current_length
            pad_tensor = torch.zeros(
                (embeds.text_embeds.shape[0], pad_length, embeds.text_embeds.shape[2]),
                device=embeds.text_embeds.device,
                dtype=embeds.text_embeds.dtype,
            )
            embeds.text_embeds = torch.cat([pad_tensor, embeds.text_embeds], dim=1)
            if embeds.attention_mask is not None:
                pad_mask = torch.zeros(
                    (embeds.attention_mask.shape[0], pad_length),
                    device=embeds.attention_mask.device,
                    dtype=embeds.attention_mask.dtype,
                )
                embeds.attention_mask = torch.cat(
                    [pad_mask, embeds.attention_mask], dim=1
                )
        return embeds

    def get_noise_prediction(
        self,
@@ -743,6 +773,9 @@ class LTX2Model(BaseModel):
        with torch.no_grad():
            if self.model.device == torch.device("cpu"):
                self.model.to(self.device_torch)

            # We only encode and store the minimum prompt tokens, but need them padded to 1024 for LTX2
            text_embeddings = self.pad_embeds(text_embeddings)

            batch_size, C, latent_num_frames, latent_height, latent_width = (
                latent_model_input.shape
@@ -916,11 +949,58 @@ class LTX2Model(BaseModel):
        if self.pipeline.text_encoder.device != self.device_torch:
            self.pipeline.text_encoder.to(self.device_torch)

        prompt_embeds, prompt_attention_mask, _, _ = self.pipeline.encode_prompt(
        device = self.device_torch
        scale_factor = 8
        batch_size = len(prompt)
        # Gemma expects left padding for chat-style prompts
        self.tokenizer[0].padding_side = "left"
        if self.tokenizer[0].pad_token is None:
            self.tokenizer[0].pad_token = self.tokenizer[0].eos_token

        prompt = [p.strip() for p in prompt]
        text_inputs = self.tokenizer[0](
            prompt,
            do_classifier_free_guidance=False,
            device=self.device_torch,
            # padding="max_length",
            padding="longest",
            max_length=1024,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        prompt_attention_mask = text_inputs.attention_mask

        text_input_ids = text_input_ids.to(device)
        prompt_attention_mask = prompt_attention_mask.to(device)

        text_encoder_outputs = self.text_encoder[0](
            input_ids=text_input_ids,
            attention_mask=prompt_attention_mask,
            output_hidden_states=True,
        )
        text_encoder_hidden_states = text_encoder_outputs.hidden_states
        text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1)
        sequence_lengths = prompt_attention_mask.sum(dim=-1)

        prompt_embeds = self.pipeline._pack_text_embeds(
            text_encoder_hidden_states,
            sequence_lengths,
            device=device,
            padding_side=self.tokenizer[0].padding_side,
            scale_factor=scale_factor,
        )
        prompt_embeds = prompt_embeds.to(dtype=self.torch_dtype)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, 1, 1)
        prompt_embeds = prompt_embeds.view(
            batch_size * 1, seq_len, -1
        )

        prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
        prompt_attention_mask = prompt_attention_mask.repeat(1, 1)

        pe = PromptEmbeds([prompt_embeds, None])
        pe.attention_mask = prompt_attention_mask
        return pe
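
The pad_embeds helper added above prepends zero vectors and zero attention-mask entries (left padding, matching the Gemma padding side set in __init__) until the sequence reaches 1024 tokens. A self-contained sketch of that step, with the PromptEmbeds container left out; the target of 1024 comes from the comment in the diff, and the hidden size used here is illustrative only:

import torch

def left_pad_to_length(text_embeds, attention_mask, target_length=1024):
    # Prepend zeros along the sequence dimension so short cached embeddings
    # reach the 1024 tokens the LTX-2 connector expects.
    batch, cur_len, dim = text_embeds.shape
    if cur_len >= target_length:
        return text_embeds, attention_mask
    pad_len = target_length - cur_len
    pad_embeds = torch.zeros(batch, pad_len, dim,
                             device=text_embeds.device, dtype=text_embeds.dtype)
    pad_mask = torch.zeros(batch, pad_len,
                           device=attention_mask.device, dtype=attention_mask.dtype)
    return (torch.cat([pad_embeds, text_embeds], dim=1),
            torch.cat([pad_mask, attention_mask], dim=1))

# A cached 2-prompt batch stored at 37 tokens instead of 1024.
embeds = torch.randn(2, 37, 3840)          # hidden size is illustrative only
mask = torch.ones(2, 37, dtype=torch.long)
padded, padded_mask = left_pad_to_length(embeds, mask)
print(padded.shape)             # torch.Size([2, 1024, 3840])
print(padded_mask.sum(dim=-1))  # tensor([37, 37]); the real tokens are untouched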