diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index d6e92c17b..a7228acac 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1369,7 +1369,36 @@ _DynamicInputFunc = Callable[ ] DYNAMIC_INPUT_LOOKUP: dict[str, _DynamicInputFunc] = {} def register_dynamic_input_func(io_type: str, func: _DynamicInputFunc): - DYNAMIC_INPUT_LOOKUP[io_type] = func + """Register a dynamic-input expansion callback. + + Accepts both the current 6-argument form and the legacy 5-argument form + (without ``live_input_types``). Legacy callables are silently wrapped so + custom nodes that registered against the older signature continue to work. + """ + try: + sig = inspect.signature(func) + param_count = sum( + 1 for p in sig.parameters.values() + if p.kind in (inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.VAR_POSITIONAL) + ) + # 5 = legacy signature (no live_input_types). If the callable has + # *args we can't be certain, so treat it as new-style. + has_varargs = any(p.kind is inspect.Parameter.VAR_POSITIONAL for p in sig.parameters.values()) + is_legacy = (param_count == 5 and not has_varargs) + except (TypeError, ValueError): + # Builtins / C-implemented callables without introspectable signatures + # are assumed to be new-style. + is_legacy = False + + if is_legacy: + _legacy = func + def _adapter(out_dict, live_inputs, value, input_type, curr_prefix, live_input_types=None): + _legacy(out_dict, live_inputs, value, input_type, curr_prefix) + DYNAMIC_INPUT_LOOKUP[io_type] = _adapter + else: + DYNAMIC_INPUT_LOOKUP[io_type] = func def get_dynamic_input_func(io_type: str) -> _DynamicInputFunc: return DYNAMIC_INPUT_LOOKUP[io_type] diff --git a/comfy_execution/type_resolver.py b/comfy_execution/type_resolver.py index 975d37097..d5d21d7a9 100644 --- a/comfy_execution/type_resolver.py +++ b/comfy_execution/type_resolver.py @@ -27,6 +27,25 @@ from typing import Any from comfy_api.latest import io from comfy_api.internal import _ComfyNodeInternal + +def _parse_link(val: Any) -> tuple[str, int] | None: + """Return (src_node_id, src_slot_idx) if ``val`` is a well-formed link. + + A link in the prompt schema is a length-2 list/tuple ``[node_id, slot_idx]`` + where ``node_id`` is a string and ``slot_idx`` is a non-negative int. + Anything else (including ``[node_id, "0"]`` from malformed API JSON) returns + ``None`` so callers can fall back to AnyType instead of crashing. + """ + if not isinstance(val, (list, tuple)) or len(val) != 2: + return None + src_node, src_slot = val[0], val[1] + if not isinstance(src_node, str): + return None + # bool is a subclass of int — reject it to avoid treating True/False as slot 1/0. + if isinstance(src_slot, bool) or not isinstance(src_slot, int): + return None + return src_node, src_slot + # Sentinel for "type is unknown / wildcard". Matches AnyType.io_type ("*"). ANY_TYPE: str = io.AnyType.io_type @@ -76,6 +95,18 @@ class TypeResolver: import nodes return nodes.NODE_CLASS_MAPPINGS.get(class_type) + def _get_class_def_for_node(self, node_id: str): + """Return (node_dict, class_def) for ``node_id``, or ``(None, None)``.""" + if not self._has_node(node_id): + return None, None + node = self._get_node(node_id) + if node is None: + return None, None + class_type = node.get("class_type") + if not isinstance(class_type, str): + return node, None + return node, self._get_class_def(class_type) + # ---- cache management ------------------------------------------------- def invalidate(self) -> None: """Clear all cached resolutions. Cheap; call after any graph mutation.""" @@ -97,8 +128,15 @@ class TypeResolver: """Return the resolved io_type string of ``node_id``'s output slot. Falls back to ``ANY_TYPE`` on cycle, depth-overflow, unknown class, - out-of-range slot, missing node, or unresolved MatchType template. + out-of-range slot, missing node, malformed link, or unresolved + MatchType template. """ + # Guard against malformed callers passing non-int slot indices (e.g. + # API JSON that sent a string). Falling back to AnyType is safer than + # raising TypeError mid-validation. + if isinstance(slot_idx, bool) or not isinstance(slot_idx, int): + return ANY_TYPE + cache_key = (node_id, slot_idx) if cache_key in self._output_cache: return self._output_cache[cache_key] @@ -113,17 +151,10 @@ class TypeResolver: return ANY_TYPE next_stack = _stack | {cache_key} - if not self._has_node(node_id): - return ANY_TYPE - - node = self._get_node(node_id) - if node is None: - return ANY_TYPE - - class_type = node.get("class_type") - class_def = self._get_class_def(class_type) if class_type is not None else None + node, class_def = self._get_class_def_for_node(node_id) if class_def is None: return ANY_TYPE + class_type = node.get("class_type") try: return_types = class_def.RETURN_TYPES @@ -193,13 +224,13 @@ class TypeResolver: val = inputs_dict.get(inp.id) if val is None: continue - if isinstance(val, list) and len(val) == 2 and isinstance(val[0], str): - src_node, src_slot = val[0], val[1] - t = self.resolve_output_type(src_node, src_slot, stack) + link = _parse_link(val) + if link is not None: + t = self.resolve_output_type(link[0], link[1], stack) if t != ANY_TYPE: return t - # Literal value: a MatchType slot has no concrete declared type, so - # we cannot infer anything useful here. + # Literal value (or malformed link): a MatchType slot has no + # concrete declared type, so we cannot infer anything useful here. if not any_input_seen: # Schema declared a template_id with no Input bearing it. This is a # node-author bug; warn once. @@ -212,17 +243,17 @@ class TypeResolver: def is_output_list(self, node_id: str, slot_idx: int) -> bool: """Whether the source slot is declared as a list output (``OUTPUT_IS_LIST[idx]``).""" + if isinstance(slot_idx, bool) or not isinstance(slot_idx, int): + return False cache_key = (node_id, slot_idx) if cache_key in self._is_output_list_cache: return self._is_output_list_cache[cache_key] result = False - node = self._get_node(node_id) - if node is not None: - class_def = self._get_class_def(node.get("class_type")) - if class_def is not None: - lst = getattr(class_def, "OUTPUT_IS_LIST", None) - if lst is not None and 0 <= slot_idx < len(lst): - result = bool(lst[slot_idx]) + _, class_def = self._get_class_def_for_node(node_id) + if class_def is not None: + lst = getattr(class_def, "OUTPUT_IS_LIST", None) + if lst is not None and 0 <= slot_idx < len(lst): + result = bool(lst[slot_idx]) self._is_output_list_cache[cache_key] = result return result @@ -234,7 +265,8 @@ class TypeResolver: * If the value is a literal, return the declared slot's effective io_type (peeling dynamic-input wrappers — e.g. an Autogrow-of-Image slot resolves to ``IMAGE``, not ``COMFY_AUTOGROW_V3``). - * If the value is missing or the slot is unknown, return ``ANY_TYPE``. + * If the value is missing, malformed, or the slot is unknown, return + ``ANY_TYPE``. """ node = self._get_node(node_id) if node is None: @@ -242,9 +274,9 @@ class TypeResolver: inputs = node.get("inputs", {}) or {} if input_id not in inputs: return ANY_TYPE - val = inputs[input_id] - if isinstance(val, list) and len(val) == 2 and isinstance(val[0], str): - return self.resolve_output_type(val[0], val[1]) + link = _parse_link(inputs[input_id]) + if link is not None: + return self.resolve_output_type(link[0], link[1]) return self.get_declared_slot_io_type(node_id, input_id) def is_input_list(self, node_id: str, input_id: str) -> bool: @@ -252,10 +284,10 @@ class TypeResolver: node = self._get_node(node_id) if node is None: return False - val = (node.get("inputs", {}) or {}).get(input_id) - if isinstance(val, list) and len(val) == 2 and isinstance(val[0], str): - return self.is_output_list(val[0], val[1]) - return False + link = _parse_link((node.get("inputs", {}) or {}).get(input_id)) + if link is None: + return False + return self.is_output_list(link[0], link[1]) def get_declared_slot_io_type(self, node_id: str, input_id: str) -> str: """Return the effective declared io_type of a node's input slot. @@ -269,10 +301,7 @@ class TypeResolver: * DynamicCombo / unsupported → ``ANY_TYPE`` (the combo key is itself dynamic, not a meaningful type for consumers) """ - node = self._get_node(node_id) - if node is None: - return ANY_TYPE - class_def = self._get_class_def(node.get("class_type")) + _, class_def = self._get_class_def_for_node(node_id) if class_def is None: return ANY_TYPE diff --git a/tests-unit/comfy_api_test/test_register_dynamic_input_func.py b/tests-unit/comfy_api_test/test_register_dynamic_input_func.py new file mode 100644 index 000000000..afabf9fe7 --- /dev/null +++ b/tests-unit/comfy_api_test/test_register_dynamic_input_func.py @@ -0,0 +1,79 @@ +"""Backward-compat tests for ``register_dynamic_input_func``. + +When ``live_input_types`` was added as a sixth argument to the dynamic-input +expansion callback, third-party custom nodes that registered against the +original 5-argument signature would otherwise crash with ``TypeError`` the +first time their input was expanded. ``register_dynamic_input_func`` wraps +such legacy callables transparently. +""" + +from __future__ import annotations + +from comfy_api.latest import _io + + +def test_legacy_5arg_callback_is_wrapped_transparently(): + received = {} + + def legacy(out_dict, live_inputs, value, input_type, curr_prefix): + received["args"] = (out_dict, live_inputs, value, input_type, curr_prefix) + + io_type = "TEST_LEGACY_5ARG_V3" + try: + _io.register_dynamic_input_func(io_type, legacy) + fn = _io.get_dynamic_input_func(io_type) + + # Caller invokes with 6 arguments (current signature). The shim must + # strip the trailing live_input_types argument before delegating. + fn({"required": {}}, {"a": 1}, ("X", {}), "required", ["p"], {"a": "INT"}) + + assert received["args"] == ( + {"required": {}}, {"a": 1}, ("X", {}), "required", ["p"] + ) + finally: + _io.DYNAMIC_INPUT_LOOKUP.pop(io_type, None) + + +def test_new_6arg_callback_passes_live_input_types_through(): + received = {} + + def modern(out_dict, live_inputs, value, input_type, curr_prefix, live_input_types=None): + received["live_input_types"] = live_input_types + + io_type = "TEST_MODERN_6ARG_V3" + try: + _io.register_dynamic_input_func(io_type, modern) + fn = _io.get_dynamic_input_func(io_type) + fn({}, {}, ("X", {}), "required", None, {"foo": "IMAGE"}) + assert received["live_input_types"] == {"foo": "IMAGE"} + finally: + _io.DYNAMIC_INPUT_LOOKUP.pop(io_type, None) + + +def test_callable_with_uninspectable_signature_assumed_modern(): + """``functools.partial`` and C builtins may have no introspectable signature. + + The shim must not blow up; falling back to the new signature is the safe + choice (we get a clean TypeError if the callable really is too old, which + is no worse than the pre-shim behavior). + """ + calls = [] + + class _CallableObj: + # Lambdas / objects with __call__ are introspectable; partial with + # opaque builtins are not. Simulate the latter by raising in __signature__. + def __call__(self, *args, **kwargs): + calls.append((args, kwargs)) + + @property + def __signature__(self): + raise ValueError("uninspectable") + + io_type = "TEST_UNINSPECTABLE_V3" + try: + _io.register_dynamic_input_func(io_type, _CallableObj()) + fn = _io.get_dynamic_input_func(io_type) + fn({}, {}, ("X", {}), "required", None, {"x": "INT"}) + assert calls and len(calls[0][0]) == 6 + finally: + _io.DYNAMIC_INPUT_LOOKUP.pop(io_type, None) diff --git a/tests-unit/execution_test/test_type_resolver.py b/tests-unit/execution_test/test_type_resolver.py index f79384cc0..e89874ca0 100644 --- a/tests-unit/execution_test/test_type_resolver.py +++ b/tests-unit/execution_test/test_type_resolver.py @@ -360,6 +360,52 @@ def test_invalidate_clears_cache(fake_nodes_module, TypeResolver): assert r.resolve_output_type("n1", 0) == "LATENT" +# --------------------------------------------------------------------------- +# Malformed input robustness +# --------------------------------------------------------------------------- + +def test_malformed_link_does_not_crash(fake_nodes_module, TypeResolver): + """A link with a non-int slot index must not raise; resolver returns AnyType.""" + fake_nodes_module["Src"] = _v1_node(("IMAGE",)) + fake_nodes_module["Sink"] = _v1_node(("INT",), {"required": {"x": ("*",)}}) + prompt = { + "src": {"class_type": "Src", "inputs": {}}, + # slot index sent as a string (common API JSON mistake) + "sink": {"class_type": "Sink", "inputs": {"x": ["src", "0"]}}, + } + r = TypeResolver(prompt) + # Falls back to declared slot type (still "*"), no exception. + assert r.resolve_input_type("sink", "x") == "*" + + +def test_malformed_link_wrong_arity_does_not_crash(fake_nodes_module, TypeResolver): + fake_nodes_module["Src"] = _v1_node(("IMAGE",)) + fake_nodes_module["Sink"] = _v1_node(("INT",), {"required": {"x": ("*",)}}) + prompt = { + "src": {"class_type": "Src", "inputs": {}}, + "sink": {"class_type": "Sink", "inputs": {"x": ["src"]}}, # arity 1 + } + r = TypeResolver(prompt) + assert r.resolve_input_type("sink", "x") == "*" + + +def test_direct_resolve_output_type_with_bad_slot_idx_returns_any(fake_nodes_module, TypeResolver): + fake_nodes_module["Src"] = _v1_node(("IMAGE",)) + prompt = {"src": {"class_type": "Src", "inputs": {}}} + r = TypeResolver(prompt) + # type-wise these should be unreachable through normal validation but the + # resolver must still degrade gracefully. + assert r.resolve_output_type("src", "0") == "*" + assert r.resolve_output_type("src", True) == "*" # bool is a subclass of int + assert r.is_output_list("src", "0") is False + + +def test_non_string_class_type_returns_any(fake_nodes_module, TypeResolver): + prompt = {"n1": {"class_type": 42, "inputs": {}}} + r = TypeResolver(prompt) + assert r.resolve_output_type("n1", 0) == "*" + + def test_invalidate_node_only_clears_that_node(fake_nodes_module, TypeResolver): fake_nodes_module["SrcA"] = _v1_node(("IMAGE",)) fake_nodes_module["SrcB"] = _v1_node(("LATENT",))