TypeResolver: address code review (link parsing + slot_idx guard + back-compat shim)

* Add _parse_link helper validating both node_id (str) and slot_idx (int,
  rejecting bool) so malformed API JSON (e.g. ['n1', '0']) degrades to
  AnyType instead of crashing with TypeError.
* Add slot_idx type guards in resolve_output_type and is_output_list.
* Extract _get_class_def_for_node helper to dedupe node/class lookup
  across resolve_output_type, is_output_list, get_declared_slot_io_type.
* register_dynamic_input_func now detects 5-argument legacy callables
  via inspect.signature and silently wraps them; preserves backward
  compatibility for any custom node that registered its own dynamic
  input expansion against the pre-live_input_types signature.
* Tests: malformed link (str slot idx, wrong arity), bad slot type
  directly to resolve_output_type, non-string class_type. Tests for
  the legacy 5-arg shim and the modern 6-arg passthrough, including
  callables with uninspectable signatures.

Amp-Thread-ID: https://ampcode.com/threads/T-019e8568-f382-743d-a97f-0de3ff29d501
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Jedrzej Kosinski
2026-06-01 16:30:44 -07:00
parent 19390c112a
commit e01b335e39
4 changed files with 218 additions and 35 deletions

View File

@@ -1369,7 +1369,36 @@ _DynamicInputFunc = Callable[
]
DYNAMIC_INPUT_LOOKUP: dict[str, _DynamicInputFunc] = {}
def register_dynamic_input_func(io_type: str, func: _DynamicInputFunc):
DYNAMIC_INPUT_LOOKUP[io_type] = func
"""Register a dynamic-input expansion callback.
Accepts both the current 6-argument form and the legacy 5-argument form
(without ``live_input_types``). Legacy callables are silently wrapped so
custom nodes that registered against the older signature continue to work.
"""
try:
sig = inspect.signature(func)
param_count = sum(
1 for p in sig.parameters.values()
if p.kind in (inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.VAR_POSITIONAL)
)
# 5 = legacy signature (no live_input_types). If the callable has
# *args we can't be certain, so treat it as new-style.
has_varargs = any(p.kind is inspect.Parameter.VAR_POSITIONAL for p in sig.parameters.values())
is_legacy = (param_count == 5 and not has_varargs)
except (TypeError, ValueError):
# Builtins / C-implemented callables without introspectable signatures
# are assumed to be new-style.
is_legacy = False
if is_legacy:
_legacy = func
def _adapter(out_dict, live_inputs, value, input_type, curr_prefix, live_input_types=None):
_legacy(out_dict, live_inputs, value, input_type, curr_prefix)
DYNAMIC_INPUT_LOOKUP[io_type] = _adapter
else:
DYNAMIC_INPUT_LOOKUP[io_type] = func
def get_dynamic_input_func(io_type: str) -> _DynamicInputFunc:
return DYNAMIC_INPUT_LOOKUP[io_type]

View File

