Fix ComfyAPI initialization for base async case, move on_load to be called before get_node_list

This commit is contained in:
Jedrzej Kosinski
2026-02-14 20:23:39 -08:00
parent 3caf8115b0
commit a9cc84bb78
3 changed files with 14 additions and 13 deletions

View File

@@ -22,14 +22,17 @@ class ComfyAPI_latest(ComfyAPIBase):
VERSION = "latest"
STABLE = False
def __init__(self):
super().__init__()
self.node_replacement = self.NodeReplacement()
self.execution = self.Execution()
class NodeReplacement(ProxiedSingleton):
async def register(self, node_replace: 'node_replace.NodeReplace') -> None:
"""Register a node replacement mapping."""
from server import PromptServer
PromptServer.instance.node_replace_manager.register(node_replace)
node_replacement: NodeReplacement
class Execution(ProxiedSingleton):
async def set_progress(
self,
@@ -82,8 +85,6 @@ class ComfyAPI_latest(ComfyAPIBase):
image=to_display,
)
execution: Execution
class ComfyExtension(ABC):
async def on_load(self) -> None:
"""

View File

@@ -16,7 +16,7 @@ async def register_replacements():
async def register_replacements_longeredge():
# No dynamic inputs here
await api.NodeReplacement().register(node_replace.NodeReplace(
await api.node_replacement.register(node_replace.NodeReplace(
new_node_id="ImageScaleToMaxDimension",
old_node_id="ResizeImagesByLongerEdge",
old_widget_ids=["longer_edge"],
@@ -31,7 +31,7 @@ async def register_replacements_longeredge():
async def register_replacements_batchimages():
# BatchImages node uses Autogrow
await api.NodeReplacement().register(node_replace.NodeReplace(
await api.node_replacement.register(node_replace.NodeReplace(
new_node_id="BatchImagesNode",
old_node_id="ImageBatch",
input_mapping=[
@@ -42,7 +42,7 @@ async def register_replacements_batchimages():
async def register_replacements_upscaleimage():
# ResizeImageMaskNode uses DynamicCombo
await api.NodeReplacement().register(node_replace.NodeReplace(
await api.node_replacement.register(node_replace.NodeReplace(
new_node_id="ResizeImageMaskNode",
old_node_id="ImageScaleBy",
old_widget_ids=["upscale_method", "scale_by"],
@@ -56,7 +56,7 @@ async def register_replacements_upscaleimage():
async def register_replacements_controlnet():
# T2IAdapterLoader → ControlNetLoader
await api.NodeReplacement().register(node_replace.NodeReplace(
await api.node_replacement.register(node_replace.NodeReplace(
new_node_id="ControlNetLoader",
old_node_id="T2IAdapterLoader",
input_mapping=[
@@ -66,28 +66,28 @@ async def register_replacements_controlnet():
async def register_replacements_load3d():
# Load3DAnimation merged into Load3D
await api.NodeReplacement().register(node_replace.NodeReplace(
await api.node_replacement.register(node_replace.NodeReplace(
new_node_id="Load3D",
old_node_id="Load3DAnimation",
))
async def register_replacements_preview3d():
# Preview3DAnimation merged into Preview3D
await api.NodeReplacement().register(node_replace.NodeReplace(
await api.node_replacement.register(node_replace.NodeReplace(
new_node_id="Preview3D",
old_node_id="Preview3DAnimation",
))
async def register_replacements_svdimg2vid():
# Typo fix: SDV → SVD
await api.NodeReplacement().register(node_replace.NodeReplace(
await api.node_replacement.register(node_replace.NodeReplace(
new_node_id="SVD_img2vid_Conditioning",
old_node_id="SDV_img2vid_Conditioning",
))
async def register_replacements_conditioningavg():
# Typo fix: trailing space in node name
await api.NodeReplacement().register(node_replace.NodeReplace(
await api.node_replacement.register(node_replace.NodeReplace(
new_node_id="ConditioningAverage",
old_node_id="ConditioningAverage ",
))

View File

@@ -2264,6 +2264,7 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
if not isinstance(extension, ComfyExtension):
logging.warning(f"comfy_entrypoint in {module_path} did not return a ComfyExtension, skipping.")
return False
await extension.on_load()
node_list = await extension.get_node_list()
if not isinstance(node_list, list):
logging.warning(f"comfy_entrypoint in {module_path} did not return a list of nodes, skipping.")
@@ -2276,7 +2277,6 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format(module_parent, get_module_name(module_path))
if schema.display_name is not None:
NODE_DISPLAY_NAME_MAPPINGS[schema.node_id] = schema.display_name
await extension.on_load()
return True
except Exception as e:
logging.warning(f"Error while calling comfy_entrypoint in {module_path}: {e}")