diff --git a/comfy/isolation/__init__.py b/comfy/isolation/__init__.py new file mode 100644 index 000000000..c72a92807 --- /dev/null +++ b/comfy/isolation/__init__.py @@ -0,0 +1,327 @@ +# pylint: disable=consider-using-from-import,cyclic-import,global-statement,global-variable-not-assigned,import-outside-toplevel,logging-fstring-interpolation +from __future__ import annotations +import asyncio +import inspect +import logging +import os +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Set, TYPE_CHECKING +import folder_paths +from .extension_loader import load_isolated_node +from .manifest_loader import find_manifest_directories +from .runtime_helpers import build_stub_class, get_class_types_for_extension +from .shm_forensics import scan_shm_forensics, start_shm_forensics + +if TYPE_CHECKING: + from pyisolate import ExtensionManager + from .extension_wrapper import ComfyNodeExtension + +LOG_PREFIX = "][" +isolated_node_timings: List[tuple[float, Path, int]] = [] + +PYISOLATE_VENV_ROOT = Path(folder_paths.base_path) / ".pyisolate_venvs" +PYISOLATE_VENV_ROOT.mkdir(parents=True, exist_ok=True) + +logger = logging.getLogger(__name__) +_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024 + + +def initialize_proxies() -> None: + from .child_hooks import is_child_process + + is_child = is_child_process() + + if is_child: + from .child_hooks import initialize_child_process + + initialize_child_process() + else: + from .host_hooks import initialize_host_process + + initialize_host_process() + start_shm_forensics() + + +@dataclass(frozen=True) +class IsolatedNodeSpec: + node_name: str + display_name: str + stub_class: type + module_path: Path + + +_ISOLATED_NODE_SPECS: List[IsolatedNodeSpec] = [] +_CLAIMED_PATHS: Set[Path] = set() +_ISOLATION_SCAN_ATTEMPTED = False +_EXTENSION_MANAGERS: List["ExtensionManager"] = [] +_RUNNING_EXTENSIONS: Dict[str, "ComfyNodeExtension"] = {} 
+_ISOLATION_BACKGROUND_TASK: Optional["asyncio.Task[List[IsolatedNodeSpec]]"] = None +_EARLY_START_TIME: Optional[float] = None + + +def start_isolation_loading_early(loop: "asyncio.AbstractEventLoop") -> None: + global _ISOLATION_BACKGROUND_TASK, _EARLY_START_TIME + if _ISOLATION_BACKGROUND_TASK is not None: + return + _EARLY_START_TIME = time.perf_counter() + _ISOLATION_BACKGROUND_TASK = loop.create_task(initialize_isolation_nodes()) + + +async def await_isolation_loading() -> List[IsolatedNodeSpec]: + global _ISOLATION_BACKGROUND_TASK, _EARLY_START_TIME + if _ISOLATION_BACKGROUND_TASK is not None: + specs = await _ISOLATION_BACKGROUND_TASK + return specs + return await initialize_isolation_nodes() + + +async def initialize_isolation_nodes() -> List[IsolatedNodeSpec]: + global _ISOLATED_NODE_SPECS, _ISOLATION_SCAN_ATTEMPTED, _CLAIMED_PATHS + + if _ISOLATED_NODE_SPECS: + return _ISOLATED_NODE_SPECS + + if _ISOLATION_SCAN_ATTEMPTED: + return [] + + _ISOLATION_SCAN_ATTEMPTED = True + manifest_entries = find_manifest_directories() + _CLAIMED_PATHS = {entry[0].resolve() for entry in manifest_entries} + + if not manifest_entries: + return [] + + os.environ["PYISOLATE_ISOLATION_ACTIVE"] = "1" + concurrency_limit = max(1, (os.cpu_count() or 4) // 2) + semaphore = asyncio.Semaphore(concurrency_limit) + + async def load_with_semaphore( + node_dir: Path, manifest: Path + ) -> List[IsolatedNodeSpec]: + async with semaphore: + load_start = time.perf_counter() + spec_list = await load_isolated_node( + node_dir, + manifest, + logger, + lambda name, info, extension: build_stub_class( + name, + info, + extension, + _RUNNING_EXTENSIONS, + logger, + ), + PYISOLATE_VENV_ROOT, + _EXTENSION_MANAGERS, + ) + spec_list = [ + IsolatedNodeSpec( + node_name=node_name, + display_name=display_name, + stub_class=stub_cls, + module_path=node_dir, + ) + for node_name, display_name, stub_cls in spec_list + ] + isolated_node_timings.append( + (time.perf_counter() - load_start, node_dir, 
len(spec_list)) + ) + return spec_list + + tasks = [ + load_with_semaphore(node_dir, manifest) + for node_dir, manifest in manifest_entries + ] + results = await asyncio.gather(*tasks, return_exceptions=True) + + specs: List[IsolatedNodeSpec] = [] + for result in results: + if isinstance(result, Exception): + logger.error( + "%s Isolated node failed during startup; continuing: %s", + LOG_PREFIX, + result, + ) + continue + specs.extend(result) + + _ISOLATED_NODE_SPECS = specs + return list(_ISOLATED_NODE_SPECS) + + +def _get_class_types_for_extension(extension_name: str) -> Set[str]: + """Get all node class types (node names) belonging to an extension.""" + extension = _RUNNING_EXTENSIONS.get(extension_name) + if not extension: + return set() + + ext_path = Path(extension.module_path) + class_types = set() + for spec in _ISOLATED_NODE_SPECS: + if spec.module_path.resolve() == ext_path.resolve(): + class_types.add(spec.node_name) + + return class_types + + +async def notify_execution_graph(needed_class_types: Set[str]) -> None: + """Evict running extensions not needed for current execution.""" + + async def _stop_extension( + ext_name: str, extension: "ComfyNodeExtension", reason: str + ) -> None: + logger.info("%s ISO:eject_start ext=%s reason=%s", LOG_PREFIX, ext_name, reason) + logger.debug("%s ISO:stop_start ext=%s", LOG_PREFIX, ext_name) + stop_result = extension.stop() + if inspect.isawaitable(stop_result): + await stop_result + _RUNNING_EXTENSIONS.pop(ext_name, None) + logger.debug("%s ISO:stop_done ext=%s", LOG_PREFIX, ext_name) + scan_shm_forensics("ISO:stop_extension", refresh_model_context=True) + + scan_shm_forensics("ISO:notify_graph_start", refresh_model_context=True) + logger.debug( + "%s ISO:notify_graph_start running=%d needed=%d", + LOG_PREFIX, + len(_RUNNING_EXTENSIONS), + len(needed_class_types), + ) + for ext_name, extension in list(_RUNNING_EXTENSIONS.items()): + ext_class_types = _get_class_types_for_extension(ext_name) + + # If NONE of this 
extension's nodes are in the execution graph → evict + if not ext_class_types.intersection(needed_class_types): + await _stop_extension( + ext_name, + extension, + "isolated custom_node not in execution graph, evicting", + ) + + # Isolated child processes add steady VRAM pressure; reclaim host-side models + # at workflow boundaries so subsequent host nodes (e.g. CLIP encode) keep headroom. + try: + import comfy.model_management as model_management + + device = model_management.get_torch_device() + if getattr(device, "type", None) == "cuda": + required = max( + model_management.minimum_inference_memory(), + _WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES, + ) + free_before = model_management.get_free_memory(device) + if free_before < required and _RUNNING_EXTENSIONS: + for ext_name, extension in list(_RUNNING_EXTENSIONS.items()): + await _stop_extension( + ext_name, + extension, + f"boundary low-vram restart (free={int(free_before)} target={int(required)})", + ) + if model_management.get_free_memory(device) < required: + model_management.unload_all_models() + model_management.cleanup_models_gc() + model_management.cleanup_models() + if model_management.get_free_memory(device) < required: + model_management.free_memory(required, device, for_dynamic=False) + model_management.soft_empty_cache() + except Exception: + logger.debug( + "%s workflow-boundary host VRAM relief failed", LOG_PREFIX, exc_info=True + ) + finally: + scan_shm_forensics("ISO:notify_graph_done", refresh_model_context=True) + logger.debug( + "%s ISO:notify_graph_done running=%d", LOG_PREFIX, len(_RUNNING_EXTENSIONS) + ) + + +async def flush_running_extensions_transport_state() -> int: + total_flushed = 0 + for ext_name, extension in list(_RUNNING_EXTENSIONS.items()): + flush_fn = getattr(extension, "flush_transport_state", None) + if not callable(flush_fn): + continue + try: + flushed = await flush_fn() + if isinstance(flushed, int): + total_flushed += flushed + if flushed > 0: + logger.debug( + "%s %s 
workflow-end flush released=%d", + LOG_PREFIX, + ext_name, + flushed, + ) + except Exception: + logger.debug( + "%s %s workflow-end flush failed", LOG_PREFIX, ext_name, exc_info=True + ) + scan_shm_forensics( + "ISO:flush_running_extensions_transport_state", refresh_model_context=True + ) + return total_flushed + + +def get_claimed_paths() -> Set[Path]: + return _CLAIMED_PATHS + + +def update_rpc_event_loops(loop: "asyncio.AbstractEventLoop | None" = None) -> None: + """Update all active RPC instances with the current event loop. + + This MUST be called at the start of each workflow execution to ensure + RPC calls are scheduled on the correct event loop. This handles the case + where asyncio.run() creates a new event loop for each workflow. + + Args: + loop: The event loop to use. If None, uses asyncio.get_running_loop(). + """ + if loop is None: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.get_event_loop() + + update_count = 0 + + # Update RPCs from ExtensionManagers + for manager in _EXTENSION_MANAGERS: + if not hasattr(manager, "extensions"): + continue + for name, extension in manager.extensions.items(): + if hasattr(extension, "rpc") and extension.rpc is not None: + if hasattr(extension.rpc, "update_event_loop"): + extension.rpc.update_event_loop(loop) + update_count += 1 + logger.debug(f"{LOG_PREFIX}Updated loop on extension '{name}'") + + # Also update RPCs from running extensions (they may have direct RPC refs) + for name, extension in _RUNNING_EXTENSIONS.items(): + if hasattr(extension, "rpc") and extension.rpc is not None: + if hasattr(extension.rpc, "update_event_loop"): + extension.rpc.update_event_loop(loop) + update_count += 1 + logger.debug(f"{LOG_PREFIX}Updated loop on running extension '{name}'") + + if update_count > 0: + logger.debug(f"{LOG_PREFIX}Updated event loop on {update_count} RPC instances") + else: + logger.debug( + f"{LOG_PREFIX}No RPC instances found to update (managers={len(_EXTENSION_MANAGERS)}, 
running={len(_RUNNING_EXTENSIONS)})" + ) + + +__all__ = [ + "LOG_PREFIX", + "initialize_proxies", + "initialize_isolation_nodes", + "start_isolation_loading_early", + "await_isolation_loading", + "notify_execution_graph", + "flush_running_extensions_transport_state", + "get_claimed_paths", + "update_rpc_event_loops", + "IsolatedNodeSpec", + "get_class_types_for_extension", +] diff --git a/comfy/isolation/adapter.py b/comfy/isolation/adapter.py new file mode 100644 index 000000000..2dea2f0f0 --- /dev/null +++ b/comfy/isolation/adapter.py @@ -0,0 +1,505 @@ +# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,protected-access,raise-missing-from,useless-return,wrong-import-position +from __future__ import annotations + +import logging +import os +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional + +from pyisolate.interfaces import IsolationAdapter, SerializerRegistryProtocol # type: ignore[import-untyped] +from pyisolate._internal.rpc_protocol import AsyncRPC, ProxiedSingleton # type: ignore[import-untyped] + +try: + from comfy.isolation.clip_proxy import CLIPProxy, CLIPRegistry + from comfy.isolation.model_patcher_proxy import ( + ModelPatcherProxy, + ModelPatcherRegistry, + ) + from comfy.isolation.model_sampling_proxy import ( + ModelSamplingProxy, + ModelSamplingRegistry, + ) + from comfy.isolation.vae_proxy import VAEProxy, VAERegistry, FirstStageModelRegistry + from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy + from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy + from comfy.isolation.proxies.prompt_server_impl import PromptServerService + from comfy.isolation.proxies.utils_proxy import UtilsProxy + from comfy.isolation.proxies.progress_proxy import ProgressProxy +except ImportError as exc: # Fail loud if Comfy environment is incomplete + raise ImportError(f"ComfyUI environment incomplete: {exc}") + +logger = logging.getLogger(__name__) + +# Force /dev/shm for 
shared memory (bwrap makes /tmp private) +import tempfile + +if os.path.exists("/dev/shm"): + # Only override if not already set or if default is not /dev/shm + current_tmp = tempfile.gettempdir() + if not current_tmp.startswith("/dev/shm"): + logger.debug( + f"Configuring shared memory: Changing TMPDIR from {current_tmp} to /dev/shm" + ) + os.environ["TMPDIR"] = "/dev/shm" + tempfile.tempdir = None # Clear cache to force re-evaluation + + +class ComfyUIAdapter(IsolationAdapter): + # ComfyUI-specific IsolationAdapter implementation + + @property + def identifier(self) -> str: + return "comfyui" + + def get_path_config(self, module_path: str) -> Optional[Dict[str, Any]]: + if "ComfyUI" in module_path and "custom_nodes" in module_path: + parts = module_path.split("ComfyUI") + if len(parts) > 1: + comfy_root = parts[0] + "ComfyUI" + return { + "preferred_root": comfy_root, + "additional_paths": [ + os.path.join(comfy_root, "custom_nodes"), + os.path.join(comfy_root, "comfy"), + ], + } + return None + + def setup_child_environment(self, snapshot: Dict[str, Any]) -> None: + comfy_root = snapshot.get("preferred_root") + if not comfy_root: + return + + requirements_path = Path(comfy_root) / "requirements.txt" + if requirements_path.exists(): + import re + + for line in requirements_path.read_text().splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + pkg_name = re.split(r"[<>=!~\[]", line)[0].strip() + if pkg_name: + logging.getLogger(pkg_name).setLevel(logging.ERROR) + + def register_serializers(self, registry: SerializerRegistryProtocol) -> None: + def serialize_model_patcher(obj: Any) -> Dict[str, Any]: + # Child-side: must already have _instance_id (proxy) + if os.environ.get("PYISOLATE_CHILD") == "1": + if hasattr(obj, "_instance_id"): + return {"__type__": "ModelPatcherRef", "model_id": obj._instance_id} + raise RuntimeError( + f"ModelPatcher in child lacks _instance_id: " + f"{type(obj).__module__}.{type(obj).__name__}" + ) + # 
Host-side: register with registry + if hasattr(obj, "_instance_id"): + return {"__type__": "ModelPatcherRef", "model_id": obj._instance_id} + model_id = ModelPatcherRegistry().register(obj) + return {"__type__": "ModelPatcherRef", "model_id": model_id} + + def deserialize_model_patcher(data: Any) -> Any: + """Deserialize ModelPatcher refs; pass through already-materialized objects.""" + if isinstance(data, dict): + return ModelPatcherProxy( + data["model_id"], registry=None, manage_lifecycle=False + ) + return data + + def deserialize_model_patcher_ref(data: Dict[str, Any]) -> Any: + """Context-aware ModelPatcherRef deserializer for both host and child.""" + is_child = os.environ.get("PYISOLATE_CHILD") == "1" + if is_child: + return ModelPatcherProxy( + data["model_id"], registry=None, manage_lifecycle=False + ) + else: + return ModelPatcherRegistry()._get_instance(data["model_id"]) + + # Register ModelPatcher type for serialization + registry.register( + "ModelPatcher", serialize_model_patcher, deserialize_model_patcher + ) + # Register ModelPatcherProxy type (already a proxy, just return ref) + registry.register( + "ModelPatcherProxy", serialize_model_patcher, deserialize_model_patcher + ) + # Register ModelPatcherRef for deserialization (context-aware: host or child) + registry.register("ModelPatcherRef", None, deserialize_model_patcher_ref) + + def serialize_clip(obj: Any) -> Dict[str, Any]: + if hasattr(obj, "_instance_id"): + return {"__type__": "CLIPRef", "clip_id": obj._instance_id} + clip_id = CLIPRegistry().register(obj) + return {"__type__": "CLIPRef", "clip_id": clip_id} + + def deserialize_clip(data: Any) -> Any: + if isinstance(data, dict): + return CLIPProxy(data["clip_id"], registry=None, manage_lifecycle=False) + return data + + def deserialize_clip_ref(data: Dict[str, Any]) -> Any: + """Context-aware CLIPRef deserializer for both host and child.""" + is_child = os.environ.get("PYISOLATE_CHILD") == "1" + if is_child: + return 
CLIPProxy(data["clip_id"], registry=None, manage_lifecycle=False) + else: + return CLIPRegistry()._get_instance(data["clip_id"]) + + # Register CLIP type for serialization + registry.register("CLIP", serialize_clip, deserialize_clip) + # Register CLIPProxy type (already a proxy, just return ref) + registry.register("CLIPProxy", serialize_clip, deserialize_clip) + # Register CLIPRef for deserialization (context-aware: host or child) + registry.register("CLIPRef", None, deserialize_clip_ref) + + def serialize_vae(obj: Any) -> Dict[str, Any]: + if hasattr(obj, "_instance_id"): + return {"__type__": "VAERef", "vae_id": obj._instance_id} + vae_id = VAERegistry().register(obj) + return {"__type__": "VAERef", "vae_id": vae_id} + + def deserialize_vae(data: Any) -> Any: + if isinstance(data, dict): + return VAEProxy(data["vae_id"]) + return data + + def deserialize_vae_ref(data: Dict[str, Any]) -> Any: + """Context-aware VAERef deserializer for both host and child.""" + is_child = os.environ.get("PYISOLATE_CHILD") == "1" + if is_child: + # Child: create a proxy + return VAEProxy(data["vae_id"]) + else: + # Host: lookup real VAE from registry + return VAERegistry()._get_instance(data["vae_id"]) + + # Register VAE type for serialization + registry.register("VAE", serialize_vae, deserialize_vae) + # Register VAEProxy type (already a proxy, just return ref) + registry.register("VAEProxy", serialize_vae, deserialize_vae) + # Register VAERef for deserialization (context-aware: host or child) + registry.register("VAERef", None, deserialize_vae_ref) + + # ModelSampling serialization - handles ModelSampling* types + # copyreg removed - no pickle fallback allowed + + def serialize_model_sampling(obj: Any) -> Dict[str, Any]: + # Child-side: must already have _instance_id (proxy) + if os.environ.get("PYISOLATE_CHILD") == "1": + if hasattr(obj, "_instance_id"): + return {"__type__": "ModelSamplingRef", "ms_id": obj._instance_id} + raise RuntimeError( + f"ModelSampling in child lacks 
_instance_id: " + f"{type(obj).__module__}.{type(obj).__name__}" + ) + # Host-side: register with ModelSamplingRegistry and return JSON-safe dict + ms_id = ModelSamplingRegistry().register(obj) + return {"__type__": "ModelSamplingRef", "ms_id": ms_id} + + def deserialize_model_sampling(data: Any) -> Any: + """Deserialize ModelSampling refs; pass through already-materialized objects.""" + if isinstance(data, dict): + return ModelSamplingProxy(data["ms_id"]) + return data + + def deserialize_model_sampling_ref(data: Dict[str, Any]) -> Any: + """Context-aware ModelSamplingRef deserializer for both host and child.""" + is_child = os.environ.get("PYISOLATE_CHILD") == "1" + if is_child: + return ModelSamplingProxy(data["ms_id"]) + else: + return ModelSamplingRegistry()._get_instance(data["ms_id"]) + + # Register ModelSampling type and proxy + registry.register( + "ModelSamplingDiscrete", + serialize_model_sampling, + deserialize_model_sampling, + ) + registry.register( + "ModelSamplingContinuousEDM", + serialize_model_sampling, + deserialize_model_sampling, + ) + registry.register( + "ModelSamplingContinuousV", + serialize_model_sampling, + deserialize_model_sampling, + ) + registry.register( + "ModelSamplingProxy", serialize_model_sampling, deserialize_model_sampling + ) + # Register ModelSamplingRef for deserialization (context-aware: host or child) + registry.register("ModelSamplingRef", None, deserialize_model_sampling_ref) + + def serialize_cond(obj: Any) -> Dict[str, Any]: + type_key = f"{type(obj).__module__}.{type(obj).__name__}" + return { + "__type__": type_key, + "cond": obj.cond, + } + + def deserialize_cond(data: Dict[str, Any]) -> Any: + import importlib + + type_key = data["__type__"] + module_name, class_name = type_key.rsplit(".", 1) + module = importlib.import_module(module_name) + cls = getattr(module, class_name) + return cls(data["cond"]) + + def _serialize_public_state(obj: Any) -> Dict[str, Any]: + state: Dict[str, Any] = {} + for key, value in 
obj.__dict__.items(): + if key.startswith("_"): + continue + if callable(value): + continue + state[key] = value + return state + + def serialize_latent_format(obj: Any) -> Dict[str, Any]: + type_key = f"{type(obj).__module__}.{type(obj).__name__}" + return { + "__type__": type_key, + "state": _serialize_public_state(obj), + } + + def deserialize_latent_format(data: Dict[str, Any]) -> Any: + import importlib + + type_key = data["__type__"] + module_name, class_name = type_key.rsplit(".", 1) + module = importlib.import_module(module_name) + cls = getattr(module, class_name) + obj = cls() + for key, value in data.get("state", {}).items(): + prop = getattr(type(obj), key, None) + if isinstance(prop, property) and prop.fset is None: + continue + setattr(obj, key, value) + return obj + + import comfy.conds + + for cond_cls in vars(comfy.conds).values(): + if not isinstance(cond_cls, type): + continue + if not issubclass(cond_cls, comfy.conds.CONDRegular): + continue + type_key = f"{cond_cls.__module__}.{cond_cls.__name__}" + registry.register(type_key, serialize_cond, deserialize_cond) + registry.register(cond_cls.__name__, serialize_cond, deserialize_cond) + + import comfy.latent_formats + + for latent_cls in vars(comfy.latent_formats).values(): + if not isinstance(latent_cls, type): + continue + if not issubclass(latent_cls, comfy.latent_formats.LatentFormat): + continue + type_key = f"{latent_cls.__module__}.{latent_cls.__name__}" + registry.register( + type_key, serialize_latent_format, deserialize_latent_format + ) + registry.register( + latent_cls.__name__, serialize_latent_format, deserialize_latent_format + ) + + # V3 API: unwrap NodeOutput.args + def deserialize_node_output(data: Any) -> Any: + return getattr(data, "args", data) + + registry.register("NodeOutput", None, deserialize_node_output) + + # KSAMPLER serializer: stores sampler name instead of function object + # sampler_function is a callable which gets filtered out by JSONSocketTransport + def 
serialize_ksampler(obj: Any) -> Dict[str, Any]: + func_name = obj.sampler_function.__name__ + # Map function name back to sampler name + if func_name == "sample_unipc": + sampler_name = "uni_pc" + elif func_name == "sample_unipc_bh2": + sampler_name = "uni_pc_bh2" + elif func_name == "dpm_fast_function": + sampler_name = "dpm_fast" + elif func_name == "dpm_adaptive_function": + sampler_name = "dpm_adaptive" + elif func_name.startswith("sample_"): + sampler_name = func_name[7:] # Remove "sample_" prefix + else: + sampler_name = func_name + return { + "__type__": "KSAMPLER", + "sampler_name": sampler_name, + "extra_options": obj.extra_options, + "inpaint_options": obj.inpaint_options, + } + + def deserialize_ksampler(data: Dict[str, Any]) -> Any: + import comfy.samplers + + return comfy.samplers.ksampler( + data["sampler_name"], + data.get("extra_options", {}), + data.get("inpaint_options", {}), + ) + + registry.register("KSAMPLER", serialize_ksampler, deserialize_ksampler) + + from comfy.isolation.model_patcher_proxy_utils import register_hooks_serializers + + register_hooks_serializers(registry) + + # Generic Numpy Serializer + def serialize_numpy(obj: Any) -> Any: + import torch + + try: + # Attempt zero-copy conversion to Tensor + return torch.from_numpy(obj) + except Exception: + # Fallback for non-numeric arrays (strings, objects, mixes) + return obj.tolist() + + registry.register("ndarray", serialize_numpy, None) + + def provide_rpc_services(self) -> List[type[ProxiedSingleton]]: + return [ + PromptServerService, + FolderPathsProxy, + ModelManagementProxy, + UtilsProxy, + ProgressProxy, + VAERegistry, + CLIPRegistry, + ModelPatcherRegistry, + ModelSamplingRegistry, + FirstStageModelRegistry, + ] + + def handle_api_registration(self, api: ProxiedSingleton, rpc: AsyncRPC) -> None: + # Resolve the real name whether it's an instance or the Singleton class itself + api_name = api.__name__ if isinstance(api, type) else api.__class__.__name__ + + if api_name == 
"FolderPathsProxy": + import folder_paths + + # Replace module-level functions with proxy methods + # This is aggressive but necessary for transparent proxying + # Handle both instance and class cases + instance = api() if isinstance(api, type) else api + for name in dir(instance): + if not name.startswith("_"): + setattr(folder_paths, name, getattr(instance, name)) + return + + if api_name == "ModelManagementProxy": + import comfy.model_management + + instance = api() if isinstance(api, type) else api + # Replace module-level functions with proxy methods + for name in dir(instance): + if not name.startswith("_"): + setattr(comfy.model_management, name, getattr(instance, name)) + return + + if api_name == "UtilsProxy": + import comfy.utils + + # Static Injection of RPC mechanism to ensure Child can access it + # independent of instance lifecycle. + api.set_rpc(rpc) + + # Don't overwrite host hook (infinite recursion) + return + + if api_name == "PromptServerProxy": + # Defer heavy import to child context + import server + + instance = api() if isinstance(api, type) else api + proxy = ( + instance.instance + ) # PromptServerProxy instance has .instance property returning self + + original_register_route = proxy.register_route + + def register_route_wrapper( + method: str, path: str, handler: Callable[..., Any] + ) -> None: + callback_id = rpc.register_callback(handler) + loop = getattr(rpc, "loop", None) + if loop and loop.is_running(): + import asyncio + + asyncio.create_task( + original_register_route( + method, path, handler=callback_id, is_callback=True + ) + ) + else: + original_register_route( + method, path, handler=callback_id, is_callback=True + ) + return None + + proxy.register_route = register_route_wrapper + + class RouteTableDefProxy: + def __init__(self, proxy_instance: Any): + self.proxy = proxy_instance + + def get( + self, path: str, **kwargs: Any + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + def decorator(handler: Callable[..., 
Any]) -> Callable[..., Any]: + self.proxy.register_route("GET", path, handler) + return handler + + return decorator + + def post( + self, path: str, **kwargs: Any + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: + self.proxy.register_route("POST", path, handler) + return handler + + return decorator + + def patch( + self, path: str, **kwargs: Any + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: + self.proxy.register_route("PATCH", path, handler) + return handler + + return decorator + + def put( + self, path: str, **kwargs: Any + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: + self.proxy.register_route("PUT", path, handler) + return handler + + return decorator + + def delete( + self, path: str, **kwargs: Any + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: + self.proxy.register_route("DELETE", path, handler) + return handler + + return decorator + + proxy.routes = RouteTableDefProxy(proxy) + + if ( + hasattr(server, "PromptServer") + and getattr(server.PromptServer, "instance", None) != proxy + ): + server.PromptServer.instance = proxy diff --git a/comfy/isolation/child_hooks.py b/comfy/isolation/child_hooks.py new file mode 100644 index 000000000..a1ba201ac --- /dev/null +++ b/comfy/isolation/child_hooks.py @@ -0,0 +1,141 @@ +# pylint: disable=import-outside-toplevel,logging-fstring-interpolation +# Child process initialization for PyIsolate +import logging +import os + +logger = logging.getLogger(__name__) + + +def is_child_process() -> bool: + return os.environ.get("PYISOLATE_CHILD") == "1" + + +def initialize_child_process() -> None: + # Manual RPC injection + try: + from pyisolate._internal.rpc_protocol import get_child_rpc_instance + + rpc = 
get_child_rpc_instance() + if rpc: + _setup_prompt_server_stub(rpc) + _setup_utils_proxy(rpc) + else: + logger.warning("Could not get child RPC instance for manual injection") + _setup_prompt_server_stub() + _setup_utils_proxy() + except Exception as e: + logger.error(f"Manual RPC Injection failed: {e}") + _setup_prompt_server_stub() + _setup_utils_proxy() + + _setup_logging() + + +def _setup_prompt_server_stub(rpc=None) -> None: + try: + from .proxies.prompt_server_impl import PromptServerStub + import sys + import types + + # Mock server module + if "server" not in sys.modules: + mock_server = types.ModuleType("server") + sys.modules["server"] = mock_server + + server = sys.modules["server"] + + if not hasattr(server, "PromptServer"): + + class MockPromptServer: + pass + + server.PromptServer = MockPromptServer + + stub = PromptServerStub() + + if rpc: + PromptServerStub.set_rpc(rpc) + if hasattr(stub, "set_rpc"): + stub.set_rpc(rpc) + + server.PromptServer.instance = stub + + except Exception as e: + logger.error(f"Failed to setup PromptServerStub: {e}") + + +def _setup_utils_proxy(rpc=None) -> None: + try: + import comfy.utils + import asyncio + + # Capture main loop during initialization (safe context) + main_loop = None + try: + main_loop = asyncio.get_running_loop() + except RuntimeError: + try: + main_loop = asyncio.get_event_loop() + except RuntimeError: + pass + + try: + from .proxies.base import set_global_loop + + if main_loop: + set_global_loop(main_loop) + except ImportError: + pass + + # Sync hook wrapper for progress updates + def sync_hook_wrapper( + value: int, total: int, preview: None = None, node_id: None = None + ) -> None: + if node_id is None: + try: + from comfy_execution.utils import get_executing_context + + ctx = get_executing_context() + if ctx: + node_id = ctx.node_id + else: + pass + except Exception: + pass + + # Bypass blocked event loop by direct outbox injection + if rpc: + try: + # Use captured main loop if available (for 
threaded execution), or current loop + loop = main_loop + if loop is None: + loop = asyncio.get_event_loop() + + rpc.outbox.put( + { + "kind": "call", + "object_id": "UtilsProxy", + "parent_call_id": None, # We are root here usually + "calling_loop": loop, + "future": loop.create_future(), # Dummy future + "method": "progress_bar_hook", + "args": (value, total, preview, node_id), + "kwargs": {}, + } + ) + + except Exception as e: + logging.getLogger(__name__).error(f"Manual Inject Failed: {e}") + else: + logging.getLogger(__name__).warning( + "No RPC instance available for progress update" + ) + + comfy.utils.PROGRESS_BAR_HOOK = sync_hook_wrapper + + except Exception as e: + logger.error(f"Failed to setup UtilsProxy hook: {e}") + + +def _setup_logging() -> None: + logging.getLogger().setLevel(logging.INFO) diff --git a/comfy/isolation/extension_loader.py b/comfy/isolation/extension_loader.py new file mode 100644 index 000000000..55db00ade --- /dev/null +++ b/comfy/isolation/extension_loader.py @@ -0,0 +1,248 @@ +# pylint: disable=cyclic-import,import-outside-toplevel,redefined-outer-name +from __future__ import annotations + +import logging +import os +import inspect +import sys +import types +import platform +from pathlib import Path +from typing import Callable, Dict, List, Tuple + +import pyisolate +from pyisolate import ExtensionManager, ExtensionManagerConfig + +from .extension_wrapper import ComfyNodeExtension +from .manifest_loader import is_cache_valid, load_from_cache, save_to_cache +from .host_policy import load_host_policy + +try: + import tomllib +except ImportError: + import tomli as tomllib # type: ignore[no-redef] + +logger = logging.getLogger(__name__) + + +async def _stop_extension_safe( + extension: ComfyNodeExtension, extension_name: str +) -> None: + try: + stop_result = extension.stop() + if inspect.isawaitable(stop_result): + await stop_result + except Exception: + logger.debug("][ %s stop failed", extension_name, exc_info=True) + + +def 
_normalize_dependency_spec(dep: str, base_paths: list[Path]) -> str: + req, sep, marker = dep.partition(";") + req = req.strip() + marker_suffix = f";{marker}" if sep else "" + + def _resolve_local_path(local_path: str) -> Path | None: + for base in base_paths: + candidate = (base / local_path).resolve() + if candidate.exists(): + return candidate + return None + + if req.startswith("./") or req.startswith("../"): + resolved = _resolve_local_path(req) + if resolved is not None: + return f"{resolved}{marker_suffix}" + + if req.startswith("file://"): + raw = req[len("file://") :] + if raw.startswith("./") or raw.startswith("../"): + resolved = _resolve_local_path(raw) + if resolved is not None: + return f"file://{resolved}{marker_suffix}" + + return dep + + +def get_enforcement_policy() -> Dict[str, bool]: + return { + "force_isolated": os.environ.get("PYISOLATE_ENFORCE_ISOLATED") == "1", + "force_sandbox": os.environ.get("PYISOLATE_ENFORCE_SANDBOX") == "1", + } + + +class ExtensionLoadError(RuntimeError): + pass + + +def register_dummy_module(extension_name: str, node_dir: Path) -> None: + normalized_name = extension_name.replace("-", "_").replace(".", "_") + if normalized_name not in sys.modules: + dummy_module = types.ModuleType(normalized_name) + dummy_module.__file__ = str(node_dir / "__init__.py") + dummy_module.__path__ = [str(node_dir)] + dummy_module.__package__ = normalized_name + sys.modules[normalized_name] = dummy_module + + +def _is_stale_node_cache(cached_data: Dict[str, Dict]) -> bool: + for details in cached_data.values(): + if not isinstance(details, dict): + return True + if details.get("is_v3") and "schema_v1" not in details: + return True + return False + + +async def load_isolated_node( + node_dir: Path, + manifest_path: Path, + logger: logging.Logger, + build_stub_class: Callable[[str, Dict[str, object], ComfyNodeExtension], type], + venv_root: Path, + extension_managers: List[ExtensionManager], +) -> List[Tuple[str, str, type]]: + try: + with 
manifest_path.open("rb") as handle: + manifest_data = tomllib.load(handle) + except Exception as e: + logger.warning(f"][ Failed to parse {manifest_path}: {e}") + return [] + + # Parse [tool.comfy.isolation] + tool_config = manifest_data.get("tool", {}).get("comfy", {}).get("isolation", {}) + can_isolate = tool_config.get("can_isolate", False) + share_torch = tool_config.get("share_torch", False) + + # Parse [project] dependencies + project_config = manifest_data.get("project", {}) + dependencies = project_config.get("dependencies", []) + if not isinstance(dependencies, list): + dependencies = [] + + # Get extension name (default to folder name if not in project.name) + extension_name = project_config.get("name", node_dir.name) + + # LOGIC: Isolation Decision + policy = get_enforcement_policy() + isolated = can_isolate or policy["force_isolated"] + + if not isolated: + return [] + + logger.info(f"][ Loading isolated node: {extension_name}") + + import folder_paths + + base_paths = [Path(folder_paths.base_path), node_dir] + dependencies = [ + _normalize_dependency_spec(dep, base_paths) if isinstance(dep, str) else dep + for dep in dependencies + ] + + manager_config = ExtensionManagerConfig(venv_root_path=str(venv_root)) + manager: ExtensionManager = pyisolate.ExtensionManager( + ComfyNodeExtension, manager_config + ) + extension_managers.append(manager) + + host_policy = load_host_policy(Path(folder_paths.base_path)) + + sandbox_config = {} + is_linux = platform.system() == "Linux" + if is_linux and isolated: + sandbox_config = { + "network": host_policy["allow_network"], + "writable_paths": host_policy["writable_paths"], + "readonly_paths": host_policy["readonly_paths"], + } + share_cuda_ipc = share_torch and is_linux + + extension_config = { + "name": extension_name, + "module_path": str(node_dir), + "isolated": True, + "dependencies": dependencies, + "share_torch": share_torch, + "share_cuda_ipc": share_cuda_ipc, + "sandbox": sandbox_config, + } + + extension = 
manager.load_extension(extension_config) + register_dummy_module(extension_name, node_dir) + + # Try cache first (lazy spawn) + if is_cache_valid(node_dir, manifest_path, venv_root): + cached_data = load_from_cache(node_dir, venv_root) + if cached_data: + if _is_stale_node_cache(cached_data): + logger.debug( + "][ %s cache is stale/incompatible; rebuilding metadata", + extension_name, + ) + else: + logger.debug(f"][ {extension_name} loaded from cache") + specs: List[Tuple[str, str, type]] = [] + for node_name, details in cached_data.items(): + stub_cls = build_stub_class(node_name, details, extension) + specs.append( + (node_name, details.get("display_name", node_name), stub_cls) + ) + return specs + + # Cache miss - spawn process and get metadata + logger.debug(f"][ {extension_name} cache miss, spawning process for metadata") + + try: + remote_nodes: Dict[str, str] = await extension.list_nodes() + except Exception as exc: + logger.warning( + "][ %s metadata discovery failed, skipping isolated load: %s", + extension_name, + exc, + ) + await _stop_extension_safe(extension, extension_name) + return [] + + if not remote_nodes: + logger.debug("][ %s exposed no isolated nodes; skipping", extension_name) + await _stop_extension_safe(extension, extension_name) + return [] + + specs: List[Tuple[str, str, type]] = [] + cache_data: Dict[str, Dict] = {} + + for node_name, display_name in remote_nodes.items(): + try: + details = await extension.get_node_details(node_name) + except Exception as exc: + logger.warning( + "][ %s failed to load metadata for %s, skipping node: %s", + extension_name, + node_name, + exc, + ) + continue + details["display_name"] = display_name + cache_data[node_name] = details + stub_cls = build_stub_class(node_name, details, extension) + specs.append((node_name, display_name, stub_cls)) + + if not specs: + logger.warning( + "][ %s produced no usable nodes after metadata scan; skipping", + extension_name, + ) + await _stop_extension_safe(extension, 
extension_name) + return [] + + # Save metadata to cache for future runs + save_to_cache(node_dir, venv_root, cache_data, manifest_path) + logger.debug(f"][ {extension_name} metadata cached") + + # EJECT: Kill process after getting metadata (will respawn on first execution) + await _stop_extension_safe(extension, extension_name) + + return specs + + +__all__ = ["ExtensionLoadError", "register_dummy_module", "load_isolated_node"] diff --git a/comfy/isolation/extension_wrapper.py b/comfy/isolation/extension_wrapper.py new file mode 100644 index 000000000..23148e470 --- /dev/null +++ b/comfy/isolation/extension_wrapper.py @@ -0,0 +1,673 @@ +# pylint: disable=consider-using-from-import,cyclic-import,import-outside-toplevel,logging-fstring-interpolation,protected-access,wrong-import-position +from __future__ import annotations + +import asyncio +import torch + + +class AttrDict(dict): + def __getattr__(self, item): + try: + return self[item] + except KeyError as e: + raise AttributeError(item) from e + + def copy(self): + return AttrDict(super().copy()) + + +import importlib +import inspect +import json +import logging +import os +import sys +import uuid +from dataclasses import asdict +from typing import Any, Dict, List, Tuple + +from pyisolate import ExtensionBase + +from comfy_api.internal import _ComfyNodeInternal + +LOG_PREFIX = "][" +V3_DISCOVERY_TIMEOUT = 30 +_PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024 + +logger = logging.getLogger(__name__) + + +def _flush_tensor_transport_state(marker: str) -> int: + try: + from pyisolate import flush_tensor_keeper # type: ignore[attr-defined] + except Exception: + return 0 + if not callable(flush_tensor_keeper): + return 0 + flushed = flush_tensor_keeper() + if flushed > 0: + logger.debug( + "%s %s flush_tensor_keeper released=%d", LOG_PREFIX, marker, flushed + ) + return flushed + + +def _relieve_child_vram_pressure(marker: str) -> None: + import comfy.model_management as model_management + + 
def _relieve_child_vram_pressure(marker: str) -> None:
    """Free VRAM in the child before running a node when below the floor."""
    import comfy.model_management as model_management

    model_management.cleanup_models_gc()
    model_management.cleanup_models()

    device = model_management.get_torch_device()
    if not hasattr(device, "type") or device.type == "cpu":
        return

    required = max(
        model_management.minimum_inference_memory(),
        _PRE_EXEC_MIN_FREE_VRAM_BYTES,
    )
    if model_management.get_free_memory(device) < required:
        model_management.free_memory(required, device, for_dynamic=True)
        # Retry harder if the dynamic pass was not enough.
        if model_management.get_free_memory(device) < required:
            model_management.free_memory(required, device, for_dynamic=False)
            model_management.cleanup_models()
            model_management.soft_empty_cache()
        logger.debug("%s %s free_memory target=%d", LOG_PREFIX, marker, required)


def _sanitize_for_transport(value):
    """Recursively convert INPUT_TYPES-style metadata to JSON-safe structures."""
    if isinstance(value, (str, int, float, bool, type(None))):
        return value

    type_name = value.__class__.__name__
    if type_name == "FlexibleOptionalInputType":
        return {
            "__pyisolate_flexible_optional__": True,
            "type": _sanitize_for_transport(getattr(value, "type", "*")),
        }
    if type_name == "AnyType":
        return {"__pyisolate_any_type__": True, "value": str(value)}
    if type_name == "ByPassTypeTuple":
        return {
            "__pyisolate_bypass_tuple__": [
                _sanitize_for_transport(item) for item in tuple(value)
            ]
        }

    if isinstance(value, dict):
        return {key: _sanitize_for_transport(item) for key, item in value.items()}
    if isinstance(value, tuple):
        return {"__pyisolate_tuple__": [_sanitize_for_transport(item) for item in value]}
    if isinstance(value, list):
        return [_sanitize_for_transport(item) for item in value]

    # Anything unrecognized degrades to its string form.
    return str(value)


# Re-export RemoteObjectHandle from pyisolate for backward compatibility.
# The canonical definition is now in pyisolate._internal.remote_handle.
from pyisolate._internal.remote_handle import RemoteObjectHandle  # noqa: E402,F401


class ComfyNodeExtension(ExtensionBase):
    """Child-process wrapper that exposes a custom node package over RPC."""

    def __init__(self) -> None:
        super().__init__()
        self.node_classes: Dict[str, type] = {}
        self.display_names: Dict[str, str] = {}
        self.node_instances: Dict[str, Any] = {}
        self.remote_objects: Dict[str, Any] = {}
        self._route_handlers: Dict[str, Any] = {}
        self._module: Any = None

    async def on_module_loaded(self, module: Any) -> None:
        """Collect V1 node mappings and discover V3 extensions from *module*."""
        self._module = module

        # Registries are initialized in host_hooks.py initialize_host_process().
        # They auto-register via ProxiedSingleton when instantiated, so no
        # additional setup happens here.
        self.node_classes = getattr(module, "NODE_CLASS_MAPPINGS", {}) or {}
        self.display_names = getattr(module, "NODE_DISPLAY_NAME_MAPPINGS", {}) or {}

        try:
            from comfy_api.latest import ComfyExtension

            for member_name, member in inspect.getmembers(module):
                is_ext_class = (
                    inspect.isclass(member)
                    and issubclass(member, ComfyExtension)
                    and member is not ComfyExtension
                )
                if not is_ext_class:
                    continue
                # Only extensions defined by this package (not re-imports).
                if not member.__module__.startswith(module.__name__):
                    continue
                try:
                    extension_obj = member()
                    try:
                        await asyncio.wait_for(
                            extension_obj.on_load(), timeout=V3_DISCOVERY_TIMEOUT
                        )
                    except asyncio.TimeoutError:
                        logger.error(
                            "%s V3 Extension %s timed out in on_load()",
                            LOG_PREFIX,
                            member_name,
                        )
                        continue
                    try:
                        discovered = await asyncio.wait_for(
                            extension_obj.get_node_list(),
                            timeout=V3_DISCOVERY_TIMEOUT,
                        )
                    except asyncio.TimeoutError:
                        logger.error(
                            "%s V3 Extension %s timed out in get_node_list()",
                            LOG_PREFIX,
                            member_name,
                        )
                        continue
                    for node_cls in discovered:
                        if hasattr(node_cls, "GET_SCHEMA"):
                            schema = node_cls.GET_SCHEMA()
                            self.node_classes[schema.node_id] = node_cls
                            if schema.display_name:
                                self.display_names[schema.node_id] = (
                                    schema.display_name
                                )
                except Exception as e:
                    logger.error(
                        "%s V3 Extension %s failed: %s", LOG_PREFIX, member_name, e
                    )
        except ImportError:
            pass

        # Re-home node classes whose __module__ leaked a filesystem path.
        module_name = getattr(module, "__name__", "isolated_nodes")
        for node_cls in self.node_classes.values():
            if hasattr(node_cls, "__module__") and "/" in str(node_cls.__module__):
                node_cls.__module__ = module_name

        self.node_instances = {}

    async def list_nodes(self) -> Dict[str, str]:
        """Map node id -> display name (falls back to the id itself)."""
        return {key: self.display_names.get(key, key) for key in self.node_classes}

    async def get_node_info(self, node_name: str) -> Dict[str, Any]:
        """Alias kept for older callers; delegates to get_node_details."""
        return await self.get_node_details(node_name)

    async def get_node_details(self, node_name: str) -> Dict[str, Any]:
        """Serialize one node class's metadata for transfer to the host."""
        node_cls = self._get_node_class(node_name)
        is_v3 = issubclass(node_cls, _ComfyNodeInternal)

        input_types_raw = (
            node_cls.INPUT_TYPES() if hasattr(node_cls, "INPUT_TYPES") else {}
        )
        output_is_list = getattr(node_cls, "OUTPUT_IS_LIST", None)
        if output_is_list is not None:
            output_is_list = tuple(bool(flag) for flag in output_is_list)

        details: Dict[str, Any] = {
            "input_types": _sanitize_for_transport(input_types_raw),
            "return_types": tuple(
                str(t) for t in getattr(node_cls, "RETURN_TYPES", ())
            ),
            "return_names": getattr(node_cls, "RETURN_NAMES", None),
            "function": str(getattr(node_cls, "FUNCTION", "execute")),
            "category": str(getattr(node_cls, "CATEGORY", "")),
            "output_node": bool(getattr(node_cls, "OUTPUT_NODE", False)),
            "output_is_list": output_is_list,
            "is_v3": is_v3,
        }

        if is_v3:
            try:
                schema = node_cls.GET_SCHEMA()
                schema_v1 = asdict(schema.get_v1_info(node_cls))
                try:
                    schema_v3 = asdict(schema.get_v3_info(node_cls))
                except (AttributeError, TypeError):
                    # Older schema objects lack get_v3_info; rebuild manually.
                    schema_v3 = self._build_schema_v3_fallback(schema)
                details.update(
                    {
                        "schema_v1": schema_v1,
                        "schema_v3": schema_v3,
                        "hidden": [h.value for h in (schema.hidden or [])],
                        "description": getattr(schema, "description", ""),
                        "deprecated": bool(getattr(node_cls, "DEPRECATED", False)),
                        "experimental": bool(
                            getattr(node_cls, "EXPERIMENTAL", False)
                        ),
                        "api_node": bool(getattr(node_cls, "API_NODE", False)),
                        "input_is_list": bool(
                            getattr(node_cls, "INPUT_IS_LIST", False)
                        ),
                        "not_idempotent": bool(
                            getattr(node_cls, "NOT_IDEMPOTENT", False)
                        ),
                    }
                )
            except Exception as exc:
                logger.warning(
                    "%s V3 schema serialization failed for %s: %s",
                    LOG_PREFIX,
                    node_name,
                    exc,
                )
        return details

    def _build_schema_v3_fallback(self, schema) -> Dict[str, Any]:
        """Hand-build a v3 schema dict when get_v3_info is unavailable."""
        input_dict: Dict[str, Any] = {}
        output_dict: Dict[str, Any] = {}
        hidden_list: List[str] = []

        for inp in getattr(schema, "inputs", None) or []:
            self._add_schema_io_v3(inp, input_dict)
        for out in getattr(schema, "outputs", None) or []:
            self._add_schema_io_v3(out, output_dict)
        for h in getattr(schema, "hidden", None) or []:
            hidden_list.append(getattr(h, "value", str(h)))

        return {
            "input": input_dict,
            "output": output_dict,
            "hidden": hidden_list,
            "name": getattr(schema, "node_id", None),
            "display_name": getattr(schema, "display_name", None),
            "description": getattr(schema, "description", None),
            "category": getattr(schema, "category", None),
            "output_node": getattr(schema, "is_output_node", False),
            "deprecated": getattr(schema, "is_deprecated", False),
            "experimental": getattr(schema, "is_experimental", False),
            "api_node": getattr(schema, "is_api_node", False),
        }

    def _add_schema_io_v3(self, io_obj: Any, target: Dict[str, Any]) -> None:
        """Record one schema input/output as ``id -> (io_type, payload)``."""
        io_id = getattr(io_obj, "id", None)
        if io_id is None:
            return

        io_type_fn = getattr(io_obj, "get_io_type", None)
        if callable(io_type_fn):
            io_type = io_type_fn()
        else:
            io_type = getattr(io_obj, "io_type", None)

        as_dict_fn = getattr(io_obj, "as_dict", None)
        payload = as_dict_fn() if callable(as_dict_fn) else {}

        target[str(io_id)] = (io_type, payload)

    async def get_input_types(self, node_name: str) -> Dict[str, Any]:
        """Return the raw INPUT_TYPES mapping for *node_name* (or empty)."""
        node_cls = self._get_node_class(node_name)
        if hasattr(node_cls, "INPUT_TYPES"):
            return node_cls.INPUT_TYPES()
        return {}
