mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-08 23:38:21 +00:00
TypeResolver: address code review (link parsing + slot_idx guard + back-compat shim)
* Add _parse_link helper validating both node_id (str) and slot_idx (int, rejecting bool) so malformed API JSON (e.g. ['n1', '0']) degrades to AnyType instead of crashing with TypeError. * Add slot_idx type guards in resolve_output_type and is_output_list. * Extract _get_class_def_for_node helper to dedupe node/class lookup across resolve_output_type, is_output_list, get_declared_slot_io_type. * register_dynamic_input_func now detects 5-argument legacy callables via inspect.signature and silently wraps them; preserves backward compatibility for any custom node that registered its own dynamic input expansion against the pre-live_input_types signature. * Tests: malformed link (str slot idx, wrong arity), bad slot type directly to resolve_output_type, non-string class_type. Tests for the legacy 5-arg shim and the modern 6-arg passthrough, including callables with uninspectable signatures. Amp-Thread-ID: https://ampcode.com/threads/T-019e8568-f382-743d-a97f-0de3ff29d501 Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
@@ -1369,7 +1369,36 @@ _DynamicInputFunc = Callable[
|
||||
]
|
||||
DYNAMIC_INPUT_LOOKUP: dict[str, _DynamicInputFunc] = {}
|
||||
def register_dynamic_input_func(io_type: str, func: _DynamicInputFunc):
|
||||
DYNAMIC_INPUT_LOOKUP[io_type] = func
|
||||
"""Register a dynamic-input expansion callback.
|
||||
|
||||
Accepts both the current 6-argument form and the legacy 5-argument form
|
||||
(without ``live_input_types``). Legacy callables are silently wrapped so
|
||||
custom nodes that registered against the older signature continue to work.
|
||||
"""
|
||||
try:
|
||||
sig = inspect.signature(func)
|
||||
param_count = sum(
|
||||
1 for p in sig.parameters.values()
|
||||
if p.kind in (inspect.Parameter.POSITIONAL_ONLY,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
inspect.Parameter.VAR_POSITIONAL)
|
||||
)
|
||||
# 5 = legacy signature (no live_input_types). If the callable has
|
||||
# *args we can't be certain, so treat it as new-style.
|
||||
has_varargs = any(p.kind is inspect.Parameter.VAR_POSITIONAL for p in sig.parameters.values())
|
||||
is_legacy = (param_count == 5 and not has_varargs)
|
||||
except (TypeError, ValueError):
|
||||
# Builtins / C-implemented callables without introspectable signatures
|
||||
# are assumed to be new-style.
|
||||
is_legacy = False
|
||||
|
||||
if is_legacy:
|
||||
_legacy = func
|
||||
def _adapter(out_dict, live_inputs, value, input_type, curr_prefix, live_input_types=None):
|
||||
_legacy(out_dict, live_inputs, value, input_type, curr_prefix)
|
||||
DYNAMIC_INPUT_LOOKUP[io_type] = _adapter
|
||||
else:
|
||||
DYNAMIC_INPUT_LOOKUP[io_type] = func
|
||||
|
||||
def get_dynamic_input_func(io_type: str) -> _DynamicInputFunc:
|
||||
return DYNAMIC_INPUT_LOOKUP[io_type]
|
||||
|
||||
@@ -27,6 +27,25 @@ from typing import Any
|
||||
from comfy_api.latest import io
|
||||
from comfy_api.internal import _ComfyNodeInternal
|
||||
|
||||
|
||||
def _parse_link(val: Any) -> tuple[str, int] | None:
|
||||
"""Return (src_node_id, src_slot_idx) if ``val`` is a well-formed link.
|
||||
|
||||
A link in the prompt schema is a length-2 list/tuple ``[node_id, slot_idx]``
|
||||
where ``node_id`` is a string and ``slot_idx`` is a non-negative int.
|
||||
Anything else (including ``[node_id, "0"]`` from malformed API JSON) returns
|
||||
``None`` so callers can fall back to AnyType instead of crashing.
|
||||
"""
|
||||
if not isinstance(val, (list, tuple)) or len(val) != 2:
|
||||
return None
|
||||
src_node, src_slot = val[0], val[1]
|
||||
if not isinstance(src_node, str):
|
||||
return None
|
||||
# bool is a subclass of int — reject it to avoid treating True/False as slot 1/0.
|
||||
if isinstance(src_slot, bool) or not isinstance(src_slot, int):
|
||||
return None
|
||||
return src_node, src_slot
|
||||
|
||||
# Sentinel for "type is unknown / wildcard". Matches AnyType.io_type ("*").
|
||||
ANY_TYPE: str = io.AnyType.io_type
|
||||
|
||||
@@ -76,6 +95,18 @@ class TypeResolver:
|
||||
import nodes
|
||||
return nodes.NODE_CLASS_MAPPINGS.get(class_type)
|
||||
|
||||
def _get_class_def_for_node(self, node_id: str):
|
||||
"""Return (node_dict, class_def) for ``node_id``, or ``(None, None)``."""
|
||||
if not self._has_node(node_id):
|
||||
return None, None
|
||||
node = self._get_node(node_id)
|
||||
if node is None:
|
||||
return None, None
|
||||
class_type = node.get("class_type")
|
||||
if not isinstance(class_type, str):
|
||||
return node, None
|
||||
return node, self._get_class_def(class_type)
|
||||
|
||||
# ---- cache management -------------------------------------------------
|
||||
def invalidate(self) -> None:
|
||||
"""Clear all cached resolutions. Cheap; call after any graph mutation."""
|
||||
@@ -97,8 +128,15 @@ class TypeResolver:
|
||||
"""Return the resolved io_type string of ``node_id``'s output slot.
|
||||
|
||||
Falls back to ``ANY_TYPE`` on cycle, depth-overflow, unknown class,
|
||||
out-of-range slot, missing node, or unresolved MatchType template.
|
||||
out-of-range slot, missing node, malformed link, or unresolved
|
||||
MatchType template.
|
||||
"""
|
||||
# Guard against malformed callers passing non-int slot indices (e.g.
|
||||
# API JSON that sent a string). Falling back to AnyType is safer than
|
||||
# raising TypeError mid-validation.
|
||||
if isinstance(slot_idx, bool) or not isinstance(slot_idx, int):
|
||||
return ANY_TYPE
|
||||
|
||||
cache_key = (node_id, slot_idx)
|
||||
if cache_key in self._output_cache:
|
||||
return self._output_cache[cache_key]
|
||||
@@ -113,17 +151,10 @@ class TypeResolver:
|
||||
return ANY_TYPE
|
||||
next_stack = _stack | {cache_key}
|
||||
|
||||
if not self._has_node(node_id):
|
||||
return ANY_TYPE
|
||||
|
||||
node = self._get_node(node_id)
|
||||
if node is None:
|
||||
return ANY_TYPE
|
||||
|
||||
class_type = node.get("class_type")
|
||||
class_def = self._get_class_def(class_type) if class_type is not None else None
|
||||
node, class_def = self._get_class_def_for_node(node_id)
|
||||
if class_def is None:
|
||||
return ANY_TYPE
|
||||
class_type = node.get("class_type")
|
||||
|
||||
try:
|
||||
return_types = class_def.RETURN_TYPES
|
||||
@@ -193,13 +224,13 @@ class TypeResolver:
|
||||
val = inputs_dict.get(inp.id)
|
||||
if val is None:
|
||||
continue
|
||||
if isinstance(val, list) and len(val) == 2 and isinstance(val[0], str):
|
||||
src_node, src_slot = val[0], val[1]
|
||||
t = self.resolve_output_type(src_node, src_slot, stack)
|
||||
link = _parse_link(val)
|
||||
if link is not None:
|
||||
t = self.resolve_output_type(link[0], link[1], stack)
|
||||
if t != ANY_TYPE:
|
||||
return t
|
||||
# Literal value: a MatchType slot has no concrete declared type, so
|
||||
# we cannot infer anything useful here.
|
||||
# Literal value (or malformed link): a MatchType slot has no
|
||||
# concrete declared type, so we cannot infer anything useful here.
|
||||
if not any_input_seen:
|
||||
# Schema declared a template_id with no Input bearing it. This is a
|
||||
# node-author bug; warn once.
|
||||
@@ -212,17 +243,17 @@ class TypeResolver:
|
||||
|
||||
def is_output_list(self, node_id: str, slot_idx: int) -> bool:
|
||||
"""Whether the source slot is declared as a list output (``OUTPUT_IS_LIST[idx]``)."""
|
||||
if isinstance(slot_idx, bool) or not isinstance(slot_idx, int):
|
||||
return False
|
||||
cache_key = (node_id, slot_idx)
|
||||
if cache_key in self._is_output_list_cache:
|
||||
return self._is_output_list_cache[cache_key]
|
||||
result = False
|
||||
node = self._get_node(node_id)
|
||||
if node is not None:
|
||||
class_def = self._get_class_def(node.get("class_type"))
|
||||
if class_def is not None:
|
||||
lst = getattr(class_def, "OUTPUT_IS_LIST", None)
|
||||
if lst is not None and 0 <= slot_idx < len(lst):
|
||||
result = bool(lst[slot_idx])
|
||||
_, class_def = self._get_class_def_for_node(node_id)
|
||||
if class_def is not None:
|
||||
lst = getattr(class_def, "OUTPUT_IS_LIST", None)
|
||||
if lst is not None and 0 <= slot_idx < len(lst):
|
||||
result = bool(lst[slot_idx])
|
||||
self._is_output_list_cache[cache_key] = result
|
||||
return result
|
||||
|
||||
@@ -234,7 +265,8 @@ class TypeResolver:
|
||||
* If the value is a literal, return the declared slot's effective
|
||||
io_type (peeling dynamic-input wrappers — e.g. an Autogrow-of-Image
|
||||
slot resolves to ``IMAGE``, not ``COMFY_AUTOGROW_V3``).
|
||||
* If the value is missing or the slot is unknown, return ``ANY_TYPE``.
|
||||
* If the value is missing, malformed, or the slot is unknown, return
|
||||
``ANY_TYPE``.
|
||||
"""
|
||||
node = self._get_node(node_id)
|
||||
if node is None:
|
||||
@@ -242,9 +274,9 @@ class TypeResolver:
|
||||
inputs = node.get("inputs", {}) or {}
|
||||
if input_id not in inputs:
|
||||
return ANY_TYPE
|
||||
val = inputs[input_id]
|
||||
if isinstance(val, list) and len(val) == 2 and isinstance(val[0], str):
|
||||
return self.resolve_output_type(val[0], val[1])
|
||||
link = _parse_link(inputs[input_id])
|
||||
if link is not None:
|
||||
return self.resolve_output_type(link[0], link[1])
|
||||
return self.get_declared_slot_io_type(node_id, input_id)
|
||||
|
||||
def is_input_list(self, node_id: str, input_id: str) -> bool:
|
||||
@@ -252,10 +284,10 @@ class TypeResolver:
|
||||
node = self._get_node(node_id)
|
||||
if node is None:
|
||||
return False
|
||||
val = (node.get("inputs", {}) or {}).get(input_id)
|
||||
if isinstance(val, list) and len(val) == 2 and isinstance(val[0], str):
|
||||
return self.is_output_list(val[0], val[1])
|
||||
return False
|
||||
link = _parse_link((node.get("inputs", {}) or {}).get(input_id))
|
||||
if link is None:
|
||||
return False
|
||||
return self.is_output_list(link[0], link[1])
|
||||
|
||||
def get_declared_slot_io_type(self, node_id: str, input_id: str) -> str:
|
||||
"""Return the effective declared io_type of a node's input slot.
|
||||
@@ -269,10 +301,7 @@ class TypeResolver:
|
||||
* DynamicCombo / unsupported → ``ANY_TYPE`` (the combo key is itself
|
||||
dynamic, not a meaningful type for consumers)
|
||||
"""
|
||||
node = self._get_node(node_id)
|
||||
if node is None:
|
||||
return ANY_TYPE
|
||||
class_def = self._get_class_def(node.get("class_type"))
|
||||
_, class_def = self._get_class_def_for_node(node_id)
|
||||
if class_def is None:
|
||||
return ANY_TYPE
|
||||
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
"""Backward-compat tests for ``register_dynamic_input_func``.
|
||||
|
||||
When ``live_input_types`` was added as a sixth argument to the dynamic-input
|
||||
expansion callback, third-party custom nodes that registered against the
|
||||
original 5-argument signature would otherwise crash with ``TypeError`` the
|
||||
first time their input was expanded. ``register_dynamic_input_func`` wraps
|
||||
such legacy callables transparently.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from comfy_api.latest import _io
|
||||
|
||||
|
||||
def test_legacy_5arg_callback_is_wrapped_transparently():
|
||||
received = {}
|
||||
|
||||
def legacy(out_dict, live_inputs, value, input_type, curr_prefix):
|
||||
received["args"] = (out_dict, live_inputs, value, input_type, curr_prefix)
|
||||
|
||||
io_type = "TEST_LEGACY_5ARG_V3"
|
||||
try:
|
||||
_io.register_dynamic_input_func(io_type, legacy)
|
||||
fn = _io.get_dynamic_input_func(io_type)
|
||||
|
||||
# Caller invokes with 6 arguments (current signature). The shim must
|
||||
# strip the trailing live_input_types argument before delegating.
|
||||
fn({"required": {}}, {"a": 1}, ("X", {}), "required", ["p"], {"a": "INT"})
|
||||
|
||||
assert received["args"] == (
|
||||
{"required": {}}, {"a": 1}, ("X", {}), "required", ["p"]
|
||||
)
|
||||
finally:
|
||||
_io.DYNAMIC_INPUT_LOOKUP.pop(io_type, None)
|
||||
|
||||
|
||||
def test_new_6arg_callback_passes_live_input_types_through():
|
||||
received = {}
|
||||
|
||||
def modern(out_dict, live_inputs, value, input_type, curr_prefix, live_input_types=None):
|
||||
received["live_input_types"] = live_input_types
|
||||
|
||||
io_type = "TEST_MODERN_6ARG_V3"
|
||||
try:
|
||||
_io.register_dynamic_input_func(io_type, modern)
|
||||
fn = _io.get_dynamic_input_func(io_type)
|
||||
fn({}, {}, ("X", {}), "required", None, {"foo": "IMAGE"})
|
||||
assert received["live_input_types"] == {"foo": "IMAGE"}
|
||||
finally:
|
||||
_io.DYNAMIC_INPUT_LOOKUP.pop(io_type, None)
|
||||
|
||||
|
||||
def test_callable_with_uninspectable_signature_assumed_modern():
|
||||
"""``functools.partial`` and C builtins may have no introspectable signature.
|
||||
|
||||
The shim must not blow up; falling back to the new signature is the safe
|
||||
choice (we get a clean TypeError if the callable really is too old, which
|
||||
is no worse than the pre-shim behavior).
|
||||
"""
|
||||
calls = []
|
||||
|
||||
class _CallableObj:
|
||||
# Lambdas / objects with __call__ are introspectable; partial with
|
||||
# opaque builtins are not. Simulate the latter by raising in __signature__.
|
||||
def __call__(self, *args, **kwargs):
|
||||
calls.append((args, kwargs))
|
||||
|
||||
@property
|
||||
def __signature__(self):
|
||||
raise ValueError("uninspectable")
|
||||
|
||||
io_type = "TEST_UNINSPECTABLE_V3"
|
||||
try:
|
||||
_io.register_dynamic_input_func(io_type, _CallableObj())
|
||||
fn = _io.get_dynamic_input_func(io_type)
|
||||
fn({}, {}, ("X", {}), "required", None, {"x": "INT"})
|
||||
assert calls and len(calls[0][0]) == 6
|
||||
finally:
|
||||
_io.DYNAMIC_INPUT_LOOKUP.pop(io_type, None)
|
||||
@@ -360,6 +360,52 @@ def test_invalidate_clears_cache(fake_nodes_module, TypeResolver):
|
||||
assert r.resolve_output_type("n1", 0) == "LATENT"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Malformed input robustness
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_malformed_link_does_not_crash(fake_nodes_module, TypeResolver):
|
||||
"""A link with a non-int slot index must not raise; resolver returns AnyType."""
|
||||
fake_nodes_module["Src"] = _v1_node(("IMAGE",))
|
||||
fake_nodes_module["Sink"] = _v1_node(("INT",), {"required": {"x": ("*",)}})
|
||||
prompt = {
|
||||
"src": {"class_type": "Src", "inputs": {}},
|
||||
# slot index sent as a string (common API JSON mistake)
|
||||
"sink": {"class_type": "Sink", "inputs": {"x": ["src", "0"]}},
|
||||
}
|
||||
r = TypeResolver(prompt)
|
||||
# Falls back to declared slot type (still "*"), no exception.
|
||||
assert r.resolve_input_type("sink", "x") == "*"
|
||||
|
||||
|
||||
def test_malformed_link_wrong_arity_does_not_crash(fake_nodes_module, TypeResolver):
|
||||
fake_nodes_module["Src"] = _v1_node(("IMAGE",))
|
||||
fake_nodes_module["Sink"] = _v1_node(("INT",), {"required": {"x": ("*",)}})
|
||||
prompt = {
|
||||
"src": {"class_type": "Src", "inputs": {}},
|
||||
"sink": {"class_type": "Sink", "inputs": {"x": ["src"]}}, # arity 1
|
||||
}
|
||||
r = TypeResolver(prompt)
|
||||
assert r.resolve_input_type("sink", "x") == "*"
|
||||
|
||||
|
||||
def test_direct_resolve_output_type_with_bad_slot_idx_returns_any(fake_nodes_module, TypeResolver):
|
||||
fake_nodes_module["Src"] = _v1_node(("IMAGE",))
|
||||
prompt = {"src": {"class_type": "Src", "inputs": {}}}
|
||||
r = TypeResolver(prompt)
|
||||
# type-wise these should be unreachable through normal validation but the
|
||||
# resolver must still degrade gracefully.
|
||||
assert r.resolve_output_type("src", "0") == "*"
|
||||
assert r.resolve_output_type("src", True) == "*" # bool is a subclass of int
|
||||
assert r.is_output_list("src", "0") is False
|
||||
|
||||
|
||||
def test_non_string_class_type_returns_any(fake_nodes_module, TypeResolver):
|
||||
prompt = {"n1": {"class_type": 42, "inputs": {}}}
|
||||
r = TypeResolver(prompt)
|
||||
assert r.resolve_output_type("n1", 0) == "*"
|
||||
|
||||
|
||||
def test_invalidate_node_only_clears_that_node(fake_nodes_module, TypeResolver):
|
||||
fake_nodes_module["SrcA"] = _v1_node(("IMAGE",))
|
||||
fake_nodes_module["SrcB"] = _v1_node(("LATENT",))
|
||||
|
||||
Reference in New Issue
Block a user