feat(isolation): singleton proxies for ComfyUI services

Adds proxy infrastructure for cross-process service access: BaseRegistry/
BaseProxy pattern, FolderPaths, ModelManagement, PromptServer, Progress,
Utils, HelperProxies, and WebDirectory proxies. These provide transparent
RPC access to host-side ComfyUI services from isolated child processes.
This commit is contained in:
John Pollock
2026-04-07 05:58:12 -05:00
parent 7d512fa9c3
commit 94720c0c02
9 changed files with 1480 additions and 0 deletions

View File

@@ -0,0 +1,17 @@
from .base import (
IS_CHILD_PROCESS,
BaseProxy,
BaseRegistry,
detach_if_grad,
get_thread_loop,
run_coro_in_new_loop,
)
__all__ = [
"IS_CHILD_PROCESS",
"BaseRegistry",
"BaseProxy",
"get_thread_loop",
"run_coro_in_new_loop",
"detach_if_grad",
]

View File

@@ -0,0 +1,301 @@
# pylint: disable=global-statement,import-outside-toplevel,protected-access
from __future__ import annotations
import asyncio
import concurrent.futures
import logging
import os
import threading
import time
import weakref
from typing import Any, Callable, Dict, Generic, Optional, TypeVar
try:
from pyisolate import ProxiedSingleton
except ImportError:
class ProxiedSingleton: # type: ignore[no-redef]
pass
logger = logging.getLogger(__name__)
IS_CHILD_PROCESS = os.environ.get("PYISOLATE_CHILD") == "1"
_thread_local = threading.local()
T = TypeVar("T")
def get_thread_loop() -> asyncio.AbstractEventLoop:
loop = getattr(_thread_local, "loop", None)
if loop is None or loop.is_closed():
loop = asyncio.new_event_loop()
_thread_local.loop = loop
return loop
def run_coro_in_new_loop(coro: Any) -> Any:
result_box: Dict[str, Any] = {}
exc_box: Dict[str, BaseException] = {}
def runner() -> None:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
result_box["value"] = loop.run_until_complete(coro)
except Exception as exc: # noqa: BLE001
exc_box["exc"] = exc
finally:
loop.close()
t = threading.Thread(target=runner, daemon=True)
t.start()
t.join()
if "exc" in exc_box:
raise exc_box["exc"]
return result_box.get("value")
def detach_if_grad(obj: Any) -> Any:
try:
import torch
except Exception:
return obj
if isinstance(obj, torch.Tensor):
return obj.detach() if obj.requires_grad else obj
if isinstance(obj, (list, tuple)):
return type(obj)(detach_if_grad(x) for x in obj)
if isinstance(obj, dict):
return {k: detach_if_grad(v) for k, v in obj.items()}
return obj
class BaseRegistry(ProxiedSingleton, Generic[T]):
_type_prefix: str = "base"
def __init__(self) -> None:
if hasattr(ProxiedSingleton, "__init__") and ProxiedSingleton is not object:
super().__init__()
self._registry: Dict[str, T] = {}
self._id_map: Dict[int, str] = {}
self._counter = 0
self._lock = threading.Lock()
def register(self, instance: T) -> str:
with self._lock:
obj_id = id(instance)
if obj_id in self._id_map:
return self._id_map[obj_id]
instance_id = f"{self._type_prefix}_{self._counter}"
self._counter += 1
self._registry[instance_id] = instance
self._id_map[obj_id] = instance_id
return instance_id
def unregister_sync(self, instance_id: str) -> None:
with self._lock:
instance = self._registry.pop(instance_id, None)
if instance:
self._id_map.pop(id(instance), None)
def _get_instance(self, instance_id: str) -> T:
if IS_CHILD_PROCESS:
raise RuntimeError(
f"[{self.__class__.__name__}] _get_instance called in child"
)
with self._lock:
instance = self._registry.get(instance_id)
if instance is None:
raise ValueError(f"{instance_id} not found")
return instance
_GLOBAL_LOOP: Optional[asyncio.AbstractEventLoop] = None
def set_global_loop(loop: asyncio.AbstractEventLoop) -> None:
global _GLOBAL_LOOP
_GLOBAL_LOOP = loop
def run_sync_rpc_coro(coro: Any, timeout_ms: Optional[int] = None) -> Any:
if timeout_ms is not None:
coro = asyncio.wait_for(coro, timeout=timeout_ms / 1000.0)
try:
if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running():
try:
curr_loop = asyncio.get_running_loop()
if curr_loop is _GLOBAL_LOOP:
pass
except RuntimeError:
future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP)
return future.result(
timeout=(timeout_ms / 1000.0) if timeout_ms is not None else None
)
try:
asyncio.get_running_loop()
return run_coro_in_new_loop(coro)
except RuntimeError:
loop = get_thread_loop()
return loop.run_until_complete(coro)
except asyncio.TimeoutError as exc:
raise TimeoutError(f"Isolation RPC timeout (timeout_ms={timeout_ms})") from exc
except concurrent.futures.TimeoutError as exc:
raise TimeoutError(f"Isolation RPC timeout (timeout_ms={timeout_ms})") from exc
def call_singleton_rpc(
caller: Any,
method_name: str,
*args: Any,
timeout_ms: Optional[int] = None,
**kwargs: Any,
) -> Any:
if caller is None:
raise RuntimeError(f"No RPC caller available for {method_name}")
method = getattr(caller, method_name)
return run_sync_rpc_coro(method(*args, **kwargs), timeout_ms=timeout_ms)
class BaseProxy(Generic[T]):
_registry_class: type = BaseRegistry # type: ignore[type-arg]
__module__: str = "comfy.isolation.proxies.base"
_TIMEOUT_RPC_METHODS = frozenset(
{
"partially_load",
"partially_unload",
"load",
"patch_model",
"unpatch_model",
"inner_model_apply_model",
"memory_required",
"model_dtype",
"inner_model_memory_required",
"inner_model_extra_conds_shapes",
"inner_model_extra_conds",
"process_latent_in",
"process_latent_out",
"scale_latent_inpaint",
}
)
def __init__(
self,
instance_id: str,
registry: Optional[Any] = None,
manage_lifecycle: bool = False,
) -> None:
self._instance_id = instance_id
self._rpc_caller: Optional[Any] = None
self._registry = registry if registry is not None else self._registry_class()
self._manage_lifecycle = manage_lifecycle
self._cleaned_up = False
if manage_lifecycle and not IS_CHILD_PROCESS:
self._finalizer = weakref.finalize(
self, self._registry.unregister_sync, instance_id
)
def _get_rpc(self) -> Any:
if self._rpc_caller is None:
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
rpc = get_child_rpc_instance()
if rpc is None:
raise RuntimeError(f"[{self.__class__.__name__}] No RPC in child")
self._rpc_caller = rpc.create_caller(
self._registry_class, self._registry_class.get_remote_id()
)
return self._rpc_caller
def _rpc_timeout_ms_for_method(self, method_name: str) -> Optional[int]:
if method_name not in self._TIMEOUT_RPC_METHODS:
return None
try:
timeout_ms = int(
os.environ.get("COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS", "120000")
)
except ValueError:
timeout_ms = 120000
return max(1, timeout_ms)
def _call_rpc(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
rpc = self._get_rpc()
method = getattr(rpc, method_name)
timeout_ms = self._rpc_timeout_ms_for_method(method_name)
coro = method(self._instance_id, *args, **kwargs)
if timeout_ms is not None:
coro = asyncio.wait_for(coro, timeout=timeout_ms / 1000.0)
start_epoch = time.time()
start_perf = time.perf_counter()
thread_id = threading.get_ident()
try:
running_loop = asyncio.get_running_loop()
loop_id: Optional[int] = id(running_loop)
except RuntimeError:
loop_id = None
logger.debug(
"ISO:rpc_start proxy=%s method=%s instance_id=%s start_ts=%.6f "
"thread=%s loop=%s timeout_ms=%s",
self.__class__.__name__,
method_name,
self._instance_id,
start_epoch,
thread_id,
loop_id,
timeout_ms,
)
try:
return run_sync_rpc_coro(coro, timeout_ms=timeout_ms)
except TimeoutError as exc:
raise TimeoutError(
f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} "
f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})"
) from exc
finally:
end_epoch = time.time()
elapsed_ms = (time.perf_counter() - start_perf) * 1000.0
logger.debug(
"ISO:rpc_end proxy=%s method=%s instance_id=%s end_ts=%.6f "
"elapsed_ms=%.3f thread=%s loop=%s",
self.__class__.__name__,
method_name,
self._instance_id,
end_epoch,
elapsed_ms,
thread_id,
loop_id,
)
def __getstate__(self) -> Dict[str, Any]:
return {"_instance_id": self._instance_id}
def __setstate__(self, state: Dict[str, Any]) -> None:
self._instance_id = state["_instance_id"]
self._rpc_caller = None
self._registry = self._registry_class()
self._manage_lifecycle = False
self._cleaned_up = False
def cleanup(self) -> None:
if self._cleaned_up or IS_CHILD_PROCESS:
return
self._cleaned_up = True
finalizer = getattr(self, "_finalizer", None)
if finalizer is not None:
finalizer.detach()
self._registry.unregister_sync(self._instance_id)
def __repr__(self) -> str:
return f"<{self.__class__.__name__} {self._instance_id}>"
def create_rpc_method(method_name: str) -> Callable[..., Any]:
def method(self: BaseProxy[Any], *args: Any, **kwargs: Any) -> Any:
return self._call_rpc(method_name, *args, **kwargs)
method.__name__ = method_name
return method

