# pylint: disable=cyclic-import,import-outside-toplevel,redefined-outer-name from __future__ import annotations import logging import os import inspect import sys import types import platform from pathlib import Path from typing import Callable, Dict, List, Tuple import pyisolate from pyisolate import ExtensionManager, ExtensionManagerConfig from packaging.requirements import InvalidRequirement, Requirement from packaging.utils import canonicalize_name from .extension_wrapper import ComfyNodeExtension from .manifest_loader import is_cache_valid, load_from_cache, save_to_cache from .host_policy import load_host_policy try: import tomllib except ImportError: import tomli as tomllib # type: ignore[no-redef] logger = logging.getLogger(__name__) async def _stop_extension_safe( extension: ComfyNodeExtension, extension_name: str ) -> None: try: stop_result = extension.stop() if inspect.isawaitable(stop_result): await stop_result except Exception: logger.debug("][ %s stop failed", extension_name, exc_info=True) def _normalize_dependency_spec(dep: str, base_paths: list[Path]) -> str: req, sep, marker = dep.partition(";") req = req.strip() marker_suffix = f";{marker}" if sep else "" def _resolve_local_path(local_path: str) -> Path | None: for base in base_paths: candidate = (base / local_path).resolve() if candidate.exists(): return candidate return None if req.startswith("./") or req.startswith("../"): resolved = _resolve_local_path(req) if resolved is not None: return f"{resolved}{marker_suffix}" if req.startswith("file://"): raw = req[len("file://") :] if raw.startswith("./") or raw.startswith("../"): resolved = _resolve_local_path(raw) if resolved is not None: return f"file://{resolved}{marker_suffix}" return dep def _dependency_name_from_spec(dep: str) -> str | None: stripped = dep.strip() if not stripped or stripped == "-e" or stripped.startswith("-e "): return None if stripped.startswith(("/", "./", "../", "file://")): return None try: return canonicalize_name(Requirement(stripped).name) except InvalidRequirement: return None def _parse_cuda_wheels_config( tool_config: dict[str, object], dependencies: list[str] ) -> dict[str, object] | None: raw_config = tool_config.get("cuda_wheels") if raw_config is None: return None if not isinstance(raw_config, dict): raise ExtensionLoadError( "[tool.comfy.isolation.cuda_wheels] must be a table" ) index_url = raw_config.get("index_url") if not isinstance(index_url, str) or not index_url.strip(): raise ExtensionLoadError( "[tool.comfy.isolation.cuda_wheels.index_url] must be a non-empty string" ) packages = raw_config.get("packages") if not isinstance(packages, list) or not all( isinstance(package_name, str) and package_name.strip() for package_name in packages ): raise ExtensionLoadError( "[tool.comfy.isolation.cuda_wheels.packages] must be a list of non-empty strings" ) declared_dependencies = { dependency_name for dep in dependencies if (dependency_name := _dependency_name_from_spec(dep)) is not None } normalized_packages = [canonicalize_name(package_name) for package_name in packages] missing = [ package_name for package_name in normalized_packages if package_name not in declared_dependencies ] if missing: missing_joined = ", ".join(sorted(missing)) raise ExtensionLoadError( "[tool.comfy.isolation.cuda_wheels.packages] references undeclared dependencies: " f"{missing_joined}" ) package_map = raw_config.get("package_map", {}) if not isinstance(package_map, dict): raise ExtensionLoadError( "[tool.comfy.isolation.cuda_wheels.package_map] must be a table" ) normalized_package_map: dict[str, str] = {} for dependency_name, index_package_name in package_map.items(): if not isinstance(dependency_name, str) or not dependency_name.strip(): raise ExtensionLoadError( "[tool.comfy.isolation.cuda_wheels.package_map] keys must be non-empty strings" ) if not isinstance(index_package_name, str) or not index_package_name.strip(): raise ExtensionLoadError( "[tool.comfy.isolation.cuda_wheels.package_map] values must be non-empty strings" ) canonical_dependency_name = canonicalize_name(dependency_name) if canonical_dependency_name not in normalized_packages: raise ExtensionLoadError( "[tool.comfy.isolation.cuda_wheels.package_map] can only override packages listed in " "[tool.comfy.isolation.cuda_wheels.packages]" ) normalized_package_map[canonical_dependency_name] = index_package_name.strip() return { "index_url": index_url.rstrip("/") + "/", "packages": normalized_packages, "package_map": normalized_package_map, } def get_enforcement_policy() -> Dict[str, bool]: return { "force_isolated": os.environ.get("PYISOLATE_ENFORCE_ISOLATED") == "1", "force_sandbox": os.environ.get("PYISOLATE_ENFORCE_SANDBOX") == "1", } class ExtensionLoadError(RuntimeError): pass def register_dummy_module(extension_name: str, node_dir: Path) -> None: normalized_name = extension_name.replace("-", "_").replace(".", "_") if normalized_name not in sys.modules: dummy_module = types.ModuleType(normalized_name) dummy_module.__file__ = str(node_dir / "__init__.py") dummy_module.__path__ = [str(node_dir)] dummy_module.__package__ = normalized_name sys.modules[normalized_name] = dummy_module def _is_stale_node_cache(cached_data: Dict[str, Dict]) -> bool: for details in cached_data.values(): if not isinstance(details, dict): return True if details.get("is_v3") and "schema_v1" not in details: return True return False async def load_isolated_node( node_dir: Path, manifest_path: Path, logger: logging.Logger, build_stub_class: Callable[[str, Dict[str, object], ComfyNodeExtension], type], venv_root: Path, extension_managers: List[ExtensionManager], ) -> List[Tuple[str, str, type]]: try: with manifest_path.open("rb") as handle: manifest_data = tomllib.load(handle) except Exception as e: logger.warning(f"][ Failed to parse {manifest_path}: {e}") return [] # Parse [tool.comfy.isolation] tool_config = manifest_data.get("tool", {}).get("comfy", {}).get("isolation", {}) can_isolate = tool_config.get("can_isolate", False) share_torch = tool_config.get("share_torch", False) # Parse [project] dependencies project_config = manifest_data.get("project", {}) dependencies = project_config.get("dependencies", []) if not isinstance(dependencies, list): dependencies = [] # Get extension name (default to folder name if not in project.name) extension_name = project_config.get("name", node_dir.name) # LOGIC: Isolation Decision policy = get_enforcement_policy() isolated = can_isolate or policy["force_isolated"] if not isolated: return [] logger.info(f"][ Loading isolated node: {extension_name}") import folder_paths base_paths = [Path(folder_paths.base_path), node_dir] dependencies = [ _normalize_dependency_spec(dep, base_paths) if isinstance(dep, str) else dep for dep in dependencies ] cuda_wheels = _parse_cuda_wheels_config(tool_config, dependencies) manager_config = ExtensionManagerConfig(venv_root_path=str(venv_root)) manager: ExtensionManager = pyisolate.ExtensionManager( ComfyNodeExtension, manager_config ) extension_managers.append(manager) host_policy = load_host_policy(Path(folder_paths.base_path)) sandbox_config = {} is_linux = platform.system() == "Linux" if is_linux and isolated: sandbox_config = { "network": host_policy["allow_network"], "writable_paths": host_policy["writable_paths"], "readonly_paths": host_policy["readonly_paths"], } share_cuda_ipc = share_torch and is_linux extension_config = { "name": extension_name, "module_path": str(node_dir), "isolated": True, "dependencies": dependencies, "share_torch": share_torch, "share_cuda_ipc": share_cuda_ipc, "sandbox": sandbox_config, } if cuda_wheels is not None: extension_config["cuda_wheels"] = cuda_wheels extension = manager.load_extension(extension_config) register_dummy_module(extension_name, node_dir) # Try cache first (lazy spawn) if is_cache_valid(node_dir, manifest_path, venv_root): cached_data = load_from_cache(node_dir, venv_root) if cached_data: if _is_stale_node_cache(cached_data): logger.debug( "][ %s cache is stale/incompatible; rebuilding metadata", extension_name, ) else: logger.debug(f"][ {extension_name} loaded from cache") specs: List[Tuple[str, str, type]] = [] for node_name, details in cached_data.items(): stub_cls = build_stub_class(node_name, details, extension) specs.append( (node_name, details.get("display_name", node_name), stub_cls) ) return specs # Cache miss - spawn process and get metadata logger.debug(f"][ {extension_name} cache miss, spawning process for metadata") try: remote_nodes: Dict[str, str] = await extension.list_nodes() except Exception as exc: logger.warning( "][ %s metadata discovery failed, skipping isolated load: %s", extension_name, exc, ) await _stop_extension_safe(extension, extension_name) return [] if not remote_nodes: logger.debug("][ %s exposed no isolated nodes; skipping", extension_name) await _stop_extension_safe(extension, extension_name) return [] specs: List[Tuple[str, str, type]] = [] cache_data: Dict[str, Dict] = {} for node_name, display_name in remote_nodes.items(): try: details = await extension.get_node_details(node_name) except Exception as exc: logger.warning( "][ %s failed to load metadata for %s, skipping node: %s", extension_name, node_name, exc, ) continue details["display_name"] = display_name cache_data[node_name] = details stub_cls = build_stub_class(node_name, details, extension) specs.append((node_name, display_name, stub_cls)) if not specs: logger.warning( "][ %s produced no usable nodes after metadata scan; skipping", extension_name, ) await _stop_extension_safe(extension, extension_name) return [] # Save metadata to cache for future runs save_to_cache(node_dir, venv_root, cache_data, manifest_path) logger.debug(f"][ {extension_name} metadata cached") # EJECT: Kill process after getting metadata (will respawn on first execution) await _stop_extension_safe(extension, extension_name) return specs __all__ = ["ExtensionLoadError", "register_dummy_module", "load_isolated_node"]