Add a public API function for registering node replacements, refactor code accordingly

This commit is contained in:
Jedrzej Kosinski
2026-02-14 19:28:02 -08:00
parent 1ab9d4a9db
commit 07d1ee2ca9
3 changed files with 17 additions and 15 deletions

View File

@@ -7,7 +7,7 @@ if TYPE_CHECKING:
from comfy_api.latest._node_replace import NodeReplace
from comfy_execution.graph_utils import is_link
from nodes import NODE_CLASS_MAPPINGS
import nodes
class NodeStruct(TypedDict):
inputs: dict[str, str | int | float | bool | tuple[str, int]]
@@ -48,7 +48,7 @@ class NodeReplaceManager:
for node_number, node_struct in prompt.items():
class_type = node_struct["class_type"]
# need replacement if not in NODE_CLASS_MAPPINGS and has replacement
if class_type not in NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type):
if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type):
need_replacement.add(node_number)
# keep track of connections
for input_id, input_value in node_struct["inputs"].items():
@@ -65,7 +65,7 @@ class NodeReplaceManager:
replacement = replacements[0]
new_node_id = replacement.new_node_id
# if replacement is not a valid node, skip trying to replace it as will only cause confusion
if new_node_id not in NODE_CLASS_MAPPINGS.keys():
if new_node_id not in nodes.NODE_CLASS_MAPPINGS.keys():
continue
# first, replace node id (class_type)
new_node_struct = copy_node_struct(node_struct, empty_inputs=True)

View File

@@ -1,6 +1,12 @@
from __future__ import annotations
from typing import Any, TypedDict
from server import PromptServer
def register(node_replace: NodeReplace):
"""Register a node replacement mapping."""
PromptServer.instance.node_replace_manager.register(node_replace)
class InputMapOldId(TypedDict):

View File

@@ -1,9 +1,5 @@
from comfy_api.latest import ComfyExtension, io, node_replace
from server import PromptServer
def _register(nr: node_replace.NodeReplace):
"""Helper to register replacements via PromptServer."""
PromptServer.instance.node_replace_manager.register(nr)
async def register_replacements():
"""Register all built-in node replacements."""
@@ -18,7 +14,7 @@ async def register_replacements():
def register_replacements_longeredge():
# No dynamic inputs here
_register(node_replace.NodeReplace(
node_replace.register(node_replace.NodeReplace(
new_node_id="ImageScaleToMaxDimension",
old_node_id="ResizeImagesByLongerEdge",
old_widget_ids=["longer_edge"],
@@ -33,7 +29,7 @@ def register_replacements_longeredge():
def register_replacements_batchimages():
# BatchImages node uses Autogrow
_register(node_replace.NodeReplace(
node_replace.register(node_replace.NodeReplace(
new_node_id="BatchImagesNode",
old_node_id="ImageBatch",
input_mapping=[
@@ -44,7 +40,7 @@ def register_replacements_batchimages():
def register_replacements_upscaleimage():
# ResizeImageMaskNode uses DynamicCombo
_register(node_replace.NodeReplace(
node_replace.register(node_replace.NodeReplace(
new_node_id="ResizeImageMaskNode",
old_node_id="ImageScaleBy",
old_widget_ids=["upscale_method", "scale_by"],
@@ -58,7 +54,7 @@ def register_replacements_upscaleimage():
def register_replacements_controlnet():
# T2IAdapterLoader → ControlNetLoader
_register(node_replace.NodeReplace(
node_replace.register(node_replace.NodeReplace(
new_node_id="ControlNetLoader",
old_node_id="T2IAdapterLoader",
input_mapping=[
@@ -68,28 +64,28 @@ def register_replacements_controlnet():
def register_replacements_load3d():
# Load3DAnimation merged into Load3D
_register(node_replace.NodeReplace(
node_replace.register(node_replace.NodeReplace(
new_node_id="Load3D",
old_node_id="Load3DAnimation",
))
def register_replacements_preview3d():
# Preview3DAnimation merged into Preview3D
_register(node_replace.NodeReplace(
node_replace.register(node_replace.NodeReplace(
new_node_id="Preview3D",
old_node_id="Preview3DAnimation",
))
def register_replacements_svdimg2vid():
# Typo fix: SDV → SVD
_register(node_replace.NodeReplace(
node_replace.register(node_replace.NodeReplace(
new_node_id="SVD_img2vid_Conditioning",
old_node_id="SDV_img2vid_Conditioning",
))
def register_replacements_conditioningavg():
# Typo fix: trailing space in node name
_register(node_replace.NodeReplace(
node_replace.register(node_replace.NodeReplace(
new_node_id="ConditioningAverage",
old_node_id="ConditioningAverage ",
))