mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-07 22:30:00 +00:00
feat: per-guide attention strength control in self-attention (#12518)
Implements per-guide attention attenuation via log-space additive bias in self-attention. Each guide reference tracks its own strength and optional spatial mask in conditioning metadata (guide_attention_entries).
This commit is contained in:
@@ -65,6 +65,42 @@ from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
|
||||
|
||||
class _CONDGuideEntries(comfy.conds.CONDConstant):
    """Conditioning constant that carries ``guide_attention_entries``.

    Individual entries may hold a ``pixel_mask`` tensor. The inherited
    ``CONDConstant.can_concat`` relies on ``!=``, which raises a
    ``ValueError`` when applied to tensors, so concat-compatibility is
    decided here by comparing the entries field by field instead.
    """

    @staticmethod
    def _entries_match(left, right):
        """Return True when two guide entries are equivalent.

        The scalar fields ``pre_filter_count``, ``strength`` and the optional
        ``latent_shape`` are compared first; the optional ``pixel_mask``
        tensors are only inspected when those agree.
        """
        if (left["pre_filter_count"] != right["pre_filter_count"]
                or left["strength"] != right["strength"]
                or left.get("latent_shape") != right.get("latent_shape")):
            return False

        mask_l = left.get("pixel_mask")
        mask_r = right.get("pixel_mask")

        # One entry has a mask and the other does not: not equivalent.
        if (mask_l is None) != (mask_r is None):
            return False

        # Neither has a mask, or both reference the very same tensor object:
        # no element-wise comparison is required.
        if mask_l is None or mask_l is mask_r:
            return True

        # Metadata is checked before torch.equal so the element-wise
        # comparison only runs on tensors of matching shape/device/dtype.
        mismatch = (mask_l.shape != mask_r.shape
                    or mask_l.device != mask_r.device
                    or mask_l.dtype != mask_r.dtype
                    or not torch.equal(mask_l, mask_r))
        return not mismatch

    def can_concat(self, other):
        """Tell whether ``other`` holds the same guide entries as ``self``.

        Returns False when ``other`` is not a ``_CONDGuideEntries``, when the
        two entry sequences differ in length, or when any pair of entries at
        the same position differs; otherwise returns True.
        """
        if not isinstance(other, _CONDGuideEntries):
            return False

        mine, theirs = self.cond, other.cond
        if len(mine) != len(theirs):
            return False

        # all() stops at the first mismatching pair, like an early return.
        return all(
            self._entries_match(entry_l, entry_r)
            for entry_l, entry_r in zip(mine, theirs)
        )
|
||||
|
||||
class ModelType(Enum):
|
||||
EPS = 1
|
||||
V_PREDICTION = 2
|
||||
@@ -974,6 +1010,10 @@ class LTXV(BaseModel):
|
||||
if keyframe_idxs is not None:
|
||||
out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
|
||||
|
||||
guide_attention_entries = kwargs.get("guide_attention_entries", None)
|
||||
if guide_attention_entries is not None:
|
||||
out['guide_attention_entries'] = _CONDGuideEntries(guide_attention_entries)
|
||||
|
||||
return out
|
||||
|
||||
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
|
||||
@@ -1026,6 +1066,10 @@ class LTXAV(BaseModel):
|
||||
if latent_shapes is not None:
|
||||
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
|
||||
|
||||
guide_attention_entries = kwargs.get("guide_attention_entries", None)
|
||||
if guide_attention_entries is not None:
|
||||
out['guide_attention_entries'] = _CONDGuideEntries(guide_attention_entries)
|
||||
|
||||
return out
|
||||
|
||||
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user