mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-23 22:17:29 +00:00
Compare commits
2 Commits
feat/glsl-
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e87858e974 | ||
|
|
da6edb5a4e |
@@ -681,6 +681,33 @@ class LTXAVModel(LTXVModel):
|
||||
additional_args["has_spatial_mask"] = has_spatial_mask
|
||||
|
||||
ax, a_latent_coords = self.a_patchifier.patchify(ax)
|
||||
|
||||
# Inject reference audio for ID-LoRA in-context conditioning
|
||||
ref_audio = kwargs.get("ref_audio", None)
|
||||
ref_audio_seq_len = 0
|
||||
if ref_audio is not None:
|
||||
ref_tokens = ref_audio["tokens"].to(dtype=ax.dtype, device=ax.device)
|
||||
if ref_tokens.shape[0] < ax.shape[0]:
|
||||
ref_tokens = ref_tokens.expand(ax.shape[0], -1, -1)
|
||||
ref_audio_seq_len = ref_tokens.shape[1]
|
||||
B = ax.shape[0]
|
||||
|
||||
# Compute negative temporal positions matching ID-LoRA convention:
|
||||
# offset by -(end_of_last_token + time_per_latent) so reference ends just before t=0
|
||||
p = self.a_patchifier
|
||||
tpl = p.hop_length * p.audio_latent_downsample_factor / p.sample_rate
|
||||
ref_start = p._get_audio_latent_time_in_sec(0, ref_audio_seq_len, torch.float32, ax.device)
|
||||
ref_end = p._get_audio_latent_time_in_sec(1, ref_audio_seq_len + 1, torch.float32, ax.device)
|
||||
time_offset = ref_end[-1].item() + tpl
|
||||
ref_start = (ref_start - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
|
||||
ref_end = (ref_end - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
|
||||
ref_pos = torch.stack([ref_start, ref_end], dim=-1)
|
||||
|
||||
additional_args["ref_audio_seq_len"] = ref_audio_seq_len
|
||||
additional_args["target_audio_seq_len"] = ax.shape[1]
|
||||
ax = torch.cat([ref_tokens, ax], dim=1)
|
||||
a_latent_coords = torch.cat([ref_pos.to(a_latent_coords), a_latent_coords], dim=2)
|
||||
|
||||
ax = self.audio_patchify_proj(ax)
|
||||
|
||||
# additional_args.update({"av_orig_shape": list(x.shape)})
|
||||
@@ -721,6 +748,14 @@ class LTXAVModel(LTXVModel):
|
||||
|
||||
# Prepare audio timestep
|
||||
a_timestep = kwargs.get("a_timestep")
|
||||
ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0)
|
||||
if ref_audio_seq_len > 0 and a_timestep is not None:
|
||||
# Reference tokens must have timestep=0, expand scalar/1D timestep to per-token so ref=0 and target=sigma.
|
||||
target_len = kwargs.get("target_audio_seq_len")
|
||||
if a_timestep.dim() <= 1:
|
||||
a_timestep = a_timestep.view(-1, 1).expand(batch_size, target_len)
|
||||
ref_ts = torch.zeros(batch_size, ref_audio_seq_len, *a_timestep.shape[2:], device=a_timestep.device, dtype=a_timestep.dtype)
|
||||
a_timestep = torch.cat([ref_ts, a_timestep], dim=1)
|
||||
if a_timestep is not None:
|
||||
a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
|
||||
a_timestep_flat = a_timestep_scaled.flatten()
|
||||
@@ -955,6 +990,13 @@ class LTXAVModel(LTXVModel):
|
||||
v_embedded_timestep = embedded_timestep[0]
|
||||
a_embedded_timestep = embedded_timestep[1]
|
||||
|
||||
# Trim reference audio tokens before unpatchification
|
||||
ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0)
|
||||
if ref_audio_seq_len > 0:
|
||||
ax = ax[:, ref_audio_seq_len:]
|
||||
if a_embedded_timestep.shape[1] > 1:
|
||||
a_embedded_timestep = a_embedded_timestep[:, ref_audio_seq_len:]
|
||||
|
||||
# Expand compressed video timestep if needed
|
||||
if isinstance(v_embedded_timestep, CompressedTimestep):
|
||||
v_embedded_timestep = v_embedded_timestep.expand()
|
||||
|
||||
@@ -1061,6 +1061,10 @@ class LTXAV(BaseModel):
|
||||
if guide_attention_entries is not None:
|
||||
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
|
||||
|
||||
ref_audio = kwargs.get("ref_audio", None)
|
||||
if ref_audio is not None:
|
||||
out['ref_audio'] = comfy.conds.CONDConstant(ref_audio)
|
||||
|
||||
return out
|
||||
|
||||
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
|
||||
|
||||
@@ -3,6 +3,7 @@ import node_helpers
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.model_sampling
|
||||
import comfy.samplers
|
||||
import comfy.utils
|
||||
import math
|
||||
import numpy as np
|
||||
@@ -682,6 +683,84 @@ class LTXVSeparateAVLatent(io.ComfyNode):
|
||||
return io.NodeOutput(video_latent, audio_latent)
|
||||
|
||||
|
||||
class LTXVReferenceAudio(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="LTXVReferenceAudio",
|
||||
display_name="LTXV Reference Audio (ID-LoRA)",
|
||||
category="conditioning/audio",
|
||||
description="Set reference audio for ID-LoRA speaker identity transfer. Encodes a reference audio clip into the conditioning and optionally patches the model with identity guidance (extra forward pass without reference, amplifying the speaker identity effect).",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Audio.Input("reference_audio", tooltip="Reference audio clip whose speaker identity to transfer. ~5 seconds recommended (training duration). Shorter or longer clips may degrade voice identity transfer."),
|
||||
io.Vae.Input(id="audio_vae", display_name="Audio VAE", tooltip="LTXV Audio VAE for encoding."),
|
||||
io.Float.Input("identity_guidance_scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01, tooltip="Strength of identity guidance. Runs an extra forward pass without reference each step to amplify speaker identity. Set to 0 to disable (no extra pass)."),
|
||||
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="Start of the sigma range where identity guidance is active."),
|
||||
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="End of the sigma range where identity guidance is active."),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(),
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, positive, negative, reference_audio, audio_vae, identity_guidance_scale, start_percent, end_percent) -> io.NodeOutput:
|
||||
# Encode reference audio to latents and patchify
|
||||
audio_latents = audio_vae.encode(reference_audio)
|
||||
b, c, t, f = audio_latents.shape
|
||||
ref_tokens = audio_latents.permute(0, 2, 1, 3).reshape(b, t, c * f)
|
||||
ref_audio = {"tokens": ref_tokens}
|
||||
|
||||
positive = node_helpers.conditioning_set_values(positive, {"ref_audio": ref_audio})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"ref_audio": ref_audio})
|
||||
|
||||
# Patch model with identity guidance
|
||||
m = model.clone()
|
||||
scale = identity_guidance_scale
|
||||
model_sampling = m.get_model_object("model_sampling")
|
||||
sigma_start = model_sampling.percent_to_sigma(start_percent)
|
||||
sigma_end = model_sampling.percent_to_sigma(end_percent)
|
||||
|
||||
def post_cfg_function(args):
|
||||
if scale == 0:
|
||||
return args["denoised"]
|
||||
|
||||
sigma = args["sigma"]
|
||||
sigma_ = sigma[0].item()
|
||||
if sigma_ > sigma_start or sigma_ < sigma_end:
|
||||
return args["denoised"]
|
||||
|
||||
cond_pred = args["cond_denoised"]
|
||||
cond = args["cond"]
|
||||
cfg_result = args["denoised"]
|
||||
model_options = args["model_options"].copy()
|
||||
x = args["input"]
|
||||
|
||||
# Strip ref_audio from conditioning for the no-reference pass
|
||||
noref_cond = []
|
||||
for entry in cond:
|
||||
new_entry = entry.copy()
|
||||
mc = new_entry.get("model_conds", {}).copy()
|
||||
mc.pop("ref_audio", None)
|
||||
new_entry["model_conds"] = mc
|
||||
noref_cond.append(new_entry)
|
||||
|
||||
(pred_noref,) = comfy.samplers.calc_cond_batch(
|
||||
args["model"], [noref_cond], x, sigma, model_options
|
||||
)
|
||||
|
||||
return cfg_result + (cond_pred - pred_noref) * scale
|
||||
|
||||
m.set_model_sampler_post_cfg_function(post_cfg_function)
|
||||
|
||||
return io.NodeOutput(m, positive, negative)
|
||||
|
||||
|
||||
class LtxvExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
@@ -697,6 +776,7 @@ class LtxvExtension(ComfyExtension):
|
||||
LTXVCropGuides,
|
||||
LTXVConcatAVLatent,
|
||||
LTXVSeparateAVLatent,
|
||||
LTXVReferenceAudio,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
comfyui_manager==4.1b6
|
||||
comfyui_manager==4.1b8
|
||||
|
||||
Reference in New Issue
Block a user