mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-29 02:41:27 +00:00
Add a public API function for registering node replacements, refactor code accordingly
This commit is contained in:
@@ -7,7 +7,7 @@ if TYPE_CHECKING:
|
|||||||
from comfy_api.latest._node_replace import NodeReplace
|
from comfy_api.latest._node_replace import NodeReplace
|
||||||
|
|
||||||
from comfy_execution.graph_utils import is_link
|
from comfy_execution.graph_utils import is_link
|
||||||
from nodes import NODE_CLASS_MAPPINGS
|
import nodes
|
||||||
|
|
||||||
class NodeStruct(TypedDict):
|
class NodeStruct(TypedDict):
|
||||||
inputs: dict[str, str | int | float | bool | tuple[str, int]]
|
inputs: dict[str, str | int | float | bool | tuple[str, int]]
|
||||||
@@ -48,7 +48,7 @@ class NodeReplaceManager:
|
|||||||
for node_number, node_struct in prompt.items():
|
for node_number, node_struct in prompt.items():
|
||||||
class_type = node_struct["class_type"]
|
class_type = node_struct["class_type"]
|
||||||
# need replacement if not in NODE_CLASS_MAPPINGS and has replacement
|
# need replacement if not in NODE_CLASS_MAPPINGS and has replacement
|
||||||
if class_type not in NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type):
|
if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type):
|
||||||
need_replacement.add(node_number)
|
need_replacement.add(node_number)
|
||||||
# keep track of connections
|
# keep track of connections
|
||||||
for input_id, input_value in node_struct["inputs"].items():
|
for input_id, input_value in node_struct["inputs"].items():
|
||||||
@@ -65,7 +65,7 @@ class NodeReplaceManager:
|
|||||||
replacement = replacements[0]
|
replacement = replacements[0]
|
||||||
new_node_id = replacement.new_node_id
|
new_node_id = replacement.new_node_id
|
||||||
# if replacement is not a valid node, skip trying to replace it as will only cause confusion
|
# if replacement is not a valid node, skip trying to replace it as will only cause confusion
|
||||||
if new_node_id not in NODE_CLASS_MAPPINGS.keys():
|
if new_node_id not in nodes.NODE_CLASS_MAPPINGS.keys():
|
||||||
continue
|
continue
|
||||||
# first, replace node id (class_type)
|
# first, replace node id (class_type)
|
||||||
new_node_struct = copy_node_struct(node_struct, empty_inputs=True)
|
new_node_struct = copy_node_struct(node_struct, empty_inputs=True)
|
||||||
|
|||||||
@@ -1,6 +1,12 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, TypedDict
|
from typing import Any, TypedDict
|
||||||
|
from server import PromptServer
|
||||||
|
|
||||||
|
|
||||||
|
def register(node_replace: NodeReplace):
|
||||||
|
"""Register a node replacement mapping."""
|
||||||
|
PromptServer.instance.node_replace_manager.register(node_replace)
|
||||||
|
|
||||||
|
|
||||||
class InputMapOldId(TypedDict):
|
class InputMapOldId(TypedDict):
|
||||||
|
|||||||
@@ -1,9 +1,5 @@
|
|||||||
from comfy_api.latest import ComfyExtension, io, node_replace
|
from comfy_api.latest import ComfyExtension, io, node_replace
|
||||||
from server import PromptServer
|
|
||||||
|
|
||||||
def _register(nr: node_replace.NodeReplace):
|
|
||||||
"""Helper to register replacements via PromptServer."""
|
|
||||||
PromptServer.instance.node_replace_manager.register(nr)
|
|
||||||
|
|
||||||
async def register_replacements():
|
async def register_replacements():
|
||||||
"""Register all built-in node replacements."""
|
"""Register all built-in node replacements."""
|
||||||
@@ -18,7 +14,7 @@ async def register_replacements():
|
|||||||
|
|
||||||
def register_replacements_longeredge():
|
def register_replacements_longeredge():
|
||||||
# No dynamic inputs here
|
# No dynamic inputs here
|
||||||
_register(node_replace.NodeReplace(
|
node_replace.register(node_replace.NodeReplace(
|
||||||
new_node_id="ImageScaleToMaxDimension",
|
new_node_id="ImageScaleToMaxDimension",
|
||||||
old_node_id="ResizeImagesByLongerEdge",
|
old_node_id="ResizeImagesByLongerEdge",
|
||||||
old_widget_ids=["longer_edge"],
|
old_widget_ids=["longer_edge"],
|
||||||
@@ -33,7 +29,7 @@ def register_replacements_longeredge():
|
|||||||
|
|
||||||
def register_replacements_batchimages():
|
def register_replacements_batchimages():
|
||||||
# BatchImages node uses Autogrow
|
# BatchImages node uses Autogrow
|
||||||
_register(node_replace.NodeReplace(
|
node_replace.register(node_replace.NodeReplace(
|
||||||
new_node_id="BatchImagesNode",
|
new_node_id="BatchImagesNode",
|
||||||
old_node_id="ImageBatch",
|
old_node_id="ImageBatch",
|
||||||
input_mapping=[
|
input_mapping=[
|
||||||
@@ -44,7 +40,7 @@ def register_replacements_batchimages():
|
|||||||
|
|
||||||
def register_replacements_upscaleimage():
|
def register_replacements_upscaleimage():
|
||||||
# ResizeImageMaskNode uses DynamicCombo
|
# ResizeImageMaskNode uses DynamicCombo
|
||||||
_register(node_replace.NodeReplace(
|
node_replace.register(node_replace.NodeReplace(
|
||||||
new_node_id="ResizeImageMaskNode",
|
new_node_id="ResizeImageMaskNode",
|
||||||
old_node_id="ImageScaleBy",
|
old_node_id="ImageScaleBy",
|
||||||
old_widget_ids=["upscale_method", "scale_by"],
|
old_widget_ids=["upscale_method", "scale_by"],
|
||||||
@@ -58,7 +54,7 @@ def register_replacements_upscaleimage():
|
|||||||
|
|
||||||
def register_replacements_controlnet():
|
def register_replacements_controlnet():
|
||||||
# T2IAdapterLoader → ControlNetLoader
|
# T2IAdapterLoader → ControlNetLoader
|
||||||
_register(node_replace.NodeReplace(
|
node_replace.register(node_replace.NodeReplace(
|
||||||
new_node_id="ControlNetLoader",
|
new_node_id="ControlNetLoader",
|
||||||
old_node_id="T2IAdapterLoader",
|
old_node_id="T2IAdapterLoader",
|
||||||
input_mapping=[
|
input_mapping=[
|
||||||
@@ -68,28 +64,28 @@ def register_replacements_controlnet():
|
|||||||
|
|
||||||
def register_replacements_load3d():
|
def register_replacements_load3d():
|
||||||
# Load3DAnimation merged into Load3D
|
# Load3DAnimation merged into Load3D
|
||||||
_register(node_replace.NodeReplace(
|
node_replace.register(node_replace.NodeReplace(
|
||||||
new_node_id="Load3D",
|
new_node_id="Load3D",
|
||||||
old_node_id="Load3DAnimation",
|
old_node_id="Load3DAnimation",
|
||||||
))
|
))
|
||||||
|
|
||||||
def register_replacements_preview3d():
|
def register_replacements_preview3d():
|
||||||
# Preview3DAnimation merged into Preview3D
|
# Preview3DAnimation merged into Preview3D
|
||||||
_register(node_replace.NodeReplace(
|
node_replace.register(node_replace.NodeReplace(
|
||||||
new_node_id="Preview3D",
|
new_node_id="Preview3D",
|
||||||
old_node_id="Preview3DAnimation",
|
old_node_id="Preview3DAnimation",
|
||||||
))
|
))
|
||||||
|
|
||||||
def register_replacements_svdimg2vid():
|
def register_replacements_svdimg2vid():
|
||||||
# Typo fix: SDV → SVD
|
# Typo fix: SDV → SVD
|
||||||
_register(node_replace.NodeReplace(
|
node_replace.register(node_replace.NodeReplace(
|
||||||
new_node_id="SVD_img2vid_Conditioning",
|
new_node_id="SVD_img2vid_Conditioning",
|
||||||
old_node_id="SDV_img2vid_Conditioning",
|
old_node_id="SDV_img2vid_Conditioning",
|
||||||
))
|
))
|
||||||
|
|
||||||
def register_replacements_conditioningavg():
|
def register_replacements_conditioningavg():
|
||||||
# Typo fix: trailing space in node name
|
# Typo fix: trailing space in node name
|
||||||
_register(node_replace.NodeReplace(
|
node_replace.register(node_replace.NodeReplace(
|
||||||
new_node_id="ConditioningAverage",
|
new_node_id="ConditioningAverage",
|
||||||
old_node_id="ConditioningAverage ",
|
old_node_id="ConditioningAverage ",
|
||||||
))
|
))
|
||||||
|
|||||||
Reference in New Issue
Block a user