diff --git a/app/node_replace_manager.py b/app/node_replace_manager.py new file mode 100644 index 000000000..3b1b7ab36 --- /dev/null +++ b/app/node_replace_manager.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from aiohttp import web + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from comfy_api.latest._node_replace import NodeReplace + +REGISTERED_NODE_REPLACEMENTS: dict[str, list[NodeReplace]] = {} + +def register_node_replacement(node_replace: NodeReplace): + REGISTERED_NODE_REPLACEMENTS.setdefault(node_replace.old_node_id, []).append(node_replace) + +def registered_as_dict(): + return { + k: [v.as_dict() for v in v_list] for k, v_list in REGISTERED_NODE_REPLACEMENTS.items() + } + +class NodeReplaceManager: + def add_routes(self, routes): + @routes.get("/node_replacements") + async def get_node_replacements(request): + return web.json_response(registered_as_dict()) diff --git a/comfy_api/latest/_node_replace.py b/comfy_api/latest/_node_replace.py index 4703937d9..b8278d09d 100644 --- a/comfy_api/latest/_node_replace.py +++ b/comfy_api/latest/_node_replace.py @@ -1,6 +1,13 @@ from __future__ import annotations from typing import Any +import app.node_replace_manager + +def register_node_replacement(node_replace: NodeReplace): + """ + Register node replacement. + """ + app.node_replace_manager.register_node_replacement(node_replace) class NodeReplace: @@ -12,8 +19,8 @@ class NodeReplace: def __init__(self, new_node_id: str, old_node_id: str, - input_mapping: list[InputMap], - output_mapping: list[OutputMap], + input_mapping: list[InputMap] | None=None, + output_mapping: list[OutputMap] | None=None, ): self.new_node_id = new_node_id self.old_node_id = old_node_id @@ -27,8 +34,8 @@ class NodeReplace: return { "new_node_id": self.new_node_id, "old_node_id": self.old_node_id, - "input_mapping": [m.as_dict() for m in self.input_mapping], - "output_mapping": [m.as_dict() for m in self.output_mapping], + "input_mapping": [m.as_dict() for m in self.input_mapping] if self.input_mapping else None, + "output_mapping": [m.as_dict() for m in self.output_mapping] if self.output_mapping else None, } diff --git a/server.py b/server.py index 1888745b7..8da28912a 100644 --- a/server.py +++ b/server.py @@ -40,6 +40,7 @@ from app.user_manager import UserManager from app.model_manager import ModelFileManager from app.custom_node_manager import CustomNodeManager from app.subgraph_manager import SubgraphManager +from app.node_replace_manager import NodeReplaceManager from typing import Optional, Union from api_server.routes.internal.internal_routes import InternalRoutes from protocol import BinaryEventTypes @@ -204,6 +205,7 @@ class PromptServer(): self.model_file_manager = ModelFileManager() self.custom_node_manager = CustomNodeManager() self.subgraph_manager = SubgraphManager() + self.node_replace_manager = NodeReplaceManager() self.internal_routes = InternalRoutes(self) self.supports = ["custom_nodes_from_web"] self.prompt_queue = execution.PromptQueue(self) @@ -992,6 +994,7 @@ class PromptServer(): self.model_file_manager.add_routes(self.routes) self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items()) self.subgraph_manager.add_routes(self.routes, nodes.LOADED_MODULE_DIRS.items()) + self.node_replace_manager.add_routes(self.routes) self.app.add_subapp('/internal', self.internal_routes.get_app()) # Prefix every route with /api for easier matching for delegation.