View File

@@ -0,0 +1,221 @@
from __future__ import annotations
import logging
import os
import traceback
from typing import Any, Dict, Optional
from pyisolate import ProxiedSingleton
from .base import call_singleton_rpc
_fp_logger = logging.getLogger(__name__)
def _folder_paths():
import folder_paths
return folder_paths
def _is_child_process() -> bool:
return os.environ.get("PYISOLATE_CHILD") == "1"
def _serialize_folder_names_and_paths(data: dict[str, tuple[list[str], set[str]]]) -> dict[str, dict[str, list[str]]]:
return {
key: {"paths": list(paths), "extensions": sorted(list(extensions))}
for key, (paths, extensions) in data.items()
}
def _deserialize_folder_names_and_paths(data: dict[str, dict[str, list[str]]]) -> dict[str, tuple[list[str], set[str]]]:
return {
key: (list(value.get("paths", [])), set(value.get("extensions", [])))
for key, value in data.items()
}
class FolderPathsProxy(ProxiedSingleton):
"""
Dynamic proxy for folder_paths.
Uses __getattr__ for most lookups, with explicit handling for
mutable collections to ensure efficient by-value transfer.
"""
_rpc: Optional[Any] = None
@classmethod
def set_rpc(cls, rpc: Any) -> None:
cls._rpc = rpc.create_caller(cls, cls.get_remote_id())
@classmethod
def clear_rpc(cls) -> None:
cls._rpc = None
@classmethod
def _get_caller(cls) -> Any:
if cls._rpc is None:
raise RuntimeError("FolderPathsProxy RPC caller is not configured")
return cls._rpc
def __getattr__(self, name):
if _is_child_process():
property_rpc = {
"models_dir": "rpc_get_models_dir",
"folder_names_and_paths": "rpc_get_folder_names_and_paths",
"extension_mimetypes_cache": "rpc_get_extension_mimetypes_cache",
"filename_list_cache": "rpc_get_filename_list_cache",
}
rpc_name = property_rpc.get(name)
if rpc_name is not None:
return call_singleton_rpc(self._get_caller(), rpc_name)
raise AttributeError(name)
return getattr(_folder_paths(), name)
@property
def folder_names_and_paths(self) -> Dict:
if _is_child_process():
payload = call_singleton_rpc(self._get_caller(), "rpc_get_folder_names_and_paths")
return _deserialize_folder_names_and_paths(payload)
return _folder_paths().folder_names_and_paths
@property
def extension_mimetypes_cache(self) -> Dict:
if _is_child_process():
return dict(call_singleton_rpc(self._get_caller(), "rpc_get_extension_mimetypes_cache"))
return dict(_folder_paths().extension_mimetypes_cache)
@property
def filename_list_cache(self) -> Dict:
if _is_child_process():
return dict(call_singleton_rpc(self._get_caller(), "rpc_get_filename_list_cache"))
return dict(_folder_paths().filename_list_cache)
@property
def models_dir(self) -> str:
if _is_child_process():
return str(call_singleton_rpc(self._get_caller(), "rpc_get_models_dir"))
return _folder_paths().models_dir
def get_temp_directory(self) -> str:
if _is_child_process():
return call_singleton_rpc(self._get_caller(), "rpc_get_temp_directory")
return _folder_paths().get_temp_directory()
def get_input_directory(self) -> str:
if _is_child_process():
return call_singleton_rpc(self._get_caller(), "rpc_get_input_directory")
return _folder_paths().get_input_directory()
def get_output_directory(self) -> str:
if _is_child_process():
return call_singleton_rpc(self._get_caller(), "rpc_get_output_directory")
return _folder_paths().get_output_directory()
def get_user_directory(self) -> str:
if _is_child_process():
return call_singleton_rpc(self._get_caller(), "rpc_get_user_directory")
return _folder_paths().get_user_directory()
def get_annotated_filepath(self, name: str, default_dir: str | None = None) -> str:
if _is_child_process():
return call_singleton_rpc(
self._get_caller(), "rpc_get_annotated_filepath", name, default_dir
)
return _folder_paths().get_annotated_filepath(name, default_dir)
def exists_annotated_filepath(self, name: str) -> bool:
if _is_child_process():
return bool(
call_singleton_rpc(self._get_caller(), "rpc_exists_annotated_filepath", name)
)
return bool(_folder_paths().exists_annotated_filepath(name))
def add_model_folder_path(
self, folder_name: str, full_folder_path: str, is_default: bool = False
) -> None:
if _is_child_process():
call_singleton_rpc(
self._get_caller(),
"rpc_add_model_folder_path",
folder_name,
full_folder_path,
is_default,
)
return None
_folder_paths().add_model_folder_path(folder_name, full_folder_path, is_default)
return None
def get_folder_paths(self, folder_name: str) -> list[str]:
if _is_child_process():
return list(call_singleton_rpc(self._get_caller(), "rpc_get_folder_paths", folder_name))
return list(_folder_paths().get_folder_paths(folder_name))
def get_filename_list(self, folder_name: str) -> list[str]:
caller_stack = "".join(traceback.format_stack()[-4:-1])
_fp_logger.warning(
"][ DIAG:FolderPathsProxy.get_filename_list called | folder=%s | is_child=%s | rpc_configured=%s\n%s",
folder_name, _is_child_process(), self._rpc is not None, caller_stack,
)
if _is_child_process():
result = list(call_singleton_rpc(self._get_caller(), "rpc_get_filename_list", folder_name))
_fp_logger.warning(
"][ DIAG:FolderPathsProxy.get_filename_list RPC result | folder=%s | count=%d | first=%s",
folder_name, len(result), result[:3] if result else "EMPTY",
)
return result
result = list(_folder_paths().get_filename_list(folder_name))
_fp_logger.warning(
"][ DIAG:FolderPathsProxy.get_filename_list LOCAL result | folder=%s | count=%d | first=%s",
folder_name, len(result), result[:3] if result else "EMPTY",
)
return result
def get_full_path(self, folder_name: str, filename: str) -> str | None:
if _is_child_process():
return call_singleton_rpc(self._get_caller(), "rpc_get_full_path", folder_name, filename)
return _folder_paths().get_full_path(folder_name, filename)
async def rpc_get_models_dir(self) -> str:
return _folder_paths().models_dir
async def rpc_get_folder_names_and_paths(self) -> dict[str, dict[str, list[str]]]:
return _serialize_folder_names_and_paths(_folder_paths().folder_names_and_paths)
async def rpc_get_extension_mimetypes_cache(self) -> dict[str, Any]:
return dict(_folder_paths().extension_mimetypes_cache)
async def rpc_get_filename_list_cache(self) -> dict[str, Any]:
return dict(_folder_paths().filename_list_cache)
async def rpc_get_temp_directory(self) -> str:
return _folder_paths().get_temp_directory()
async def rpc_get_input_directory(self) -> str:
return _folder_paths().get_input_directory()
async def rpc_get_output_directory(self) -> str:
return _folder_paths().get_output_directory()
async def rpc_get_user_directory(self) -> str:
return _folder_paths().get_user_directory()
async def rpc_get_annotated_filepath(self, name: str, default_dir: str | None = None) -> str:
return _folder_paths().get_annotated_filepath(name, default_dir)
async def rpc_exists_annotated_filepath(self, name: str) -> bool:
return _folder_paths().exists_annotated_filepath(name)
async def rpc_add_model_folder_path(
self, folder_name: str, full_folder_path: str, is_default: bool = False
) -> None:
_folder_paths().add_model_folder_path(folder_name, full_folder_path, is_default)
async def rpc_get_folder_paths(self, folder_name: str) -> list[str]:
return _folder_paths().get_folder_paths(folder_name)
async def rpc_get_filename_list(self, folder_name: str) -> list[str]:
return _folder_paths().get_filename_list(folder_name)
async def rpc_get_full_path(self, folder_name: str, filename: str) -> str | None:
return _folder_paths().get_full_path(folder_name, filename)