@@ -27,6 +27,25 @@ from typing import Any
from comfy_api.latest import io
from comfy_api.internal import _ComfyNodeInternal
def _parse_link(val: Any) -> tuple[str, int] | None:
"""Return (src_node_id, src_slot_idx) if ``val`` is a well-formed link.
A link in the prompt schema is a length-2 list/tuple ``[node_id, slot_idx]``
where ``node_id`` is a string and ``slot_idx`` is a non-negative int.
Anything else (including ``[node_id, "0"]`` from malformed API JSON) returns
``None`` so callers can fall back to AnyType instead of crashing.
"""
if not isinstance(val, (list, tuple)) or len(val) != 2:
return None
src_node, src_slot = val[0], val[1]
if not isinstance(src_node, str):
return None
# bool is a subclass of int — reject it to avoid treating True/False as slot 1/0.
if isinstance(src_slot, bool) or not isinstance(src_slot, int):
return None
return src_node, src_slot
# Sentinel for "type is unknown / wildcard". Matches AnyType.io_type ("*").
ANY_TYPE: str = io.AnyType.io_type
@@ -76,6 +95,18 @@ class TypeResolver:
import nodes
return nodes.NODE_CLASS_MAPPINGS.get(class_type)
def _get_class_def_for_node(self, node_id: str):
"""Return (node_dict, class_def) for ``node_id``, or ``(None, None)``."""
if not self._has_node(node_id):
return None, None
node = self._get_node(node_id)
if node is None:
return None, None
class_type = node.get("class_type")
if not isinstance(class_type, str):
return node, None
return node, self._get_class_def(class_type)
# ---- cache management -------------------------------------------------
def invalidate(self) -> None:
"""Clear all cached resolutions. Cheap; call after any graph mutation."""
@@ -97,8 +128,15 @@ class TypeResolver:
"""Return the resolved io_type string of ``node_id``'s output slot.
Falls back to ``ANY_TYPE`` on cycle, depth-overflow, unknown class,
out-of-range slot, missing node, or unresolved MatchType template.
out-of-range slot, missing node, malformed link, or unresolved
MatchType template.
"""
# Guard against malformed callers passing non-int slot indices (e.g.
# API JSON that sent a string). Falling back to AnyType is safer than
# raising TypeError mid-validation.
if isinstance(slot_idx, bool) or not isinstance(slot_idx, int):
return ANY_TYPE
cache_key = (node_id, slot_idx)
if cache_key in self._output_cache:
return self._output_cache[cache_key]
@@ -113,17 +151,10 @@ class TypeResolver:
return ANY_TYPE
next_stack = _stack | {cache_key}
if not self._has_node(node_id):
return ANY_TYPE
node = self._get_node(node_id)
if node is None:
return ANY_TYPE
class_type = node.get("class_type")
class_def = self._get_class_def(class_type) if class_type is not None else None
node, class_def = self._get_class_def_for_node(node_id)
if class_def is None:
return ANY_TYPE
class_type = node.get("class_type")
try:
return_types = class_def.RETURN_TYPES
@@ -193,13 +224,13 @@ class TypeResolver:
val = inputs_dict.get(inp.id)
if val is None:
continue
if isinstance(val, list) and len(val) == 2 and isinstance(val[0], str):
src_node, src_slot = val[0], val[1]
t = self.resolve_output_type(src_node, src_slot, stack)
link = _parse_link(val)
if link is not None:
t = self.resolve_output_type(link[0], link[1], stack)
if t != ANY_TYPE:
return t
# Literal value: a MatchType slot has no concrete declared type, so
# we cannot infer anything useful here.
# Literal value (or malformed link): a MatchType slot has no
# concrete declared type, so we cannot infer anything useful here.
if not any_input_seen:
# Schema declared a template_id with no Input bearing it. This is a
# node-author bug; warn once.
@@ -212,17 +243,17 @@ class TypeResolver:
def is_output_list(self, node_id: str, slot_idx: int) -> bool:
"""Whether the source slot is declared as a list output (``OUTPUT_IS_LIST[idx]``)."""
if isinstance(slot_idx, bool) or not isinstance(slot_idx, int):
return False
cache_key = (node_id, slot_idx)
if cache_key in self._is_output_list_cache:
return self._is_output_list_cache[cache_key]
result = False
node = self._get_node(node_id)
if node is not None:
class_def = self._get_class_def(node.get("class_type"))
if class_def is not None:
lst = getattr(class_def, "OUTPUT_IS_LIST", None)
if lst is not None and 0 <= slot_idx < len(lst):
result = bool(lst[slot_idx])
_, class_def = self._get_class_def_for_node(node_id)
if class_def is not None:
lst = getattr(class_def, "OUTPUT_IS_LIST", None)
if lst is not None and 0 <= slot_idx < len(lst):
result = bool(lst[slot_idx])
self._is_output_list_cache[cache_key] = result
return result
@@ -234,7 +265,8 @@ class TypeResolver:
* If the value is a literal, return the declared slot's effective
io_type (peeling dynamic-input wrappers — e.g. an Autogrow-of-Image
slot resolves to ``IMAGE``, not ``COMFY_AUTOGROW_V3``).
* If the value is missing or the slot is unknown, return ``ANY_TYPE``.
* If the value is missing, malformed, or the slot is unknown, return
``ANY_TYPE``.
"""
node = self._get_node(node_id)
if node is None:
@@ -242,9 +274,9 @@ class TypeResolver:
inputs = node.get("inputs", {}) or {}
if input_id not in inputs:
return ANY_TYPE
val = inputs[input_id]
if isinstance(val, list) and len(val) == 2 and isinstance(val[0], str):
return self.resolve_output_type(val[0], val[1])
link = _parse_link(inputs[input_id])
if link is not None:
return self.resolve_output_type(link[0], link[1])
return self.get_declared_slot_io_type(node_id, input_id)
def is_input_list(self, node_id: str, input_id: str) -> bool:
@@ -252,10 +284,10 @@ class TypeResolver:
node = self._get_node(node_id)
if node is None:
return False
val = (node.get("inputs", {}) or {}).get(input_id)
if isinstance(val, list) and len(val) == 2 and isinstance(val[0], str):
return self.is_output_list(val[0], val[1])
return False
link = _parse_link((node.get("inputs", {}) or {}).get(input_id))
if link is None:
return False
return self.is_output_list(link[0], link[1])
def get_declared_slot_io_type(self, node_id: str, input_id: str) -> str:
"""Return the effective declared io_type of a node's input slot.
@@ -269,10 +301,7 @@ class TypeResolver:
* DynamicCombo / unsupported → ``ANY_TYPE`` (the combo key is itself
dynamic, not a meaningful type for consumers)
"""
node = self._get_node(node_id)
if node is None:
return ANY_TYPE
class_def = self._get_class_def(node.get("class_type"))
_, class_def = self._get_class_def_for_node(node_id)
if class_def is None:
return ANY_TYPE

View File

@@ -0,0 +1,79 @@
"""Backward-compat tests for ``register_dynamic_input_func``.
When ``live_input_types`` was added as a sixth argument to the dynamic-input
expansion callback, third-party custom nodes that registered against the
original 5-argument signature would otherwise crash with ``TypeError`` the
first time their input was expanded. ``register_dynamic_input_func`` wraps
such legacy callables transparently.
"""
from __future__ import annotations
from comfy_api.latest import _io
def test_legacy_5arg_callback_is_wrapped_transparently():
received = {}
def legacy(out_dict, live_inputs, value, input_type, curr_prefix):
received["args"] = (out_dict, live_inputs, value, input_type, curr_prefix)
io_type = "TEST_LEGACY_5ARG_V3"
try:
_io.register_dynamic_input_func(io_type, legacy)
fn = _io.get_dynamic_input_func(io_type)
# Caller invokes with 6 arguments (current signature). The shim must
# strip the trailing live_input_types argument before delegating.
fn({"required": {}}, {"a": 1}, ("X", {}), "required", ["p"], {"a": "INT"})
assert received["args"] == (
{"required": {}}, {"a": 1}, ("X", {}), "required", ["p"]
)
finally:
_io.DYNAMIC_INPUT_LOOKUP.pop(io_type, None)
def test_new_6arg_callback_passes_live_input_types_through():
received = {}
def modern(out_dict, live_inputs, value, input_type, curr_prefix, live_input_types=None):
received["live_input_types"] = live_input_types
io_type = "TEST_MODERN_6ARG_V3"
try:
_io.register_dynamic_input_func(io_type, modern)
fn = _io.get_dynamic_input_func(io_type)
fn({}, {}, ("X", {}), "required", None, {"foo": "IMAGE"})
assert received["live_input_types"] == {"foo": "IMAGE"}
finally:
_io.DYNAMIC_INPUT_LOOKUP.pop(io_type, None)
def test_callable_with_uninspectable_signature_assumed_modern():
"""``functools.partial`` and C builtins may have no introspectable signature.
The shim must not blow up; falling back to the new signature is the safe
choice (we get a clean TypeError if the callable really is too old, which
is no worse than the pre-shim behavior).
"""
calls = []
class _CallableObj:
# Lambdas / objects with __call__ are introspectable; partial with
# opaque builtins are not. Simulate the latter by raising in __signature__.
def __call__(self, *args, **kwargs):
calls.append((args, kwargs))
@property
def __signature__(self):
raise ValueError("uninspectable")
io_type = "TEST_UNINSPECTABLE_V3"
try:
_io.register_dynamic_input_func(io_type, _CallableObj())
fn = _io.get_dynamic_input_func(io_type)
fn({}, {}, ("X", {}), "required", None, {"x": "INT"})
assert calls and len(calls[0][0]) == 6
finally:
_io.DYNAMIC_INPUT_LOOKUP.pop(io_type, None)

View File

@@ -360,6 +360,52 @@ def test_invalidate_clears_cache(fake_nodes_module, TypeResolver):
assert r.resolve_output_type("n1", 0) == "LATENT"
# ---------------------------------------------------------------------------
# Malformed input robustness
# ---------------------------------------------------------------------------
def test_malformed_link_does_not_crash(fake_nodes_module, TypeResolver):
"""A link with a non-int slot index must not raise; resolver returns AnyType."""
fake_nodes_module["Src"] = _v1_node(("IMAGE",))
fake_nodes_module["Sink"] = _v1_node(("INT",), {"required": {"x": ("*",)}})
prompt = {
"src": {"class_type": "Src", "inputs": {}},
# slot index sent as a string (common API JSON mistake)
"sink": {"class_type": "Sink", "inputs": {"x": ["src", "0"]}},
}
r = TypeResolver(prompt)
# Falls back to declared slot type (still "*"), no exception.
assert r.resolve_input_type("sink", "x") == "*"
def test_malformed_link_wrong_arity_does_not_crash(fake_nodes_module, TypeResolver):
fake_nodes_module["Src"] = _v1_node(("IMAGE",))
fake_nodes_module["Sink"] = _v1_node(("INT",), {"required": {"x": ("*",)}})
prompt = {
"src": {"class_type": "Src", "inputs": {}},
"sink": {"class_type": "Sink", "inputs": {"x": ["src"]}}, # arity 1
}
r = TypeResolver(prompt)
assert r.resolve_input_type("sink", "x") == "*"
def test_direct_resolve_output_type_with_bad_slot_idx_returns_any(fake_nodes_module, TypeResolver):
fake_nodes_module["Src"] = _v1_node(("IMAGE",))
prompt = {"src": {"class_type": "Src", "inputs": {}}}
r = TypeResolver(prompt)
# type-wise these should be unreachable through normal validation but the
# resolver must still degrade gracefully.
assert r.resolve_output_type("src", "0") == "*"
assert r.resolve_output_type("src", True) == "*" # bool is a subclass of int
assert r.is_output_list("src", "0") is False
def test_non_string_class_type_returns_any(fake_nodes_module, TypeResolver):
prompt = {"n1": {"class_type": 42, "inputs": {}}}
r = TypeResolver(prompt)
assert r.resolve_output_type("n1", 0) == "*"
def test_invalidate_node_only_clears_that_node(fake_nodes_module, TypeResolver):
fake_nodes_module["SrcA"] = _v1_node(("IMAGE",))
fake_nodes_module["SrcB"] = _v1_node(("LATENT",))