feat: per-guide attention strength control in self-attention (#12518)

Implements per-guide attention attenuation via log-space additive bias
in self-attention. Each guide reference tracks its own strength and
optional spatial mask in conditioning metadata (guide_attention_entries).
This commit is contained in:
Tavi Halperin
2026-02-26 08:25:23 +02:00
committed by GitHub
parent 907e5dcbbf
commit a4522017c5
4 changed files with 352 additions and 12 deletions

View File

@@ -65,6 +65,42 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
class _CONDGuideEntries(comfy.conds.CONDConstant):
"""CONDConstant subclass that safely compares guide_attention_entries.
guide_attention_entries may contain ``pixel_mask`` tensors. The default
``CONDConstant.can_concat`` uses ``!=`` which triggers a ``ValueError``
on tensors. This subclass performs a structural comparison instead.
"""
def can_concat(self, other):
if not isinstance(other, _CONDGuideEntries):
return False
a, b = self.cond, other.cond
if len(a) != len(b):
return False
for ea, eb in zip(a, b):
if ea["pre_filter_count"] != eb["pre_filter_count"]:
return False
if ea["strength"] != eb["strength"]:
return False
if ea.get("latent_shape") != eb.get("latent_shape"):
return False
a_has = ea.get("pixel_mask") is not None
b_has = eb.get("pixel_mask") is not None
if a_has != b_has:
return False
if a_has:
pm_a, pm_b = ea["pixel_mask"], eb["pixel_mask"]
if pm_a is not pm_b:
if (pm_a.shape != pm_b.shape
or pm_a.device != pm_b.device
or pm_a.dtype != pm_b.dtype
or not torch.equal(pm_a, pm_b)):
return False
return True
class ModelType(Enum):
EPS = 1
V_PREDICTION = 2
@@ -974,6 +1010,10 @@ class LTXV(BaseModel):
if keyframe_idxs is not None:
out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
guide_attention_entries = kwargs.get("guide_attention_entries", None)
if guide_attention_entries is not None:
out['guide_attention_entries'] = _CONDGuideEntries(guide_attention_entries)
return out
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
@@ -1026,6 +1066,10 @@ class LTXAV(BaseModel):
if latent_shapes is not None:
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
guide_attention_entries = kwargs.get("guide_attention_entries", None)
if guide_attention_entries is not None:
out['guide_attention_entries'] = _CONDGuideEntries(guide_attention_entries)
return out
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):