+ ) + if os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1": + _relieve_child_vram_pressure("EXT:pre_execute") + + resolved_inputs = self._resolve_remote_objects(inputs) + + instance = self._get_node_instance(node_name) + node_cls = self._get_node_class(node_name) + + # V3 API nodes expect hidden parameters in cls.hidden, not as kwargs + # Hidden params come through RPC as string keys like "Hidden.prompt" + from comfy_api.latest._io import Hidden, HiddenHolder + + # Map string representations back to Hidden enum keys + hidden_string_map = { + "Hidden.unique_id": Hidden.unique_id, + "Hidden.prompt": Hidden.prompt, + "Hidden.extra_pnginfo": Hidden.extra_pnginfo, + "Hidden.dynprompt": Hidden.dynprompt, + "Hidden.auth_token_comfy_org": Hidden.auth_token_comfy_org, + "Hidden.api_key_comfy_org": Hidden.api_key_comfy_org, + } + + # Find and extract hidden parameters (both enum and string form) + hidden_found = {} + keys_to_remove = [] + + for key in list(resolved_inputs.keys()): + # Check string form first (from RPC serialization) + if key in hidden_string_map: + hidden_found[hidden_string_map[key]] = resolved_inputs[key] + keys_to_remove.append(key) + # Also check enum form (direct calls) + elif isinstance(key, Hidden): + hidden_found[key] = resolved_inputs[key] + keys_to_remove.append(key) + + # Remove hidden params from kwargs + for key in keys_to_remove: + resolved_inputs.pop(key) + + # Set hidden on node class if any hidden params found + if hidden_found: + if not hasattr(node_cls, "hidden") or node_cls.hidden is None: + node_cls.hidden = HiddenHolder.from_dict(hidden_found) + else: + # Update existing hidden holder + for key, value in hidden_found.items(): + setattr(node_cls.hidden, key.value.lower(), value) + + function_name = getattr(node_cls, "FUNCTION", "execute") + if not hasattr(instance, function_name): + raise AttributeError(f"Node {node_name} missing callable '{function_name}'") + + handler = getattr(instance, function_name) + + try: + if 
asyncio.iscoroutinefunction(handler): + result = await handler(**resolved_inputs) + else: + import functools + + loop = asyncio.get_running_loop() + result = await loop.run_in_executor( + None, functools.partial(handler, **resolved_inputs) + ) + except Exception: + logger.exception( + "%s ISO:child_execute_error ext=%s node=%s", + LOG_PREFIX, + getattr(self, "name", "?"), + node_name, + ) + raise + + if type(result).__name__ == "NodeOutput": + result = result.args + if self._is_comfy_protocol_return(result): + logger.debug( + "%s ISO:child_execute_done ext=%s node=%s protocol_return=1", + LOG_PREFIX, + getattr(self, "name", "?"), + node_name, + ) + return self._wrap_unpicklable_objects(result) + + if not isinstance(result, tuple): + result = (result,) + logger.debug( + "%s ISO:child_execute_done ext=%s node=%s protocol_return=0 outputs=%d", + LOG_PREFIX, + getattr(self, "name", "?"), + node_name, + len(result), + ) + return self._wrap_unpicklable_objects(result) + + async def flush_transport_state(self) -> int: + if os.environ.get("PYISOLATE_ISOLATION_ACTIVE") != "1": + return 0 + logger.debug( + "%s ISO:child_flush_start ext=%s", LOG_PREFIX, getattr(self, "name", "?") + ) + flushed = _flush_tensor_transport_state("EXT:workflow_end") + try: + from comfy.isolation.model_patcher_proxy_registry import ( + ModelPatcherRegistry, + ) + + registry = ModelPatcherRegistry() + removed = registry.sweep_pending_cleanup() + if removed > 0: + logger.debug( + "%s EXT:workflow_end registry sweep removed=%d", LOG_PREFIX, removed + ) + except Exception: + logger.debug( + "%s EXT:workflow_end registry sweep failed", LOG_PREFIX, exc_info=True + ) + logger.debug( + "%s ISO:child_flush_done ext=%s flushed=%d", + LOG_PREFIX, + getattr(self, "name", "?"), + flushed, + ) + return flushed + + async def get_remote_object(self, object_id: str) -> Any: + """Retrieve a remote object by ID for host-side deserialization.""" + if object_id not in self.remote_objects: + raise KeyError(f"Remote 
object {object_id} not found") + + return self.remote_objects[object_id] + + def _wrap_unpicklable_objects(self, data: Any) -> Any: + if isinstance(data, (str, int, float, bool, type(None))): + return data + if isinstance(data, torch.Tensor): + return data.detach() if data.requires_grad else data + + # Special-case clip vision outputs: preserve attribute access by packing fields + if hasattr(data, "penultimate_hidden_states") or hasattr( + data, "last_hidden_state" + ): + fields = {} + for attr in ( + "penultimate_hidden_states", + "last_hidden_state", + "image_embeds", + "text_embeds", + ): + if hasattr(data, attr): + try: + fields[attr] = self._wrap_unpicklable_objects( + getattr(data, attr) + ) + except Exception: + pass + if fields: + return {"__pyisolate_attribute_container__": True, "data": fields} + + # Avoid converting arbitrary objects with stateful methods (models, etc.) + # They will be handled via RemoteObjectHandle below. + + type_name = type(data).__name__ + if type_name == "ModelPatcherProxy": + return {"__type__": "ModelPatcherRef", "model_id": data._instance_id} + if type_name == "CLIPProxy": + return {"__type__": "CLIPRef", "clip_id": data._instance_id} + if type_name == "VAEProxy": + return {"__type__": "VAERef", "vae_id": data._instance_id} + if type_name == "ModelSamplingProxy": + return {"__type__": "ModelSamplingRef", "ms_id": data._instance_id} + + if isinstance(data, (list, tuple)): + wrapped = [self._wrap_unpicklable_objects(item) for item in data] + return tuple(wrapped) if isinstance(data, tuple) else wrapped + if isinstance(data, dict): + converted_dict = { + k: self._wrap_unpicklable_objects(v) for k, v in data.items() + } + return {"__pyisolate_attrdict__": True, "data": converted_dict} + + object_id = str(uuid.uuid4()) + self.remote_objects[object_id] = data + return RemoteObjectHandle(object_id, type(data).__name__) + + def _resolve_remote_objects(self, data: Any) -> Any: + if isinstance(data, RemoteObjectHandle): + if 
data.object_id not in self.remote_objects: + raise KeyError(f"Remote object {data.object_id} not found") + return self.remote_objects[data.object_id] + + if isinstance(data, dict): + ref_type = data.get("__type__") + if ref_type in ("CLIPRef", "ModelPatcherRef", "VAERef"): + from pyisolate._internal.model_serialization import ( + deserialize_proxy_result, + ) + + return deserialize_proxy_result(data) + if ref_type == "ModelSamplingRef": + from pyisolate._internal.model_serialization import ( + deserialize_proxy_result, + ) + + return deserialize_proxy_result(data) + return {k: self._resolve_remote_objects(v) for k, v in data.items()} + + if isinstance(data, (list, tuple)): + resolved = [self._resolve_remote_objects(item) for item in data] + return tuple(resolved) if isinstance(data, tuple) else resolved + return data + + def _get_node_class(self, node_name: str) -> type: + if node_name not in self.node_classes: + raise KeyError(f"Unknown node: {node_name}") + return self.node_classes[node_name] + + def _get_node_instance(self, node_name: str) -> Any: + if node_name not in self.node_instances: + if node_name not in self.node_classes: + raise KeyError(f"Unknown node: {node_name}") + self.node_instances[node_name] = self.node_classes[node_name]() + return self.node_instances[node_name] + + async def before_module_loaded(self) -> None: + # Inject initialization here if we think this is the child + try: + from comfy.isolation import initialize_proxies + + initialize_proxies() + except Exception as e: + logging.getLogger(__name__).error( + f"Failed to call initialize_proxies in before_module_loaded: {e}" + ) + + await super().before_module_loaded() + try: + from comfy_api.latest import ComfyAPI_latest + from .proxies.progress_proxy import ProgressProxy + + ComfyAPI_latest.Execution = ProgressProxy + # ComfyAPI_latest.execution = ProgressProxy() # Eliminated to avoid Singleton collision + # fp_proxy = FolderPathsProxy() # Eliminated to avoid Singleton collision + # 
latest_ui.folder_paths = fp_proxy + # latest_resources.folder_paths = fp_proxy + except Exception: + pass + + async def call_route_handler( + self, + handler_module: str, + handler_func: str, + request_data: Dict[str, Any], + ) -> Any: + cache_key = f"{handler_module}.{handler_func}" + if cache_key not in self._route_handlers: + if self._module is not None and hasattr(self._module, "__file__"): + node_dir = os.path.dirname(self._module.__file__) + if node_dir not in sys.path: + sys.path.insert(0, node_dir) + try: + module = importlib.import_module(handler_module) + self._route_handlers[cache_key] = getattr(module, handler_func) + except (ImportError, AttributeError) as e: + raise ValueError(f"Route handler not found: {cache_key}") from e + + handler = self._route_handlers[cache_key] + mock_request = MockRequest(request_data) + + if asyncio.iscoroutinefunction(handler): + result = await handler(mock_request) + else: + result = handler(mock_request) + return self._serialize_response(result) + + def _is_comfy_protocol_return(self, result: Any) -> bool: + """ + Check if the result matches the ComfyUI 'Protocol Return' schema. + + A Protocol Return is a dictionary containing specific reserved keys that + ComfyUI's execution engine interprets as instructions (UI updates, + Workflow expansion, etc.) rather than purely data outputs. 
+ + Schema: + - Must be a dict + - Must contain at least one of: 'ui', 'result', 'expand' + """ + if not isinstance(result, dict): + return False + return any(key in result for key in ("ui", "result", "expand")) + + def _serialize_response(self, response: Any) -> Dict[str, Any]: + if response is None: + return {"type": "text", "body": "", "status": 204} + if isinstance(response, dict): + return {"type": "json", "body": response, "status": 200} + if isinstance(response, str): + return {"type": "text", "body": response, "status": 200} + if hasattr(response, "text") and hasattr(response, "status"): + return { + "type": "text", + "body": response.text + if hasattr(response, "text") + else str(response.body), + "status": response.status, + "headers": dict(response.headers) + if hasattr(response, "headers") + else {}, + } + if hasattr(response, "body") and hasattr(response, "status"): + body = response.body + if isinstance(body, bytes): + try: + return { + "type": "text", + "body": body.decode("utf-8"), + "status": response.status, + } + except UnicodeDecodeError: + return { + "type": "binary", + "body": body.hex(), + "status": response.status, + } + return {"type": "json", "body": body, "status": response.status} + return {"type": "text", "body": str(response), "status": 200} + + +class MockRequest: + def __init__(self, data: Dict[str, Any]): + self.method = data.get("method", "GET") + self.path = data.get("path", "/") + self.query = data.get("query", {}) + self._body = data.get("body", {}) + self._text = data.get("text", "") + self.headers = data.get("headers", {}) + self.content_type = data.get( + "content_type", self.headers.get("Content-Type", "application/json") + ) + self.match_info = data.get("match_info", {}) + + async def json(self) -> Any: + if isinstance(self._body, dict): + return self._body + if isinstance(self._body, str): + return json.loads(self._body) + return {} + + async def post(self) -> Dict[str, Any]: + if isinstance(self._body, dict): + return 
self._body + return {} + + async def text(self) -> str: + if self._text: + return self._text + if isinstance(self._body, str): + return self._body + if isinstance(self._body, dict): + return json.dumps(self._body) + return "" + + async def read(self) -> bytes: + return (await self.text()).encode("utf-8") diff --git a/comfy/isolation/host_hooks.py b/comfy/isolation/host_hooks.py new file mode 100644 index 000000000..86cde10a8 --- /dev/null +++ b/comfy/isolation/host_hooks.py @@ -0,0 +1,26 @@ +# pylint: disable=import-outside-toplevel +# Host process initialization for PyIsolate +import logging + +logger = logging.getLogger(__name__) + + +def initialize_host_process() -> None: + root = logging.getLogger() + for handler in root.handlers[:]: + root.removeHandler(handler) + root.addHandler(logging.NullHandler()) + + from .proxies.folder_paths_proxy import FolderPathsProxy + from .proxies.model_management_proxy import ModelManagementProxy + from .proxies.progress_proxy import ProgressProxy + from .proxies.prompt_server_impl import PromptServerService + from .proxies.utils_proxy import UtilsProxy + from .vae_proxy import VAERegistry + + FolderPathsProxy() + ModelManagementProxy() + ProgressProxy() + PromptServerService() + UtilsProxy() + VAERegistry() diff --git a/comfy/isolation/host_policy.py b/comfy/isolation/host_policy.py new file mode 100644 index 000000000..660dcda20 --- /dev/null +++ b/comfy/isolation/host_policy.py @@ -0,0 +1,83 @@ +# pylint: disable=logging-fstring-interpolation +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Dict, List, TypedDict + +try: + import tomllib +except ImportError: + import tomli as tomllib # type: ignore[no-redef] + +logger = logging.getLogger(__name__) + + +class HostSecurityPolicy(TypedDict): + allow_network: bool + writable_paths: List[str] + readonly_paths: List[str] + whitelist: Dict[str, str] + + +DEFAULT_POLICY: HostSecurityPolicy = { + "allow_network": False, + 
"writable_paths": ["/dev/shm", "/tmp"], + "readonly_paths": [], + "whitelist": {}, +} + + +def _default_policy() -> HostSecurityPolicy: + return { + "allow_network": DEFAULT_POLICY["allow_network"], + "writable_paths": list(DEFAULT_POLICY["writable_paths"]), + "readonly_paths": list(DEFAULT_POLICY["readonly_paths"]), + "whitelist": dict(DEFAULT_POLICY["whitelist"]), + } + + +def load_host_policy(comfy_root: Path) -> HostSecurityPolicy: + config_path = comfy_root / "pyproject.toml" + policy = _default_policy() + + if not config_path.exists(): + logger.debug("Host policy file missing at %s, using defaults.", config_path) + return policy + + try: + with config_path.open("rb") as f: + data = tomllib.load(f) + except Exception: + logger.warning( + "Failed to parse host policy from %s, using defaults.", + config_path, + exc_info=True, + ) + return policy + + tool_config = data.get("tool", {}).get("comfy", {}).get("host", {}) + if not isinstance(tool_config, dict): + logger.debug("No [tool.comfy.host] section found, using defaults.") + return policy + + if "allow_network" in tool_config: + policy["allow_network"] = bool(tool_config["allow_network"]) + + if "writable_paths" in tool_config: + policy["writable_paths"] = [str(p) for p in tool_config["writable_paths"]] + + if "readonly_paths" in tool_config: + policy["readonly_paths"] = [str(p) for p in tool_config["readonly_paths"]] + + whitelist_raw = tool_config.get("whitelist") + if isinstance(whitelist_raw, dict): + policy["whitelist"] = {str(k): str(v) for k, v in whitelist_raw.items()} + + logger.debug( + f"Loaded Host Policy: {len(policy['whitelist'])} whitelisted nodes, Network={policy['allow_network']}" + ) + return policy + + +__all__ = ["HostSecurityPolicy", "load_host_policy", "DEFAULT_POLICY"] diff --git a/comfy/isolation/manifest_loader.py b/comfy/isolation/manifest_loader.py new file mode 100644 index 000000000..42007302f --- /dev/null +++ b/comfy/isolation/manifest_loader.py @@ -0,0 +1,186 @@ +# pylint: 
# --- comfy/isolation/manifest_loader.py ---
# pylint: disable=import-outside-toplevel
from __future__ import annotations

import hashlib
import json
import logging
import os
import sys
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import folder_paths

try:
    import tomllib
except ImportError:
    import tomli as tomllib  # type: ignore[no-redef]

LOG_PREFIX = "]["
logger = logging.getLogger(__name__)

CACHE_SUBDIR = "cache"
CACHE_KEY_FILE = "cache_key"
CACHE_DATA_FILE = "node_info.json"
CACHE_KEY_LENGTH = 16


def _has_isolation_section(manifest: Path) -> bool:
    """True when *manifest* parses and declares [tool.comfy.isolation]."""
    try:
        with manifest.open("rb") as f:
            parsed = tomllib.load(f)
        return "isolation" in parsed.get("tool", {}).get("comfy", {})
    except Exception:
        # Unparseable or malformed manifests are simply not isolation-enabled.
        return False


def find_manifest_directories() -> List[Tuple[Path, Path]]:
    """Find custom node directories containing a valid pyproject.toml with [tool.comfy.isolation]."""
    found: List[Tuple[Path, Path]] = []

    # Standard custom_nodes paths.
    for base_path in folder_paths.get_folder_paths("custom_nodes"):
        base = Path(base_path)
        if not base.exists() or not base.is_dir():
            continue

        for entry in base.iterdir():
            if not entry.is_dir():
                continue
            manifest = entry / "pyproject.toml"
            if manifest.exists() and _has_isolation_section(manifest):
                found.append((entry, manifest))

    return found
def compute_cache_key(node_dir: Path, manifest_path: Path) -> str:
    """Hash manifest + .py mtimes + Python version + PyIsolate version."""
    hasher = hashlib.sha256()

    # Hashing the manifest content ensures config changes invalidate cache.
    try:
        hasher.update(manifest_path.read_bytes())
    except OSError:
        hasher.update(b"__manifest_read_error__")

    try:
        for py_file in sorted(node_dir.rglob("*.py")):
            rel_text = str(py_file.relative_to(node_dir))
            if "__pycache__" in rel_text or ".venv" in rel_text:
                continue
            hasher.update(rel_text.encode("utf-8"))
            try:
                hasher.update(str(py_file.stat().st_mtime).encode("utf-8"))
            except OSError:
                hasher.update(b"__file_stat_error__")
    except OSError:
        hasher.update(b"__dir_scan_error__")

    hasher.update(sys.version.encode("utf-8"))

    try:
        import pyisolate

        hasher.update(pyisolate.__version__.encode("utf-8"))
    except (ImportError, AttributeError):
        hasher.update(b"__pyisolate_unknown__")

    return hasher.hexdigest()[:CACHE_KEY_LENGTH]


def get_cache_path(node_dir: Path, venv_root: Path) -> Tuple[Path, Path]:
    """Return (cache_key_file, cache_data_file) in venv_root/{node}/cache/."""
    cache_dir = venv_root / node_dir.name / CACHE_SUBDIR
    return (cache_dir / CACHE_KEY_FILE, cache_dir / CACHE_DATA_FILE)


def is_cache_valid(node_dir: Path, manifest_path: Path, venv_root: Path) -> bool:
    """Return True only if stored cache key matches current computed key."""
    try:
        cache_key_file, cache_data_file = get_cache_path(node_dir, venv_root)
        if not (cache_key_file.exists() and cache_data_file.exists()):
            return False
        stored_key = cache_key_file.read_text(encoding="utf-8").strip()
        return compute_cache_key(node_dir, manifest_path) == stored_key
    except Exception as e:
        logger.debug(
            "%s Cache validation error for %s: %s", LOG_PREFIX, node_dir.name, e
        )
        return False


def load_from_cache(node_dir: Path, venv_root: Path) -> Optional[Dict[str, Any]]:
    """Load node metadata from cache, return None on any error."""
    try:
        _, cache_data_file = get_cache_path(node_dir, venv_root)
        if not cache_data_file.exists():
            return None
        payload = json.loads(cache_data_file.read_text(encoding="utf-8"))
        return payload if isinstance(payload, dict) else None
    except Exception:
        return None
def _atomic_write_text(cache_dir: Path, target: Path, text: str) -> None:
    """Atomically write *text* to *target* via a temp file inside *cache_dir*.

    ``os.replace`` makes the final rename atomic; the temp file is removed on
    failure so aborted writes never leave stale ``.tmp`` files behind.
    """
    tmp_fd, tmp_path = tempfile.mkstemp(dir=str(cache_dir), suffix=".tmp")
    try:
        with os.fdopen(tmp_fd, "w", encoding="utf-8") as f:
            f.write(text)
        os.replace(tmp_path, target)
    except Exception:
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise


def save_to_cache(
    node_dir: Path, venv_root: Path, node_data: Dict[str, Any], manifest_path: Path
) -> None:
    """Save node metadata and cache key atomically.

    Best-effort: failures are logged as warnings, never raised, because a
    missing cache only costs a re-scan on the next startup. The previously
    duplicated temp-file/rename sequences are factored into
    ``_atomic_write_text``.
    """
    try:
        cache_key_file, cache_data_file = get_cache_path(node_dir, venv_root)
        cache_dir = cache_key_file.parent
        cache_dir.mkdir(parents=True, exist_ok=True)
        cache_key = compute_cache_key(node_dir, manifest_path)

        # Data first, key second: a reader that observes the new key is then
        # guaranteed to find the matching data file already in place.
        _atomic_write_text(cache_dir, cache_data_file, json.dumps(node_data, indent=2))
        _atomic_write_text(cache_dir, cache_key_file, cache_key)
    except Exception as e:
        logger.warning("%s Cache save failed for %s: %s", LOG_PREFIX, node_dir.name, e)


__all__ = [
    "LOG_PREFIX",
    "find_manifest_directories",
    "compute_cache_key",
    "get_cache_path",
    "is_cache_valid",
    "load_from_cache",
    "save_to_cache",
]
+ """ + + def run_sync(self, maybe_coro): + if not asyncio.iscoroutine(maybe_coro): + return maybe_coro + + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop and loop.is_running(): + result_container = {} + exc_container = {} + + def _runner(): + try: + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + result_container["value"] = new_loop.run_until_complete(maybe_coro) + except Exception as exc: # pragma: no cover + exc_container["error"] = exc + finally: + try: + new_loop.close() + except Exception: + pass + + t = threading.Thread(target=_runner, daemon=True) + t.start() + t.join() + + if "error" in exc_container: + raise exc_container["error"] + return result_container.get("value") + + return asyncio.run(maybe_coro) diff --git a/comfy/isolation/runtime_helpers.py b/comfy/isolation/runtime_helpers.py new file mode 100644 index 000000000..0fe783bd0 --- /dev/null +++ b/comfy/isolation/runtime_helpers.py @@ -0,0 +1,343 @@ +# pylint: disable=consider-using-from-import,import-outside-toplevel,no-member +from __future__ import annotations + +import copy +import logging +import os +from pathlib import Path +from typing import Any, Dict, List, Set, TYPE_CHECKING + +from .proxies.helper_proxies import restore_input_types +from comfy_api.internal import _ComfyNodeInternal +from comfy_api.latest import _io as latest_io +from .shm_forensics import scan_shm_forensics + +if TYPE_CHECKING: + from .extension_wrapper import ComfyNodeExtension + +LOG_PREFIX = "][" +_PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024 + + +def _resource_snapshot() -> Dict[str, int]: + fd_count = -1 + shm_sender_files = 0 + try: + fd_count = len(os.listdir("/proc/self/fd")) + except Exception: + pass + try: + shm_root = Path("/dev/shm") + if shm_root.exists(): + prefix = f"torch_{os.getpid()}_" + shm_sender_files = sum(1 for _ in shm_root.glob(f"{prefix}*")) + except Exception: + pass + return {"fd_count": fd_count, "shm_sender_files": 
shm_sender_files} + + +def _tensor_transport_summary(value: Any) -> Dict[str, int]: + summary: Dict[str, int] = { + "tensor_count": 0, + "cpu_tensors": 0, + "cuda_tensors": 0, + "shared_cpu_tensors": 0, + "tensor_bytes": 0, + } + try: + import torch + except Exception: + return summary + + def visit(node: Any) -> None: + if isinstance(node, torch.Tensor): + summary["tensor_count"] += 1 + summary["tensor_bytes"] += int(node.numel() * node.element_size()) + if node.device.type == "cpu": + summary["cpu_tensors"] += 1 + if node.is_shared(): + summary["shared_cpu_tensors"] += 1 + elif node.device.type == "cuda": + summary["cuda_tensors"] += 1 + return + if isinstance(node, dict): + for v in node.values(): + visit(v) + return + if isinstance(node, (list, tuple)): + for v in node: + visit(v) + + visit(value) + return summary + + +def _extract_hidden_unique_id(inputs: Dict[str, Any]) -> str | None: + for key, value in inputs.items(): + key_text = str(key) + if "unique_id" in key_text: + return str(value) + return None + + +def _flush_tensor_transport_state(marker: str, logger: logging.Logger) -> None: + try: + from pyisolate import flush_tensor_keeper # type: ignore[attr-defined] + except Exception: + return + if not callable(flush_tensor_keeper): + return + flushed = flush_tensor_keeper() + if flushed > 0: + logger.debug( + "%s %s flush_tensor_keeper released=%d", LOG_PREFIX, marker, flushed + ) + + +def _relieve_host_vram_pressure(marker: str, logger: logging.Logger) -> None: + import comfy.model_management as model_management + + model_management.cleanup_models_gc() + model_management.cleanup_models() + + device = model_management.get_torch_device() + if not hasattr(device, "type") or device.type == "cpu": + return + + required = max( + model_management.minimum_inference_memory(), + _PRE_EXEC_MIN_FREE_VRAM_BYTES, + ) + if model_management.get_free_memory(device) < required: + model_management.free_memory(required, device, for_dynamic=True) + if 
model_management.get_free_memory(device) < required: + model_management.free_memory(required, device, for_dynamic=False) + model_management.cleanup_models() + model_management.soft_empty_cache() + logger.debug("%s %s free_memory target=%d", LOG_PREFIX, marker, required) + + +def _detach_shared_cpu_tensors(value: Any) -> Any: + try: + import torch + except Exception: + return value + + if isinstance(value, torch.Tensor): + if value.device.type == "cpu" and value.is_shared(): + clone = value.clone() + if value.requires_grad: + clone.requires_grad_(True) + return clone + return value + if isinstance(value, list): + return [_detach_shared_cpu_tensors(v) for v in value] + if isinstance(value, tuple): + return tuple(_detach_shared_cpu_tensors(v) for v in value) + if isinstance(value, dict): + return {k: _detach_shared_cpu_tensors(v) for k, v in value.items()} + return value + + +def build_stub_class( + node_name: str, + info: Dict[str, object], + extension: "ComfyNodeExtension", + running_extensions: Dict[str, "ComfyNodeExtension"], + logger: logging.Logger, +) -> type: + is_v3 = bool(info.get("is_v3", False)) + function_name = "_pyisolate_execute" + restored_input_types = restore_input_types(info.get("input_types", {})) + + async def _execute(self, **inputs): + from comfy.isolation import _RUNNING_EXTENSIONS + + # Update BOTH the local dict AND the module-level dict + running_extensions[extension.name] = extension + _RUNNING_EXTENSIONS[extension.name] = extension + prev_child = None + node_unique_id = _extract_hidden_unique_id(inputs) + summary = _tensor_transport_summary(inputs) + resources = _resource_snapshot() + logger.debug( + "%s ISO:execute_start ext=%s node=%s uid=%s tensors=%d cpu=%d cuda=%d shared_cpu=%d bytes=%d fds=%d sender_shm=%d", + LOG_PREFIX, + extension.name, + node_name, + node_unique_id or "-", + summary["tensor_count"], + summary["cpu_tensors"], + summary["cuda_tensors"], + summary["shared_cpu_tensors"], + summary["tensor_bytes"], + 
resources["fd_count"], + resources["shm_sender_files"], + ) + scan_shm_forensics("RUNTIME:execute_start", refresh_model_context=True) + try: + if os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1": + _relieve_host_vram_pressure("RUNTIME:pre_execute", logger) + scan_shm_forensics("RUNTIME:pre_execute", refresh_model_context=True) + from pyisolate._internal.model_serialization import ( + serialize_for_isolation, + deserialize_from_isolation, + ) + + prev_child = os.environ.pop("PYISOLATE_CHILD", None) + logger.debug( + "%s ISO:serialize_start ext=%s node=%s uid=%s", + LOG_PREFIX, + extension.name, + node_name, + node_unique_id or "-", + ) + serialized = serialize_for_isolation(inputs) + logger.debug( + "%s ISO:serialize_done ext=%s node=%s uid=%s", + LOG_PREFIX, + extension.name, + node_name, + node_unique_id or "-", + ) + logger.debug( + "%s ISO:dispatch_start ext=%s node=%s uid=%s", + LOG_PREFIX, + extension.name, + node_name, + node_unique_id or "-", + ) + result = await extension.execute_node(node_name, **serialized) + logger.debug( + "%s ISO:dispatch_done ext=%s node=%s uid=%s", + LOG_PREFIX, + extension.name, + node_name, + node_unique_id or "-", + ) + deserialized = await deserialize_from_isolation(result, extension) + scan_shm_forensics("RUNTIME:post_execute", refresh_model_context=True) + return _detach_shared_cpu_tensors(deserialized) + except ImportError: + return await extension.execute_node(node_name, **inputs) + except Exception: + logger.exception( + "%s ISO:execute_error ext=%s node=%s uid=%s", + LOG_PREFIX, + extension.name, + node_name, + node_unique_id or "-", + ) + raise + finally: + if prev_child is not None: + os.environ["PYISOLATE_CHILD"] = prev_child + logger.debug( + "%s ISO:execute_end ext=%s node=%s uid=%s", + LOG_PREFIX, + extension.name, + node_name, + node_unique_id or "-", + ) + scan_shm_forensics("RUNTIME:execute_end", refresh_model_context=True) + + def _input_types( + cls, + include_hidden: bool = True, + return_schema: bool = 
False, + live_inputs: Any = None, + ): + if not is_v3: + return restored_input_types + + inputs_copy = copy.deepcopy(restored_input_types) + if not include_hidden: + inputs_copy.pop("hidden", None) + + v3_data: Dict[str, Any] = {"hidden_inputs": {}} + dynamic = inputs_copy.pop("dynamic_paths", None) + if dynamic is not None: + v3_data["dynamic_paths"] = dynamic + + if return_schema: + hidden_vals = info.get("hidden", []) or [] + hidden_enums = [] + for h in hidden_vals: + try: + hidden_enums.append(latest_io.Hidden(h)) + except Exception: + hidden_enums.append(h) + + class SchemaProxy: + hidden = hidden_enums + + return inputs_copy, SchemaProxy, v3_data + return inputs_copy + + def _validate_class(cls): + return True + + def _get_node_info_v1(cls): + return info.get("schema_v1", {}) + + def _get_base_class(cls): + return latest_io.ComfyNode + + attributes: Dict[str, object] = { + "FUNCTION": function_name, + "CATEGORY": info.get("category", ""), + "OUTPUT_NODE": info.get("output_node", False), + "RETURN_TYPES": tuple(info.get("return_types", ()) or ()), + "RETURN_NAMES": info.get("return_names"), + function_name: _execute, + "_pyisolate_extension": extension, + "_pyisolate_node_name": node_name, + "INPUT_TYPES": classmethod(_input_types), + } + + output_is_list = info.get("output_is_list") + if output_is_list is not None: + attributes["OUTPUT_IS_LIST"] = tuple(output_is_list) + + if is_v3: + attributes["VALIDATE_CLASS"] = classmethod(_validate_class) + attributes["GET_NODE_INFO_V1"] = classmethod(_get_node_info_v1) + attributes["GET_BASE_CLASS"] = classmethod(_get_base_class) + attributes["DESCRIPTION"] = info.get("description", "") + attributes["EXPERIMENTAL"] = info.get("experimental", False) + attributes["DEPRECATED"] = info.get("deprecated", False) + attributes["API_NODE"] = info.get("api_node", False) + attributes["NOT_IDEMPOTENT"] = info.get("not_idempotent", False) + attributes["INPUT_IS_LIST"] = info.get("input_is_list", False) + + class_name = 
f"PyIsolate_{node_name}".replace(" ", "_") + bases = (_ComfyNodeInternal,) if is_v3 else () + stub_cls = type(class_name, bases, attributes) + + if is_v3: + try: + stub_cls.VALIDATE_CLASS() + except Exception as e: + logger.error("%s VALIDATE_CLASS failed: %s - %s", LOG_PREFIX, node_name, e) + + return stub_cls + + +def get_class_types_for_extension( + extension_name: str, + running_extensions: Dict[str, "ComfyNodeExtension"], + specs: List[Any], +) -> Set[str]: + extension = running_extensions.get(extension_name) + if not extension: + return set() + + ext_path = Path(extension.module_path) + class_types = set() + for spec in specs: + if spec.module_path.resolve() == ext_path.resolve(): + class_types.add(spec.node_name) + return class_types + + +__all__ = ["build_stub_class", "get_class_types_for_extension"] diff --git a/comfy/isolation/shm_forensics.py b/comfy/isolation/shm_forensics.py new file mode 100644 index 000000000..36223505a --- /dev/null +++ b/comfy/isolation/shm_forensics.py @@ -0,0 +1,217 @@ +# pylint: disable=consider-using-from-import,import-outside-toplevel +from __future__ import annotations + +import atexit +import hashlib +import logging +import os +from pathlib import Path +from typing import Any, Dict, List, Set + +LOG_PREFIX = "][" +logger = logging.getLogger(__name__) + + +def _shm_debug_enabled() -> bool: + return os.environ.get("COMFY_ISO_SHM_DEBUG") == "1" + + +class _SHMForensicsTracker: + def __init__(self) -> None: + self._started = False + self._tracked_files: Set[str] = set() + self._current_model_context: Dict[str, str] = { + "id": "unknown", + "name": "unknown", + "hash": "????", + } + + @staticmethod + def _snapshot_shm() -> Set[str]: + shm_path = Path("/dev/shm") + if not shm_path.exists(): + return set() + return {f.name for f in shm_path.glob("torch_*")} + + def start(self) -> None: + if self._started or not _shm_debug_enabled(): + return + self._tracked_files = self._snapshot_shm() + self._started = True + logger.debug( + "%s 
SHM:forensics_enabled tracked=%d", LOG_PREFIX, len(self._tracked_files) + ) + + def stop(self) -> None: + if not self._started: + return + self.scan("shutdown", refresh_model_context=True) + self._started = False + logger.debug("%s SHM:forensics_disabled", LOG_PREFIX) + + def _compute_model_hash(self, model_patcher: Any) -> str: + try: + model_instance_id = getattr(model_patcher, "_instance_id", None) + if model_instance_id is not None: + model_id_text = str(model_instance_id) + return model_id_text[-4:] if len(model_id_text) >= 4 else model_id_text + + import torch + + real_model = ( + model_patcher.model + if hasattr(model_patcher, "model") + else model_patcher + ) + tensor = None + if hasattr(real_model, "parameters"): + for p in real_model.parameters(): + if torch.is_tensor(p) and p.numel() > 0: + tensor = p + break + + if tensor is None: + return "0000" + + flat = tensor.flatten() + values = [] + indices = [0, flat.shape[0] // 2, flat.shape[0] - 1] + for i in indices: + if i < flat.shape[0]: + values.append(flat[i].item()) + + size = 0 + if hasattr(model_patcher, "model_size"): + size = model_patcher.model_size() + sample_str = f"{values}_{id(model_patcher):016x}_{size}" + return hashlib.sha256(sample_str.encode()).hexdigest()[-4:] + except Exception: + return "err!" 
+ + def _get_models_snapshot(self) -> List[Dict[str, Any]]: + try: + import comfy.model_management as model_management + except Exception: + return [] + + snapshot: List[Dict[str, Any]] = [] + try: + for loaded_model in model_management.current_loaded_models: + model = loaded_model.model + if model is None: + continue + if str(getattr(loaded_model, "device", "")) != "cuda:0": + continue + + name = ( + model.model.__class__.__name__ + if hasattr(model, "model") + else type(model).__name__ + ) + model_hash = self._compute_model_hash(model) + model_instance_id = getattr(model, "_instance_id", None) + if model_instance_id is None: + model_instance_id = model_hash + snapshot.append( + { + "name": str(name), + "id": str(model_instance_id), + "hash": str(model_hash or "????"), + "used": bool(getattr(loaded_model, "currently_used", False)), + } + ) + except Exception: + return [] + + return snapshot + + def _update_model_context(self) -> None: + snapshot = self._get_models_snapshot() + selected = None + + used_models = [m for m in snapshot if m.get("used") and m.get("id")] + if used_models: + selected = used_models[-1] + else: + live_models = [m for m in snapshot if m.get("id")] + if live_models: + selected = live_models[-1] + + if selected is None: + self._current_model_context = { + "id": "unknown", + "name": "unknown", + "hash": "????", + } + return + + self._current_model_context = { + "id": str(selected.get("id", "unknown")), + "name": str(selected.get("name", "unknown")), + "hash": str(selected.get("hash", "????") or "????"), + } + + def scan(self, marker: str, refresh_model_context: bool = True) -> None: + if not self._started or not _shm_debug_enabled(): + return + + if refresh_model_context: + self._update_model_context() + + current = self._snapshot_shm() + added = current - self._tracked_files + removed = self._tracked_files - current + self._tracked_files = current + + if not added and not removed: + logger.debug("%s SHM:scan marker=%s changes=0", LOG_PREFIX, 
# Module-level singleton: one tracker per process is sufficient, since it
# observes the process-wide /dev/shm namespace.
_TRACKER = _SHMForensicsTracker()


def start_shm_forensics() -> None:
    """Begin tracking /dev/shm torch_* files (no-op unless COMFY_ISO_SHM_DEBUG=1)."""
    _TRACKER.start()


def scan_shm_forensics(marker: str, refresh_model_context: bool = True) -> None:
    """Log shm file changes since the previous scan, tagged with *marker*."""
    _TRACKER.scan(marker, refresh_model_context=refresh_model_context)


def stop_shm_forensics() -> None:
    """Run a final scan and disable tracking."""
    _TRACKER.stop()


# Ensure a last scan/flush runs at interpreter shutdown.
atexit.register(stop_shm_forensics)