View File

@@ -0,0 +1,158 @@
from __future__ import annotations
import os
from typing import Any, Dict, Optional
from pyisolate import ProxiedSingleton
from .base import call_singleton_rpc
class AnyTypeProxy(str):
"""Replacement for custom AnyType objects used by some nodes."""
def __new__(cls, value: str = "*"):
return super().__new__(cls, value)
def __ne__(self, other): # type: ignore[override]
return False
class FlexibleOptionalInputProxy(dict):
"""Replacement for FlexibleOptionalInputType to allow dynamic inputs."""
def __init__(self, flex_type, data: Optional[Dict[str, object]] = None):
super().__init__()
self.type = flex_type
if data:
self.update(data)
def __getitem__(self, key): # type: ignore[override]
return (self.type,)
def __contains__(self, key): # type: ignore[override]
return True
class ByPassTypeTupleProxy(tuple):
"""Replacement for ByPassTypeTuple to mirror wildcard fallback behavior."""
def __new__(cls, values):
return super().__new__(cls, values)
def __getitem__(self, index): # type: ignore[override]
if index >= len(self):
return AnyTypeProxy("*")
return super().__getitem__(index)
def _restore_special_value(value: Any) -> Any:
if isinstance(value, dict):
if value.get("__pyisolate_any_type__"):
return AnyTypeProxy(value.get("value", "*"))
if value.get("__pyisolate_flexible_optional__"):
flex_type = _restore_special_value(value.get("type"))
data_raw = value.get("data")
data = (
{k: _restore_special_value(v) for k, v in data_raw.items()}
if isinstance(data_raw, dict)
else {}
)
return FlexibleOptionalInputProxy(flex_type, data)
if value.get("__pyisolate_tuple__") is not None:
return tuple(
_restore_special_value(v) for v in value["__pyisolate_tuple__"]
)
if value.get("__pyisolate_bypass_tuple__") is not None:
return ByPassTypeTupleProxy(
tuple(
_restore_special_value(v)
for v in value["__pyisolate_bypass_tuple__"]
)
)
return {k: _restore_special_value(v) for k, v in value.items()}
if isinstance(value, list):
return [_restore_special_value(v) for v in value]
return value
def _serialize_special_value(value: Any) -> Any:
if isinstance(value, AnyTypeProxy):
return {"__pyisolate_any_type__": True, "value": str(value)}
if isinstance(value, FlexibleOptionalInputProxy):
return {
"__pyisolate_flexible_optional__": True,
"type": _serialize_special_value(value.type),
"data": {k: _serialize_special_value(v) for k, v in value.items()},
}
if isinstance(value, ByPassTypeTupleProxy):
return {
"__pyisolate_bypass_tuple__": [_serialize_special_value(v) for v in value]
}
if isinstance(value, tuple):
return {"__pyisolate_tuple__": [_serialize_special_value(v) for v in value]}
if isinstance(value, list):
return [_serialize_special_value(v) for v in value]
if isinstance(value, dict):
return {k: _serialize_special_value(v) for k, v in value.items()}
return value
def _restore_input_types_local(raw: Dict[str, object]) -> Dict[str, object]:
if not isinstance(raw, dict):
return raw # type: ignore[return-value]
restored: Dict[str, object] = {}
for section, entries in raw.items():
if isinstance(entries, dict) and entries.get("__pyisolate_flexible_optional__"):
restored[section] = _restore_special_value(entries)
elif isinstance(entries, dict):
restored[section] = {
k: _restore_special_value(v) for k, v in entries.items()
}
else:
restored[section] = _restore_special_value(entries)
return restored
class HelperProxiesService(ProxiedSingleton):
_rpc: Optional[Any] = None
@classmethod
def set_rpc(cls, rpc: Any) -> None:
cls._rpc = rpc.create_caller(cls, cls.get_remote_id())
@classmethod
def clear_rpc(cls) -> None:
cls._rpc = None
@classmethod
def _get_caller(cls) -> Any:
if cls._rpc is None:
raise RuntimeError("HelperProxiesService RPC caller is not configured")
return cls._rpc
async def rpc_restore_input_types(self, raw: Dict[str, object]) -> Dict[str, object]:
restored = _restore_input_types_local(raw)
return _serialize_special_value(restored)
def restore_input_types(raw: Dict[str, object]) -> Dict[str, object]:
"""Restore serialized INPUT_TYPES payload back into ComfyUI-compatible objects."""
if os.environ.get("PYISOLATE_CHILD") == "1":
payload = call_singleton_rpc(
HelperProxiesService._get_caller(),
"rpc_restore_input_types",
raw,
)
return _restore_input_types_local(payload)
return _restore_input_types_local(raw)
__all__ = [
"AnyTypeProxy",
"FlexibleOptionalInputProxy",
"ByPassTypeTupleProxy",
"HelperProxiesService",
"restore_input_types",
]

