mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-27 18:01:25 +00:00
feat(isolation): process isolation for custom nodes via pyisolate
Adds opt-in process isolation for custom nodes using pyisolate's bwrap sandbox and JSON-RPC bridge. Each isolated node pack runs in its own child process with zero-copy tensor transfer via shared memory. Core infrastructure: - CLI flag --use-process-isolation to enable isolation - Host/child startup fencing via PYISOLATE_CHILD env var - Manifest-driven node discovery and extension loading - JSON-RPC bridge between host and child processes - Shared memory forensics for leak detection Proxy layer: - ModelPatcher, CLIP, VAE, and ModelSampling proxies - Host service proxies (folder_paths, model_management, progress, etc.) - Proxy base with automatic method forwarding Execution integration: - Extension wrapper with V3 hidden param mapping - Runtime helpers for isolated node execution - Host policy for node isolation decisions - Fenced sampler device handling and model ejection parity Serializers for cross-process data transfer: - File3D (GLB), PLY (structured + gaussian), NPZ (streaming frames), VIDEO (VideoFromFile + VideoFromComponents) serializers - data_type flag in SerializerRegistry for type-aware dispatch - Isolated get_temp_directory() fence New core save nodes: - SavePLY and SaveNPZ with comfytype registrations (Ply, Npz) DynamicVRAM compatibility: - comfy-aimdo early init gated by isolation fence Tests: - Integration and policy tests for isolation lifecycle - Manifest loader, host policy, proxy, and adapter unit tests Depends on: pyisolate >= 0.9.2
This commit is contained in:
343
comfy/isolation/runtime_helpers.py
Normal file
343
comfy/isolation/runtime_helpers.py
Normal file
@@ -0,0 +1,343 @@
|
||||
# pylint: disable=consider-using-from-import,import-outside-toplevel,no-member
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Set, TYPE_CHECKING
|
||||
|
||||
from .proxies.helper_proxies import restore_input_types
|
||||
from comfy_api.internal import _ComfyNodeInternal
|
||||
from comfy_api.latest import _io as latest_io
|
||||
from .shm_forensics import scan_shm_forensics
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .extension_wrapper import ComfyNodeExtension
|
||||
|
||||
LOG_PREFIX = "]["
|
||||
_PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
|
||||
|
||||
|
||||
def _resource_snapshot() -> Dict[str, int]:
|
||||
fd_count = -1
|
||||
shm_sender_files = 0
|
||||
try:
|
||||
fd_count = len(os.listdir("/proc/self/fd"))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
shm_root = Path("/dev/shm")
|
||||
if shm_root.exists():
|
||||
prefix = f"torch_{os.getpid()}_"
|
||||
shm_sender_files = sum(1 for _ in shm_root.glob(f"{prefix}*"))
|
||||
except Exception:
|
||||
pass
|
||||
return {"fd_count": fd_count, "shm_sender_files": shm_sender_files}
|
||||
|
||||
|
||||
def _tensor_transport_summary(value: Any) -> Dict[str, int]:
|
||||
summary: Dict[str, int] = {
|
||||
"tensor_count": 0,
|
||||
"cpu_tensors": 0,
|
||||
"cuda_tensors": 0,
|
||||
"shared_cpu_tensors": 0,
|
||||
"tensor_bytes": 0,
|
||||
}
|
||||
try:
|
||||
import torch
|
||||
except Exception:
|
||||
return summary
|
||||
|
||||
def visit(node: Any) -> None:
|
||||
if isinstance(node, torch.Tensor):
|
||||
summary["tensor_count"] += 1
|
||||
summary["tensor_bytes"] += int(node.numel() * node.element_size())
|
||||
if node.device.type == "cpu":
|
||||
summary["cpu_tensors"] += 1
|
||||
if node.is_shared():
|
||||
summary["shared_cpu_tensors"] += 1
|
||||
elif node.device.type == "cuda":
|
||||
summary["cuda_tensors"] += 1
|
||||
return
|
||||
if isinstance(node, dict):
|
||||
for v in node.values():
|
||||
visit(v)
|
||||
return
|
||||
if isinstance(node, (list, tuple)):
|
||||
for v in node:
|
||||
visit(v)
|
||||
|
||||
visit(value)
|
||||
return summary
|
||||
|
||||
|
||||
def _extract_hidden_unique_id(inputs: Dict[str, Any]) -> str | None:
|
||||
for key, value in inputs.items():
|
||||
key_text = str(key)
|
||||
if "unique_id" in key_text:
|
||||
return str(value)
|
||||
return None
|
||||
|
||||
|
||||
def _flush_tensor_transport_state(marker: str, logger: logging.Logger) -> None:
|
||||
try:
|
||||
from pyisolate import flush_tensor_keeper # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
return
|
||||
if not callable(flush_tensor_keeper):
|
||||
return
|
||||
flushed = flush_tensor_keeper()
|
||||
if flushed > 0:
|
||||
logger.debug(
|
||||
"%s %s flush_tensor_keeper released=%d", LOG_PREFIX, marker, flushed
|
||||
)
|
||||
|
||||
|
||||
def _relieve_host_vram_pressure(marker: str, logger: logging.Logger) -> None:
|
||||
import comfy.model_management as model_management
|
||||
|
||||
model_management.cleanup_models_gc()
|
||||
model_management.cleanup_models()
|
||||
|
||||
device = model_management.get_torch_device()
|
||||
if not hasattr(device, "type") or device.type == "cpu":
|
||||
return
|
||||
|
||||
required = max(
|
||||
model_management.minimum_inference_memory(),
|
||||
_PRE_EXEC_MIN_FREE_VRAM_BYTES,
|
||||
)
|
||||
if model_management.get_free_memory(device) < required:
|
||||
model_management.free_memory(required, device, for_dynamic=True)
|
||||
if model_management.get_free_memory(device) < required:
|
||||
model_management.free_memory(required, device, for_dynamic=False)
|
||||
model_management.cleanup_models()
|
||||
model_management.soft_empty_cache()
|
||||
logger.debug("%s %s free_memory target=%d", LOG_PREFIX, marker, required)
|
||||
|
||||
|
||||
def _detach_shared_cpu_tensors(value: Any) -> Any:
|
||||
try:
|
||||
import torch
|
||||
except Exception:
|
||||
return value
|
||||
|
||||
if isinstance(value, torch.Tensor):
|
||||
if value.device.type == "cpu" and value.is_shared():
|
||||
clone = value.clone()
|
||||
if value.requires_grad:
|
||||
clone.requires_grad_(True)
|
||||
return clone
|
||||
return value
|
||||
if isinstance(value, list):
|
||||
return [_detach_shared_cpu_tensors(v) for v in value]
|
||||
if isinstance(value, tuple):
|
||||
return tuple(_detach_shared_cpu_tensors(v) for v in value)
|
||||
if isinstance(value, dict):
|
||||
return {k: _detach_shared_cpu_tensors(v) for k, v in value.items()}
|
||||
return value
|
||||
|
||||
|
||||
def build_stub_class(
|
||||
node_name: str,
|
||||
info: Dict[str, object],
|
||||
extension: "ComfyNodeExtension",
|
||||
running_extensions: Dict[str, "ComfyNodeExtension"],
|
||||
logger: logging.Logger,
|
||||
) -> type:
|
||||
is_v3 = bool(info.get("is_v3", False))
|
||||
function_name = "_pyisolate_execute"
|
||||
restored_input_types = restore_input_types(info.get("input_types", {}))
|
||||
|
||||
async def _execute(self, **inputs):
|
||||
from comfy.isolation import _RUNNING_EXTENSIONS
|
||||
|
||||
# Update BOTH the local dict AND the module-level dict
|
||||
running_extensions[extension.name] = extension
|
||||
_RUNNING_EXTENSIONS[extension.name] = extension
|
||||
prev_child = None
|
||||
node_unique_id = _extract_hidden_unique_id(inputs)
|
||||
summary = _tensor_transport_summary(inputs)
|
||||
resources = _resource_snapshot()
|
||||
logger.debug(
|
||||
"%s ISO:execute_start ext=%s node=%s uid=%s tensors=%d cpu=%d cuda=%d shared_cpu=%d bytes=%d fds=%d sender_shm=%d",
|
||||
LOG_PREFIX,
|
||||
extension.name,
|
||||
node_name,
|
||||
node_unique_id or "-",
|
||||
summary["tensor_count"],
|
||||
summary["cpu_tensors"],
|
||||
summary["cuda_tensors"],
|
||||
summary["shared_cpu_tensors"],
|
||||
summary["tensor_bytes"],
|
||||
resources["fd_count"],
|
||||
resources["shm_sender_files"],
|
||||
)
|
||||
scan_shm_forensics("RUNTIME:execute_start", refresh_model_context=True)
|
||||
try:
|
||||
if os.environ.get("PYISOLATE_CHILD") != "1":
|
||||
_relieve_host_vram_pressure("RUNTIME:pre_execute", logger)
|
||||
scan_shm_forensics("RUNTIME:pre_execute", refresh_model_context=True)
|
||||
from pyisolate._internal.model_serialization import (
|
||||
serialize_for_isolation,
|
||||
deserialize_from_isolation,
|
||||
)
|
||||
|
||||
prev_child = os.environ.pop("PYISOLATE_CHILD", None)
|
||||
logger.debug(
|
||||
"%s ISO:serialize_start ext=%s node=%s uid=%s",
|
||||
LOG_PREFIX,
|
||||
extension.name,
|
||||
node_name,
|
||||
node_unique_id or "-",
|
||||
)
|
||||
serialized = serialize_for_isolation(inputs)
|
||||
logger.debug(
|
||||
"%s ISO:serialize_done ext=%s node=%s uid=%s",
|
||||
LOG_PREFIX,
|
||||
extension.name,
|
||||
node_name,
|
||||
node_unique_id or "-",
|
||||
)
|
||||
logger.debug(
|
||||
"%s ISO:dispatch_start ext=%s node=%s uid=%s",
|
||||
LOG_PREFIX,
|
||||
extension.name,
|
||||
node_name,
|
||||
node_unique_id or "-",
|
||||
)
|
||||
result = await extension.execute_node(node_name, **serialized)
|
||||
logger.debug(
|
||||
"%s ISO:dispatch_done ext=%s node=%s uid=%s",
|
||||
LOG_PREFIX,
|
||||
extension.name,
|
||||
node_name,
|
||||
node_unique_id or "-",
|
||||
)
|
||||
deserialized = await deserialize_from_isolation(result, extension)
|
||||
scan_shm_forensics("RUNTIME:post_execute", refresh_model_context=True)
|
||||
return _detach_shared_cpu_tensors(deserialized)
|
||||
except ImportError:
|
||||
return await extension.execute_node(node_name, **inputs)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"%s ISO:execute_error ext=%s node=%s uid=%s",
|
||||
LOG_PREFIX,
|
||||
extension.name,
|
||||
node_name,
|
||||
node_unique_id or "-",
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
if prev_child is not None:
|
||||
os.environ["PYISOLATE_CHILD"] = prev_child
|
||||
logger.debug(
|
||||
"%s ISO:execute_end ext=%s node=%s uid=%s",
|
||||
LOG_PREFIX,
|
||||
extension.name,
|
||||
node_name,
|
||||
node_unique_id or "-",
|
||||
)
|
||||
scan_shm_forensics("RUNTIME:execute_end", refresh_model_context=True)
|
||||
|
||||
def _input_types(
|
||||
cls,
|
||||
include_hidden: bool = True,
|
||||
return_schema: bool = False,
|
||||
live_inputs: Any = None,
|
||||
):
|
||||
if not is_v3:
|
||||
return restored_input_types
|
||||
|
||||
inputs_copy = copy.deepcopy(restored_input_types)
|
||||
if not include_hidden:
|
||||
inputs_copy.pop("hidden", None)
|
||||
|
||||
v3_data: Dict[str, Any] = {"hidden_inputs": {}}
|
||||
dynamic = inputs_copy.pop("dynamic_paths", None)
|
||||
if dynamic is not None:
|
||||
v3_data["dynamic_paths"] = dynamic
|
||||
|
||||
if return_schema:
|
||||
hidden_vals = info.get("hidden", []) or []
|
||||
hidden_enums = []
|
||||
for h in hidden_vals:
|
||||
try:
|
||||
hidden_enums.append(latest_io.Hidden(h))
|
||||
except Exception:
|
||||
hidden_enums.append(h)
|
||||
|
||||
class SchemaProxy:
|
||||
hidden = hidden_enums
|
||||
|
||||
return inputs_copy, SchemaProxy, v3_data
|
||||
return inputs_copy
|
||||
|
||||
def _validate_class(cls):
|
||||
return True
|
||||
|
||||
def _get_node_info_v1(cls):
|
||||
return info.get("schema_v1", {})
|
||||
|
||||
def _get_base_class(cls):
|
||||
return latest_io.ComfyNode
|
||||
|
||||
attributes: Dict[str, object] = {
|
||||
"FUNCTION": function_name,
|
||||
"CATEGORY": info.get("category", ""),
|
||||
"OUTPUT_NODE": info.get("output_node", False),
|
||||
"RETURN_TYPES": tuple(info.get("return_types", ()) or ()),
|
||||
"RETURN_NAMES": info.get("return_names"),
|
||||
function_name: _execute,
|
||||
"_pyisolate_extension": extension,
|
||||
"_pyisolate_node_name": node_name,
|
||||
"INPUT_TYPES": classmethod(_input_types),
|
||||
}
|
||||
|
||||
output_is_list = info.get("output_is_list")
|
||||
if output_is_list is not None:
|
||||
attributes["OUTPUT_IS_LIST"] = tuple(output_is_list)
|
||||
|
||||
if is_v3:
|
||||
attributes["VALIDATE_CLASS"] = classmethod(_validate_class)
|
||||
attributes["GET_NODE_INFO_V1"] = classmethod(_get_node_info_v1)
|
||||
attributes["GET_BASE_CLASS"] = classmethod(_get_base_class)
|
||||
attributes["DESCRIPTION"] = info.get("description", "")
|
||||
attributes["EXPERIMENTAL"] = info.get("experimental", False)
|
||||
attributes["DEPRECATED"] = info.get("deprecated", False)
|
||||
attributes["API_NODE"] = info.get("api_node", False)
|
||||
attributes["NOT_IDEMPOTENT"] = info.get("not_idempotent", False)
|
||||
attributes["INPUT_IS_LIST"] = info.get("input_is_list", False)
|
||||
|
||||
class_name = f"PyIsolate_{node_name}".replace(" ", "_")
|
||||
bases = (_ComfyNodeInternal,) if is_v3 else ()
|
||||
stub_cls = type(class_name, bases, attributes)
|
||||
|
||||
if is_v3:
|
||||
try:
|
||||
stub_cls.VALIDATE_CLASS()
|
||||
except Exception as e:
|
||||
logger.error("%s VALIDATE_CLASS failed: %s - %s", LOG_PREFIX, node_name, e)
|
||||
|
||||
return stub_cls
|
||||
|
||||
|
||||
def get_class_types_for_extension(
|
||||
extension_name: str,
|
||||
running_extensions: Dict[str, "ComfyNodeExtension"],
|
||||
specs: List[Any],
|
||||
) -> Set[str]:
|
||||
extension = running_extensions.get(extension_name)
|
||||
if not extension:
|
||||
return set()
|
||||
|
||||
ext_path = Path(extension.module_path)
|
||||
class_types = set()
|
||||
for spec in specs:
|
||||
if spec.module_path.resolve() == ext_path.resolve():
|
||||
class_types.add(spec.node_name)
|
||||
return class_types
|
||||
|
||||
|
||||
__all__ = ["build_stub_class", "get_class_types_for_extension"]
|
||||
Reference in New Issue
Block a user