refactor: process isolation support for node replacement API (#12298)

* refactor: process isolation support for node replacement API

- Move REGISTERED_NODE_REPLACEMENTS global to NodeReplaceManager instance state
- Add NodeReplacement class to ComfyAPI_latest with async register() method
- Deprecate module-level register_node_replacement() function
- Call register_replacements() from comfy_entrypoint()

This enables pyisolate compatibility where extensions run in separate
processes and communicate via RPC. The async API allows registration
calls to cross process boundaries.

Refs: TDD-002
Amp-Thread-ID: https://ampcode.com/threads/T-019c2b33-ac55-76a9-9c6b-0246a8625f21

* fix: remove whitespace and deprecation cruft

Amp-Thread-ID: https://ampcode.com/threads/T-019c2be8-0b34-747e-b1f7-20a1a1e6c9df
This commit is contained in:
Christian Byrne
2026-02-05 12:21:03 -08:00
committed by GitHub
parent d5b3da823d
commit a2d4c0f98b
4 changed files with 65 additions and 55 deletions

View File

@@ -6,18 +6,33 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy_api.latest._node_replace import NodeReplace
REGISTERED_NODE_REPLACEMENTS: dict[str, list[NodeReplace]] = {}
def register_node_replacement(node_replace: NodeReplace):
REGISTERED_NODE_REPLACEMENTS.setdefault(node_replace.old_node_id, []).append(node_replace)
def registered_as_dict():
return {
k: [v.as_dict() for v in v_list] for k, v_list in REGISTERED_NODE_REPLACEMENTS.items()
}
class NodeReplaceManager:
"""Manages node replacement registrations."""
def __init__(self):
self._replacements: dict[str, list[NodeReplace]] = {}
def register(self, node_replace: NodeReplace):
"""Register a node replacement mapping."""
self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)
def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
"""Get replacements for an old node ID."""
return self._replacements.get(old_node_id)
def has_replacement(self, old_node_id: str) -> bool:
"""Check if a replacement exists for an old node ID."""
return old_node_id in self._replacements
def as_dict(self):
"""Serialize all replacements to dict."""
return {
k: [v.as_dict() for v in v_list]
for k, v_list in self._replacements.items()
}
def add_routes(self, routes):
@routes.get("/node_replacements")
async def get_node_replacements(request):
return web.json_response(registered_as_dict())
return web.json_response(self.as_dict())

View File

@@ -22,6 +22,14 @@ class ComfyAPI_latest(ComfyAPIBase):
VERSION = "latest"
STABLE = False
class NodeReplacement(ProxiedSingleton):
async def register(self, node_replace: 'node_replace.NodeReplace') -> None:
"""Register a node replacement mapping."""
from server import PromptServer
PromptServer.instance.node_replace_manager.register(node_replace)
node_replacement: NodeReplacement
class Execution(ProxiedSingleton):
async def set_progress(
self,

View File

@@ -1,13 +1,6 @@
from __future__ import annotations
from typing import Any
import app.node_replace_manager
def register_node_replacement(node_replace: NodeReplace):
"""
Register node replacement.
"""
app.node_replace_manager.register_node_replacement(node_replace)
class NodeReplace:
@@ -30,9 +23,7 @@ class NodeReplace:
self.output_mapping = output_mapping
def as_dict(self):
"""
Create serializable representation of the node replacement.
"""
"""Create serializable representation of the node replacement."""
return {
"new_node_id": self.new_node_id,
"old_node_id": self.old_node_id,
@@ -58,9 +49,7 @@ class InputMap:
}
class OldId(_Assign):
"""
Connect the input of the old node with given id to new node when replacing.
"""
"""Connect the input of the old node with given id to new node when replacing."""
def __init__(self, old_id: str):
super().__init__("old_id")
self.old_id = old_id
@@ -71,9 +60,7 @@ class InputMap:
}
class SetValue(_Assign):
"""
Use the given value for the input of the new node when replacing; assumes input is a widget.
"""
"""Use the given value for the input of the new node when replacing; assumes input is a widget."""
def __init__(self, value: Any):
super().__init__("set_value")
self.value = value
@@ -95,9 +82,7 @@ class InputMap:
class OutputMap:
"""
Map outputs of node replacement via indexes, as that's how outputs are stored.
"""
"""Map outputs of node replacement via indexes, as that's how outputs are stored."""
def __init__(self, new_idx: int, old_idx: int):
self.new_idx = new_idx
self.old_idx = old_idx