View File

@@ -0,0 +1,142 @@
from __future__ import annotations
import os
from typing import Any, Optional
from pyisolate import ProxiedSingleton
from .base import call_singleton_rpc
def _mm():
import comfy.model_management
return comfy.model_management
def _is_child_process() -> bool:
return os.environ.get("PYISOLATE_CHILD") == "1"
class TorchDeviceProxy:
def __init__(self, device_str: str):
self._device_str = device_str
if ":" in device_str:
device_type, index = device_str.split(":", 1)
self.type = device_type
self.index = int(index)
else:
self.type = device_str
self.index = None
def __str__(self) -> str:
return self._device_str
def __repr__(self) -> str:
return f"TorchDeviceProxy({self._device_str!r})"
def _serialize_value(value: Any) -> Any:
value_type = type(value)
if value_type.__module__ == "torch" and value_type.__name__ == "device":
return {"__pyisolate_torch_device__": str(value)}
if isinstance(value, TorchDeviceProxy):
return {"__pyisolate_torch_device__": str(value)}
if isinstance(value, tuple):
return {"__pyisolate_tuple__": [_serialize_value(item) for item in value]}
if isinstance(value, list):
return [_serialize_value(item) for item in value]
if isinstance(value, dict):
return {key: _serialize_value(inner) for key, inner in value.items()}
return value
def _deserialize_value(value: Any) -> Any:
if isinstance(value, dict):
if "__pyisolate_torch_device__" in value:
return TorchDeviceProxy(value["__pyisolate_torch_device__"])
if "__pyisolate_tuple__" in value:
return tuple(_deserialize_value(item) for item in value["__pyisolate_tuple__"])
return {key: _deserialize_value(inner) for key, inner in value.items()}
if isinstance(value, list):
return [_deserialize_value(item) for item in value]
return value
def _normalize_argument(value: Any) -> Any:
if isinstance(value, TorchDeviceProxy):
import torch
return torch.device(str(value))
if isinstance(value, dict):
if "__pyisolate_torch_device__" in value:
import torch
return torch.device(value["__pyisolate_torch_device__"])
if "__pyisolate_tuple__" in value:
return tuple(_normalize_argument(item) for item in value["__pyisolate_tuple__"])
return {key: _normalize_argument(inner) for key, inner in value.items()}
if isinstance(value, list):
return [_normalize_argument(item) for item in value]
return value
class ModelManagementProxy(ProxiedSingleton):
"""
Exact-relay proxy for comfy.model_management.
Child calls never import comfy.model_management directly; they serialize
arguments, relay to host, and deserialize the host result back.
"""
_rpc: Optional[Any] = None
@classmethod
def set_rpc(cls, rpc: Any) -> None:
cls._rpc = rpc.create_caller(cls, cls.get_remote_id())
@classmethod
def clear_rpc(cls) -> None:
cls._rpc = None
@classmethod
def _get_caller(cls) -> Any:
if cls._rpc is None:
raise RuntimeError("ModelManagementProxy RPC caller is not configured")
return cls._rpc
def _relay_call(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
payload = call_singleton_rpc(
self._get_caller(),
"rpc_call",
method_name,
_serialize_value(args),
_serialize_value(kwargs),
)
return _deserialize_value(payload)
@property
def VRAMState(self):
return _mm().VRAMState
@property
def CPUState(self):
return _mm().CPUState
@property
def OOM_EXCEPTION(self):
return _mm().OOM_EXCEPTION
def __getattr__(self, name: str):
if _is_child_process():
def child_method(*args: Any, **kwargs: Any) -> Any:
return self._relay_call(name, *args, **kwargs)
return child_method
return getattr(_mm(), name)
async def rpc_call(self, method_name: str, args: Any, kwargs: Any) -> Any:
normalized_args = _normalize_argument(_deserialize_value(args))
normalized_kwargs = _normalize_argument(_deserialize_value(kwargs))
method = getattr(_mm(), method_name)
result = method(*normalized_args, **normalized_kwargs)
return _serialize_value(result)

View File

@@ -0,0 +1,87 @@
from __future__ import annotations
import logging
import os
from typing import Any, Optional
try:
from pyisolate import ProxiedSingleton
except ImportError:
class ProxiedSingleton:
pass
from .base import call_singleton_rpc
def _get_progress_state():
from comfy_execution.progress import get_progress_state
return get_progress_state()
def _is_child_process() -> bool:
return os.environ.get("PYISOLATE_CHILD") == "1"
logger = logging.getLogger(__name__)
class ProgressProxy(ProxiedSingleton):
_rpc: Optional[Any] = None
@classmethod
def set_rpc(cls, rpc: Any) -> None:
cls._rpc = rpc.create_caller(cls, cls.get_remote_id())
@classmethod
def clear_rpc(cls) -> None:
cls._rpc = None
@classmethod
def _get_caller(cls) -> Any:
if cls._rpc is None:
raise RuntimeError("ProgressProxy RPC caller is not configured")
return cls._rpc
def set_progress(
self,
value: float,
max_value: float,
node_id: Optional[str] = None,
image: Any = None,
) -> None:
if _is_child_process():
call_singleton_rpc(
self._get_caller(),
"rpc_set_progress",
value,
max_value,
node_id,
image,
)
return None
_get_progress_state().update_progress(
node_id=node_id,
value=value,
max_value=max_value,
image=image,
)
return None
async def rpc_set_progress(
self,
value: float,
max_value: float,
node_id: Optional[str] = None,
image: Any = None,
) -> None:
_get_progress_state().update_progress(
node_id=node_id,
value=value,
max_value=max_value,
image=image,
)
__all__ = ["ProgressProxy"]

View File

@@ -0,0 +1,271 @@
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,redefined-outer-name,reimported,super-init-not-called
"""Stateless RPC Implementation for PromptServer.
Replaces the legacy PromptServerProxy (Singleton) with a clean Service/Stub architecture.
- Host: PromptServerService (RPC Handler)
- Child: PromptServerStub (Interface Implementation)
"""
from __future__ import annotations
import asyncio
import os
from typing import Any, Dict, Optional, Callable
import logging
# IMPORTS
from pyisolate import ProxiedSingleton
from .base import call_singleton_rpc
logger = logging.getLogger(__name__)
LOG_PREFIX = "[Isolation:C<->H]"
# ...
# =============================================================================
# CHILD SIDE: PromptServerStub
# =============================================================================
class PromptServerStub:
"""Stateless Stub for PromptServer."""
# Masquerade as the real server module
__module__ = "server"
_instance: Optional["PromptServerStub"] = None
_rpc: Optional[Any] = None # This will be the Caller object
_source_file: Optional[str] = None
def __init__(self):
self.routes = RouteStub(self)
@classmethod
def set_rpc(cls, rpc: Any) -> None:
"""Inject RPC client (called by adapter.py or manually)."""
# Create caller for HOST Service
# Assuming Host Service is registered as "PromptServerService" (class name)
# We target the Host Service Class
target_id = "PromptServerService"
# We need to pass a class to create_caller? Usually yes.
# But we don't have the Service class imported here necessarily (if running on child).
# pyisolate check verify_service type?
# If we pass PromptServerStub as the 'class', it might mismatch if checking types.
# But we can try passing PromptServerStub if it mirrors the service name? No, stub is PromptServerStub.
# We need a dummy class with right name?
# Or just rely on string ID if create_caller supports it?
# Standard: rpc.create_caller(PromptServerStub, target_id)
# But wait, PromptServerStub is the *Local* class.
# We want to call *Remote* class.
# If we use PromptServerStub as the type, returning object will be typed as PromptServerStub?
# The first arg is 'service_cls'.
cls._rpc = rpc.create_caller(
PromptServerService, target_id
) # We import Service below?
@classmethod
def clear_rpc(cls) -> None:
cls._rpc = None
# We need PromptServerService available for the create_caller call?
# Or just use the Stub class if ID matches?
# prompt_server_impl.py defines BOTH. So PromptServerService IS available!
@property
def instance(self) -> "PromptServerStub":
return self
# ... Compatibility ...
@classmethod
def _get_source_file(cls) -> str:
if cls._source_file is None:
import folder_paths
cls._source_file = os.path.join(folder_paths.base_path, "server.py")
return cls._source_file
@property
def __file__(self) -> str:
return self._get_source_file()
# --- Properties ---
@property
def client_id(self) -> Optional[str]:
return "isolated_client"
def supports(self, feature: str) -> bool:
return True
@property
def app(self):
raise RuntimeError(
"PromptServer.app is not accessible in isolated nodes. Use RPC routes instead."
)
@property
def prompt_queue(self):
raise RuntimeError(
"PromptServer.prompt_queue is not accessible in isolated nodes."
)
# --- UI Communication (RPC Delegates) ---
async def send_sync(
self, event: str, data: Dict[str, Any], sid: Optional[str] = None
) -> None:
if self._rpc:
await self._rpc.ui_send_sync(event, data, sid)
async def send(
self, event: str, data: Dict[str, Any], sid: Optional[str] = None
) -> None:
if self._rpc:
await self._rpc.ui_send(event, data, sid)
def send_progress_text(self, text: str, node_id: str, sid=None) -> None:
if self._rpc:
# Fire and forget likely needed. If method is async on host, caller invocation returns coroutine.
# We must schedule it?
# Or use fire_remote equivalent?
# Caller object usually proxies calls. If host method is async, it returns coro.
# If we are sync here (send_progress_text checks imply sync usage), we must background it.
# But UtilsProxy hook wrapper creates task.
# Does send_progress_text need to be sync? Yes, node code calls it sync.
import asyncio
try:
loop = asyncio.get_running_loop()
loop.create_task(self._rpc.ui_send_progress_text(text, node_id, sid))
except RuntimeError:
call_singleton_rpc(self._rpc, "ui_send_progress_text", text, node_id, sid)
# --- Route Registration Logic ---
def register_route(self, method: str, path: str, handler: Callable):
"""Register a route handler via RPC."""
if not self._rpc:
logger.error("RPC not initialized in PromptServerStub")
return
# Fire registration async
try:
loop = asyncio.get_running_loop()
loop.create_task(self._rpc.register_route_rpc(method, path, handler))
except RuntimeError:
call_singleton_rpc(self._rpc, "register_route_rpc", method, path, handler)
class RouteStub:
"""Simulates aiohttp.web.RouteTableDef."""
def __init__(self, stub: PromptServerStub):
self._stub = stub
def get(self, path: str):
def decorator(handler):
self._stub.register_route("GET", path, handler)
return handler
return decorator
def post(self, path: str):
def decorator(handler):
self._stub.register_route("POST", path, handler)
return handler
return decorator
def patch(self, path: str):
def decorator(handler):
self._stub.register_route("PATCH", path, handler)
return handler
return decorator
def put(self, path: str):
def decorator(handler):
self._stub.register_route("PUT", path, handler)
return handler
return decorator
def delete(self, path: str):
def decorator(handler):
self._stub.register_route("DELETE", path, handler)
return handler
return decorator
# =============================================================================
# HOST SIDE: PromptServerService
# =============================================================================
class PromptServerService(ProxiedSingleton):
"""Host-side RPC Service for PromptServer."""
def __init__(self):
# We will bind to the real server instance lazily or via global import
pass
@property
def server(self):
from server import PromptServer
return PromptServer.instance
async def ui_send_sync(
self, event: str, data: Dict[str, Any], sid: Optional[str] = None
):
await self.server.send_sync(event, data, sid)
async def ui_send(
self, event: str, data: Dict[str, Any], sid: Optional[str] = None
):
await self.server.send(event, data, sid)
async def ui_send_progress_text(self, text: str, node_id: str, sid=None):
# Made async to be awaitable by RPC layer
self.server.send_progress_text(text, node_id, sid)
async def register_route_rpc(self, method: str, path: str, child_handler_proxy):
"""RPC Target: Register a route that forwards to the Child."""
from aiohttp import web
logger.debug(f"{LOG_PREFIX} Registering Isolated Route {method} {path}")
async def route_wrapper(request: web.Request) -> web.Response:
# 1. Capture request data
req_data = {
"method": request.method,
"path": request.path,
"query": dict(request.query),
}
if request.can_read_body:
req_data["text"] = await request.text()
try:
# 2. Call Child Handler via RPC (child_handler_proxy is async callable)
result = await child_handler_proxy(req_data)
# 3. Serialize Response
return self._serialize_response(result)
except Exception as e:
logger.error(f"{LOG_PREFIX} Isolated Route Error: {e}")
return web.Response(status=500, text=str(e))
# Register loop
self.server.app.router.add_route(method, path, route_wrapper)
def _serialize_response(self, result: Any) -> Any:
"""Helper to convert Child result -> web.Response"""
from aiohttp import web
if isinstance(result, web.Response):
return result
# Handle dict (json)
if isinstance(result, dict):
return web.json_response(result)
# Handle string
if isinstance(result, str):
return web.Response(text=result)
# Fallback
return web.Response(text=str(result))

View File

@@ -0,0 +1,64 @@
# pylint: disable=cyclic-import,import-outside-toplevel
from __future__ import annotations
from typing import Optional, Any
from pyisolate import ProxiedSingleton
import os
def _comfy_utils():
import comfy.utils
return comfy.utils
class UtilsProxy(ProxiedSingleton):
"""
Proxy for comfy.utils.
Primarily handles the PROGRESS_BAR_HOOK to ensure progress updates
from isolated nodes reach the host.
"""
# _instance and __new__ removed to rely on SingletonMetaclass
_rpc: Optional[Any] = None
@classmethod
def set_rpc(cls, rpc: Any) -> None:
# Create caller using class name as ID (standard for Singletons)
cls._rpc = rpc.create_caller(cls, "UtilsProxy")
@classmethod
def clear_rpc(cls) -> None:
cls._rpc = None
async def progress_bar_hook(
self,
value: int,
total: int,
preview: Optional[bytes] = None,
node_id: Optional[str] = None,
) -> Any:
"""
Host-side implementation: forwards the call to the real global hook.
Child-side: this method call is intercepted by RPC and sent to host.
"""
if os.environ.get("PYISOLATE_CHILD") == "1":
if UtilsProxy._rpc is None:
raise RuntimeError("UtilsProxy RPC caller is not configured")
return await UtilsProxy._rpc.progress_bar_hook(
value, total, preview, node_id
)
# Host Execution
utils = _comfy_utils()
if utils.PROGRESS_BAR_HOOK is not None:
return utils.PROGRESS_BAR_HOOK(value, total, preview, node_id)
return None
def set_progress_bar_global_hook(self, hook: Any) -> None:
"""Forward hook registration (though usually not needed from child)."""
if os.environ.get("PYISOLATE_CHILD") == "1":
raise RuntimeError(
"UtilsProxy.set_progress_bar_global_hook is not available in child without exact relay support"
)
_comfy_utils().set_progress_bar_global_hook(hook)

View File

@@ -0,0 +1,219 @@
"""WebDirectoryProxy — serves isolated node web assets via RPC.
Child side: enumerates and reads files from the extension's web/ directory.
Host side: gets an RPC proxy that fetches file listings and contents on demand.
Only files with allowed extensions (.js, .html, .css) are served.
Directory traversal is rejected. File contents are base64-encoded for
safe JSON-RPC transport.
"""
from __future__ import annotations
import base64
import logging
import os
from pathlib import Path, PurePosixPath
from typing import Any, Dict, List
from pyisolate import ProxiedSingleton
logger = logging.getLogger(__name__)
ALLOWED_EXTENSIONS = frozenset({".js", ".html", ".css"})
MIME_TYPES = {
".js": "application/javascript",
".html": "text/html",
".css": "text/css",
}
class WebDirectoryProxy(ProxiedSingleton):
"""Proxy for serving isolated extension web directories.
On the child side, this class has direct filesystem access to the
extension's web/ directory. On the host side, callers get an RPC
proxy whose method calls are forwarded to the child.
"""
# {extension_name: absolute_path_to_web_dir}
_web_dirs: dict[str, str] = {}
@classmethod
def register_web_dir(cls, extension_name: str, web_dir_path: str) -> None:
"""Register an extension's web directory (child-side only)."""
cls._web_dirs[extension_name] = web_dir_path
logger.info(
"][ WebDirectoryProxy: registered %s -> %s",
extension_name,
web_dir_path,
)
def list_web_files(self, extension_name: str) -> List[Dict[str, str]]:
"""Return a list of servable files in the extension's web directory.
Each entry is {"relative_path": "js/foo.js", "content_type": "application/javascript"}.
Only files with allowed extensions are included.
"""
web_dir = self._web_dirs.get(extension_name)
if not web_dir:
return []
root = Path(web_dir)
if not root.is_dir():
return []
result: List[Dict[str, str]] = []
for path in sorted(root.rglob("*")):
if not path.is_file():
continue
ext = path.suffix.lower()
if ext not in ALLOWED_EXTENSIONS:
continue
rel = path.relative_to(root)
result.append({
"relative_path": str(PurePosixPath(rel)),
"content_type": MIME_TYPES[ext],
})
return result
def get_web_file(
self, extension_name: str, relative_path: str
) -> Dict[str, Any]:
"""Return the contents of a single web file as base64.
Raises ValueError for traversal attempts or disallowed file types.
Returns {"content": <base64 str>, "content_type": <MIME str>}.
"""
_validate_path(relative_path)
web_dir = self._web_dirs.get(extension_name)
if not web_dir:
raise FileNotFoundError(
f"No web directory registered for {extension_name}"
)
root = Path(web_dir)
target = (root / relative_path).resolve()
# Ensure resolved path is under the web directory
if not str(target).startswith(str(root.resolve())):
raise ValueError(f"Path escapes web directory: {relative_path}")
if not target.is_file():
raise FileNotFoundError(f"File not found: {relative_path}")
ext = target.suffix.lower()
if ext not in ALLOWED_EXTENSIONS:
raise ValueError(f"Disallowed file type: {ext}")
content_type = MIME_TYPES[ext]
raw = target.read_bytes()
return {
"content": base64.b64encode(raw).decode("ascii"),
"content_type": content_type,
}
def _validate_path(relative_path: str) -> None:
"""Reject directory traversal and absolute paths."""
if os.path.isabs(relative_path):
raise ValueError(f"Absolute paths are not allowed: {relative_path}")
if ".." in PurePosixPath(relative_path).parts:
raise ValueError(f"Directory traversal is not allowed: {relative_path}")
# ---------------------------------------------------------------------------
# Host-side cache and aiohttp handler
# ---------------------------------------------------------------------------
class WebDirectoryCache:
"""Host-side in-memory cache for proxied web directory contents.
Populated lazily via RPC calls to the child's WebDirectoryProxy.
Once a file is cached, subsequent requests are served from memory.
"""
def __init__(self) -> None:
# {extension_name: {relative_path: {"content": bytes, "content_type": str}}}
self._file_cache: dict[str, dict[str, dict[str, Any]]] = {}
# {extension_name: [{"relative_path": str, "content_type": str}, ...]}
self._listing_cache: dict[str, list[dict[str, str]]] = {}
# {extension_name: WebDirectoryProxy (RPC proxy instance)}
self._proxies: dict[str, Any] = {}
def register_proxy(self, extension_name: str, proxy: Any) -> None:
"""Register an RPC proxy for an extension's web directory."""
self._proxies[extension_name] = proxy
logger.info(
"][ WebDirectoryCache: registered proxy for %s", extension_name
)
@property
def extension_names(self) -> list[str]:
return list(self._proxies.keys())
def list_files(self, extension_name: str) -> list[dict[str, str]]:
"""List servable files for an extension (cached after first call)."""
if extension_name not in self._listing_cache:
proxy = self._proxies.get(extension_name)
if proxy is None:
return []
try:
self._listing_cache[extension_name] = proxy.list_web_files(
extension_name
)
except Exception:
logger.warning(
"][ WebDirectoryCache: failed to list files for %s",
extension_name,
exc_info=True,
)
return []
return self._listing_cache[extension_name]
def get_file(
self, extension_name: str, relative_path: str
) -> dict[str, Any] | None:
"""Get file content (cached after first fetch). Returns None on miss."""
ext_cache = self._file_cache.get(extension_name)
if ext_cache and relative_path in ext_cache:
return ext_cache[relative_path]
proxy = self._proxies.get(extension_name)
if proxy is None:
return None
try:
result = proxy.get_web_file(extension_name, relative_path)
except (FileNotFoundError, ValueError):
return None
except Exception:
logger.warning(
"][ WebDirectoryCache: failed to fetch %s/%s",
extension_name,
relative_path,
exc_info=True,
)
return None
decoded = {
"content": base64.b64decode(result["content"]),
"content_type": result["content_type"],
}
if extension_name not in self._file_cache:
self._file_cache[extension_name] = {}
self._file_cache[extension_name][relative_path] = decoded
return decoded
# Global cache instance — populated during isolation loading
_web_directory_cache = WebDirectoryCache()
def get_web_directory_cache() -> WebDirectoryCache:
return _web_directory_cache