From 07d1ee2ca9b1f33917d36c42192ffb91326170ff Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 14 Feb 2026 19:28:02 -0800 Subject: [PATCH] Add a public API function for registering node replacements, refactor code accordingly --- app/node_replace_manager.py | 6 +++--- comfy_api/latest/_node_replace.py | 6 ++++++ comfy_extras/nodes_replacements.py | 20 ++++++++------------ 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/app/node_replace_manager.py b/app/node_replace_manager.py index 7b8043d36..e805dee2f 100644 --- a/app/node_replace_manager.py +++ b/app/node_replace_manager.py @@ -7,7 +7,7 @@ if TYPE_CHECKING: from comfy_api.latest._node_replace import NodeReplace from comfy_execution.graph_utils import is_link -from nodes import NODE_CLASS_MAPPINGS +import nodes class NodeStruct(TypedDict): inputs: dict[str, str | int | float | bool | tuple[str, int]] @@ -48,7 +48,7 @@ class NodeReplaceManager: for node_number, node_struct in prompt.items(): class_type = node_struct["class_type"] # need replacement if not in NODE_CLASS_MAPPINGS and has replacement - if class_type not in NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type): + if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type): need_replacement.add(node_number) # keep track of connections for input_id, input_value in node_struct["inputs"].items(): @@ -65,7 +65,7 @@ class NodeReplaceManager: replacement = replacements[0] new_node_id = replacement.new_node_id # if replacement is not a valid node, skip trying to replace it as will only cause confusion - if new_node_id not in NODE_CLASS_MAPPINGS.keys(): + if new_node_id not in nodes.NODE_CLASS_MAPPINGS.keys(): continue # first, replace node id (class_type) new_node_struct = copy_node_struct(node_struct, empty_inputs=True) diff --git a/comfy_api/latest/_node_replace.py b/comfy_api/latest/_node_replace.py index c87b487c0..a55b2b733 100644 --- a/comfy_api/latest/_node_replace.py +++ b/comfy_api/latest/_node_replace.py @@ -1,6 +1,12 @@ from __future__ import annotations from typing import Any, TypedDict +from server import PromptServer + + +def register(node_replace: NodeReplace): + """Register a node replacement mapping.""" + PromptServer.instance.node_replace_manager.register(node_replace) class InputMapOldId(TypedDict): diff --git a/comfy_extras/nodes_replacements.py b/comfy_extras/nodes_replacements.py index de0ba6db0..79e03f55f 100644 --- a/comfy_extras/nodes_replacements.py +++ b/comfy_extras/nodes_replacements.py @@ -1,9 +1,5 @@ from comfy_api.latest import ComfyExtension, io, node_replace -from server import PromptServer -def _register(nr: node_replace.NodeReplace): - """Helper to register replacements via PromptServer.""" - PromptServer.instance.node_replace_manager.register(nr) async def register_replacements(): """Register all built-in node replacements.""" @@ -18,7 +14,7 @@ async def register_replacements(): def register_replacements_longeredge(): # No dynamic inputs here - _register(node_replace.NodeReplace( + node_replace.register(node_replace.NodeReplace( new_node_id="ImageScaleToMaxDimension", old_node_id="ResizeImagesByLongerEdge", old_widget_ids=["longer_edge"], @@ -33,7 +29,7 @@ def register_replacements_longeredge(): def register_replacements_batchimages(): # BatchImages node uses Autogrow - _register(node_replace.NodeReplace( + node_replace.register(node_replace.NodeReplace( new_node_id="BatchImagesNode", old_node_id="ImageBatch", input_mapping=[ @@ -44,7 +40,7 @@ def register_replacements_batchimages(): def register_replacements_upscaleimage(): # ResizeImageMaskNode uses DynamicCombo - _register(node_replace.NodeReplace( + node_replace.register(node_replace.NodeReplace( new_node_id="ResizeImageMaskNode", old_node_id="ImageScaleBy", old_widget_ids=["upscale_method", "scale_by"], @@ -58,7 +54,7 @@ def register_replacements_upscaleimage(): def register_replacements_controlnet(): # T2IAdapterLoader → ControlNetLoader - _register(node_replace.NodeReplace( + node_replace.register(node_replace.NodeReplace( new_node_id="ControlNetLoader", old_node_id="T2IAdapterLoader", input_mapping=[ @@ -68,28 +64,28 @@ def register_replacements_controlnet(): def register_replacements_load3d(): # Load3DAnimation merged into Load3D - _register(node_replace.NodeReplace( + node_replace.register(node_replace.NodeReplace( new_node_id="Load3D", old_node_id="Load3DAnimation", )) def register_replacements_preview3d(): # Preview3DAnimation merged into Preview3D - _register(node_replace.NodeReplace( + node_replace.register(node_replace.NodeReplace( new_node_id="Preview3D", old_node_id="Preview3DAnimation", )) def register_replacements_svdimg2vid(): # Typo fix: SDV → SVD - _register(node_replace.NodeReplace( + node_replace.register(node_replace.NodeReplace( new_node_id="SVD_img2vid_Conditioning", old_node_id="SDV_img2vid_Conditioning", )) def register_replacements_conditioningavg(): # Typo fix: trailing space in node name - _register(node_replace.NodeReplace( + node_replace.register(node_replace.NodeReplace( new_node_id="ConditioningAverage", old_node_id="ConditioningAverage ", ))