View File

@@ -656,21 +656,22 @@ class BatchImagesMasksLatentsNode(io.ComfyNode):
return io.NodeOutput(batched)
from comfy_api.latest import node_replace
from comfy_api.latest import ComfyAPI, node_replace
def register_replacements():
register_replacements_longeredge()
register_replacements_batchimages()
register_replacements_upscaleimage()
register_replacements_controlnet()
register_replacements_load3d()
register_replacements_preview3d()
register_replacements_svdimg2vid()
register_replacements_conditioningavg()
async def register_replacements():
"""Register all built-in node replacements using the async API."""
await register_replacements_longeredge()
await register_replacements_batchimages()
await register_replacements_upscaleimage()
await register_replacements_controlnet()
await register_replacements_load3d()
await register_replacements_preview3d()
await register_replacements_svdimg2vid()
await register_replacements_conditioningavg()
def register_replacements_longeredge():
async def register_replacements_longeredge():
# No dynamic inputs here
node_replace.register_node_replacement(node_replace.NodeReplace(
await ComfyAPI.node_replacement.register(node_replace.NodeReplace(
new_node_id="ImageScaleToMaxDimension",
old_node_id="ResizeImagesByLongerEdge",
old_widget_ids=["longer_edge"],
@@ -683,9 +684,9 @@ def register_replacements_longeredge():
output_mapping=[node_replace.OutputMap(new_idx=0, old_idx=0)],
))
def register_replacements_batchimages():
async def register_replacements_batchimages():
# BatchImages node uses Autogrow
node_replace.register_node_replacement(node_replace.NodeReplace(
await ComfyAPI.node_replacement.register(node_replace.NodeReplace(
new_node_id="BatchImagesNode",
old_node_id="ImageBatch",
input_mapping=[
@@ -694,9 +695,9 @@ def register_replacements_batchimages():
],
))
def register_replacements_upscaleimage():
async def register_replacements_upscaleimage():
# ResizeImageMaskNode uses DynamicCombo
node_replace.register_node_replacement(node_replace.NodeReplace(
await ComfyAPI.node_replacement.register(node_replace.NodeReplace(
new_node_id="ResizeImageMaskNode",
old_node_id="ImageScaleBy",
old_widget_ids=["upscale_method", "scale_by"],
@@ -708,9 +709,9 @@ def register_replacements_upscaleimage():
],
))
def register_replacements_controlnet():
async def register_replacements_controlnet():
# T2IAdapterLoader → ControlNetLoader
node_replace.register_node_replacement(node_replace.NodeReplace(
await ComfyAPI.node_replacement.register(node_replace.NodeReplace(
new_node_id="ControlNetLoader",
old_node_id="T2IAdapterLoader",
input_mapping=[
@@ -718,30 +719,30 @@ def register_replacements_controlnet():
],
))
def register_replacements_load3d():
async def register_replacements_load3d():
# Load3DAnimation merged into Load3D
node_replace.register_node_replacement(node_replace.NodeReplace(
await ComfyAPI.node_replacement.register(node_replace.NodeReplace(
new_node_id="Load3D",
old_node_id="Load3DAnimation",
))
def register_replacements_preview3d():
async def register_replacements_preview3d():
# Preview3DAnimation merged into Preview3D
node_replace.register_node_replacement(node_replace.NodeReplace(
await ComfyAPI.node_replacement.register(node_replace.NodeReplace(
new_node_id="Preview3D",
old_node_id="Preview3DAnimation",
))
def register_replacements_svdimg2vid():
async def register_replacements_svdimg2vid():
# Typo fix: SDV → SVD
node_replace.register_node_replacement(node_replace.NodeReplace(
await ComfyAPI.node_replacement.register(node_replace.NodeReplace(
new_node_id="SVD_img2vid_Conditioning",
old_node_id="SDV_img2vid_Conditioning",
))
def register_replacements_conditioningavg():
async def register_replacements_conditioningavg():
# Typo fix: trailing space in node name
node_replace.register_node_replacement(node_replace.NodeReplace(
await ComfyAPI.node_replacement.register(node_replace.NodeReplace(
new_node_id="ConditioningAverage",
old_node_id="ConditioningAverage ",
))
@@ -763,4 +764,5 @@ class PostProcessingExtension(ComfyExtension):
]
async def comfy_entrypoint() -> PostProcessingExtension:
await register_replacements()
return PostProcessingExtension()