diff --git a/comfy/conds.py b/comfy/conds.py index 5af3e93ea..55d8cdd78 100644 --- a/comfy/conds.py +++ b/comfy/conds.py @@ -4,6 +4,25 @@ import comfy.utils import logging +def is_equal(x, y): + if torch.is_tensor(x) and torch.is_tensor(y): + return torch.equal(x, y) + elif isinstance(x, dict) and isinstance(y, dict): + if x.keys() != y.keys(): + return False + return all(is_equal(x[k], y[k]) for k in x) + elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)): + if type(x) is not type(y) or len(x) != len(y): + return False + return all(is_equal(a, b) for a, b in zip(x, y)) + else: + try: + return x == y + except Exception: + logging.warning("comparison issue with COND") + return False + + class CONDRegular: def __init__(self, cond): self.cond = cond @@ -84,7 +103,7 @@ class CONDConstant(CONDRegular): return self._copy_with(self.cond) def can_concat(self, other): - if self.cond != other.cond: + if not is_equal(self.cond, other.cond): return False return True diff --git a/comfy/model_base.py b/comfy/model_base.py index 04695c079..8f852e3c6 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -65,42 +65,6 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from comfy.model_patcher import ModelPatcher - -class _CONDGuideEntries(comfy.conds.CONDConstant): - """CONDConstant subclass that safely compares guide_attention_entries. - - guide_attention_entries may contain ``pixel_mask`` tensors. The default - ``CONDConstant.can_concat`` uses ``!=`` which triggers a ``ValueError`` - on tensors. This subclass performs a structural comparison instead. - """ - - def can_concat(self, other): - if not isinstance(other, _CONDGuideEntries): - return False - a, b = self.cond, other.cond - if len(a) != len(b): - return False - for ea, eb in zip(a, b): - if ea["pre_filter_count"] != eb["pre_filter_count"]: - return False - if ea["strength"] != eb["strength"]: - return False - if ea.get("latent_shape") != eb.get("latent_shape"): - return False - a_has = ea.get("pixel_mask") is not None - b_has = eb.get("pixel_mask") is not None - if a_has != b_has: - return False - if a_has: - pm_a, pm_b = ea["pixel_mask"], eb["pixel_mask"] - if pm_a is not pm_b: - if (pm_a.shape != pm_b.shape - or pm_a.device != pm_b.device - or pm_a.dtype != pm_b.dtype - or not torch.equal(pm_a, pm_b)): - return False - return True - class ModelType(Enum): EPS = 1 V_PREDICTION = 2 @@ -1012,7 +976,7 @@ class LTXV(BaseModel): guide_attention_entries = kwargs.get("guide_attention_entries", None) if guide_attention_entries is not None: - out['guide_attention_entries'] = _CONDGuideEntries(guide_attention_entries) + out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries) return out @@ -1068,7 +1032,7 @@ class LTXAV(BaseModel): guide_attention_entries = kwargs.get("guide_attention_entries", None) if guide_attention_entries is not None: - out['guide_attention_entries'] = _CONDGuideEntries(guide_attention_entries) + out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries) return out