diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py
index 2c6954ecd..2b080aaeb 100644
--- a/comfy/ldm/lightricks/av_model.py
+++ b/comfy/ldm/lightricks/av_model.py
@@ -9,6 +9,7 @@ from comfy.ldm.lightricks.model import (
     LTXVModel,
 )
 from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
+from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
 import comfy.ldm.common_dit
 
 class CompressedTimestep:
@@ -450,6 +451,29 @@ class LTXAVModel(LTXVModel):
             operations=self.operations,
         )
 
+        self.audio_embeddings_connector = Embeddings1DConnector(
+            split_rope=True,
+            double_precision_rope=True,
+            dtype=dtype,
+            device=device,
+            operations=self.operations,
+        )
+
+        self.video_embeddings_connector = Embeddings1DConnector(
+            split_rope=True,
+            double_precision_rope=True,
+            dtype=dtype,
+            device=device,
+            operations=self.operations,
+        )
+
+    def preprocess_text_embeds(self, context):
+        if context.shape[-1] == self.caption_channels * 2:
+            return context
+        out_vid = self.video_embeddings_connector(context)[0]
+        out_audio = self.audio_embeddings_connector(context)[0]
+        return torch.concat((out_vid, out_audio), dim=-1)
+
     def _init_transformer_blocks(self, device, dtype, **kwargs):
         """Initialize transformer blocks for LTXAV."""
         self.transformer_blocks = nn.ModuleList(
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 9dcef8741..2f49578f6 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -988,10 +988,14 @@ class LTXAV(BaseModel):
     def extra_conds(self, **kwargs):
        out = super().extra_conds(**kwargs)
         attention_mask = kwargs.get("attention_mask", None)
+        device = kwargs["device"]
+
         if attention_mask is not None:
             out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
         cross_attn = kwargs.get("cross_attn", None)
         if cross_attn is not None:
+            if hasattr(self.diffusion_model, "preprocess_text_embeds"):
+                cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()))
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
 
         out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py
index 82fbacf59..e2ce22e37 100644
--- a/comfy/text_encoders/lt.py
+++ b/comfy/text_encoders/lt.py
@@ -3,7 +3,6 @@ import os
 from transformers import T5TokenizerFast
 from .spiece_tokenizer import SPieceTokenizer
 import comfy.text_encoders.genmo
-from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
 import torch
 import comfy.utils
 import math
@@ -109,22 +108,6 @@ class LTXAVTEModel(torch.nn.Module):
            operations = self.gemma3_12b.operations # TODO
         self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
 
-        self.audio_embeddings_connector = Embeddings1DConnector(
-            split_rope=True,
-            double_precision_rope=True,
-            dtype=dtype,
-            device=device,
-            operations=operations,
-        )
-
-        self.video_embeddings_connector = Embeddings1DConnector(
-            split_rope=True,
-            double_precision_rope=True,
-            dtype=dtype,
-            device=device,
-            operations=operations,
-        )
-
     def set_clip_options(self, options):
         self.execution_device = options.get("execution_device", self.execution_device)
         self.gemma3_12b.set_clip_options(options)
@@ -146,10 +129,6 @@ class LTXAVTEModel(torch.nn.Module):
         out = out.reshape((out.shape[0], out.shape[1], -1))
         out = self.text_embedding_projection(out)
         out = out.float()
-        out_vid = self.video_embeddings_connector(out)[0]
-        out_audio = self.audio_embeddings_connector(out)[0]
-        out = torch.concat((out_vid, out_audio), dim=-1)
-
         return out.to(out_device), pooled
 
     def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
@@ -159,14 +138,14 @@ class LTXAVTEModel(torch.nn.Module):
         if "model.layers.47.self_attn.q_norm.weight" in sd:
             return self.gemma3_12b.load_sd(sd)
         else:
-            sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True)
+            sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight"}, filter_keys=True)
             if len(sdo) == 0:
                 sdo = sd
 
             missing_all = []
             unexpected_all = []
 
-            for prefix, component in [("text_embedding_projection.", self.text_embedding_projection), ("video_embeddings_connector.", self.video_embeddings_connector), ("audio_embeddings_connector.", self.audio_embeddings_connector)]:
+            for prefix, component in [("text_embedding_projection.", self.text_embedding_projection)]:
                 component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)}
                 if component_sd:
                     missing, unexpected = component.load_state_dict(component_sd, strict=False, assign=getattr(self, "can_assign_sd", False))
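
What the diff does overall: the video/audio `Embeddings1DConnector` projection moves out of the text encoder (`comfy/text_encoders/lt.py`) and into the diffusion model (`comfy/ldm/lightricks/av_model.py`), and `model_base.py` now applies it lazily when building conds. The shape check in `preprocess_text_embeds` makes the call idempotent: once the video and audio connector outputs are concatenated, the last dimension is `caption_channels * 2`, so embeddings that were already projected pass through untouched. A minimal sketch of that guard follows; `ToyAVModel` and the `nn.Linear` stand-ins are illustrative assumptions, not the real classes, and the toy width replaces the 3840-wide projection used in `lt.py`:

```python
import torch
import torch.nn as nn

CAPTION_CHANNELS = 64  # toy width; the real text_embedding_projection outputs 3840


class ToyAVModel(nn.Module):
    """Stand-in for LTXAVModel; nn.Linear replaces Embeddings1DConnector."""

    def __init__(self):
        super().__init__()
        self.caption_channels = CAPTION_CHANNELS
        self.video_embeddings_connector = nn.Linear(CAPTION_CHANNELS, CAPTION_CHANNELS)
        self.audio_embeddings_connector = nn.Linear(CAPTION_CHANNELS, CAPTION_CHANNELS)

    def preprocess_text_embeds(self, context):
        # Already projected: the video half and audio half are concatenated,
        # so the last dim is caption_channels * 2 and a second call is a no-op.
        if context.shape[-1] == self.caption_channels * 2:
            return context
        out_vid = self.video_embeddings_connector(context)
        out_audio = self.audio_embeddings_connector(context)
        return torch.concat((out_vid, out_audio), dim=-1)


model = ToyAVModel()
raw = torch.randn(1, 77, CAPTION_CHANNELS)   # fresh text-encoder output
once = model.preprocess_text_embeds(raw)     # projected: (1, 77, 2 * 64)
twice = model.preprocess_text_embeds(once)   # guard fires, returned as-is
assert once.shape[-1] == 2 * CAPTION_CHANNELS
assert twice is once
```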
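
The `load_sd` changes follow from the same move: the `model.diffusion_model.video_embeddings_connector.` / `model.diffusion_model.audio_embeddings_connector.` prefix remaps disappear from the text encoder loader, presumably because those weights now sit under the diffusion model and load with its state dict instead. The remaining call relies on `comfy.utils.state_dict_prefix_replace(..., filter_keys=True)` returning only the keys whose prefixes matched, which is why the `if len(sdo) == 0: sdo = sd` fallback exists for checkpoints whose keys already use the local names. A toy re-implementation of that filtering semantics, under the assumption stated above (`prefix_replace_filtered` is a hypothetical helper, not the actual comfy.utils function):

```python
def prefix_replace_filtered(sd, replace_prefix):
    # Keep only keys that start with one of the given prefixes, renaming
    # the matched prefix; everything else is dropped (filter_keys=True).
    out = {}
    for old, new in replace_prefix.items():
        for k, v in sd.items():
            if k.startswith(old):
                out[new + k[len(old):]] = v
    return out


sd = {"text_embedding_projection.aggregate_embed.weight": "W"}
sdo = prefix_replace_filtered(
    sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight"}
)
assert list(sdo) == ["text_embedding_projection.weight"]
```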