mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-02 02:39:06 +00:00
Compare commits
22 Commits
v0.17.2
...
pyisolate-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
82af45530d | ||
|
|
9e3e939db1 | ||
|
|
b11129e169 | ||
|
|
a6b5e6545d | ||
|
|
d90e28863e | ||
|
|
683e2d6a73 | ||
|
|
878684d8b2 | ||
|
|
c02372936d | ||
|
|
6aa0b838a0 | ||
|
|
54461f9ecc | ||
|
|
b602cc4533 | ||
|
|
08b92a48c3 | ||
|
|
c5e7b9cdaf | ||
|
|
623a9d21e9 | ||
|
|
9250191c65 | ||
|
|
a0f8784e9f | ||
|
|
7962db477a | ||
|
|
3c8ba051b6 | ||
|
|
a1c3124821 | ||
|
|
9ca799362d | ||
|
|
22f5e43c12 | ||
|
|
3cfd5e3311 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -24,3 +24,4 @@ web_custom_versions/
|
||||
openapi.yaml
|
||||
filtered-openapi.yaml
|
||||
uv.lock
|
||||
.pyisolate_venvs/
|
||||
|
||||
@@ -179,6 +179,8 @@ parser.add_argument("--disable-api-nodes", action="store_true", help="Disable lo
|
||||
|
||||
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
||||
|
||||
parser.add_argument("--use-process-isolation", action="store_true", help="Enable process isolation for custom nodes with pyisolate.yaml manifests.")
|
||||
|
||||
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
|
||||
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
|
||||
|
||||
|
||||
@@ -176,8 +176,8 @@ class InputTypeOptions(TypedDict):
|
||||
"""COMBO type only. Specifies the configuration for a multi-select widget.
|
||||
Available after ComfyUI frontend v1.13.4
|
||||
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
|
||||
gradient_stops: NotRequired[list[dict]]
|
||||
"""Gradient color stops for gradientslider display mode. Each stop is {"offset": float, "color": [r, g, b]}."""
|
||||
gradient_stops: NotRequired[list[list[float]]]
|
||||
"""Gradient color stops for gradientslider display mode. Each stop is [offset, r, g, b] (``FLOAT``)."""
|
||||
|
||||
|
||||
class HiddenInputTypeDict(TypedDict):
|
||||
|
||||
@@ -14,6 +14,9 @@ if TYPE_CHECKING:
|
||||
import comfy.lora
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
from comfy.cli_args import args
|
||||
import uuid
|
||||
import os
|
||||
from node_helpers import conditioning_set_values
|
||||
|
||||
# #######################################################################################################
|
||||
@@ -61,8 +64,37 @@ class EnumHookScope(enum.Enum):
|
||||
HookedOnly = "hooked_only"
|
||||
|
||||
|
||||
_ISOLATION_HOOKREF_MODE = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
|
||||
|
||||
class _HookRef:
|
||||
pass
|
||||
def __init__(self):
|
||||
if _ISOLATION_HOOKREF_MODE:
|
||||
self._pyisolate_id = str(uuid.uuid4())
|
||||
|
||||
def _ensure_pyisolate_id(self):
|
||||
pyisolate_id = getattr(self, "_pyisolate_id", None)
|
||||
if pyisolate_id is None:
|
||||
pyisolate_id = str(uuid.uuid4())
|
||||
self._pyisolate_id = pyisolate_id
|
||||
return pyisolate_id
|
||||
|
||||
def __eq__(self, other):
|
||||
if not _ISOLATION_HOOKREF_MODE:
|
||||
return self is other
|
||||
if not isinstance(other, _HookRef):
|
||||
return False
|
||||
return self._ensure_pyisolate_id() == other._ensure_pyisolate_id()
|
||||
|
||||
def __hash__(self):
|
||||
if not _ISOLATION_HOOKREF_MODE:
|
||||
return id(self)
|
||||
return hash(self._ensure_pyisolate_id())
|
||||
|
||||
def __str__(self):
|
||||
if not _ISOLATION_HOOKREF_MODE:
|
||||
return super().__str__()
|
||||
return f"PYISOLATE_HOOKREF:{self._ensure_pyisolate_id()}"
|
||||
|
||||
|
||||
def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||
@@ -168,6 +200,8 @@ class WeightHook(Hook):
|
||||
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
|
||||
else:
|
||||
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
||||
if self.weights is None:
|
||||
self.weights = {}
|
||||
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
|
||||
else:
|
||||
if target == EnumWeightTarget.Clip:
|
||||
|
||||
436
comfy/isolation/__init__.py
Normal file
436
comfy/isolation/__init__.py
Normal file
@@ -0,0 +1,436 @@
|
||||
# pylint: disable=consider-using-from-import,cyclic-import,global-statement,global-variable-not-assigned,import-outside-toplevel,logging-fstring-interpolation
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set, TYPE_CHECKING
|
||||
_IMPORT_TORCH = os.environ.get("PYISOLATE_IMPORT_TORCH", "1") == "1"
|
||||
|
||||
load_isolated_node = None
|
||||
find_manifest_directories = None
|
||||
build_stub_class = None
|
||||
get_class_types_for_extension = None
|
||||
scan_shm_forensics = None
|
||||
start_shm_forensics = None
|
||||
|
||||
if _IMPORT_TORCH:
|
||||
import folder_paths
|
||||
from .extension_loader import load_isolated_node
|
||||
from .manifest_loader import find_manifest_directories
|
||||
from .runtime_helpers import build_stub_class, get_class_types_for_extension
|
||||
from .shm_forensics import scan_shm_forensics, start_shm_forensics
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pyisolate import ExtensionManager
|
||||
from .extension_wrapper import ComfyNodeExtension
|
||||
|
||||
LOG_PREFIX = "]["
|
||||
isolated_node_timings: List[tuple[float, Path, int]] = []
|
||||
|
||||
if _IMPORT_TORCH:
|
||||
PYISOLATE_VENV_ROOT = Path(folder_paths.base_path) / ".pyisolate_venvs"
|
||||
PYISOLATE_VENV_ROOT.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
|
||||
_MODEL_PATCHER_IDLE_TIMEOUT_MS = 120000
|
||||
|
||||
|
||||
def initialize_proxies() -> None:
|
||||
from .child_hooks import is_child_process
|
||||
|
||||
is_child = is_child_process()
|
||||
|
||||
if is_child:
|
||||
from .child_hooks import initialize_child_process
|
||||
|
||||
initialize_child_process()
|
||||
else:
|
||||
from .host_hooks import initialize_host_process
|
||||
|
||||
initialize_host_process()
|
||||
if start_shm_forensics is not None:
|
||||
start_shm_forensics()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IsolatedNodeSpec:
|
||||
node_name: str
|
||||
display_name: str
|
||||
stub_class: type
|
||||
module_path: Path
|
||||
|
||||
|
||||
_ISOLATED_NODE_SPECS: List[IsolatedNodeSpec] = []
|
||||
_CLAIMED_PATHS: Set[Path] = set()
|
||||
_ISOLATION_SCAN_ATTEMPTED = False
|
||||
_EXTENSION_MANAGERS: List["ExtensionManager"] = []
|
||||
_RUNNING_EXTENSIONS: Dict[str, "ComfyNodeExtension"] = {}
|
||||
_ISOLATION_BACKGROUND_TASK: Optional["asyncio.Task[List[IsolatedNodeSpec]]"] = None
|
||||
_EARLY_START_TIME: Optional[float] = None
|
||||
|
||||
|
||||
def start_isolation_loading_early(loop: "asyncio.AbstractEventLoop") -> None:
|
||||
global _ISOLATION_BACKGROUND_TASK, _EARLY_START_TIME
|
||||
if _ISOLATION_BACKGROUND_TASK is not None:
|
||||
return
|
||||
_EARLY_START_TIME = time.perf_counter()
|
||||
_ISOLATION_BACKGROUND_TASK = loop.create_task(initialize_isolation_nodes())
|
||||
|
||||
|
||||
async def await_isolation_loading() -> List[IsolatedNodeSpec]:
|
||||
global _ISOLATION_BACKGROUND_TASK, _EARLY_START_TIME
|
||||
if _ISOLATION_BACKGROUND_TASK is not None:
|
||||
specs = await _ISOLATION_BACKGROUND_TASK
|
||||
return specs
|
||||
return await initialize_isolation_nodes()
|
||||
|
||||
|
||||
async def initialize_isolation_nodes() -> List[IsolatedNodeSpec]:
|
||||
global _ISOLATED_NODE_SPECS, _ISOLATION_SCAN_ATTEMPTED, _CLAIMED_PATHS
|
||||
|
||||
if _ISOLATED_NODE_SPECS:
|
||||
return _ISOLATED_NODE_SPECS
|
||||
|
||||
if _ISOLATION_SCAN_ATTEMPTED:
|
||||
return []
|
||||
|
||||
_ISOLATION_SCAN_ATTEMPTED = True
|
||||
if find_manifest_directories is None or load_isolated_node is None or build_stub_class is None:
|
||||
return []
|
||||
manifest_entries = find_manifest_directories()
|
||||
_CLAIMED_PATHS = {entry[0].resolve() for entry in manifest_entries}
|
||||
|
||||
if not manifest_entries:
|
||||
return []
|
||||
|
||||
os.environ["PYISOLATE_ISOLATION_ACTIVE"] = "1"
|
||||
concurrency_limit = max(1, (os.cpu_count() or 4) // 2)
|
||||
semaphore = asyncio.Semaphore(concurrency_limit)
|
||||
|
||||
async def load_with_semaphore(
|
||||
node_dir: Path, manifest: Path
|
||||
) -> List[IsolatedNodeSpec]:
|
||||
async with semaphore:
|
||||
load_start = time.perf_counter()
|
||||
spec_list = await load_isolated_node(
|
||||
node_dir,
|
||||
manifest,
|
||||
logger,
|
||||
lambda name, info, extension: build_stub_class(
|
||||
name,
|
||||
info,
|
||||
extension,
|
||||
_RUNNING_EXTENSIONS,
|
||||
logger,
|
||||
),
|
||||
PYISOLATE_VENV_ROOT,
|
||||
_EXTENSION_MANAGERS,
|
||||
)
|
||||
spec_list = [
|
||||
IsolatedNodeSpec(
|
||||
node_name=node_name,
|
||||
display_name=display_name,
|
||||
stub_class=stub_cls,
|
||||
module_path=node_dir,
|
||||
)
|
||||
for node_name, display_name, stub_cls in spec_list
|
||||
]
|
||||
isolated_node_timings.append(
|
||||
(time.perf_counter() - load_start, node_dir, len(spec_list))
|
||||
)
|
||||
return spec_list
|
||||
|
||||
tasks = [
|
||||
load_with_semaphore(node_dir, manifest)
|
||||
for node_dir, manifest in manifest_entries
|
||||
]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
specs: List[IsolatedNodeSpec] = []
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
logger.error(
|
||||
"%s Isolated node failed during startup; continuing: %s",
|
||||
LOG_PREFIX,
|
||||
result,
|
||||
)
|
||||
continue
|
||||
specs.extend(result)
|
||||
|
||||
_ISOLATED_NODE_SPECS = specs
|
||||
return list(_ISOLATED_NODE_SPECS)
|
||||
|
||||
|
||||
def _get_class_types_for_extension(extension_name: str) -> Set[str]:
|
||||
"""Get all node class types (node names) belonging to an extension."""
|
||||
extension = _RUNNING_EXTENSIONS.get(extension_name)
|
||||
if not extension:
|
||||
return set()
|
||||
|
||||
ext_path = Path(extension.module_path)
|
||||
class_types = set()
|
||||
for spec in _ISOLATED_NODE_SPECS:
|
||||
if spec.module_path.resolve() == ext_path.resolve():
|
||||
class_types.add(spec.node_name)
|
||||
|
||||
return class_types
|
||||
|
||||
|
||||
async def notify_execution_graph(needed_class_types: Set[str], caches: list | None = None) -> None:
|
||||
"""Evict running extensions not needed for current execution.
|
||||
|
||||
When *caches* is provided, cache entries for evicted extensions' node
|
||||
class_types are invalidated to prevent stale ``RemoteObjectHandle``
|
||||
references from surviving in the output cache.
|
||||
"""
|
||||
await wait_for_model_patcher_quiescence(
|
||||
timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS,
|
||||
fail_loud=True,
|
||||
marker="ISO:notify_graph_wait_idle",
|
||||
)
|
||||
|
||||
evicted_class_types: Set[str] = set()
|
||||
|
||||
async def _stop_extension(
|
||||
ext_name: str, extension: "ComfyNodeExtension", reason: str
|
||||
) -> None:
|
||||
# Collect class_types BEFORE stopping so we can invalidate cache entries.
|
||||
ext_class_types = _get_class_types_for_extension(ext_name)
|
||||
evicted_class_types.update(ext_class_types)
|
||||
logger.info("%s ISO:eject_start ext=%s reason=%s", LOG_PREFIX, ext_name, reason)
|
||||
logger.debug("%s ISO:stop_start ext=%s", LOG_PREFIX, ext_name)
|
||||
stop_result = extension.stop()
|
||||
if inspect.isawaitable(stop_result):
|
||||
await stop_result
|
||||
_RUNNING_EXTENSIONS.pop(ext_name, None)
|
||||
logger.debug("%s ISO:stop_done ext=%s", LOG_PREFIX, ext_name)
|
||||
if scan_shm_forensics is not None:
|
||||
scan_shm_forensics("ISO:stop_extension", refresh_model_context=True)
|
||||
|
||||
if scan_shm_forensics is not None:
|
||||
scan_shm_forensics("ISO:notify_graph_start", refresh_model_context=True)
|
||||
isolated_class_types_in_graph = needed_class_types.intersection(
|
||||
{spec.node_name for spec in _ISOLATED_NODE_SPECS}
|
||||
)
|
||||
graph_uses_isolation = bool(isolated_class_types_in_graph)
|
||||
logger.debug(
|
||||
"%s ISO:notify_graph_start running=%d needed=%d",
|
||||
LOG_PREFIX,
|
||||
len(_RUNNING_EXTENSIONS),
|
||||
len(needed_class_types),
|
||||
)
|
||||
if graph_uses_isolation:
|
||||
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
|
||||
ext_class_types = _get_class_types_for_extension(ext_name)
|
||||
|
||||
# If NONE of this extension's nodes are in the execution graph -> evict.
|
||||
if not ext_class_types.intersection(needed_class_types):
|
||||
await _stop_extension(
|
||||
ext_name,
|
||||
extension,
|
||||
"isolated custom_node not in execution graph, evicting",
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"%s ISO:notify_graph_skip_evict running=%d reason=no isolated nodes in graph",
|
||||
LOG_PREFIX,
|
||||
len(_RUNNING_EXTENSIONS),
|
||||
)
|
||||
|
||||
# Isolated child processes add steady VRAM pressure; reclaim host-side models
|
||||
# at workflow boundaries so subsequent host nodes (e.g. CLIP encode) keep headroom.
|
||||
try:
|
||||
import comfy.model_management as model_management
|
||||
|
||||
device = model_management.get_torch_device()
|
||||
if getattr(device, "type", None) == "cuda":
|
||||
required = max(
|
||||
model_management.minimum_inference_memory(),
|
||||
_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES,
|
||||
)
|
||||
free_before = model_management.get_free_memory(device)
|
||||
if free_before < required and _RUNNING_EXTENSIONS and graph_uses_isolation:
|
||||
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
|
||||
await _stop_extension(
|
||||
ext_name,
|
||||
extension,
|
||||
f"boundary low-vram restart (free={int(free_before)} target={int(required)})",
|
||||
)
|
||||
if model_management.get_free_memory(device) < required:
|
||||
model_management.unload_all_models()
|
||||
model_management.cleanup_models_gc()
|
||||
model_management.cleanup_models()
|
||||
if model_management.get_free_memory(device) < required:
|
||||
model_management.free_memory(required, device, for_dynamic=False)
|
||||
model_management.soft_empty_cache()
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"%s workflow-boundary host VRAM relief failed", LOG_PREFIX, exc_info=True
|
||||
)
|
||||
finally:
|
||||
# Invalidate cached outputs for evicted extensions so stale
|
||||
# RemoteObjectHandle references are not served from cache.
|
||||
if evicted_class_types and caches:
|
||||
total_invalidated = 0
|
||||
for cache in caches:
|
||||
if hasattr(cache, "invalidate_by_class_types"):
|
||||
total_invalidated += cache.invalidate_by_class_types(
|
||||
evicted_class_types
|
||||
)
|
||||
if total_invalidated > 0:
|
||||
logger.info(
|
||||
"%s ISO:cache_invalidated count=%d class_types=%s",
|
||||
LOG_PREFIX,
|
||||
total_invalidated,
|
||||
evicted_class_types,
|
||||
)
|
||||
scan_shm_forensics("ISO:notify_graph_done", refresh_model_context=True)
|
||||
logger.debug(
|
||||
"%s ISO:notify_graph_done running=%d", LOG_PREFIX, len(_RUNNING_EXTENSIONS)
|
||||
)
|
||||
|
||||
|
||||
async def flush_running_extensions_transport_state() -> int:
|
||||
await wait_for_model_patcher_quiescence(
|
||||
timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS,
|
||||
fail_loud=True,
|
||||
marker="ISO:flush_transport_wait_idle",
|
||||
)
|
||||
total_flushed = 0
|
||||
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
|
||||
flush_fn = getattr(extension, "flush_transport_state", None)
|
||||
if not callable(flush_fn):
|
||||
continue
|
||||
try:
|
||||
flushed = await flush_fn()
|
||||
if isinstance(flushed, int):
|
||||
total_flushed += flushed
|
||||
if flushed > 0:
|
||||
logger.debug(
|
||||
"%s %s workflow-end flush released=%d",
|
||||
LOG_PREFIX,
|
||||
ext_name,
|
||||
flushed,
|
||||
)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"%s %s workflow-end flush failed", LOG_PREFIX, ext_name, exc_info=True
|
||||
)
|
||||
scan_shm_forensics(
|
||||
"ISO:flush_running_extensions_transport_state", refresh_model_context=True
|
||||
)
|
||||
return total_flushed
|
||||
|
||||
|
||||
async def wait_for_model_patcher_quiescence(
|
||||
timeout_ms: int = _MODEL_PATCHER_IDLE_TIMEOUT_MS,
|
||||
*,
|
||||
fail_loud: bool = False,
|
||||
marker: str = "ISO:wait_model_patcher_idle",
|
||||
) -> bool:
|
||||
try:
|
||||
from comfy.isolation.model_patcher_proxy_registry import ModelPatcherRegistry
|
||||
|
||||
registry = ModelPatcherRegistry()
|
||||
start = time.perf_counter()
|
||||
idle = await registry.wait_all_idle(timeout_ms)
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000.0
|
||||
if idle:
|
||||
logger.debug(
|
||||
"%s %s idle=1 timeout_ms=%d elapsed_ms=%.3f",
|
||||
LOG_PREFIX,
|
||||
marker,
|
||||
timeout_ms,
|
||||
elapsed_ms,
|
||||
)
|
||||
return True
|
||||
|
||||
states = await registry.get_all_operation_states()
|
||||
logger.error(
|
||||
"%s %s idle_timeout timeout_ms=%d elapsed_ms=%.3f states=%s",
|
||||
LOG_PREFIX,
|
||||
marker,
|
||||
timeout_ms,
|
||||
elapsed_ms,
|
||||
states,
|
||||
)
|
||||
if fail_loud:
|
||||
raise TimeoutError(
|
||||
f"ModelPatcherRegistry did not quiesce within {timeout_ms} ms"
|
||||
)
|
||||
return False
|
||||
except Exception:
|
||||
if fail_loud:
|
||||
raise
|
||||
logger.debug("%s %s failed", LOG_PREFIX, marker, exc_info=True)
|
||||
return False
|
||||
|
||||
|
||||
def get_claimed_paths() -> Set[Path]:
|
||||
return _CLAIMED_PATHS
|
||||
|
||||
|
||||
def update_rpc_event_loops(loop: "asyncio.AbstractEventLoop | None" = None) -> None:
|
||||
"""Update all active RPC instances with the current event loop.
|
||||
|
||||
This MUST be called at the start of each workflow execution to ensure
|
||||
RPC calls are scheduled on the correct event loop. This handles the case
|
||||
where asyncio.run() creates a new event loop for each workflow.
|
||||
|
||||
Args:
|
||||
loop: The event loop to use. If None, uses asyncio.get_running_loop().
|
||||
"""
|
||||
if loop is None:
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
update_count = 0
|
||||
|
||||
# Update RPCs from ExtensionManagers
|
||||
for manager in _EXTENSION_MANAGERS:
|
||||
if not hasattr(manager, "extensions"):
|
||||
continue
|
||||
for name, extension in manager.extensions.items():
|
||||
if hasattr(extension, "rpc") and extension.rpc is not None:
|
||||
if hasattr(extension.rpc, "update_event_loop"):
|
||||
extension.rpc.update_event_loop(loop)
|
||||
update_count += 1
|
||||
logger.debug(f"{LOG_PREFIX}Updated loop on extension '{name}'")
|
||||
|
||||
# Also update RPCs from running extensions (they may have direct RPC refs)
|
||||
for name, extension in _RUNNING_EXTENSIONS.items():
|
||||
if hasattr(extension, "rpc") and extension.rpc is not None:
|
||||
if hasattr(extension.rpc, "update_event_loop"):
|
||||
extension.rpc.update_event_loop(loop)
|
||||
update_count += 1
|
||||
logger.debug(f"{LOG_PREFIX}Updated loop on running extension '{name}'")
|
||||
|
||||
if update_count > 0:
|
||||
logger.debug(f"{LOG_PREFIX}Updated event loop on {update_count} RPC instances")
|
||||
else:
|
||||
logger.debug(
|
||||
f"{LOG_PREFIX}No RPC instances found to update (managers={len(_EXTENSION_MANAGERS)}, running={len(_RUNNING_EXTENSIONS)})"
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"LOG_PREFIX",
|
||||
"initialize_proxies",
|
||||
"initialize_isolation_nodes",
|
||||
"start_isolation_loading_early",
|
||||
"await_isolation_loading",
|
||||
"notify_execution_graph",
|
||||
"flush_running_extensions_transport_state",
|
||||
"wait_for_model_patcher_quiescence",
|
||||
"get_claimed_paths",
|
||||
"update_rpc_event_loops",
|
||||
"IsolatedNodeSpec",
|
||||
"get_class_types_for_extension",
|
||||
]
|
||||
965
comfy/isolation/adapter.py
Normal file
965
comfy/isolation/adapter.py
Normal file
@@ -0,0 +1,965 @@
|
||||
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,protected-access,raise-missing-from,useless-return,wrong-import-position
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import inspect
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, cast
|
||||
|
||||
from pyisolate.interfaces import IsolationAdapter, SerializerRegistryProtocol # type: ignore[import-untyped]
|
||||
from pyisolate._internal.rpc_protocol import AsyncRPC, ProxiedSingleton # type: ignore[import-untyped]
|
||||
|
||||
_IMPORT_TORCH = os.environ.get("PYISOLATE_IMPORT_TORCH", "1") == "1"
|
||||
|
||||
# Singleton proxies that do NOT transitively import torch/PIL/psutil/aiohttp.
|
||||
# Safe to import in sealed workers without host framework modules.
|
||||
from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy
|
||||
from comfy.isolation.proxies.helper_proxies import HelperProxiesService
|
||||
from comfy.isolation.proxies.web_directory_proxy import WebDirectoryProxy
|
||||
|
||||
# Singleton proxies that transitively import torch, PIL, or heavy host modules.
|
||||
# Only available when torch/host framework is present.
|
||||
CLIPProxy = None
|
||||
CLIPRegistry = None
|
||||
ModelPatcherProxy = None
|
||||
ModelPatcherRegistry = None
|
||||
ModelSamplingProxy = None
|
||||
ModelSamplingRegistry = None
|
||||
VAEProxy = None
|
||||
VAERegistry = None
|
||||
FirstStageModelRegistry = None
|
||||
ModelManagementProxy = None
|
||||
PromptServerService = None
|
||||
ProgressProxy = None
|
||||
UtilsProxy = None
|
||||
_HAS_TORCH_PROXIES = False
|
||||
if _IMPORT_TORCH:
|
||||
from comfy.isolation.clip_proxy import CLIPProxy, CLIPRegistry
|
||||
from comfy.isolation.model_patcher_proxy import (
|
||||
ModelPatcherProxy,
|
||||
ModelPatcherRegistry,
|
||||
)
|
||||
from comfy.isolation.model_sampling_proxy import (
|
||||
ModelSamplingProxy,
|
||||
ModelSamplingRegistry,
|
||||
)
|
||||
from comfy.isolation.vae_proxy import VAEProxy, VAERegistry, FirstStageModelRegistry
|
||||
from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy
|
||||
from comfy.isolation.proxies.prompt_server_impl import PromptServerService
|
||||
from comfy.isolation.proxies.progress_proxy import ProgressProxy
|
||||
from comfy.isolation.proxies.utils_proxy import UtilsProxy
|
||||
_HAS_TORCH_PROXIES = True
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Force /dev/shm for shared memory (bwrap makes /tmp private)
|
||||
import tempfile
|
||||
|
||||
if os.path.exists("/dev/shm"):
|
||||
# Only override if not already set or if default is not /dev/shm
|
||||
current_tmp = tempfile.gettempdir()
|
||||
if not current_tmp.startswith("/dev/shm"):
|
||||
logger.debug(
|
||||
f"Configuring shared memory: Changing TMPDIR from {current_tmp} to /dev/shm"
|
||||
)
|
||||
os.environ["TMPDIR"] = "/dev/shm"
|
||||
tempfile.tempdir = None # Clear cache to force re-evaluation
|
||||
|
||||
|
||||
class ComfyUIAdapter(IsolationAdapter):
|
||||
# ComfyUI-specific IsolationAdapter implementation
|
||||
|
||||
@property
|
||||
def identifier(self) -> str:
|
||||
return "comfyui"
|
||||
|
||||
def get_path_config(self, module_path: str) -> Optional[Dict[str, Any]]:
|
||||
if "ComfyUI" in module_path and "custom_nodes" in module_path:
|
||||
parts = module_path.split("ComfyUI")
|
||||
if len(parts) > 1:
|
||||
comfy_root = parts[0] + "ComfyUI"
|
||||
return {
|
||||
"preferred_root": comfy_root,
|
||||
"additional_paths": [
|
||||
os.path.join(comfy_root, "custom_nodes"),
|
||||
os.path.join(comfy_root, "comfy"),
|
||||
],
|
||||
"filtered_subdirs": ["comfy", "app", "comfy_execution", "utils"],
|
||||
}
|
||||
return None
|
||||
|
||||
def get_sandbox_system_paths(self) -> Optional[List[str]]:
|
||||
"""Returns required application paths to mount in the sandbox."""
|
||||
# By inspecting where our adapter is loaded from, we can determine the comfy root
|
||||
adapter_file = inspect.getfile(self.__class__)
|
||||
# adapter_file = /home/johnj/ComfyUI/comfy/isolation/adapter.py
|
||||
comfy_root = os.path.dirname(os.path.dirname(os.path.dirname(adapter_file)))
|
||||
if os.path.exists(comfy_root):
|
||||
return [comfy_root]
|
||||
return None
|
||||
|
||||
def setup_child_environment(self, snapshot: Dict[str, Any]) -> None:
|
||||
comfy_root = snapshot.get("preferred_root")
|
||||
if not comfy_root:
|
||||
return
|
||||
|
||||
requirements_path = Path(comfy_root) / "requirements.txt"
|
||||
if requirements_path.exists():
|
||||
import re
|
||||
|
||||
for line in requirements_path.read_text().splitlines():
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
pkg_name = re.split(r"[<>=!~\[]", line)[0].strip()
|
||||
if pkg_name:
|
||||
logging.getLogger(pkg_name).setLevel(logging.ERROR)
|
||||
|
||||
def register_serializers(self, registry: SerializerRegistryProtocol) -> None:
|
||||
if not _IMPORT_TORCH:
|
||||
# Sealed worker without torch — register torch-free TensorValue handler
|
||||
# so IMAGE/MASK/LATENT tensors arrive as numpy arrays, not raw dicts.
|
||||
import numpy as np
|
||||
|
||||
_TORCH_DTYPE_TO_NUMPY = {
|
||||
"torch.float32": np.float32,
|
||||
"torch.float64": np.float64,
|
||||
"torch.float16": np.float16,
|
||||
"torch.bfloat16": np.float32, # numpy has no bfloat16; upcast
|
||||
"torch.int32": np.int32,
|
||||
"torch.int64": np.int64,
|
||||
"torch.int16": np.int16,
|
||||
"torch.int8": np.int8,
|
||||
"torch.uint8": np.uint8,
|
||||
"torch.bool": np.bool_,
|
||||
}
|
||||
|
||||
def _deserialize_tensor_value(data: Dict[str, Any]) -> Any:
|
||||
dtype_str = data["dtype"]
|
||||
np_dtype = _TORCH_DTYPE_TO_NUMPY.get(dtype_str, np.float32)
|
||||
shape = tuple(data["tensor_size"])
|
||||
arr = np.array(data["data"], dtype=np_dtype).reshape(shape)
|
||||
return arr
|
||||
|
||||
_NUMPY_TO_TORCH_DTYPE = {
|
||||
np.float32: "torch.float32",
|
||||
np.float64: "torch.float64",
|
||||
np.float16: "torch.float16",
|
||||
np.int32: "torch.int32",
|
||||
np.int64: "torch.int64",
|
||||
np.int16: "torch.int16",
|
||||
np.int8: "torch.int8",
|
||||
np.uint8: "torch.uint8",
|
||||
np.bool_: "torch.bool",
|
||||
}
|
||||
|
||||
def _serialize_tensor_value(obj: Any) -> Dict[str, Any]:
|
||||
arr = np.asarray(obj, dtype=np.float32) if obj.dtype not in _NUMPY_TO_TORCH_DTYPE else np.asarray(obj)
|
||||
dtype_str = _NUMPY_TO_TORCH_DTYPE.get(arr.dtype.type, "torch.float32")
|
||||
return {
|
||||
"__type__": "TensorValue",
|
||||
"dtype": dtype_str,
|
||||
"tensor_size": list(arr.shape),
|
||||
"requires_grad": False,
|
||||
"data": arr.tolist(),
|
||||
}
|
||||
|
||||
registry.register("TensorValue", _serialize_tensor_value, _deserialize_tensor_value, data_type=True)
|
||||
# ndarray output from sealed workers serializes as TensorValue for host torch reconstruction
|
||||
registry.register("ndarray", _serialize_tensor_value, _deserialize_tensor_value, data_type=True)
|
||||
return
|
||||
|
||||
import torch
|
||||
|
||||
def serialize_device(obj: Any) -> Dict[str, Any]:
|
||||
return {"__type__": "device", "device_str": str(obj)}
|
||||
|
||||
def deserialize_device(data: Dict[str, Any]) -> Any:
|
||||
return torch.device(data["device_str"])
|
||||
|
||||
registry.register("device", serialize_device, deserialize_device)
|
||||
|
||||
_VALID_DTYPES = {
|
||||
"float16", "float32", "float64", "bfloat16",
|
||||
"int8", "int16", "int32", "int64",
|
||||
"uint8", "bool",
|
||||
}
|
||||
|
||||
def serialize_dtype(obj: Any) -> Dict[str, Any]:
|
||||
return {"__type__": "dtype", "dtype_str": str(obj)}
|
||||
|
||||
def deserialize_dtype(data: Dict[str, Any]) -> Any:
|
||||
dtype_name = data["dtype_str"].replace("torch.", "")
|
||||
if dtype_name not in _VALID_DTYPES:
|
||||
raise ValueError(f"Invalid dtype: {data['dtype_str']}")
|
||||
return getattr(torch, dtype_name)
|
||||
|
||||
registry.register("dtype", serialize_dtype, deserialize_dtype)
|
||||
|
||||
from comfy_api.latest._io import FolderType
|
||||
from comfy_api.latest._ui import SavedImages, SavedResult
|
||||
|
||||
def serialize_saved_result(obj: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"__type__": "SavedResult",
|
||||
"filename": obj.filename,
|
||||
"subfolder": obj.subfolder,
|
||||
"folder_type": obj.type.value,
|
||||
}
|
||||
|
||||
def deserialize_saved_result(data: Dict[str, Any]) -> Any:
|
||||
if isinstance(data, SavedResult):
|
||||
return data
|
||||
folder_type = data["folder_type"] if "folder_type" in data else data["type"]
|
||||
return SavedResult(
|
||||
filename=data["filename"],
|
||||
subfolder=data["subfolder"],
|
||||
type=FolderType(folder_type),
|
||||
)
|
||||
|
||||
registry.register(
|
||||
"SavedResult",
|
||||
serialize_saved_result,
|
||||
deserialize_saved_result,
|
||||
data_type=True,
|
||||
)
|
||||
|
||||
def serialize_saved_images(obj: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"__type__": "SavedImages",
|
||||
"results": [serialize_saved_result(result) for result in obj.results],
|
||||
"is_animated": obj.is_animated,
|
||||
}
|
||||
|
||||
def deserialize_saved_images(data: Dict[str, Any]) -> Any:
|
||||
return SavedImages(
|
||||
results=[deserialize_saved_result(result) for result in data["results"]],
|
||||
is_animated=data.get("is_animated", False),
|
||||
)
|
||||
|
||||
registry.register(
|
||||
"SavedImages",
|
||||
serialize_saved_images,
|
||||
deserialize_saved_images,
|
||||
data_type=True,
|
||||
)
|
||||
|
||||
def serialize_model_patcher(obj: Any) -> Dict[str, Any]:
|
||||
# Child-side: must already have _instance_id (proxy)
|
||||
if os.environ.get("PYISOLATE_CHILD") == "1":
|
||||
if hasattr(obj, "_instance_id"):
|
||||
return {"__type__": "ModelPatcherRef", "model_id": obj._instance_id}
|
||||
raise RuntimeError(
|
||||
f"ModelPatcher in child lacks _instance_id: "
|
||||
f"{type(obj).__module__}.{type(obj).__name__}"
|
||||
)
|
||||
# Host-side: register with registry
|
||||
if hasattr(obj, "_instance_id"):
|
||||
return {"__type__": "ModelPatcherRef", "model_id": obj._instance_id}
|
||||
model_id = ModelPatcherRegistry().register(obj)
|
||||
return {"__type__": "ModelPatcherRef", "model_id": model_id}
|
||||
|
||||
def deserialize_model_patcher(data: Any) -> Any:
|
||||
"""Deserialize ModelPatcher refs; pass through already-materialized objects."""
|
||||
if isinstance(data, dict):
|
||||
return ModelPatcherProxy(
|
||||
data["model_id"], registry=None, manage_lifecycle=False
|
||||
)
|
||||
return data
|
||||
|
||||
def deserialize_model_patcher_ref(data: Dict[str, Any]) -> Any:
|
||||
"""Context-aware ModelPatcherRef deserializer for both host and child."""
|
||||
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
if is_child:
|
||||
return ModelPatcherProxy(
|
||||
data["model_id"], registry=None, manage_lifecycle=False
|
||||
)
|
||||
else:
|
||||
return ModelPatcherRegistry()._get_instance(data["model_id"])
|
||||
|
||||
# Register ModelPatcher type for serialization
|
||||
registry.register(
|
||||
"ModelPatcher", serialize_model_patcher, deserialize_model_patcher
|
||||
)
|
||||
# Register ModelPatcherProxy type (already a proxy, just return ref)
|
||||
registry.register(
|
||||
"ModelPatcherProxy", serialize_model_patcher, deserialize_model_patcher
|
||||
)
|
||||
# Register ModelPatcherRef for deserialization (context-aware: host or child)
|
||||
registry.register("ModelPatcherRef", None, deserialize_model_patcher_ref)
|
||||
|
||||
def serialize_clip(obj: Any) -> Dict[str, Any]:
|
||||
if hasattr(obj, "_instance_id"):
|
||||
return {"__type__": "CLIPRef", "clip_id": obj._instance_id}
|
||||
clip_id = CLIPRegistry().register(obj)
|
||||
return {"__type__": "CLIPRef", "clip_id": clip_id}
|
||||
|
||||
def deserialize_clip(data: Any) -> Any:
|
||||
if isinstance(data, dict):
|
||||
return CLIPProxy(data["clip_id"], registry=None, manage_lifecycle=False)
|
||||
return data
|
||||
|
||||
def deserialize_clip_ref(data: Dict[str, Any]) -> Any:
|
||||
"""Context-aware CLIPRef deserializer for both host and child."""
|
||||
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
if is_child:
|
||||
return CLIPProxy(data["clip_id"], registry=None, manage_lifecycle=False)
|
||||
else:
|
||||
return CLIPRegistry()._get_instance(data["clip_id"])
|
||||
|
||||
# Register CLIP type for serialization
|
||||
registry.register("CLIP", serialize_clip, deserialize_clip)
|
||||
# Register CLIPProxy type (already a proxy, just return ref)
|
||||
registry.register("CLIPProxy", serialize_clip, deserialize_clip)
|
||||
# Register CLIPRef for deserialization (context-aware: host or child)
|
||||
registry.register("CLIPRef", None, deserialize_clip_ref)
|
||||
|
||||
def serialize_vae(obj: Any) -> Dict[str, Any]:
|
||||
if hasattr(obj, "_instance_id"):
|
||||
return {"__type__": "VAERef", "vae_id": obj._instance_id}
|
||||
vae_id = VAERegistry().register(obj)
|
||||
return {"__type__": "VAERef", "vae_id": vae_id}
|
||||
|
||||
def deserialize_vae(data: Any) -> Any:
|
||||
if isinstance(data, dict):
|
||||
return VAEProxy(data["vae_id"])
|
||||
return data
|
||||
|
||||
def deserialize_vae_ref(data: Dict[str, Any]) -> Any:
|
||||
"""Context-aware VAERef deserializer for both host and child."""
|
||||
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
if is_child:
|
||||
# Child: create a proxy
|
||||
return VAEProxy(data["vae_id"])
|
||||
else:
|
||||
# Host: lookup real VAE from registry
|
||||
return VAERegistry()._get_instance(data["vae_id"])
|
||||
|
||||
# Register VAE type for serialization
|
||||
registry.register("VAE", serialize_vae, deserialize_vae)
|
||||
# Register VAEProxy type (already a proxy, just return ref)
|
||||
registry.register("VAEProxy", serialize_vae, deserialize_vae)
|
||||
# Register VAERef for deserialization (context-aware: host or child)
|
||||
registry.register("VAERef", None, deserialize_vae_ref)
|
||||
|
||||
# ModelSampling serialization - handles ModelSampling* types
|
||||
# copyreg removed - no pickle fallback allowed
|
||||
|
||||
def serialize_model_sampling(obj: Any) -> Dict[str, Any]:
|
||||
# Proxy with _instance_id — return ref (works from both host and child)
|
||||
if hasattr(obj, "_instance_id"):
|
||||
return {"__type__": "ModelSamplingRef", "ms_id": obj._instance_id}
|
||||
# Child-side: object created locally in child (e.g. ModelSamplingAdvanced
|
||||
# in nodes_z_image_turbo.py). Serialize as inline data so the host can
|
||||
# reconstruct the real torch.nn.Module.
|
||||
if os.environ.get("PYISOLATE_CHILD") == "1":
|
||||
import base64
|
||||
import io as _io
|
||||
|
||||
# Identify base classes from comfy.model_sampling
|
||||
bases = []
|
||||
for base in type(obj).__mro__:
|
||||
if base.__module__ == "comfy.model_sampling" and base.__name__ != "object":
|
||||
bases.append(base.__name__)
|
||||
# Serialize state_dict as base64 safetensors-like
|
||||
sd = obj.state_dict()
|
||||
sd_serialized = {}
|
||||
for k, v in sd.items():
|
||||
buf = _io.BytesIO()
|
||||
torch.save(v, buf)
|
||||
sd_serialized[k] = base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
# Capture plain attrs (shift, multiplier, sigma_data, etc.)
|
||||
plain_attrs = {}
|
||||
for k, v in obj.__dict__.items():
|
||||
if k.startswith("_"):
|
||||
continue
|
||||
if isinstance(v, (bool, int, float, str)):
|
||||
plain_attrs[k] = v
|
||||
return {
|
||||
"__type__": "ModelSamplingInline",
|
||||
"bases": bases,
|
||||
"state_dict": sd_serialized,
|
||||
"attrs": plain_attrs,
|
||||
}
|
||||
# Host-side: register with ModelSamplingRegistry and return JSON-safe dict
|
||||
ms_id = ModelSamplingRegistry().register(obj)
|
||||
return {"__type__": "ModelSamplingRef", "ms_id": ms_id}
|
||||
|
||||
def deserialize_model_sampling(data: Any) -> Any:
|
||||
"""Deserialize ModelSampling refs or inline data."""
|
||||
if isinstance(data, dict):
|
||||
if data.get("__type__") == "ModelSamplingInline":
|
||||
return _reconstruct_model_sampling_inline(data)
|
||||
return ModelSamplingProxy(data["ms_id"])
|
||||
return data
|
||||
|
||||
def _reconstruct_model_sampling_inline(data: Dict[str, Any]) -> Any:
|
||||
"""Reconstruct a ModelSampling object on the host from inline child data."""
|
||||
import comfy.model_sampling as _ms
|
||||
import base64
|
||||
import io as _io
|
||||
|
||||
# Resolve base classes
|
||||
base_classes = []
|
||||
for name in data["bases"]:
|
||||
cls = getattr(_ms, name, None)
|
||||
if cls is not None:
|
||||
base_classes.append(cls)
|
||||
if not base_classes:
|
||||
raise RuntimeError(
|
||||
f"Cannot reconstruct ModelSampling: no known bases in {data['bases']}"
|
||||
)
|
||||
# Create dynamic class matching the child's class hierarchy
|
||||
ReconstructedSampling = type("ReconstructedSampling", tuple(base_classes), {})
|
||||
obj = ReconstructedSampling.__new__(ReconstructedSampling)
|
||||
torch.nn.Module.__init__(obj)
|
||||
# Restore plain attributes first
|
||||
for k, v in data.get("attrs", {}).items():
|
||||
setattr(obj, k, v)
|
||||
# Restore state_dict (buffers like sigmas)
|
||||
for k, v_b64 in data.get("state_dict", {}).items():
|
||||
buf = _io.BytesIO(base64.b64decode(v_b64))
|
||||
tensor = torch.load(buf, weights_only=True)
|
||||
# Register as buffer so it's part of state_dict
|
||||
parts = k.split(".")
|
||||
if len(parts) == 1:
|
||||
cast(Any, obj).register_buffer(parts[0], tensor) # pylint: disable=no-member
|
||||
else:
|
||||
setattr(obj, parts[0], tensor)
|
||||
# Register on host so future references use proxy pattern.
|
||||
# Skip in child process — register() is async RPC and cannot be
|
||||
# called synchronously during deserialization.
|
||||
if os.environ.get("PYISOLATE_CHILD") != "1":
|
||||
ModelSamplingRegistry().register(obj)
|
||||
return obj
|
||||
|
||||
def deserialize_model_sampling_ref(data: Dict[str, Any]) -> Any:
|
||||
"""Context-aware ModelSamplingRef deserializer for both host and child."""
|
||||
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
if is_child:
|
||||
return ModelSamplingProxy(data["ms_id"])
|
||||
else:
|
||||
return ModelSamplingRegistry()._get_instance(data["ms_id"])
|
||||
|
||||
# Register all ModelSampling* and StableCascadeSampling classes dynamically
|
||||
import comfy.model_sampling
|
||||
|
||||
for ms_cls in vars(comfy.model_sampling).values():
|
||||
if not isinstance(ms_cls, type):
|
||||
continue
|
||||
if not issubclass(ms_cls, torch.nn.Module):
|
||||
continue
|
||||
if not (ms_cls.__name__.startswith("ModelSampling") or ms_cls.__name__ == "StableCascadeSampling"):
|
||||
continue
|
||||
registry.register(
|
||||
ms_cls.__name__,
|
||||
serialize_model_sampling,
|
||||
deserialize_model_sampling,
|
||||
)
|
||||
registry.register(
|
||||
"ModelSamplingProxy", serialize_model_sampling, deserialize_model_sampling
|
||||
)
|
||||
# Register ModelSamplingRef for deserialization (context-aware: host or child)
|
||||
registry.register("ModelSamplingRef", None, deserialize_model_sampling_ref)
|
||||
# Register ModelSamplingInline for deserialization (child→host inline transfer)
|
||||
registry.register(
|
||||
"ModelSamplingInline", None, lambda data: _reconstruct_model_sampling_inline(data)
|
||||
)
|
||||
|
||||
def serialize_cond(obj: Any) -> Dict[str, Any]:
|
||||
type_key = f"{type(obj).__module__}.{type(obj).__name__}"
|
||||
return {
|
||||
"__type__": type_key,
|
||||
"cond": obj.cond,
|
||||
}
|
||||
|
||||
def deserialize_cond(data: Dict[str, Any]) -> Any:
|
||||
import importlib
|
||||
|
||||
type_key = data["__type__"]
|
||||
module_name, class_name = type_key.rsplit(".", 1)
|
||||
module = importlib.import_module(module_name)
|
||||
cls = getattr(module, class_name)
|
||||
return cls(data["cond"])
|
||||
|
||||
def _serialize_public_state(obj: Any) -> Dict[str, Any]:
|
||||
state: Dict[str, Any] = {}
|
||||
for key, value in obj.__dict__.items():
|
||||
if key.startswith("_"):
|
||||
continue
|
||||
if callable(value):
|
||||
continue
|
||||
state[key] = value
|
||||
return state
|
||||
|
||||
def serialize_latent_format(obj: Any) -> Dict[str, Any]:
|
||||
type_key = f"{type(obj).__module__}.{type(obj).__name__}"
|
||||
return {
|
||||
"__type__": type_key,
|
||||
"state": _serialize_public_state(obj),
|
||||
}
|
||||
|
||||
def deserialize_latent_format(data: Dict[str, Any]) -> Any:
|
||||
import importlib
|
||||
|
||||
type_key = data["__type__"]
|
||||
module_name, class_name = type_key.rsplit(".", 1)
|
||||
module = importlib.import_module(module_name)
|
||||
cls = getattr(module, class_name)
|
||||
obj = cls()
|
||||
for key, value in data.get("state", {}).items():
|
||||
prop = getattr(type(obj), key, None)
|
||||
if isinstance(prop, property) and prop.fset is None:
|
||||
continue
|
||||
setattr(obj, key, value)
|
||||
return obj
|
||||
|
||||
import comfy.conds
|
||||
|
||||
for cond_cls in vars(comfy.conds).values():
|
||||
if not isinstance(cond_cls, type):
|
||||
continue
|
||||
if not issubclass(cond_cls, comfy.conds.CONDRegular):
|
||||
continue
|
||||
type_key = f"{cond_cls.__module__}.{cond_cls.__name__}"
|
||||
registry.register(type_key, serialize_cond, deserialize_cond)
|
||||
registry.register(cond_cls.__name__, serialize_cond, deserialize_cond)
|
||||
|
||||
import comfy.latent_formats
|
||||
|
||||
for latent_cls in vars(comfy.latent_formats).values():
|
||||
if not isinstance(latent_cls, type):
|
||||
continue
|
||||
if not issubclass(latent_cls, comfy.latent_formats.LatentFormat):
|
||||
continue
|
||||
type_key = f"{latent_cls.__module__}.{latent_cls.__name__}"
|
||||
registry.register(
|
||||
type_key, serialize_latent_format, deserialize_latent_format
|
||||
)
|
||||
registry.register(
|
||||
latent_cls.__name__, serialize_latent_format, deserialize_latent_format
|
||||
)
|
||||
|
||||
# V3 API: unwrap NodeOutput.args
|
||||
def deserialize_node_output(data: Any) -> Any:
|
||||
return getattr(data, "args", data)
|
||||
|
||||
registry.register("NodeOutput", None, deserialize_node_output)
|
||||
|
||||
# KSAMPLER serializer: stores sampler name instead of function object
|
||||
# sampler_function is a callable which gets filtered out by JSONSocketTransport
|
||||
def serialize_ksampler(obj: Any) -> Dict[str, Any]:
|
||||
func_name = obj.sampler_function.__name__
|
||||
# Map function name back to sampler name
|
||||
if func_name == "sample_unipc":
|
||||
sampler_name = "uni_pc"
|
||||
elif func_name == "sample_unipc_bh2":
|
||||
sampler_name = "uni_pc_bh2"
|
||||
elif func_name == "dpm_fast_function":
|
||||
sampler_name = "dpm_fast"
|
||||
elif func_name == "dpm_adaptive_function":
|
||||
sampler_name = "dpm_adaptive"
|
||||
elif func_name.startswith("sample_"):
|
||||
sampler_name = func_name[7:] # Remove "sample_" prefix
|
||||
else:
|
||||
sampler_name = func_name
|
||||
return {
|
||||
"__type__": "KSAMPLER",
|
||||
"sampler_name": sampler_name,
|
||||
"extra_options": obj.extra_options,
|
||||
"inpaint_options": obj.inpaint_options,
|
||||
}
|
||||
|
||||
def deserialize_ksampler(data: Dict[str, Any]) -> Any:
|
||||
import comfy.samplers
|
||||
|
||||
return comfy.samplers.ksampler(
|
||||
data["sampler_name"],
|
||||
data.get("extra_options", {}),
|
||||
data.get("inpaint_options", {}),
|
||||
)
|
||||
|
||||
registry.register("KSAMPLER", serialize_ksampler, deserialize_ksampler)
|
||||
|
||||
from comfy.isolation.model_patcher_proxy_utils import register_hooks_serializers
|
||||
|
||||
register_hooks_serializers(registry)
|
||||
|
||||
# Generic Numpy Serializer
|
||||
def serialize_numpy(obj: Any) -> Any:
|
||||
import torch
|
||||
|
||||
try:
|
||||
# Attempt zero-copy conversion to Tensor
|
||||
return torch.from_numpy(obj)
|
||||
except Exception:
|
||||
# Fallback for non-numeric arrays (strings, objects, mixes)
|
||||
return obj.tolist()
|
||||
|
||||
def deserialize_numpy_b64(data: Any) -> Any:
|
||||
"""Deserialize base64-encoded ndarray from sealed worker."""
|
||||
import base64
|
||||
import numpy as np
|
||||
if isinstance(data, dict) and "data" in data and "dtype" in data:
|
||||
raw = base64.b64decode(data["data"])
|
||||
arr = np.frombuffer(raw, dtype=np.dtype(data["dtype"])).reshape(data["shape"])
|
||||
return torch.from_numpy(arr.copy())
|
||||
return data
|
||||
|
||||
registry.register("ndarray", serialize_numpy, deserialize_numpy_b64)
|
||||
|
||||
# -- File3D (comfy_api.latest._util.geometry_types) ---------------------
|
||||
# Origin: comfy_api by ComfyOrg (Alexander Piskun), PR #12129
|
||||
|
||||
def serialize_file3d(obj: Any) -> Dict[str, Any]:
|
||||
import base64
|
||||
return {
|
||||
"__type__": "File3D",
|
||||
"format": obj.format,
|
||||
"data": base64.b64encode(obj.get_bytes()).decode("ascii"),
|
||||
}
|
||||
|
||||
def deserialize_file3d(data: Any) -> Any:
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from comfy_api.latest._util.geometry_types import File3D
|
||||
return File3D(BytesIO(base64.b64decode(data["data"])), file_format=data["format"])
|
||||
|
||||
registry.register("File3D", serialize_file3d, deserialize_file3d, data_type=True)
|
||||
|
||||
# -- VIDEO (comfy_api.latest._input_impl.video_types) -------------------
|
||||
# Origin: ComfyAPI Core v0.0.2 by ComfyOrg (guill), PR #8962
|
||||
|
||||
def serialize_video(obj: Any) -> Dict[str, Any]:
|
||||
components = obj.get_components()
|
||||
images = components.images.detach() if components.images.requires_grad else components.images
|
||||
result: Dict[str, Any] = {
|
||||
"__type__": "VIDEO",
|
||||
"images": images,
|
||||
"frame_rate_num": components.frame_rate.numerator,
|
||||
"frame_rate_den": components.frame_rate.denominator,
|
||||
}
|
||||
if components.audio is not None:
|
||||
waveform = components.audio["waveform"]
|
||||
if waveform.requires_grad:
|
||||
waveform = waveform.detach()
|
||||
result["audio_waveform"] = waveform
|
||||
result["audio_sample_rate"] = components.audio["sample_rate"]
|
||||
if components.metadata is not None:
|
||||
result["metadata"] = components.metadata
|
||||
return result
|
||||
|
||||
def deserialize_video(data: Any) -> Any:
|
||||
from fractions import Fraction
|
||||
from comfy_api.latest._input_impl.video_types import VideoFromComponents
|
||||
from comfy_api.latest._util.video_types import VideoComponents
|
||||
audio = None
|
||||
if "audio_waveform" in data:
|
||||
audio = {"waveform": data["audio_waveform"], "sample_rate": data["audio_sample_rate"]}
|
||||
components = VideoComponents(
|
||||
images=data["images"],
|
||||
frame_rate=Fraction(data["frame_rate_num"], data["frame_rate_den"]),
|
||||
audio=audio,
|
||||
metadata=data.get("metadata"),
|
||||
)
|
||||
return VideoFromComponents(components)
|
||||
|
||||
registry.register("VIDEO", serialize_video, deserialize_video, data_type=True)
|
||||
registry.register("VideoFromFile", serialize_video, deserialize_video, data_type=True)
|
||||
registry.register("VideoFromComponents", serialize_video, deserialize_video, data_type=True)
|
||||
|
||||
def setup_web_directory(self, module: Any) -> None:
|
||||
"""Detect WEB_DIRECTORY on a module and populate/register it.
|
||||
|
||||
Called by the sealed worker after loading the node module.
|
||||
Mirrors extension_wrapper.py:216-227 for host-coupled nodes.
|
||||
Does NOT import extension_wrapper.py (it has `import torch` at module level).
|
||||
"""
|
||||
import shutil
|
||||
|
||||
web_dir_attr = getattr(module, "WEB_DIRECTORY", None)
|
||||
if web_dir_attr is None:
|
||||
return
|
||||
|
||||
module_dir = os.path.dirname(os.path.abspath(module.__file__))
|
||||
web_dir_path = os.path.abspath(os.path.join(module_dir, web_dir_attr))
|
||||
|
||||
# Read extension name from pyproject.toml
|
||||
ext_name = os.path.basename(module_dir)
|
||||
pyproject = os.path.join(module_dir, "pyproject.toml")
|
||||
if os.path.exists(pyproject):
|
||||
try:
|
||||
import tomllib
|
||||
except ImportError:
|
||||
import tomli as tomllib # type: ignore[no-redef]
|
||||
try:
|
||||
with open(pyproject, "rb") as f:
|
||||
data = tomllib.load(f)
|
||||
name = data.get("project", {}).get("name")
|
||||
if name:
|
||||
ext_name = name
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Populate web dir if empty (mirrors _run_prestartup_web_copy)
|
||||
if not (os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path))):
|
||||
os.makedirs(web_dir_path, exist_ok=True)
|
||||
|
||||
# Module-defined copy spec
|
||||
copy_spec = getattr(module, "_PRESTARTUP_WEB_COPY", None)
|
||||
if copy_spec is not None and callable(copy_spec):
|
||||
try:
|
||||
copy_spec(web_dir_path)
|
||||
except Exception as e:
|
||||
logger.warning("][ _PRESTARTUP_WEB_COPY failed: %s", e)
|
||||
|
||||
# Fallback: comfy_3d_viewers
|
||||
try:
|
||||
from comfy_3d_viewers import copy_viewer, VIEWER_FILES
|
||||
for viewer in VIEWER_FILES:
|
||||
try:
|
||||
copy_viewer(viewer, web_dir_path)
|
||||
except Exception:
|
||||
pass
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Fallback: comfy_dynamic_widgets
|
||||
try:
|
||||
from comfy_dynamic_widgets import get_js_path
|
||||
src = os.path.realpath(get_js_path())
|
||||
if os.path.exists(src):
|
||||
dst_dir = os.path.join(web_dir_path, "js")
|
||||
os.makedirs(dst_dir, exist_ok=True)
|
||||
shutil.copy2(src, os.path.join(dst_dir, "dynamic_widgets.js"))
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path)):
|
||||
WebDirectoryProxy.register_web_dir(ext_name, web_dir_path)
|
||||
logger.info(
|
||||
"][ Adapter: registered web dir for %s (%d files)",
|
||||
ext_name,
|
||||
sum(1 for _ in Path(web_dir_path).rglob("*") if _.is_file()),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def register_host_event_handlers(extension: Any) -> None:
|
||||
"""Register host-side event handlers for an isolated extension.
|
||||
|
||||
Wires ``"progress"`` events from the child to ``comfy.utils.PROGRESS_BAR_HOOK``
|
||||
so the ComfyUI frontend receives progress bar updates.
|
||||
"""
|
||||
register_event_handler = inspect.getattr_static(
|
||||
extension, "register_event_handler", None
|
||||
)
|
||||
if not callable(register_event_handler):
|
||||
return
|
||||
|
||||
def _host_progress_handler(payload: dict) -> None:
|
||||
import comfy.utils
|
||||
|
||||
hook = comfy.utils.PROGRESS_BAR_HOOK
|
||||
if hook is not None:
|
||||
hook(
|
||||
payload.get("value", 0),
|
||||
payload.get("total", 0),
|
||||
payload.get("preview"),
|
||||
payload.get("node_id"),
|
||||
)
|
||||
|
||||
extension.register_event_handler("progress", _host_progress_handler)
|
||||
|
||||
def setup_child_event_hooks(self, extension: Any) -> None:
|
||||
"""Wire PROGRESS_BAR_HOOK in the child to emit_event on the extension.
|
||||
|
||||
Host-coupled only — sealed workers do not have comfy.utils (torch).
|
||||
"""
|
||||
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
logger.info("][ ISO:setup_child_event_hooks called, PYISOLATE_CHILD=%s", is_child)
|
||||
if not is_child:
|
||||
return
|
||||
|
||||
if not _IMPORT_TORCH:
|
||||
logger.info("][ ISO:setup_child_event_hooks skipped — sealed worker (no torch)")
|
||||
return
|
||||
|
||||
import comfy.utils
|
||||
|
||||
def _event_progress_hook(value, total, preview=None, node_id=None):
|
||||
logger.debug("][ ISO:event_progress value=%s/%s node_id=%s", value, total, node_id)
|
||||
extension.emit_event("progress", {
|
||||
"value": value,
|
||||
"total": total,
|
||||
"node_id": node_id,
|
||||
})
|
||||
|
||||
comfy.utils.PROGRESS_BAR_HOOK = _event_progress_hook
|
||||
logger.info("][ ISO:PROGRESS_BAR_HOOK wired to event channel")
|
||||
|
||||
def provide_rpc_services(self) -> List[type[ProxiedSingleton]]:
|
||||
# Always available — no torch/PIL dependency
|
||||
services: List[type[ProxiedSingleton]] = [
|
||||
FolderPathsProxy,
|
||||
HelperProxiesService,
|
||||
WebDirectoryProxy,
|
||||
]
|
||||
# Torch/PIL-dependent proxies
|
||||
if _HAS_TORCH_PROXIES:
|
||||
services.extend([
|
||||
PromptServerService,
|
||||
ModelManagementProxy,
|
||||
UtilsProxy,
|
||||
ProgressProxy,
|
||||
VAERegistry,
|
||||
CLIPRegistry,
|
||||
ModelPatcherRegistry,
|
||||
ModelSamplingRegistry,
|
||||
FirstStageModelRegistry,
|
||||
])
|
||||
return services
|
||||
|
||||
def handle_api_registration(self, api: ProxiedSingleton, rpc: AsyncRPC) -> None:
|
||||
# Resolve the real name whether it's an instance or the Singleton class itself
|
||||
api_name = api.__name__ if isinstance(api, type) else api.__class__.__name__
|
||||
|
||||
if api_name == "FolderPathsProxy":
|
||||
import folder_paths
|
||||
|
||||
# Replace module-level functions with proxy methods
|
||||
# This is aggressive but necessary for transparent proxying
|
||||
# Handle both instance and class cases
|
||||
instance = api() if isinstance(api, type) else api
|
||||
for name in dir(instance):
|
||||
if not name.startswith("_"):
|
||||
setattr(folder_paths, name, getattr(instance, name))
|
||||
|
||||
# Fence: isolated children get writable temp inside sandbox
|
||||
if os.environ.get("PYISOLATE_CHILD") == "1":
|
||||
import tempfile
|
||||
_child_temp = os.path.join(tempfile.gettempdir(), "comfyui_temp")
|
||||
os.makedirs(_child_temp, exist_ok=True)
|
||||
folder_paths.temp_directory = _child_temp
|
||||
|
||||
return
|
||||
|
||||
if api_name == "ModelManagementProxy":
|
||||
if _IMPORT_TORCH:
|
||||
import comfy.model_management
|
||||
|
||||
instance = api() if isinstance(api, type) else api
|
||||
# Replace module-level functions with proxy methods
|
||||
for name in dir(instance):
|
||||
if not name.startswith("_"):
|
||||
setattr(comfy.model_management, name, getattr(instance, name))
|
||||
return
|
||||
|
||||
if api_name == "UtilsProxy":
|
||||
if not _IMPORT_TORCH:
|
||||
logger.info("][ ISO:UtilsProxy handle_api_registration skipped — sealed worker (no torch)")
|
||||
return
|
||||
|
||||
import comfy.utils
|
||||
|
||||
# Static Injection of RPC mechanism to ensure Child can access it
|
||||
# independent of instance lifecycle.
|
||||
api.set_rpc(rpc)
|
||||
|
||||
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
logger.info("][ ISO:UtilsProxy handle_api_registration PYISOLATE_CHILD=%s", is_child)
|
||||
|
||||
# Progress hook wiring moved to setup_child_event_hooks via event channel
|
||||
|
||||
return
|
||||
|
||||
if api_name == "PromptServerProxy":
|
||||
if not _IMPORT_TORCH:
|
||||
return
|
||||
# Defer heavy import to child context
|
||||
import server
|
||||
|
||||
instance = api() if isinstance(api, type) else api
|
||||
proxy = (
|
||||
instance.instance
|
||||
) # PromptServerProxy instance has .instance property returning self
|
||||
|
||||
original_register_route = proxy.register_route
|
||||
|
||||
def register_route_wrapper(
|
||||
method: str, path: str, handler: Callable[..., Any]
|
||||
) -> None:
|
||||
callback_id = rpc.register_callback(handler)
|
||||
loop = getattr(rpc, "loop", None)
|
||||
if loop and loop.is_running():
|
||||
import asyncio
|
||||
|
||||
asyncio.create_task(
|
||||
original_register_route(
|
||||
method, path, handler=callback_id, is_callback=True
|
||||
)
|
||||
)
|
||||
else:
|
||||
original_register_route(
|
||||
method, path, handler=callback_id, is_callback=True
|
||||
)
|
||||
return None
|
||||
|
||||
proxy.register_route = register_route_wrapper
|
||||
|
||||
class RouteTableDefProxy:
|
||||
def __init__(self, proxy_instance: Any):
|
||||
self.proxy = proxy_instance
|
||||
|
||||
def get(
|
||||
self, path: str, **kwargs: Any
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
|
||||
self.proxy.register_route("GET", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
def post(
|
||||
self, path: str, **kwargs: Any
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
|
||||
self.proxy.register_route("POST", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
def patch(
|
||||
self, path: str, **kwargs: Any
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
|
||||
self.proxy.register_route("PATCH", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
def put(
|
||||
self, path: str, **kwargs: Any
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
|
||||
self.proxy.register_route("PUT", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
def delete(
|
||||
self, path: str, **kwargs: Any
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
|
||||
self.proxy.register_route("DELETE", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
proxy.routes = RouteTableDefProxy(proxy)
|
||||
|
||||
if (
|
||||
hasattr(server, "PromptServer")
|
||||
and getattr(server.PromptServer, "instance", None) != proxy
|
||||
):
|
||||
server.PromptServer.instance = proxy
|
||||
101
comfy/isolation/child_hooks.py
Normal file
101
comfy/isolation/child_hooks.py
Normal file
@@ -0,0 +1,101 @@
|
||||
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation
|
||||
# Child process initialization for PyIsolate
|
||||
import logging
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_child_process() -> bool:
|
||||
return os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
|
||||
|
||||
def initialize_child_process() -> None:
|
||||
_setup_child_loop_bridge()
|
||||
|
||||
# Manual RPC injection
|
||||
try:
|
||||
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
|
||||
|
||||
rpc = get_child_rpc_instance()
|
||||
if rpc:
|
||||
_setup_proxy_callers(rpc)
|
||||
else:
|
||||
logger.warning("Could not get child RPC instance for manual injection")
|
||||
_setup_proxy_callers()
|
||||
except Exception as e:
|
||||
logger.error(f"Manual RPC Injection failed: {e}")
|
||||
_setup_proxy_callers()
|
||||
|
||||
_setup_logging()
|
||||
|
||||
|
||||
def _setup_child_loop_bridge() -> None:
|
||||
import asyncio
|
||||
|
||||
main_loop = None
|
||||
try:
|
||||
main_loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
try:
|
||||
main_loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
if main_loop is None:
|
||||
return
|
||||
|
||||
try:
|
||||
from .proxies.base import set_global_loop
|
||||
|
||||
set_global_loop(main_loop)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def _setup_prompt_server_stub(rpc=None) -> None:
|
||||
try:
|
||||
from .proxies.prompt_server_impl import PromptServerStub
|
||||
|
||||
if rpc:
|
||||
PromptServerStub.set_rpc(rpc)
|
||||
elif hasattr(PromptServerStub, "clear_rpc"):
|
||||
PromptServerStub.clear_rpc()
|
||||
else:
|
||||
PromptServerStub._rpc = None # type: ignore[attr-defined]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to setup PromptServerStub: {e}")
|
||||
|
||||
|
||||
def _setup_proxy_callers(rpc=None) -> None:
|
||||
try:
|
||||
from .proxies.folder_paths_proxy import FolderPathsProxy
|
||||
from .proxies.helper_proxies import HelperProxiesService
|
||||
from .proxies.model_management_proxy import ModelManagementProxy
|
||||
from .proxies.progress_proxy import ProgressProxy
|
||||
from .proxies.prompt_server_impl import PromptServerStub
|
||||
from .proxies.utils_proxy import UtilsProxy
|
||||
|
||||
if rpc is None:
|
||||
FolderPathsProxy.clear_rpc()
|
||||
HelperProxiesService.clear_rpc()
|
||||
ModelManagementProxy.clear_rpc()
|
||||
ProgressProxy.clear_rpc()
|
||||
PromptServerStub.clear_rpc()
|
||||
UtilsProxy.clear_rpc()
|
||||
return
|
||||
|
||||
FolderPathsProxy.set_rpc(rpc)
|
||||
HelperProxiesService.set_rpc(rpc)
|
||||
ModelManagementProxy.set_rpc(rpc)
|
||||
ProgressProxy.set_rpc(rpc)
|
||||
PromptServerStub.set_rpc(rpc)
|
||||
UtilsProxy.set_rpc(rpc)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to setup child singleton proxy callers: {e}")
|
||||
|
||||
|
||||
def _setup_logging() -> None:
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
327
comfy/isolation/clip_proxy.py
Normal file
327
comfy/isolation/clip_proxy.py
Normal file
@@ -0,0 +1,327 @@
|
||||
# pylint: disable=attribute-defined-outside-init,import-outside-toplevel,logging-fstring-interpolation
|
||||
# CLIP Proxy implementation
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from comfy.isolation.proxies.base import (
|
||||
IS_CHILD_PROCESS,
|
||||
BaseProxy,
|
||||
BaseRegistry,
|
||||
detach_if_grad,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy
|
||||
|
||||
|
||||
class CondStageModelRegistry(BaseRegistry[Any]):
|
||||
_type_prefix = "cond_stage_model"
|
||||
|
||||
async def get_property(self, instance_id: str, name: str) -> Any:
|
||||
obj = self._get_instance(instance_id)
|
||||
return getattr(obj, name)
|
||||
|
||||
|
||||
class CondStageModelProxy(BaseProxy[CondStageModelRegistry]):
|
||||
_registry_class = CondStageModelRegistry
|
||||
__module__ = "comfy.sd"
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
try:
|
||||
return self._call_rpc("get_property", name)
|
||||
except Exception as e:
|
||||
raise AttributeError(
|
||||
f"'{self.__class__.__name__}' object has no attribute '{name}'"
|
||||
) from e
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<CondStageModelProxy {self._instance_id}>"
|
||||
|
||||
|
||||
class TokenizerRegistry(BaseRegistry[Any]):
|
||||
_type_prefix = "tokenizer"
|
||||
|
||||
async def get_property(self, instance_id: str, name: str) -> Any:
|
||||
obj = self._get_instance(instance_id)
|
||||
return getattr(obj, name)
|
||||
|
||||
|
||||
class TokenizerProxy(BaseProxy[TokenizerRegistry]):
|
||||
_registry_class = TokenizerRegistry
|
||||
__module__ = "comfy.sd"
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
try:
|
||||
return self._call_rpc("get_property", name)
|
||||
except Exception as e:
|
||||
raise AttributeError(
|
||||
f"'{self.__class__.__name__}' object has no attribute '{name}'"
|
||||
) from e
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<TokenizerProxy {self._instance_id}>"
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CLIPRegistry(BaseRegistry[Any]):
|
||||
_type_prefix = "clip"
|
||||
_allowed_setters = {
|
||||
"layer_idx",
|
||||
"tokenizer_options",
|
||||
"use_clip_schedule",
|
||||
"apply_hooks_to_conds",
|
||||
}
|
||||
|
||||
async def get_ram_usage(self, instance_id: str) -> int:
|
||||
return self._get_instance(instance_id).get_ram_usage()
|
||||
|
||||
async def get_patcher_id(self, instance_id: str) -> str:
|
||||
from comfy.isolation.model_patcher_proxy import ModelPatcherRegistry
|
||||
|
||||
return ModelPatcherRegistry().register(self._get_instance(instance_id).patcher)
|
||||
|
||||
async def get_cond_stage_model_id(self, instance_id: str) -> str:
|
||||
return CondStageModelRegistry().register(
|
||||
self._get_instance(instance_id).cond_stage_model
|
||||
)
|
||||
|
||||
async def get_tokenizer_id(self, instance_id: str) -> str:
|
||||
return TokenizerRegistry().register(self._get_instance(instance_id).tokenizer)
|
||||
|
||||
async def load_model(self, instance_id: str) -> None:
|
||||
self._get_instance(instance_id).load_model()
|
||||
|
||||
async def clip_layer(self, instance_id: str, layer_idx: int) -> None:
|
||||
self._get_instance(instance_id).clip_layer(layer_idx)
|
||||
|
||||
async def set_tokenizer_option(
|
||||
self, instance_id: str, option_name: str, value: Any
|
||||
) -> None:
|
||||
self._get_instance(instance_id).set_tokenizer_option(option_name, value)
|
||||
|
||||
async def get_property(self, instance_id: str, name: str) -> Any:
|
||||
return getattr(self._get_instance(instance_id), name)
|
||||
|
||||
async def set_property(self, instance_id: str, name: str, value: Any) -> None:
|
||||
if name not in self._allowed_setters:
|
||||
raise PermissionError(f"Setting '{name}' is not allowed via RPC")
|
||||
setattr(self._get_instance(instance_id), name, value)
|
||||
|
||||
async def tokenize(
|
||||
self, instance_id: str, text: str, return_word_ids: bool = False, **kwargs: Any
|
||||
) -> Any:
|
||||
return self._get_instance(instance_id).tokenize(
|
||||
text, return_word_ids=return_word_ids, **kwargs
|
||||
)
|
||||
|
||||
async def encode(self, instance_id: str, text: str) -> Any:
|
||||
return detach_if_grad(self._get_instance(instance_id).encode(text))
|
||||
|
||||
async def encode_from_tokens(
|
||||
self,
|
||||
instance_id: str,
|
||||
tokens: Any,
|
||||
return_pooled: bool = False,
|
||||
return_dict: bool = False,
|
||||
) -> Any:
|
||||
return detach_if_grad(
|
||||
self._get_instance(instance_id).encode_from_tokens(
|
||||
tokens, return_pooled=return_pooled, return_dict=return_dict
|
||||
)
|
||||
)
|
||||
|
||||
async def encode_from_tokens_scheduled(
|
||||
self,
|
||||
instance_id: str,
|
||||
tokens: Any,
|
||||
unprojected: bool = False,
|
||||
add_dict: Optional[dict] = None,
|
||||
show_pbar: bool = True,
|
||||
) -> Any:
|
||||
add_dict = add_dict or {}
|
||||
return detach_if_grad(
|
||||
self._get_instance(instance_id).encode_from_tokens_scheduled(
|
||||
tokens, unprojected=unprojected, add_dict=add_dict, show_pbar=show_pbar
|
||||
)
|
||||
)
|
||||
|
||||
async def add_patches(
|
||||
self,
|
||||
instance_id: str,
|
||||
patches: Any,
|
||||
strength_patch: float = 1.0,
|
||||
strength_model: float = 1.0,
|
||||
) -> Any:
|
||||
return self._get_instance(instance_id).add_patches(
|
||||
patches, strength_patch=strength_patch, strength_model=strength_model
|
||||
)
|
||||
|
||||
async def get_key_patches(self, instance_id: str) -> Any:
|
||||
return self._get_instance(instance_id).get_key_patches()
|
||||
|
||||
async def load_sd(
|
||||
self, instance_id: str, sd: dict, full_model: bool = False
|
||||
) -> Any:
|
||||
return self._get_instance(instance_id).load_sd(sd, full_model=full_model)
|
||||
|
||||
async def get_sd(self, instance_id: str) -> Any:
|
||||
return self._get_instance(instance_id).get_sd()
|
||||
|
||||
async def clone(self, instance_id: str) -> str:
|
||||
return self.register(self._get_instance(instance_id).clone())
|
||||
|
||||
|
||||
class CLIPProxy(BaseProxy[CLIPRegistry]):
|
||||
_registry_class = CLIPRegistry
|
||||
__module__ = "comfy.sd"
|
||||
|
||||
def get_ram_usage(self) -> int:
|
||||
return self._call_rpc("get_ram_usage")
|
||||
|
||||
@property
|
||||
def patcher(self) -> "ModelPatcherProxy":
|
||||
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy
|
||||
|
||||
if not hasattr(self, "_patcher_proxy"):
|
||||
patcher_id = self._call_rpc("get_patcher_id")
|
||||
self._patcher_proxy = ModelPatcherProxy(patcher_id, manage_lifecycle=False)
|
||||
return self._patcher_proxy
|
||||
|
||||
@patcher.setter
|
||||
def patcher(self, value: Any) -> None:
|
||||
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy
|
||||
|
||||
if isinstance(value, ModelPatcherProxy):
|
||||
self._patcher_proxy = value
|
||||
else:
|
||||
logger.warning(
|
||||
f"Attempted to set CLIPProxy.patcher to non-proxy object: {value}"
|
||||
)
|
||||
|
||||
@property
|
||||
def cond_stage_model(self) -> CondStageModelProxy:
|
||||
if not hasattr(self, "_cond_stage_model_proxy"):
|
||||
csm_id = self._call_rpc("get_cond_stage_model_id")
|
||||
self._cond_stage_model_proxy = CondStageModelProxy(
|
||||
csm_id, manage_lifecycle=False
|
||||
)
|
||||
return self._cond_stage_model_proxy
|
||||
|
||||
@property
|
||||
def tokenizer(self) -> TokenizerProxy:
|
||||
if not hasattr(self, "_tokenizer_proxy"):
|
||||
tok_id = self._call_rpc("get_tokenizer_id")
|
||||
self._tokenizer_proxy = TokenizerProxy(tok_id, manage_lifecycle=False)
|
||||
return self._tokenizer_proxy
|
||||
|
||||
def load_model(self) -> ModelPatcherProxy:
|
||||
self._call_rpc("load_model")
|
||||
return self.patcher
|
||||
|
||||
@property
|
||||
def layer_idx(self) -> Optional[int]:
|
||||
return self._call_rpc("get_property", "layer_idx")
|
||||
|
||||
@layer_idx.setter
|
||||
def layer_idx(self, value: Optional[int]) -> None:
|
||||
self._call_rpc("set_property", "layer_idx", value)
|
||||
|
||||
@property
|
||||
def tokenizer_options(self) -> dict:
|
||||
return self._call_rpc("get_property", "tokenizer_options")
|
||||
|
||||
@tokenizer_options.setter
|
||||
def tokenizer_options(self, value: dict) -> None:
|
||||
self._call_rpc("set_property", "tokenizer_options", value)
|
||||
|
||||
@property
|
||||
def use_clip_schedule(self) -> bool:
|
||||
return self._call_rpc("get_property", "use_clip_schedule")
|
||||
|
||||
@use_clip_schedule.setter
|
||||
def use_clip_schedule(self, value: bool) -> None:
|
||||
self._call_rpc("set_property", "use_clip_schedule", value)
|
||||
|
||||
@property
|
||||
def apply_hooks_to_conds(self) -> Any:
|
||||
return self._call_rpc("get_property", "apply_hooks_to_conds")
|
||||
|
||||
@apply_hooks_to_conds.setter
|
||||
def apply_hooks_to_conds(self, value: Any) -> None:
|
||||
self._call_rpc("set_property", "apply_hooks_to_conds", value)
|
||||
|
||||
def clip_layer(self, layer_idx: int) -> None:
|
||||
return self._call_rpc("clip_layer", layer_idx)
|
||||
|
||||
def set_tokenizer_option(self, option_name: str, value: Any) -> None:
|
||||
return self._call_rpc("set_tokenizer_option", option_name, value)
|
||||
|
||||
def tokenize(self, text: str, return_word_ids: bool = False, **kwargs: Any) -> Any:
|
||||
return self._call_rpc(
|
||||
"tokenize", text, return_word_ids=return_word_ids, **kwargs
|
||||
)
|
||||
|
||||
def encode(self, text: str) -> Any:
|
||||
return self._call_rpc("encode", text)
|
||||
|
||||
def encode_from_tokens(
|
||||
self, tokens: Any, return_pooled: bool = False, return_dict: bool = False
|
||||
) -> Any:
|
||||
res = self._call_rpc(
|
||||
"encode_from_tokens",
|
||||
tokens,
|
||||
return_pooled=return_pooled,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
if return_pooled and isinstance(res, list) and not return_dict:
|
||||
return tuple(res)
|
||||
return res
|
||||
|
||||
def encode_from_tokens_scheduled(
|
||||
self,
|
||||
tokens: Any,
|
||||
unprojected: bool = False,
|
||||
add_dict: Optional[dict] = None,
|
||||
show_pbar: bool = True,
|
||||
) -> Any:
|
||||
add_dict = add_dict or {}
|
||||
return self._call_rpc(
|
||||
"encode_from_tokens_scheduled",
|
||||
tokens,
|
||||
unprojected=unprojected,
|
||||
add_dict=add_dict,
|
||||
show_pbar=show_pbar,
|
||||
)
|
||||
|
||||
def add_patches(
|
||||
self, patches: Any, strength_patch: float = 1.0, strength_model: float = 1.0
|
||||
) -> Any:
|
||||
return self._call_rpc(
|
||||
"add_patches",
|
||||
patches,
|
||||
strength_patch=strength_patch,
|
||||
strength_model=strength_model,
|
||||
)
|
||||
|
||||
def get_key_patches(self) -> Any:
|
||||
return self._call_rpc("get_key_patches")
|
||||
|
||||
def load_sd(self, sd: dict, full_model: bool = False) -> Any:
|
||||
return self._call_rpc("load_sd", sd, full_model=full_model)
|
||||
|
||||
def get_sd(self) -> Any:
|
||||
return self._call_rpc("get_sd")
|
||||
|
||||
def clone(self) -> CLIPProxy:
|
||||
new_id = self._call_rpc("clone")
|
||||
return CLIPProxy(new_id, self._registry, manage_lifecycle=not IS_CHILD_PROCESS)
|
||||
|
||||
|
||||
if not IS_CHILD_PROCESS:
|
||||
_CLIP_REGISTRY_SINGLETON = CLIPRegistry()
|
||||
_COND_STAGE_MODEL_REGISTRY_SINGLETON = CondStageModelRegistry()
|
||||
_TOKENIZER_REGISTRY_SINGLETON = TokenizerRegistry()
|
||||
16
comfy/isolation/custom_node_serializers.py
Normal file
16
comfy/isolation/custom_node_serializers.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Compatibility shim for the indexed serializer path."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def register_custom_node_serializers(_registry: Any) -> None:
|
||||
"""Legacy no-op shim.
|
||||
|
||||
Serializer registration now lives directly in the active isolation adapter.
|
||||
This module remains importable because the isolation index still references it.
|
||||
"""
|
||||
return None
|
||||
|
||||
__all__ = ["register_custom_node_serializers"]
|
||||
489
comfy/isolation/extension_loader.py
Normal file
489
comfy/isolation/extension_loader.py
Normal file
@@ -0,0 +1,489 @@
|
||||
# pylint: disable=cyclic-import,import-outside-toplevel,redefined-outer-name
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import inspect
|
||||
import sys
|
||||
import types
|
||||
import platform
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Tuple
|
||||
|
||||
import pyisolate
|
||||
from pyisolate import ExtensionManager, ExtensionManagerConfig
|
||||
from packaging.requirements import InvalidRequirement, Requirement
|
||||
from packaging.utils import canonicalize_name
|
||||
|
||||
from .manifest_loader import is_cache_valid, load_from_cache, save_to_cache
|
||||
from .host_policy import load_host_policy
|
||||
|
||||
try:
|
||||
import tomllib
|
||||
except ImportError:
|
||||
import tomli as tomllib # type: ignore[no-redef]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _register_web_directory(extension_name: str, node_dir: Path) -> None:
|
||||
"""Register an isolated extension's web directory on the host side."""
|
||||
import nodes
|
||||
|
||||
# Method 1: pyproject.toml [tool.comfy] web field
|
||||
pyproject = node_dir / "pyproject.toml"
|
||||
if pyproject.exists():
|
||||
try:
|
||||
with pyproject.open("rb") as f:
|
||||
data = tomllib.load(f)
|
||||
web_dir_name = data.get("tool", {}).get("comfy", {}).get("web")
|
||||
if web_dir_name:
|
||||
web_dir_path = str(node_dir / web_dir_name)
|
||||
if os.path.isdir(web_dir_path):
|
||||
nodes.EXTENSION_WEB_DIRS[extension_name] = web_dir_path
|
||||
logger.debug(
|
||||
"][ Registered web dir for isolated %s: %s",
|
||||
extension_name,
|
||||
web_dir_path,
|
||||
)
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Method 2: __init__.py WEB_DIRECTORY constant (parse without importing)
|
||||
init_file = node_dir / "__init__.py"
|
||||
if init_file.exists():
|
||||
try:
|
||||
source = init_file.read_text()
|
||||
for line in source.splitlines():
|
||||
stripped = line.strip()
|
||||
if stripped.startswith("WEB_DIRECTORY"):
|
||||
# Parse: WEB_DIRECTORY = "./web" or WEB_DIRECTORY = "web"
|
||||
_, _, value = stripped.partition("=")
|
||||
value = value.strip().strip("\"'")
|
||||
if value:
|
||||
web_dir_path = str((node_dir / value).resolve())
|
||||
if os.path.isdir(web_dir_path):
|
||||
nodes.EXTENSION_WEB_DIRS[extension_name] = web_dir_path
|
||||
logger.debug(
|
||||
"][ Registered web dir for isolated %s: %s",
|
||||
extension_name,
|
||||
web_dir_path,
|
||||
)
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _get_extension_type(execution_model: str) -> type[Any]:
|
||||
if execution_model == "sealed_worker":
|
||||
return pyisolate.SealedNodeExtension
|
||||
|
||||
from .extension_wrapper import ComfyNodeExtension
|
||||
|
||||
return ComfyNodeExtension
|
||||
|
||||
|
||||
async def _stop_extension_safe(extension: Any, extension_name: str) -> None:
|
||||
try:
|
||||
stop_result = extension.stop()
|
||||
if inspect.isawaitable(stop_result):
|
||||
await stop_result
|
||||
except Exception:
|
||||
logger.debug("][ %s stop failed", extension_name, exc_info=True)
|
||||
|
||||
|
||||
def _normalize_dependency_spec(dep: str, base_paths: list[Path]) -> str:
|
||||
req, sep, marker = dep.partition(";")
|
||||
req = req.strip()
|
||||
marker_suffix = f";{marker}" if sep else ""
|
||||
|
||||
def _resolve_local_path(local_path: str) -> Path | None:
|
||||
for base in base_paths:
|
||||
candidate = (base / local_path).resolve()
|
||||
if candidate.exists():
|
||||
return candidate
|
||||
return None
|
||||
|
||||
if req.startswith("./") or req.startswith("../"):
|
||||
resolved = _resolve_local_path(req)
|
||||
if resolved is not None:
|
||||
return f"{resolved}{marker_suffix}"
|
||||
|
||||
if req.startswith("file://"):
|
||||
raw = req[len("file://") :]
|
||||
if raw.startswith("./") or raw.startswith("../"):
|
||||
resolved = _resolve_local_path(raw)
|
||||
if resolved is not None:
|
||||
return f"file://{resolved}{marker_suffix}"
|
||||
|
||||
return dep
|
||||
|
||||
|
||||
def _dependency_name_from_spec(dep: str) -> str | None:
|
||||
stripped = dep.strip()
|
||||
if not stripped or stripped == "-e" or stripped.startswith("-e "):
|
||||
return None
|
||||
if stripped.startswith(("/", "./", "../", "file://")):
|
||||
return None
|
||||
|
||||
try:
|
||||
return canonicalize_name(Requirement(stripped).name)
|
||||
except InvalidRequirement:
|
||||
return None
|
||||
|
||||
|
||||
def _parse_cuda_wheels_config(
|
||||
tool_config: dict[str, object], dependencies: list[str]
|
||||
) -> dict[str, object] | None:
|
||||
raw_config = tool_config.get("cuda_wheels")
|
||||
if raw_config is None:
|
||||
return None
|
||||
if not isinstance(raw_config, dict):
|
||||
raise ExtensionLoadError("[tool.comfy.isolation.cuda_wheels] must be a table")
|
||||
|
||||
index_url = raw_config.get("index_url")
|
||||
index_urls = raw_config.get("index_urls")
|
||||
if index_urls is not None:
|
||||
if not isinstance(index_urls, list) or not all(
|
||||
isinstance(u, str) and u.strip() for u in index_urls
|
||||
):
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.cuda_wheels.index_urls] must be a list of non-empty strings"
|
||||
)
|
||||
elif not isinstance(index_url, str) or not index_url.strip():
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.cuda_wheels.index_url] must be a non-empty string"
|
||||
)
|
||||
|
||||
packages = raw_config.get("packages")
|
||||
if not isinstance(packages, list) or not all(
|
||||
isinstance(package_name, str) and package_name.strip()
|
||||
for package_name in packages
|
||||
):
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.cuda_wheels.packages] must be a list of non-empty strings"
|
||||
)
|
||||
|
||||
declared_dependencies = {
|
||||
dependency_name
|
||||
for dep in dependencies
|
||||
if (dependency_name := _dependency_name_from_spec(dep)) is not None
|
||||
}
|
||||
normalized_packages = [canonicalize_name(package_name) for package_name in packages]
|
||||
missing = [
|
||||
package_name
|
||||
for package_name in normalized_packages
|
||||
if package_name not in declared_dependencies
|
||||
]
|
||||
if missing:
|
||||
missing_joined = ", ".join(sorted(missing))
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.cuda_wheels.packages] references undeclared dependencies: "
|
||||
f"{missing_joined}"
|
||||
)
|
||||
|
||||
package_map = raw_config.get("package_map", {})
|
||||
if not isinstance(package_map, dict):
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.cuda_wheels.package_map] must be a table"
|
||||
)
|
||||
|
||||
normalized_package_map: dict[str, str] = {}
|
||||
for dependency_name, index_package_name in package_map.items():
|
||||
if not isinstance(dependency_name, str) or not dependency_name.strip():
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.cuda_wheels.package_map] keys must be non-empty strings"
|
||||
)
|
||||
if not isinstance(index_package_name, str) or not index_package_name.strip():
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.cuda_wheels.package_map] values must be non-empty strings"
|
||||
)
|
||||
canonical_dependency_name = canonicalize_name(dependency_name)
|
||||
if canonical_dependency_name not in normalized_packages:
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.cuda_wheels.package_map] can only override packages listed in "
|
||||
"[tool.comfy.isolation.cuda_wheels.packages]"
|
||||
)
|
||||
normalized_package_map[canonical_dependency_name] = index_package_name.strip()
|
||||
|
||||
result: dict = {
|
||||
"packages": normalized_packages,
|
||||
"package_map": normalized_package_map,
|
||||
}
|
||||
if index_urls is not None:
|
||||
result["index_urls"] = [u.rstrip("/") + "/" for u in index_urls]
|
||||
else:
|
||||
result["index_url"] = index_url.rstrip("/") + "/"
|
||||
return result
|
||||
|
||||
|
||||
def get_enforcement_policy() -> Dict[str, bool]:
|
||||
return {
|
||||
"force_isolated": os.environ.get("PYISOLATE_ENFORCE_ISOLATED") == "1",
|
||||
"force_sandbox": os.environ.get("PYISOLATE_ENFORCE_SANDBOX") == "1",
|
||||
}
|
||||
|
||||
|
||||
class ExtensionLoadError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
def register_dummy_module(extension_name: str, node_dir: Path) -> None:
|
||||
normalized_name = extension_name.replace("-", "_").replace(".", "_")
|
||||
if normalized_name not in sys.modules:
|
||||
dummy_module = types.ModuleType(normalized_name)
|
||||
dummy_module.__file__ = str(node_dir / "__init__.py")
|
||||
dummy_module.__path__ = [str(node_dir)]
|
||||
dummy_module.__package__ = normalized_name
|
||||
sys.modules[normalized_name] = dummy_module
|
||||
|
||||
|
||||
def _is_stale_node_cache(cached_data: Dict[str, Dict]) -> bool:
|
||||
for details in cached_data.values():
|
||||
if not isinstance(details, dict):
|
||||
return True
|
||||
if details.get("is_v3") and "schema_v1" not in details:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def load_isolated_node(
|
||||
node_dir: Path,
|
||||
manifest_path: Path,
|
||||
logger: logging.Logger,
|
||||
build_stub_class: Callable[[str, Dict[str, object], Any], type],
|
||||
venv_root: Path,
|
||||
extension_managers: List[ExtensionManager],
|
||||
) -> List[Tuple[str, str, type]]:
|
||||
try:
|
||||
with manifest_path.open("rb") as handle:
|
||||
manifest_data = tomllib.load(handle)
|
||||
except Exception as e:
|
||||
logger.warning(f"][ Failed to parse {manifest_path}: {e}")
|
||||
return []
|
||||
|
||||
# Parse [tool.comfy.isolation]
|
||||
tool_config = manifest_data.get("tool", {}).get("comfy", {}).get("isolation", {})
|
||||
can_isolate = tool_config.get("can_isolate", False)
|
||||
share_torch = tool_config.get("share_torch", False)
|
||||
package_manager = tool_config.get("package_manager", "uv")
|
||||
is_conda = package_manager == "conda"
|
||||
execution_model = tool_config.get("execution_model")
|
||||
if execution_model is None:
|
||||
execution_model = "sealed_worker" if is_conda else "host-coupled"
|
||||
|
||||
if "sealed_host_ro_paths" in tool_config:
|
||||
raise ValueError(
|
||||
"Manifest field 'sealed_host_ro_paths' is not allowed. "
|
||||
"Configure [tool.comfy.host].sealed_worker_ro_import_paths in host policy."
|
||||
)
|
||||
|
||||
# Conda-specific manifest fields
|
||||
conda_channels: list[str] = (
|
||||
tool_config.get("conda_channels", []) if is_conda else []
|
||||
)
|
||||
conda_dependencies: list[str] = (
|
||||
tool_config.get("conda_dependencies", []) if is_conda else []
|
||||
)
|
||||
conda_platforms: list[str] = (
|
||||
tool_config.get("conda_platforms", []) if is_conda else []
|
||||
)
|
||||
conda_python: str = (
|
||||
tool_config.get("conda_python", "*") if is_conda else "*"
|
||||
)
|
||||
|
||||
# Parse [project] dependencies
|
||||
project_config = manifest_data.get("project", {})
|
||||
dependencies = project_config.get("dependencies", [])
|
||||
if not isinstance(dependencies, list):
|
||||
dependencies = []
|
||||
|
||||
# Get extension name (default to folder name if not in project.name)
|
||||
extension_name = project_config.get("name", node_dir.name)
|
||||
|
||||
# LOGIC: Isolation Decision
|
||||
policy = get_enforcement_policy()
|
||||
isolated = can_isolate or policy["force_isolated"]
|
||||
|
||||
if not isolated:
|
||||
return []
|
||||
|
||||
import folder_paths
|
||||
|
||||
base_paths = [Path(folder_paths.base_path), node_dir]
|
||||
dependencies = [
|
||||
_normalize_dependency_spec(dep, base_paths) if isinstance(dep, str) else dep
|
||||
for dep in dependencies
|
||||
]
|
||||
cuda_wheels = _parse_cuda_wheels_config(tool_config, dependencies)
|
||||
|
||||
manager_config = ExtensionManagerConfig(venv_root_path=str(venv_root))
|
||||
extension_type = _get_extension_type(execution_model)
|
||||
manager: ExtensionManager = pyisolate.ExtensionManager(
|
||||
extension_type, manager_config
|
||||
)
|
||||
extension_managers.append(manager)
|
||||
|
||||
host_policy = load_host_policy(Path(folder_paths.base_path))
|
||||
|
||||
sandbox_config = {}
|
||||
is_linux = platform.system() == "Linux"
|
||||
|
||||
if is_conda:
|
||||
share_torch = False
|
||||
share_cuda_ipc = False
|
||||
else:
|
||||
share_cuda_ipc = share_torch and is_linux
|
||||
|
||||
if is_linux and isolated:
|
||||
sandbox_config = {
|
||||
"network": host_policy["allow_network"],
|
||||
"writable_paths": host_policy["writable_paths"],
|
||||
"readonly_paths": host_policy["readonly_paths"],
|
||||
}
|
||||
|
||||
extension_config: dict = {
|
||||
"name": extension_name,
|
||||
"module_path": str(node_dir),
|
||||
"isolated": True,
|
||||
"dependencies": dependencies,
|
||||
"share_torch": share_torch,
|
||||
"share_cuda_ipc": share_cuda_ipc,
|
||||
"sandbox_mode": host_policy["sandbox_mode"],
|
||||
"sandbox": sandbox_config,
|
||||
}
|
||||
|
||||
_is_sealed = execution_model == "sealed_worker"
|
||||
_is_sandboxed = host_policy["sandbox_mode"] != "disabled" and is_linux
|
||||
logger.info(
|
||||
"][ Loading isolated node: %s (torch_share [%s], sealed [%s], sandboxed [%s])",
|
||||
extension_name,
|
||||
"x" if share_torch else " ",
|
||||
"x" if _is_sealed else " ",
|
||||
"x" if _is_sandboxed else " ",
|
||||
)
|
||||
|
||||
if cuda_wheels is not None:
|
||||
extension_config["cuda_wheels"] = cuda_wheels
|
||||
|
||||
# Conda-specific keys
|
||||
if is_conda:
|
||||
extension_config["package_manager"] = "conda"
|
||||
extension_config["conda_channels"] = conda_channels
|
||||
extension_config["conda_dependencies"] = conda_dependencies
|
||||
extension_config["conda_python"] = conda_python
|
||||
find_links = tool_config.get("find_links", [])
|
||||
if find_links:
|
||||
extension_config["find_links"] = find_links
|
||||
if conda_platforms:
|
||||
extension_config["conda_platforms"] = conda_platforms
|
||||
|
||||
if execution_model != "host-coupled":
|
||||
extension_config["execution_model"] = execution_model
|
||||
if execution_model == "sealed_worker":
|
||||
policy_ro_paths = host_policy.get("sealed_worker_ro_import_paths", [])
|
||||
if isinstance(policy_ro_paths, list) and policy_ro_paths:
|
||||
extension_config["sealed_host_ro_paths"] = list(policy_ro_paths)
|
||||
# Sealed workers keep the host RPC service inventory even when the
|
||||
# child resolves no API classes locally.
|
||||
|
||||
extension = manager.load_extension(extension_config)
|
||||
register_dummy_module(extension_name, node_dir)
|
||||
|
||||
# Register host-side event handlers via adapter
|
||||
from .adapter import ComfyUIAdapter
|
||||
ComfyUIAdapter.register_host_event_handlers(extension)
|
||||
|
||||
# Register web directory on the host — only when sandbox is disabled.
|
||||
# In sandbox mode, serving untrusted JS to the browser is not safe.
|
||||
if host_policy["sandbox_mode"] == "disabled":
|
||||
_register_web_directory(extension_name, node_dir)
|
||||
|
||||
# Register for proxied web serving — the child's web dir may have
|
||||
# content that doesn't exist on the host (e.g., pip-installed viewer
|
||||
# bundles). The WebDirectoryCache will lazily fetch via RPC.
|
||||
from .proxies.web_directory_proxy import WebDirectoryProxy, get_web_directory_cache
|
||||
cache = get_web_directory_cache()
|
||||
cache.register_proxy(extension_name, WebDirectoryProxy())
|
||||
|
||||
# Try cache first (lazy spawn)
|
||||
if is_cache_valid(node_dir, manifest_path, venv_root):
|
||||
cached_data = load_from_cache(node_dir, venv_root)
|
||||
if cached_data:
|
||||
if _is_stale_node_cache(cached_data):
|
||||
logger.debug(
|
||||
"][ %s cache is stale/incompatible; rebuilding metadata",
|
||||
extension_name,
|
||||
)
|
||||
else:
|
||||
logger.debug(f"][ {extension_name} loaded from cache")
|
||||
specs: List[Tuple[str, str, type]] = []
|
||||
for node_name, details in cached_data.items():
|
||||
stub_cls = build_stub_class(node_name, details, extension)
|
||||
specs.append(
|
||||
(node_name, details.get("display_name", node_name), stub_cls)
|
||||
)
|
||||
return specs
|
||||
|
||||
# Cache miss - spawn process and get metadata
|
||||
logger.debug(f"][ {extension_name} cache miss, spawning process for metadata")
|
||||
|
||||
try:
|
||||
remote_nodes: Dict[str, str] = await extension.list_nodes()
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"][ %s metadata discovery failed, skipping isolated load: %s",
|
||||
extension_name,
|
||||
exc,
|
||||
)
|
||||
await _stop_extension_safe(extension, extension_name)
|
||||
return []
|
||||
|
||||
if not remote_nodes:
|
||||
logger.debug("][ %s exposed no isolated nodes; skipping", extension_name)
|
||||
await _stop_extension_safe(extension, extension_name)
|
||||
return []
|
||||
|
||||
specs: List[Tuple[str, str, type]] = []
|
||||
cache_data: Dict[str, Dict] = {}
|
||||
|
||||
for node_name, display_name in remote_nodes.items():
|
||||
try:
|
||||
details = await extension.get_node_details(node_name)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"][ %s failed to load metadata for %s, skipping node: %s",
|
||||
extension_name,
|
||||
node_name,
|
||||
exc,
|
||||
)
|
||||
continue
|
||||
details["display_name"] = display_name
|
||||
cache_data[node_name] = details
|
||||
stub_cls = build_stub_class(node_name, details, extension)
|
||||
specs.append((node_name, display_name, stub_cls))
|
||||
|
||||
if not specs:
|
||||
logger.warning(
|
||||
"][ %s produced no usable nodes after metadata scan; skipping",
|
||||
extension_name,
|
||||
)
|
||||
await _stop_extension_safe(extension, extension_name)
|
||||
return []
|
||||
|
||||
# Save metadata to cache for future runs
|
||||
save_to_cache(node_dir, venv_root, cache_data, manifest_path)
|
||||
logger.debug(f"][ {extension_name} metadata cached")
|
||||
|
||||
# Re-check web directory AFTER child has populated it
|
||||
if host_policy["sandbox_mode"] == "disabled":
|
||||
_register_web_directory(extension_name, node_dir)
|
||||
|
||||
# EJECT: Kill process after getting metadata (will respawn on first execution)
|
||||
await _stop_extension_safe(extension, extension_name)
|
||||
|
||||
return specs
|
||||
|
||||
|
||||
__all__ = ["ExtensionLoadError", "register_dummy_module", "load_isolated_node"]
|
||||
878
comfy/isolation/extension_wrapper.py
Normal file
878
comfy/isolation/extension_wrapper.py
Normal file
@@ -0,0 +1,878 @@
|
||||
# pylint: disable=consider-using-from-import,cyclic-import,import-outside-toplevel,logging-fstring-interpolation,protected-access,wrong-import-position
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import torch
|
||||
|
||||
|
||||
class AttrDict(dict):
|
||||
def __getattr__(self, item):
|
||||
try:
|
||||
return self[item]
|
||||
except KeyError as e:
|
||||
raise AttributeError(item) from e
|
||||
|
||||
def copy(self):
|
||||
return AttrDict(super().copy())
|
||||
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from dataclasses import asdict
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from pyisolate import ExtensionBase
|
||||
|
||||
from comfy_api.internal import _ComfyNodeInternal
|
||||
|
||||
LOG_PREFIX = "]["
|
||||
V3_DISCOVERY_TIMEOUT = 30
|
||||
_PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _run_prestartup_web_copy(module: Any, module_dir: str, web_dir_path: str) -> None:
|
||||
"""Run the web asset copy step that prestartup_script.py used to do.
|
||||
|
||||
If the module's web/ directory is empty and the module had a
|
||||
prestartup_script.py that copied assets from pip packages, this
|
||||
function replicates that work inside the child process.
|
||||
|
||||
Generic pattern: reads _PRESTARTUP_WEB_COPY from the module if
|
||||
defined, otherwise falls back to detecting common asset packages.
|
||||
"""
|
||||
import shutil
|
||||
|
||||
# Already populated — nothing to do
|
||||
if os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path)):
|
||||
return
|
||||
|
||||
os.makedirs(web_dir_path, exist_ok=True)
|
||||
|
||||
# Try module-defined copy spec first (generic hook for any node pack)
|
||||
copy_spec = getattr(module, "_PRESTARTUP_WEB_COPY", None)
|
||||
if copy_spec is not None and callable(copy_spec):
|
||||
try:
|
||||
copy_spec(web_dir_path)
|
||||
logger.info(
|
||||
"%s Ran _PRESTARTUP_WEB_COPY for %s", LOG_PREFIX, module_dir
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"%s _PRESTARTUP_WEB_COPY failed for %s: %s",
|
||||
LOG_PREFIX, module_dir, e,
|
||||
)
|
||||
|
||||
# Fallback: detect comfy_3d_viewers and run copy_viewer()
|
||||
try:
|
||||
from comfy_3d_viewers import copy_viewer, VIEWER_FILES
|
||||
viewers = list(VIEWER_FILES.keys())
|
||||
for viewer in viewers:
|
||||
try:
|
||||
copy_viewer(viewer, web_dir_path)
|
||||
except Exception:
|
||||
pass
|
||||
if any(os.scandir(web_dir_path)):
|
||||
logger.info(
|
||||
"%s Copied %d viewer types from comfy_3d_viewers to %s",
|
||||
LOG_PREFIX, len(viewers), web_dir_path,
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Fallback: detect comfy_dynamic_widgets
|
||||
try:
|
||||
from comfy_dynamic_widgets import get_js_path
|
||||
src = os.path.realpath(get_js_path())
|
||||
if os.path.exists(src):
|
||||
dst_dir = os.path.join(web_dir_path, "js")
|
||||
os.makedirs(dst_dir, exist_ok=True)
|
||||
dst = os.path.join(dst_dir, "dynamic_widgets.js")
|
||||
shutil.copy2(src, dst)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def _read_extension_name(module_dir: str) -> str:
|
||||
"""Read extension name from pyproject.toml, falling back to directory name."""
|
||||
pyproject = os.path.join(module_dir, "pyproject.toml")
|
||||
if os.path.exists(pyproject):
|
||||
try:
|
||||
import tomllib
|
||||
except ImportError:
|
||||
import tomli as tomllib # type: ignore[no-redef]
|
||||
try:
|
||||
with open(pyproject, "rb") as f:
|
||||
data = tomllib.load(f)
|
||||
name = data.get("project", {}).get("name")
|
||||
if name:
|
||||
return name
|
||||
except Exception:
|
||||
pass
|
||||
return os.path.basename(module_dir)
|
||||
|
||||
|
||||
def _flush_tensor_transport_state(marker: str) -> int:
|
||||
try:
|
||||
from pyisolate import flush_tensor_keeper # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
return 0
|
||||
if not callable(flush_tensor_keeper):
|
||||
return 0
|
||||
flushed = flush_tensor_keeper()
|
||||
if flushed > 0:
|
||||
logger.debug(
|
||||
"%s %s flush_tensor_keeper released=%d", LOG_PREFIX, marker, flushed
|
||||
)
|
||||
return flushed
|
||||
|
||||
|
||||
def _relieve_child_vram_pressure(marker: str) -> None:
|
||||
import comfy.model_management as model_management
|
||||
|
||||
model_management.cleanup_models_gc()
|
||||
model_management.cleanup_models()
|
||||
|
||||
device = model_management.get_torch_device()
|
||||
if not hasattr(device, "type") or device.type == "cpu":
|
||||
return
|
||||
|
||||
required = max(
|
||||
model_management.minimum_inference_memory(),
|
||||
_PRE_EXEC_MIN_FREE_VRAM_BYTES,
|
||||
)
|
||||
if model_management.get_free_memory(device) < required:
|
||||
model_management.free_memory(required, device, for_dynamic=True)
|
||||
if model_management.get_free_memory(device) < required:
|
||||
model_management.free_memory(required, device, for_dynamic=False)
|
||||
model_management.cleanup_models()
|
||||
model_management.soft_empty_cache()
|
||||
logger.debug("%s %s free_memory target=%d", LOG_PREFIX, marker, required)
|
||||
|
||||
|
||||
def _sanitize_for_transport(value):
|
||||
primitives = (str, int, float, bool, type(None))
|
||||
if isinstance(value, primitives):
|
||||
return value
|
||||
|
||||
cls_name = value.__class__.__name__
|
||||
if cls_name == "FlexibleOptionalInputType":
|
||||
return {
|
||||
"__pyisolate_flexible_optional__": True,
|
||||
"type": _sanitize_for_transport(getattr(value, "type", "*")),
|
||||
}
|
||||
if cls_name == "AnyType":
|
||||
return {"__pyisolate_any_type__": True, "value": str(value)}
|
||||
if cls_name == "ByPassTypeTuple":
|
||||
return {
|
||||
"__pyisolate_bypass_tuple__": [
|
||||
_sanitize_for_transport(v) for v in tuple(value)
|
||||
]
|
||||
}
|
||||
|
||||
if isinstance(value, dict):
|
||||
return {k: _sanitize_for_transport(v) for k, v in value.items()}
|
||||
if isinstance(value, tuple):
|
||||
return {"__pyisolate_tuple__": [_sanitize_for_transport(v) for v in value]}
|
||||
if isinstance(value, list):
|
||||
return [_sanitize_for_transport(v) for v in value]
|
||||
|
||||
return str(value)
|
||||
|
||||
|
||||
# Re-export RemoteObjectHandle from pyisolate for backward compatibility
|
||||
# The canonical definition is now in pyisolate._internal.remote_handle
|
||||
from pyisolate._internal.remote_handle import RemoteObjectHandle # noqa: E402,F401
|
||||
|
||||
|
||||
class ComfyNodeExtension(ExtensionBase):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.node_classes: Dict[str, type] = {}
|
||||
self.display_names: Dict[str, str] = {}
|
||||
self.node_instances: Dict[str, Any] = {}
|
||||
self.remote_objects: Dict[str, Any] = {}
|
||||
self._route_handlers: Dict[str, Any] = {}
|
||||
self._module: Any = None
|
||||
|
||||
async def on_module_loaded(self, module: Any) -> None:
|
||||
self._module = module
|
||||
|
||||
# Registries are initialized in host_hooks.py initialize_host_process()
|
||||
# They auto-register via ProxiedSingleton when instantiated
|
||||
# NO additional setup required here - if a registry is missing from host_hooks, it WILL fail
|
||||
|
||||
self.node_classes = getattr(module, "NODE_CLASS_MAPPINGS", {}) or {}
|
||||
self.display_names = getattr(module, "NODE_DISPLAY_NAME_MAPPINGS", {}) or {}
|
||||
|
||||
# Register web directory with WebDirectoryProxy (child-side)
|
||||
web_dir_attr = getattr(module, "WEB_DIRECTORY", None)
|
||||
if web_dir_attr is not None:
|
||||
module_dir = os.path.dirname(os.path.abspath(module.__file__))
|
||||
web_dir_path = os.path.abspath(os.path.join(module_dir, web_dir_attr))
|
||||
ext_name = _read_extension_name(module_dir)
|
||||
|
||||
# If web dir is empty, run the copy step that prestartup_script.py did
|
||||
_run_prestartup_web_copy(module, module_dir, web_dir_path)
|
||||
|
||||
if os.path.isdir(web_dir_path) and any(os.scandir(web_dir_path)):
|
||||
from comfy.isolation.proxies.web_directory_proxy import WebDirectoryProxy
|
||||
WebDirectoryProxy.register_web_dir(ext_name, web_dir_path)
|
||||
|
||||
try:
|
||||
from comfy_api.latest import ComfyExtension
|
||||
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if not (
|
||||
inspect.isclass(obj)
|
||||
and issubclass(obj, ComfyExtension)
|
||||
and obj is not ComfyExtension
|
||||
):
|
||||
continue
|
||||
if not obj.__module__.startswith(module.__name__):
|
||||
continue
|
||||
try:
|
||||
ext_instance = obj()
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
ext_instance.on_load(), timeout=V3_DISCOVERY_TIMEOUT
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
"%s V3 Extension %s timed out in on_load()",
|
||||
LOG_PREFIX,
|
||||
name,
|
||||
)
|
||||
continue
|
||||
try:
|
||||
v3_nodes = await asyncio.wait_for(
|
||||
ext_instance.get_node_list(), timeout=V3_DISCOVERY_TIMEOUT
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
"%s V3 Extension %s timed out in get_node_list()",
|
||||
LOG_PREFIX,
|
||||
name,
|
||||
)
|
||||
continue
|
||||
for node_cls in v3_nodes:
|
||||
if hasattr(node_cls, "GET_SCHEMA"):
|
||||
schema = node_cls.GET_SCHEMA()
|
||||
self.node_classes[schema.node_id] = node_cls
|
||||
if schema.display_name:
|
||||
self.display_names[schema.node_id] = schema.display_name
|
||||
except Exception as e:
|
||||
logger.error("%s V3 Extension %s failed: %s", LOG_PREFIX, name, e)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
module_name = getattr(module, "__name__", "isolated_nodes")
|
||||
for node_cls in self.node_classes.values():
|
||||
if hasattr(node_cls, "__module__") and "/" in str(node_cls.__module__):
|
||||
node_cls.__module__ = module_name
|
||||
|
||||
self.node_instances = {}
|
||||
|
||||
async def list_nodes(self) -> Dict[str, str]:
|
||||
return {name: self.display_names.get(name, name) for name in self.node_classes}
|
||||
|
||||
async def get_node_info(self, node_name: str) -> Dict[str, Any]:
|
||||
return await self.get_node_details(node_name)
|
||||
|
||||
async def get_node_details(self, node_name: str) -> Dict[str, Any]:
|
||||
node_cls = self._get_node_class(node_name)
|
||||
is_v3 = issubclass(node_cls, _ComfyNodeInternal)
|
||||
|
||||
input_types_raw = (
|
||||
node_cls.INPUT_TYPES() if hasattr(node_cls, "INPUT_TYPES") else {}
|
||||
)
|
||||
output_is_list = getattr(node_cls, "OUTPUT_IS_LIST", None)
|
||||
if output_is_list is not None:
|
||||
output_is_list = tuple(bool(x) for x in output_is_list)
|
||||
|
||||
details: Dict[str, Any] = {
|
||||
"input_types": _sanitize_for_transport(input_types_raw),
|
||||
"return_types": tuple(
|
||||
str(t) for t in getattr(node_cls, "RETURN_TYPES", ())
|
||||
),
|
||||
"return_names": getattr(node_cls, "RETURN_NAMES", None),
|
||||
"function": str(getattr(node_cls, "FUNCTION", "execute")),
|
||||
"category": str(getattr(node_cls, "CATEGORY", "")),
|
||||
"output_node": bool(getattr(node_cls, "OUTPUT_NODE", False)),
|
||||
"output_is_list": output_is_list,
|
||||
"is_v3": is_v3,
|
||||
}
|
||||
|
||||
if is_v3:
|
||||
try:
|
||||
schema = node_cls.GET_SCHEMA()
|
||||
schema_v1 = asdict(schema.get_v1_info(node_cls))
|
||||
try:
|
||||
schema_v3 = asdict(schema.get_v3_info(node_cls))
|
||||
except (AttributeError, TypeError):
|
||||
schema_v3 = self._build_schema_v3_fallback(schema)
|
||||
details.update(
|
||||
{
|
||||
"schema_v1": schema_v1,
|
||||
"schema_v3": schema_v3,
|
||||
"hidden": [h.value for h in (schema.hidden or [])],
|
||||
"description": getattr(schema, "description", ""),
|
||||
"deprecated": bool(getattr(node_cls, "DEPRECATED", False)),
|
||||
"experimental": bool(getattr(node_cls, "EXPERIMENTAL", False)),
|
||||
"api_node": bool(getattr(node_cls, "API_NODE", False)),
|
||||
"input_is_list": bool(
|
||||
getattr(node_cls, "INPUT_IS_LIST", False)
|
||||
),
|
||||
"not_idempotent": bool(
|
||||
getattr(node_cls, "NOT_IDEMPOTENT", False)
|
||||
),
|
||||
"accept_all_inputs": bool(
|
||||
getattr(node_cls, "ACCEPT_ALL_INPUTS", False)
|
||||
),
|
||||
}
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"%s V3 schema serialization failed for %s: %s",
|
||||
LOG_PREFIX,
|
||||
node_name,
|
||||
exc,
|
||||
)
|
||||
return details
|
||||
|
||||
def _build_schema_v3_fallback(self, schema) -> Dict[str, Any]:
|
||||
input_dict: Dict[str, Any] = {}
|
||||
output_dict: Dict[str, Any] = {}
|
||||
hidden_list: List[str] = []
|
||||
|
||||
if getattr(schema, "inputs", None):
|
||||
for inp in schema.inputs:
|
||||
self._add_schema_io_v3(inp, input_dict)
|
||||
if getattr(schema, "outputs", None):
|
||||
for out in schema.outputs:
|
||||
self._add_schema_io_v3(out, output_dict)
|
||||
if getattr(schema, "hidden", None):
|
||||
for h in schema.hidden:
|
||||
hidden_list.append(getattr(h, "value", str(h)))
|
||||
|
||||
return {
|
||||
"input": input_dict,
|
||||
"output": output_dict,
|
||||
"hidden": hidden_list,
|
||||
"name": getattr(schema, "node_id", None),
|
||||
"display_name": getattr(schema, "display_name", None),
|
||||
"description": getattr(schema, "description", None),
|
||||
"category": getattr(schema, "category", None),
|
||||
"output_node": getattr(schema, "is_output_node", False),
|
||||
"deprecated": getattr(schema, "is_deprecated", False),
|
||||
"experimental": getattr(schema, "is_experimental", False),
|
||||
"api_node": getattr(schema, "is_api_node", False),
|
||||
}
|
||||
|
||||
def _add_schema_io_v3(self, io_obj: Any, target: Dict[str, Any]) -> None:
|
||||
io_id = getattr(io_obj, "id", None)
|
||||
if io_id is None:
|
||||
return
|
||||
|
||||
io_type_fn = getattr(io_obj, "get_io_type", None)
|
||||
io_type = (
|
||||
io_type_fn() if callable(io_type_fn) else getattr(io_obj, "io_type", None)
|
||||
)
|
||||
|
||||
as_dict_fn = getattr(io_obj, "as_dict", None)
|
||||
payload = as_dict_fn() if callable(as_dict_fn) else {}
|
||||
|
||||
target[str(io_id)] = (io_type, payload)
|
||||
|
||||
async def get_input_types(self, node_name: str) -> Dict[str, Any]:
|
||||
node_cls = self._get_node_class(node_name)
|
||||
if hasattr(node_cls, "INPUT_TYPES"):
|
||||
return node_cls.INPUT_TYPES()
|
||||
return {}
|
||||
|
||||
async def execute_node(self, node_name: str, **inputs: Any) -> Tuple[Any, ...]:
|
||||
logger.debug(
|
||||
"%s ISO:child_execute_start ext=%s node=%s input_keys=%d",
|
||||
LOG_PREFIX,
|
||||
getattr(self, "name", "?"),
|
||||
node_name,
|
||||
len(inputs),
|
||||
)
|
||||
if os.environ.get("PYISOLATE_CHILD") == "1":
|
||||
_relieve_child_vram_pressure("EXT:pre_execute")
|
||||
|
||||
resolved_inputs = self._resolve_remote_objects(inputs)
|
||||
|
||||
instance = self._get_node_instance(node_name)
|
||||
node_cls = self._get_node_class(node_name)
|
||||
|
||||
# V3 API nodes expect hidden parameters in cls.hidden, not as kwargs
|
||||
# Hidden params come through RPC as string keys like "Hidden.prompt"
|
||||
from comfy_api.latest._io import Hidden, HiddenHolder
|
||||
|
||||
# Map string representations back to Hidden enum keys
|
||||
hidden_string_map = {
|
||||
"Hidden.unique_id": Hidden.unique_id,
|
||||
"Hidden.prompt": Hidden.prompt,
|
||||
"Hidden.extra_pnginfo": Hidden.extra_pnginfo,
|
||||
"Hidden.dynprompt": Hidden.dynprompt,
|
||||
"Hidden.auth_token_comfy_org": Hidden.auth_token_comfy_org,
|
||||
"Hidden.api_key_comfy_org": Hidden.api_key_comfy_org,
|
||||
# Uppercase enum VALUE forms — V3 execution engine passes these
|
||||
"UNIQUE_ID": Hidden.unique_id,
|
||||
"PROMPT": Hidden.prompt,
|
||||
"EXTRA_PNGINFO": Hidden.extra_pnginfo,
|
||||
"DYNPROMPT": Hidden.dynprompt,
|
||||
"AUTH_TOKEN_COMFY_ORG": Hidden.auth_token_comfy_org,
|
||||
"API_KEY_COMFY_ORG": Hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
# Find and extract hidden parameters (both enum and string form)
|
||||
hidden_found = {}
|
||||
keys_to_remove = []
|
||||
|
||||
for key in list(resolved_inputs.keys()):
|
||||
# Check string form first (from RPC serialization)
|
||||
if key in hidden_string_map:
|
||||
hidden_found[hidden_string_map[key]] = resolved_inputs[key]
|
||||
keys_to_remove.append(key)
|
||||
# Also check enum form (direct calls)
|
||||
elif isinstance(key, Hidden):
|
||||
hidden_found[key] = resolved_inputs[key]
|
||||
keys_to_remove.append(key)
|
||||
|
||||
# Remove hidden params from kwargs
|
||||
for key in keys_to_remove:
|
||||
resolved_inputs.pop(key)
|
||||
|
||||
# Set hidden on node class if any hidden params found
|
||||
if hidden_found:
|
||||
if not hasattr(node_cls, "hidden") or node_cls.hidden is None:
|
||||
node_cls.hidden = HiddenHolder.from_dict(hidden_found)
|
||||
else:
|
||||
# Update existing hidden holder
|
||||
for key, value in hidden_found.items():
|
||||
setattr(node_cls.hidden, key.value.lower(), value)
|
||||
|
||||
# INPUT_IS_LIST: ComfyUI's executor passes all inputs as lists when this
|
||||
# flag is set. The isolation RPC delivers unwrapped values, so we must
|
||||
# wrap each input in a single-element list to match the contract.
|
||||
if getattr(node_cls, "INPUT_IS_LIST", False):
|
||||
resolved_inputs = {k: [v] for k, v in resolved_inputs.items()}
|
||||
|
||||
function_name = getattr(node_cls, "FUNCTION", "execute")
|
||||
if not hasattr(instance, function_name):
|
||||
raise AttributeError(f"Node {node_name} missing callable '{function_name}'")
|
||||
|
||||
handler = getattr(instance, function_name)
|
||||
|
||||
try:
|
||||
import torch
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
with torch.inference_mode():
|
||||
result = await handler(**resolved_inputs)
|
||||
else:
|
||||
import functools
|
||||
|
||||
def _run_with_inference_mode(**kwargs):
|
||||
with torch.inference_mode():
|
||||
return handler(**kwargs)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
result = await loop.run_in_executor(
|
||||
None, functools.partial(_run_with_inference_mode, **resolved_inputs)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"%s ISO:child_execute_error ext=%s node=%s",
|
||||
LOG_PREFIX,
|
||||
getattr(self, "name", "?"),
|
||||
node_name,
|
||||
)
|
||||
raise
|
||||
|
||||
if type(result).__name__ == "NodeOutput":
|
||||
node_output_dict = {
|
||||
"__node_output__": True,
|
||||
"args": self._wrap_unpicklable_objects(result.args),
|
||||
}
|
||||
if result.ui is not None:
|
||||
node_output_dict["ui"] = self._wrap_unpicklable_objects(result.ui)
|
||||
if getattr(result, "expand", None) is not None:
|
||||
node_output_dict["expand"] = result.expand
|
||||
if getattr(result, "block_execution", None) is not None:
|
||||
node_output_dict["block_execution"] = result.block_execution
|
||||
return node_output_dict
|
||||
if self._is_comfy_protocol_return(result):
|
||||
wrapped = self._wrap_unpicklable_objects(result)
|
||||
return wrapped
|
||||
|
||||
if not isinstance(result, tuple):
|
||||
result = (result,)
|
||||
wrapped = self._wrap_unpicklable_objects(result)
|
||||
return wrapped
|
||||
|
||||
async def flush_transport_state(self) -> int:
|
||||
if os.environ.get("PYISOLATE_CHILD") != "1":
|
||||
return 0
|
||||
logger.debug(
|
||||
"%s ISO:child_flush_start ext=%s", LOG_PREFIX, getattr(self, "name", "?")
|
||||
)
|
||||
flushed = _flush_tensor_transport_state("EXT:workflow_end")
|
||||
try:
|
||||
from comfy.isolation.model_patcher_proxy_registry import (
|
||||
ModelPatcherRegistry,
|
||||
)
|
||||
|
||||
registry = ModelPatcherRegistry()
|
||||
removed = registry.sweep_pending_cleanup()
|
||||
if removed > 0:
|
||||
logger.debug(
|
||||
"%s EXT:workflow_end registry sweep removed=%d", LOG_PREFIX, removed
|
||||
)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"%s EXT:workflow_end registry sweep failed", LOG_PREFIX, exc_info=True
|
||||
)
|
||||
logger.debug(
|
||||
"%s ISO:child_flush_done ext=%s flushed=%d",
|
||||
LOG_PREFIX,
|
||||
getattr(self, "name", "?"),
|
||||
flushed,
|
||||
)
|
||||
return flushed
|
||||
|
||||
async def get_remote_object(self, object_id: str) -> Any:
|
||||
"""Retrieve a remote object by ID for host-side deserialization."""
|
||||
if object_id not in self.remote_objects:
|
||||
raise KeyError(f"Remote object {object_id} not found")
|
||||
|
||||
return self.remote_objects[object_id]
|
||||
|
||||
def _store_remote_object_handle(self, obj: Any) -> RemoteObjectHandle:
|
||||
object_id = str(uuid.uuid4())
|
||||
self.remote_objects[object_id] = obj
|
||||
return RemoteObjectHandle(object_id, type(obj).__name__)
|
||||
|
||||
async def call_remote_object_method(
|
||||
self,
|
||||
object_id: str,
|
||||
method_name: str,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Invoke a method or attribute-backed accessor on a child-owned object."""
|
||||
obj = await self.get_remote_object(object_id)
|
||||
|
||||
if method_name == "get_patcher_attr":
|
||||
return getattr(obj, args[0])
|
||||
if method_name == "get_model_options":
|
||||
return getattr(obj, "model_options")
|
||||
if method_name == "set_model_options":
|
||||
setattr(obj, "model_options", args[0])
|
||||
return None
|
||||
if method_name == "get_object_patches":
|
||||
return getattr(obj, "object_patches")
|
||||
if method_name == "get_patches":
|
||||
return getattr(obj, "patches")
|
||||
if method_name == "get_wrappers":
|
||||
return getattr(obj, "wrappers")
|
||||
if method_name == "get_callbacks":
|
||||
return getattr(obj, "callbacks")
|
||||
if method_name == "get_load_device":
|
||||
return getattr(obj, "load_device")
|
||||
if method_name == "get_offload_device":
|
||||
return getattr(obj, "offload_device")
|
||||
if method_name == "get_hook_mode":
|
||||
return getattr(obj, "hook_mode")
|
||||
if method_name == "get_parent":
|
||||
parent = getattr(obj, "parent", None)
|
||||
if parent is None:
|
||||
return None
|
||||
return self._store_remote_object_handle(parent)
|
||||
if method_name == "get_inner_model_attr":
|
||||
attr_name = args[0]
|
||||
if hasattr(obj.model, attr_name):
|
||||
return getattr(obj.model, attr_name)
|
||||
if hasattr(obj, attr_name):
|
||||
return getattr(obj, attr_name)
|
||||
return None
|
||||
if method_name == "inner_model_apply_model":
|
||||
return obj.model.apply_model(*args[0], **args[1])
|
||||
if method_name == "inner_model_extra_conds_shapes":
|
||||
return obj.model.extra_conds_shapes(*args[0], **args[1])
|
||||
if method_name == "inner_model_extra_conds":
|
||||
return obj.model.extra_conds(*args[0], **args[1])
|
||||
if method_name == "inner_model_memory_required":
|
||||
return obj.model.memory_required(*args[0], **args[1])
|
||||
if method_name == "process_latent_in":
|
||||
return obj.model.process_latent_in(*args[0], **args[1])
|
||||
if method_name == "process_latent_out":
|
||||
return obj.model.process_latent_out(*args[0], **args[1])
|
||||
if method_name == "scale_latent_inpaint":
|
||||
return obj.model.scale_latent_inpaint(*args[0], **args[1])
|
||||
if method_name.startswith("get_"):
|
||||
attr_name = method_name[4:]
|
||||
if hasattr(obj, attr_name):
|
||||
return getattr(obj, attr_name)
|
||||
|
||||
target = getattr(obj, method_name)
|
||||
if callable(target):
|
||||
result = target(*args, **kwargs)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
if type(result).__name__ == "ModelPatcher":
|
||||
return self._store_remote_object_handle(result)
|
||||
return result
|
||||
if args or kwargs:
|
||||
raise TypeError(f"{method_name} is not callable on remote object {object_id}")
|
||||
return target
|
||||
|
||||
def _wrap_unpicklable_objects(self, data: Any) -> Any:
|
||||
if isinstance(data, (str, int, float, bool, type(None))):
|
||||
return data
|
||||
if isinstance(data, torch.Tensor):
|
||||
tensor = data.detach() if data.requires_grad else data
|
||||
if os.environ.get("PYISOLATE_CHILD") == "1" and tensor.device.type != "cpu":
|
||||
return tensor.cpu()
|
||||
return tensor
|
||||
|
||||
# Special-case clip vision outputs: preserve attribute access by packing fields
|
||||
if hasattr(data, "penultimate_hidden_states") or hasattr(
|
||||
data, "last_hidden_state"
|
||||
):
|
||||
fields = {}
|
||||
for attr in (
|
||||
"penultimate_hidden_states",
|
||||
"last_hidden_state",
|
||||
"image_embeds",
|
||||
"text_embeds",
|
||||
):
|
||||
if hasattr(data, attr):
|
||||
try:
|
||||
fields[attr] = self._wrap_unpicklable_objects(
|
||||
getattr(data, attr)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
if fields:
|
||||
return {"__pyisolate_attribute_container__": True, "data": fields}
|
||||
|
||||
# Avoid converting arbitrary objects with stateful methods (models, etc.)
|
||||
# They will be handled via RemoteObjectHandle below.
|
||||
|
||||
type_name = type(data).__name__
|
||||
if type_name == "ModelPatcherProxy":
|
||||
return {"__type__": "ModelPatcherRef", "model_id": data._instance_id}
|
||||
if type_name == "CLIPProxy":
|
||||
return {"__type__": "CLIPRef", "clip_id": data._instance_id}
|
||||
if type_name == "VAEProxy":
|
||||
return {"__type__": "VAERef", "vae_id": data._instance_id}
|
||||
if type_name == "ModelSamplingProxy":
|
||||
return {"__type__": "ModelSamplingRef", "ms_id": data._instance_id}
|
||||
|
||||
if isinstance(data, (list, tuple)):
|
||||
wrapped = [self._wrap_unpicklable_objects(item) for item in data]
|
||||
return tuple(wrapped) if isinstance(data, tuple) else wrapped
|
||||
if isinstance(data, dict):
|
||||
converted_dict = {
|
||||
k: self._wrap_unpicklable_objects(v) for k, v in data.items()
|
||||
}
|
||||
return {"__pyisolate_attrdict__": True, "data": converted_dict}
|
||||
|
||||
from pyisolate._internal.serialization_registry import SerializerRegistry
|
||||
|
||||
registry = SerializerRegistry.get_instance()
|
||||
if registry.is_data_type(type_name):
|
||||
serializer = registry.get_serializer(type_name)
|
||||
if serializer:
|
||||
return serializer(data)
|
||||
|
||||
return self._store_remote_object_handle(data)
|
||||
|
||||
def _resolve_remote_objects(self, data: Any) -> Any:
|
||||
if isinstance(data, RemoteObjectHandle):
|
||||
if data.object_id not in self.remote_objects:
|
||||
raise KeyError(f"Remote object {data.object_id} not found")
|
||||
return self.remote_objects[data.object_id]
|
||||
|
||||
if isinstance(data, dict):
|
||||
ref_type = data.get("__type__")
|
||||
if ref_type in ("CLIPRef", "ModelPatcherRef", "VAERef"):
|
||||
from pyisolate._internal.model_serialization import (
|
||||
deserialize_proxy_result,
|
||||
)
|
||||
|
||||
return deserialize_proxy_result(data)
|
||||
if ref_type == "ModelSamplingRef":
|
||||
from pyisolate._internal.model_serialization import (
|
||||
deserialize_proxy_result,
|
||||
)
|
||||
|
||||
return deserialize_proxy_result(data)
|
||||
return {k: self._resolve_remote_objects(v) for k, v in data.items()}
|
||||
|
||||
if isinstance(data, (list, tuple)):
|
||||
resolved = [self._resolve_remote_objects(item) for item in data]
|
||||
return tuple(resolved) if isinstance(data, tuple) else resolved
|
||||
return data
|
||||
|
||||
def _get_node_class(self, node_name: str) -> type:
|
||||
if node_name not in self.node_classes:
|
||||
raise KeyError(f"Unknown node: {node_name}")
|
||||
return self.node_classes[node_name]
|
||||
|
||||
def _get_node_instance(self, node_name: str) -> Any:
|
||||
if node_name not in self.node_instances:
|
||||
if node_name not in self.node_classes:
|
||||
raise KeyError(f"Unknown node: {node_name}")
|
||||
self.node_instances[node_name] = self.node_classes[node_name]()
|
||||
return self.node_instances[node_name]
|
||||
|
||||
async def before_module_loaded(self) -> None:
|
||||
# Inject initialization here if we think this is the child
|
||||
try:
|
||||
from comfy.isolation import initialize_proxies
|
||||
|
||||
initialize_proxies()
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).error(
|
||||
f"Failed to call initialize_proxies in before_module_loaded: {e}"
|
||||
)
|
||||
|
||||
await super().before_module_loaded()
|
||||
try:
|
||||
from comfy_api.latest import ComfyAPI_latest
|
||||
from .proxies.progress_proxy import ProgressProxy
|
||||
|
||||
ComfyAPI_latest.Execution = ProgressProxy
|
||||
# ComfyAPI_latest.execution = ProgressProxy() # Eliminated to avoid Singleton collision
|
||||
# fp_proxy = FolderPathsProxy() # Eliminated to avoid Singleton collision
|
||||
# latest_ui.folder_paths = fp_proxy
|
||||
# latest_resources.folder_paths = fp_proxy
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def call_route_handler(
|
||||
self,
|
||||
handler_module: str,
|
||||
handler_func: str,
|
||||
request_data: Dict[str, Any],
|
||||
) -> Any:
|
||||
cache_key = f"{handler_module}.{handler_func}"
|
||||
if cache_key not in self._route_handlers:
|
||||
if self._module is not None and hasattr(self._module, "__file__"):
|
||||
node_dir = os.path.dirname(self._module.__file__)
|
||||
if node_dir not in sys.path:
|
||||
sys.path.insert(0, node_dir)
|
||||
try:
|
||||
module = importlib.import_module(handler_module)
|
||||
self._route_handlers[cache_key] = getattr(module, handler_func)
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise ValueError(f"Route handler not found: {cache_key}") from e
|
||||
|
||||
handler = self._route_handlers[cache_key]
|
||||
mock_request = MockRequest(request_data)
|
||||
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
result = await handler(mock_request)
|
||||
else:
|
||||
result = handler(mock_request)
|
||||
return self._serialize_response(result)
|
||||
|
||||
def _is_comfy_protocol_return(self, result: Any) -> bool:
|
||||
"""
|
||||
Check if the result matches the ComfyUI 'Protocol Return' schema.
|
||||
|
||||
A Protocol Return is a dictionary containing specific reserved keys that
|
||||
ComfyUI's execution engine interprets as instructions (UI updates,
|
||||
Workflow expansion, etc.) rather than purely data outputs.
|
||||
|
||||
Schema:
|
||||
- Must be a dict
|
||||
- Must contain at least one of: 'ui', 'result', 'expand'
|
||||
"""
|
||||
if not isinstance(result, dict):
|
||||
return False
|
||||
return any(key in result for key in ("ui", "result", "expand"))
|
||||
|
||||
def _serialize_response(self, response: Any) -> Dict[str, Any]:
|
||||
if response is None:
|
||||
return {"type": "text", "body": "", "status": 204}
|
||||
if isinstance(response, dict):
|
||||
return {"type": "json", "body": response, "status": 200}
|
||||
if isinstance(response, str):
|
||||
return {"type": "text", "body": response, "status": 200}
|
||||
if hasattr(response, "text") and hasattr(response, "status"):
|
||||
return {
|
||||
"type": "text",
|
||||
"body": response.text
|
||||
if hasattr(response, "text")
|
||||
else str(response.body),
|
||||
"status": response.status,
|
||||
"headers": dict(response.headers)
|
||||
if hasattr(response, "headers")
|
||||
else {},
|
||||
}
|
||||
if hasattr(response, "body") and hasattr(response, "status"):
|
||||
body = response.body
|
||||
if isinstance(body, bytes):
|
||||
try:
|
||||
return {
|
||||
"type": "text",
|
||||
"body": body.decode("utf-8"),
|
||||
"status": response.status,
|
||||
}
|
||||
except UnicodeDecodeError:
|
||||
return {
|
||||
"type": "binary",
|
||||
"body": body.hex(),
|
||||
"status": response.status,
|
||||
}
|
||||
return {"type": "json", "body": body, "status": response.status}
|
||||
return {"type": "text", "body": str(response), "status": 200}
|
||||
|
||||
|
||||
class MockRequest:
|
||||
def __init__(self, data: Dict[str, Any]):
|
||||
self.method = data.get("method", "GET")
|
||||
self.path = data.get("path", "/")
|
||||
self.query = data.get("query", {})
|
||||
self._body = data.get("body", {})
|
||||
self._text = data.get("text", "")
|
||||
self.headers = data.get("headers", {})
|
||||
self.content_type = data.get(
|
||||
"content_type", self.headers.get("Content-Type", "application/json")
|
||||
)
|
||||
self.match_info = data.get("match_info", {})
|
||||
|
||||
async def json(self) -> Any:
|
||||
if isinstance(self._body, dict):
|
||||
return self._body
|
||||
if isinstance(self._body, str):
|
||||
return json.loads(self._body)
|
||||
return {}
|
||||
|
||||
async def post(self) -> Dict[str, Any]:
|
||||
if isinstance(self._body, dict):
|
||||
return self._body
|
||||
return {}
|
||||
|
||||
async def text(self) -> str:
|
||||
if self._text:
|
||||
return self._text
|
||||
if isinstance(self._body, str):
|
||||
return self._body
|
||||
if isinstance(self._body, dict):
|
||||
return json.dumps(self._body)
|
||||
return ""
|
||||
|
||||
async def read(self) -> bytes:
|
||||
return (await self.text()).encode("utf-8")
|
||||
30
comfy/isolation/host_hooks.py
Normal file
30
comfy/isolation/host_hooks.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# pylint: disable=import-outside-toplevel
|
||||
# Host process initialization for PyIsolate
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def initialize_host_process() -> None:
|
||||
root = logging.getLogger()
|
||||
for handler in root.handlers[:]:
|
||||
root.removeHandler(handler)
|
||||
root.addHandler(logging.NullHandler())
|
||||
|
||||
from .proxies.folder_paths_proxy import FolderPathsProxy
|
||||
from .proxies.helper_proxies import HelperProxiesService
|
||||
from .proxies.model_management_proxy import ModelManagementProxy
|
||||
from .proxies.progress_proxy import ProgressProxy
|
||||
from .proxies.prompt_server_impl import PromptServerService
|
||||
from .proxies.utils_proxy import UtilsProxy
|
||||
from .proxies.web_directory_proxy import WebDirectoryProxy
|
||||
from .vae_proxy import VAERegistry
|
||||
|
||||
FolderPathsProxy()
|
||||
HelperProxiesService()
|
||||
ModelManagementProxy()
|
||||
ProgressProxy()
|
||||
PromptServerService()
|
||||
UtilsProxy()
|
||||
WebDirectoryProxy()
|
||||
VAERegistry()
|
||||
178
comfy/isolation/host_policy.py
Normal file
178
comfy/isolation/host_policy.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# pylint: disable=logging-fstring-interpolation
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from pathlib import PurePosixPath
|
||||
from typing import Dict, List, TypedDict
|
||||
|
||||
try:
|
||||
import tomllib
|
||||
except ImportError:
|
||||
import tomli as tomllib # type: ignore[no-redef]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
HOST_POLICY_PATH_ENV = "COMFY_HOST_POLICY_PATH"
|
||||
VALID_SANDBOX_MODES = frozenset({"required", "disabled"})
|
||||
FORBIDDEN_WRITABLE_PATHS = frozenset({"/tmp"})
|
||||
|
||||
|
||||
class HostSecurityPolicy(TypedDict):
|
||||
sandbox_mode: str
|
||||
allow_network: bool
|
||||
writable_paths: List[str]
|
||||
readonly_paths: List[str]
|
||||
sealed_worker_ro_import_paths: List[str]
|
||||
whitelist: Dict[str, str]
|
||||
|
||||
|
||||
DEFAULT_POLICY: HostSecurityPolicy = {
|
||||
"sandbox_mode": "required",
|
||||
"allow_network": False,
|
||||
"writable_paths": ["/dev/shm"],
|
||||
"readonly_paths": [],
|
||||
"sealed_worker_ro_import_paths": [],
|
||||
"whitelist": {},
|
||||
}
|
||||
|
||||
|
||||
def _default_policy() -> HostSecurityPolicy:
|
||||
return {
|
||||
"sandbox_mode": DEFAULT_POLICY["sandbox_mode"],
|
||||
"allow_network": DEFAULT_POLICY["allow_network"],
|
||||
"writable_paths": list(DEFAULT_POLICY["writable_paths"]),
|
||||
"readonly_paths": list(DEFAULT_POLICY["readonly_paths"]),
|
||||
"sealed_worker_ro_import_paths": list(DEFAULT_POLICY["sealed_worker_ro_import_paths"]),
|
||||
"whitelist": dict(DEFAULT_POLICY["whitelist"]),
|
||||
}
|
||||
|
||||
|
||||
def _normalize_writable_paths(paths: list[object]) -> list[str]:
|
||||
normalized_paths: list[str] = []
|
||||
for raw_path in paths:
|
||||
# Host-policy paths are contract-style POSIX paths; keep representation
|
||||
# stable across Windows/Linux so tests and config behavior stay consistent.
|
||||
normalized_path = str(PurePosixPath(str(raw_path).replace("\\", "/")))
|
||||
if normalized_path in FORBIDDEN_WRITABLE_PATHS:
|
||||
continue
|
||||
normalized_paths.append(normalized_path)
|
||||
return normalized_paths
|
||||
|
||||
|
||||
def _load_whitelist_file(file_path: Path, config_path: Path) -> Dict[str, str]:
|
||||
if not file_path.is_absolute():
|
||||
file_path = config_path.parent / file_path
|
||||
if not file_path.exists():
|
||||
logger.warning("whitelist_file %s not found, skipping.", file_path)
|
||||
return {}
|
||||
entries: Dict[str, str] = {}
|
||||
for line in file_path.read_text().splitlines():
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
entries[line] = "*"
|
||||
logger.debug("Loaded %d whitelist entries from %s", len(entries), file_path)
|
||||
return entries
|
||||
|
||||
|
||||
def _normalize_sealed_worker_ro_import_paths(raw_paths: object) -> list[str]:
|
||||
if not isinstance(raw_paths, list):
|
||||
raise ValueError(
|
||||
"tool.comfy.host.sealed_worker_ro_import_paths must be a list of absolute paths."
|
||||
)
|
||||
|
||||
normalized_paths: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for raw_path in raw_paths:
|
||||
if not isinstance(raw_path, str) or not raw_path.strip():
|
||||
raise ValueError(
|
||||
"tool.comfy.host.sealed_worker_ro_import_paths entries must be non-empty strings."
|
||||
)
|
||||
normalized_path = str(PurePosixPath(raw_path.replace("\\", "/")))
|
||||
# Accept both POSIX absolute paths (/home/...) and Windows drive-letter paths (D:/...)
|
||||
is_absolute = normalized_path.startswith("/") or (
|
||||
len(normalized_path) >= 3 and normalized_path[1] == ":" and normalized_path[2] == "/"
|
||||
)
|
||||
if not is_absolute:
|
||||
raise ValueError(
|
||||
"tool.comfy.host.sealed_worker_ro_import_paths entries must be absolute paths."
|
||||
)
|
||||
if normalized_path not in seen:
|
||||
seen.add(normalized_path)
|
||||
normalized_paths.append(normalized_path)
|
||||
|
||||
return normalized_paths
|
||||
|
||||
|
||||
def load_host_policy(comfy_root: Path) -> HostSecurityPolicy:
|
||||
config_override = os.environ.get(HOST_POLICY_PATH_ENV)
|
||||
config_path = Path(config_override) if config_override else comfy_root / "pyproject.toml"
|
||||
policy = _default_policy()
|
||||
|
||||
if not config_path.exists():
|
||||
logger.debug("Host policy file missing at %s, using defaults.", config_path)
|
||||
return policy
|
||||
|
||||
try:
|
||||
with config_path.open("rb") as f:
|
||||
data = tomllib.load(f)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to parse host policy from %s, using defaults.",
|
||||
config_path,
|
||||
exc_info=True,
|
||||
)
|
||||
return policy
|
||||
|
||||
tool_config = data.get("tool", {}).get("comfy", {}).get("host", {})
|
||||
if not isinstance(tool_config, dict):
|
||||
logger.debug("No [tool.comfy.host] section found, using defaults.")
|
||||
return policy
|
||||
|
||||
sandbox_mode = tool_config.get("sandbox_mode")
|
||||
if isinstance(sandbox_mode, str):
|
||||
normalized_sandbox_mode = sandbox_mode.strip().lower()
|
||||
if normalized_sandbox_mode in VALID_SANDBOX_MODES:
|
||||
policy["sandbox_mode"] = normalized_sandbox_mode
|
||||
else:
|
||||
logger.warning(
|
||||
"Invalid host sandbox_mode %r in %s, using default %r.",
|
||||
sandbox_mode,
|
||||
config_path,
|
||||
DEFAULT_POLICY["sandbox_mode"],
|
||||
)
|
||||
|
||||
if "allow_network" in tool_config:
|
||||
policy["allow_network"] = bool(tool_config["allow_network"])
|
||||
|
||||
if "writable_paths" in tool_config:
|
||||
policy["writable_paths"] = _normalize_writable_paths(tool_config["writable_paths"])
|
||||
|
||||
if "readonly_paths" in tool_config:
|
||||
policy["readonly_paths"] = [str(p) for p in tool_config["readonly_paths"]]
|
||||
|
||||
if "sealed_worker_ro_import_paths" in tool_config:
|
||||
policy["sealed_worker_ro_import_paths"] = _normalize_sealed_worker_ro_import_paths(
|
||||
tool_config["sealed_worker_ro_import_paths"]
|
||||
)
|
||||
|
||||
whitelist_file = tool_config.get("whitelist_file")
|
||||
if isinstance(whitelist_file, str):
|
||||
policy["whitelist"].update(_load_whitelist_file(Path(whitelist_file), config_path))
|
||||
|
||||
whitelist_raw = tool_config.get("whitelist")
|
||||
if isinstance(whitelist_raw, dict):
|
||||
policy["whitelist"].update({str(k): str(v) for k, v in whitelist_raw.items()})
|
||||
|
||||
logger.debug(
|
||||
"Loaded Host Policy: %d whitelisted nodes, Sandbox=%s, Network=%s",
|
||||
len(policy["whitelist"]),
|
||||
policy["sandbox_mode"],
|
||||
policy["allow_network"],
|
||||
)
|
||||
return policy
|
||||
|
||||
|
||||
__all__ = ["HostSecurityPolicy", "load_host_policy", "DEFAULT_POLICY"]
|
||||
221
comfy/isolation/manifest_loader.py
Normal file
221
comfy/isolation/manifest_loader.py
Normal file
@@ -0,0 +1,221 @@
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import folder_paths
|
||||
|
||||
try:
|
||||
import tomllib
|
||||
except ImportError:
|
||||
import tomli as tomllib # type: ignore[no-redef]
|
||||
|
||||
LOG_PREFIX = "]["
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CACHE_SUBDIR = "cache"
|
||||
CACHE_KEY_FILE = "cache_key"
|
||||
CACHE_DATA_FILE = "node_info.json"
|
||||
CACHE_KEY_LENGTH = 16
|
||||
_NESTED_SCAN_ROOT = "packages"
|
||||
_IGNORED_MANIFEST_DIRS = {".git", ".venv", "__pycache__"}
|
||||
|
||||
|
||||
def _read_manifest(manifest_path: Path) -> dict[str, Any] | None:
|
||||
try:
|
||||
with manifest_path.open("rb") as f:
|
||||
data = tomllib.load(f)
|
||||
if isinstance(data, dict):
|
||||
return data
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _is_isolation_manifest(data: dict[str, Any]) -> bool:
|
||||
return (
|
||||
"tool" in data
|
||||
and "comfy" in data["tool"]
|
||||
and "isolation" in data["tool"]["comfy"]
|
||||
)
|
||||
|
||||
|
||||
def _discover_nested_manifests(entry: Path) -> List[Tuple[Path, Path]]:
|
||||
packages_root = entry / _NESTED_SCAN_ROOT
|
||||
if not packages_root.exists() or not packages_root.is_dir():
|
||||
return []
|
||||
|
||||
nested: List[Tuple[Path, Path]] = []
|
||||
for manifest in sorted(packages_root.rglob("pyproject.toml")):
|
||||
node_dir = manifest.parent
|
||||
if any(part in _IGNORED_MANIFEST_DIRS for part in node_dir.parts):
|
||||
continue
|
||||
|
||||
data = _read_manifest(manifest)
|
||||
if not data or not _is_isolation_manifest(data):
|
||||
continue
|
||||
|
||||
isolation = data["tool"]["comfy"]["isolation"]
|
||||
if isolation.get("standalone") is True:
|
||||
nested.append((node_dir, manifest))
|
||||
|
||||
return nested
|
||||
|
||||
|
||||
def find_manifest_directories() -> List[Tuple[Path, Path]]:
|
||||
"""Find custom node directories containing a valid pyproject.toml with [tool.comfy.isolation]."""
|
||||
manifest_dirs: List[Tuple[Path, Path]] = []
|
||||
|
||||
# Standard custom_nodes paths
|
||||
for base_path in folder_paths.get_folder_paths("custom_nodes"):
|
||||
base = Path(base_path)
|
||||
if not base.exists() or not base.is_dir():
|
||||
continue
|
||||
|
||||
for entry in base.iterdir():
|
||||
if not entry.is_dir():
|
||||
continue
|
||||
|
||||
# Look for pyproject.toml
|
||||
manifest = entry / "pyproject.toml"
|
||||
if not manifest.exists():
|
||||
continue
|
||||
|
||||
data = _read_manifest(manifest)
|
||||
if not data or not _is_isolation_manifest(data):
|
||||
continue
|
||||
|
||||
manifest_dirs.append((entry, manifest))
|
||||
manifest_dirs.extend(_discover_nested_manifests(entry))
|
||||
|
||||
return manifest_dirs
|
||||
|
||||
|
||||
def compute_cache_key(node_dir: Path, manifest_path: Path) -> str:
|
||||
"""Hash manifest + .py mtimes + Python version + PyIsolate version."""
|
||||
hasher = hashlib.sha256()
|
||||
|
||||
try:
|
||||
# Hashing the manifest content ensures config changes invalidate cache
|
||||
hasher.update(manifest_path.read_bytes())
|
||||
except OSError:
|
||||
hasher.update(b"__manifest_read_error__")
|
||||
|
||||
try:
|
||||
py_files = sorted(node_dir.rglob("*.py"))
|
||||
for py_file in py_files:
|
||||
rel_path = py_file.relative_to(node_dir)
|
||||
if "__pycache__" in str(rel_path) or ".venv" in str(rel_path):
|
||||
continue
|
||||
hasher.update(str(rel_path).encode("utf-8"))
|
||||
try:
|
||||
hasher.update(str(py_file.stat().st_mtime).encode("utf-8"))
|
||||
except OSError:
|
||||
hasher.update(b"__file_stat_error__")
|
||||
except OSError:
|
||||
hasher.update(b"__dir_scan_error__")
|
||||
|
||||
hasher.update(sys.version.encode("utf-8"))
|
||||
|
||||
try:
|
||||
import pyisolate
|
||||
|
||||
hasher.update(pyisolate.__version__.encode("utf-8"))
|
||||
except (ImportError, AttributeError):
|
||||
hasher.update(b"__pyisolate_unknown__")
|
||||
|
||||
return hasher.hexdigest()[:CACHE_KEY_LENGTH]
|
||||
|
||||
|
||||
def get_cache_path(node_dir: Path, venv_root: Path) -> Tuple[Path, Path]:
|
||||
"""Return (cache_key_file, cache_data_file) in venv_root/{node}/cache/."""
|
||||
cache_dir = venv_root / node_dir.name / CACHE_SUBDIR
|
||||
return (cache_dir / CACHE_KEY_FILE, cache_dir / CACHE_DATA_FILE)
|
||||
|
||||
|
||||
def is_cache_valid(node_dir: Path, manifest_path: Path, venv_root: Path) -> bool:
|
||||
"""Return True only if stored cache key matches current computed key."""
|
||||
try:
|
||||
cache_key_file, cache_data_file = get_cache_path(node_dir, venv_root)
|
||||
if not cache_key_file.exists() or not cache_data_file.exists():
|
||||
return False
|
||||
current_key = compute_cache_key(node_dir, manifest_path)
|
||||
stored_key = cache_key_file.read_text(encoding="utf-8").strip()
|
||||
return current_key == stored_key
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"%s Cache validation error for %s: %s", LOG_PREFIX, node_dir.name, e
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
def load_from_cache(node_dir: Path, venv_root: Path) -> Optional[Dict[str, Any]]:
|
||||
"""Load node metadata from cache, return None on any error."""
|
||||
try:
|
||||
_, cache_data_file = get_cache_path(node_dir, venv_root)
|
||||
if not cache_data_file.exists():
|
||||
return None
|
||||
data = json.loads(cache_data_file.read_text(encoding="utf-8"))
|
||||
if not isinstance(data, dict):
|
||||
return None
|
||||
return data
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def save_to_cache(
|
||||
node_dir: Path, venv_root: Path, node_data: Dict[str, Any], manifest_path: Path
|
||||
) -> None:
|
||||
"""Save node metadata and cache key atomically."""
|
||||
try:
|
||||
cache_key_file, cache_data_file = get_cache_path(node_dir, venv_root)
|
||||
cache_dir = cache_key_file.parent
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
cache_key = compute_cache_key(node_dir, manifest_path)
|
||||
|
||||
# Atomic write: data
|
||||
tmp_data_fd, tmp_data_path = tempfile.mkstemp(dir=str(cache_dir), suffix=".tmp")
|
||||
try:
|
||||
with os.fdopen(tmp_data_fd, "w", encoding="utf-8") as f:
|
||||
json.dump(node_data, f, indent=2)
|
||||
os.replace(tmp_data_path, cache_data_file)
|
||||
except Exception:
|
||||
try:
|
||||
os.unlink(tmp_data_path)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
|
||||
# Atomic write: key
|
||||
tmp_key_fd, tmp_key_path = tempfile.mkstemp(dir=str(cache_dir), suffix=".tmp")
|
||||
try:
|
||||
with os.fdopen(tmp_key_fd, "w", encoding="utf-8") as f:
|
||||
f.write(cache_key)
|
||||
os.replace(tmp_key_path, cache_key_file)
|
||||
except Exception:
|
||||
try:
|
||||
os.unlink(tmp_key_path)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("%s Cache save failed for %s: %s", LOG_PREFIX, node_dir.name, e)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"LOG_PREFIX",
|
||||
"find_manifest_directories",
|
||||
"compute_cache_key",
|
||||
"get_cache_path",
|
||||
"is_cache_valid",
|
||||
"load_from_cache",
|
||||
"save_to_cache",
|
||||
]
|
||||
888
comfy/isolation/model_patcher_proxy.py
Normal file
888
comfy/isolation/model_patcher_proxy.py
Normal file
@@ -0,0 +1,888 @@
|
||||
# pylint: disable=bare-except,consider-using-from-import,import-outside-toplevel,protected-access
|
||||
# RPC proxy for ModelPatcher (parent process)
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional, List, Set, Dict, Callable
|
||||
|
||||
from comfy.isolation.proxies.base import (
|
||||
IS_CHILD_PROCESS,
|
||||
BaseProxy,
|
||||
)
|
||||
from comfy.isolation.model_patcher_proxy_registry import (
|
||||
ModelPatcherRegistry,
|
||||
AutoPatcherEjector,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
|
||||
_registry_class = ModelPatcherRegistry
|
||||
__module__ = "comfy.model_patcher"
|
||||
_APPLY_MODEL_GUARD_PADDING_BYTES = 32 * 1024 * 1024
|
||||
|
||||
def _spawn_related_proxy(self, instance_id: str) -> "ModelPatcherProxy":
|
||||
proxy = ModelPatcherProxy(
|
||||
instance_id,
|
||||
self._registry,
|
||||
manage_lifecycle=not IS_CHILD_PROCESS,
|
||||
)
|
||||
if getattr(self, "_rpc_caller", None) is not None:
|
||||
proxy._rpc_caller = self._rpc_caller
|
||||
return proxy
|
||||
|
||||
def _get_rpc(self) -> Any:
|
||||
if self._rpc_caller is None:
|
||||
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
|
||||
|
||||
rpc = get_child_rpc_instance()
|
||||
if rpc is not None:
|
||||
self._rpc_caller = rpc.create_caller(
|
||||
self._registry_class, self._registry_class.get_remote_id()
|
||||
)
|
||||
else:
|
||||
self._rpc_caller = self._registry
|
||||
return self._rpc_caller
|
||||
|
||||
def get_all_callbacks(self, call_type: str = None) -> Any:
|
||||
return self._call_rpc("get_all_callbacks", call_type)
|
||||
|
||||
def get_all_wrappers(self, wrapper_type: str = None) -> Any:
|
||||
return self._call_rpc("get_all_wrappers", wrapper_type)
|
||||
|
||||
def _load_list(self, *args, **kwargs) -> Any:
|
||||
return self._call_rpc("load_list_internal", *args, **kwargs)
|
||||
|
||||
def prepare_hook_patches_current_keyframe(
|
||||
self, t: Any, hook_group: Any, model_options: Any
|
||||
) -> None:
|
||||
self._call_rpc(
|
||||
"prepare_hook_patches_current_keyframe", t, hook_group, model_options
|
||||
)
|
||||
|
||||
def add_hook_patches(
|
||||
self,
|
||||
hook: Any,
|
||||
patches: Any,
|
||||
strength_patch: float = 1.0,
|
||||
strength_model: float = 1.0,
|
||||
) -> None:
|
||||
self._call_rpc(
|
||||
"add_hook_patches", hook, patches, strength_patch, strength_model
|
||||
)
|
||||
|
||||
def clear_cached_hook_weights(self) -> None:
|
||||
self._call_rpc("clear_cached_hook_weights")
|
||||
|
||||
def get_combined_hook_patches(self, hooks: Any) -> Any:
|
||||
return self._call_rpc("get_combined_hook_patches", hooks)
|
||||
|
||||
def get_additional_models_with_key(self, key: str) -> Any:
|
||||
return self._call_rpc("get_additional_models_with_key", key)
|
||||
|
||||
@property
|
||||
def object_patches(self) -> Any:
|
||||
return self._call_rpc("get_object_patches")
|
||||
|
||||
@property
|
||||
def patches(self) -> Any:
|
||||
res = self._call_rpc("get_patches")
|
||||
if isinstance(res, dict):
|
||||
new_res = {}
|
||||
for k, v in res.items():
|
||||
new_list = []
|
||||
for item in v:
|
||||
if isinstance(item, list):
|
||||
new_list.append(tuple(item))
|
||||
else:
|
||||
new_list.append(item)
|
||||
new_res[k] = new_list
|
||||
return new_res
|
||||
return res
|
||||
|
||||
@property
|
||||
def pinned(self) -> Set:
|
||||
val = self._call_rpc("get_patcher_attr", "pinned")
|
||||
return set(val) if val is not None else set()
|
||||
|
||||
@property
|
||||
def hook_patches(self) -> Dict:
|
||||
val = self._call_rpc("get_patcher_attr", "hook_patches")
|
||||
if val is None:
|
||||
return {}
|
||||
try:
|
||||
from comfy.hooks import _HookRef
|
||||
import json
|
||||
|
||||
new_val = {}
|
||||
for k, v in val.items():
|
||||
if isinstance(k, str):
|
||||
if k.startswith("PYISOLATE_HOOKREF:"):
|
||||
ref_id = k.split(":", 1)[1]
|
||||
h = _HookRef()
|
||||
h._pyisolate_id = ref_id
|
||||
new_val[h] = v
|
||||
elif k.startswith("__pyisolate_key__"):
|
||||
try:
|
||||
json_str = k[len("__pyisolate_key__") :]
|
||||
data = json.loads(json_str)
|
||||
ref_id = None
|
||||
if isinstance(data, list):
|
||||
for item in data:
|
||||
if (
|
||||
isinstance(item, list)
|
||||
and len(item) == 2
|
||||
and item[0] == "id"
|
||||
):
|
||||
ref_id = item[1]
|
||||
break
|
||||
if ref_id:
|
||||
h = _HookRef()
|
||||
h._pyisolate_id = ref_id
|
||||
new_val[h] = v
|
||||
else:
|
||||
new_val[k] = v
|
||||
except Exception:
|
||||
new_val[k] = v
|
||||
else:
|
||||
new_val[k] = v
|
||||
else:
|
||||
new_val[k] = v
|
||||
return new_val
|
||||
except ImportError:
|
||||
return val
|
||||
|
||||
def set_hook_mode(self, hook_mode: Any) -> None:
|
||||
self._call_rpc("set_hook_mode", hook_mode)
|
||||
|
||||
def register_all_hook_patches(
|
||||
self,
|
||||
hooks: Any,
|
||||
target_dict: Any,
|
||||
model_options: Any = None,
|
||||
registered: Any = None,
|
||||
) -> None:
|
||||
self._call_rpc(
|
||||
"register_all_hook_patches", hooks, target_dict, model_options, registered
|
||||
)
|
||||
|
||||
def is_clone(self, other: Any) -> bool:
|
||||
if isinstance(other, ModelPatcherProxy):
|
||||
return self._call_rpc("is_clone_by_id", other._instance_id)
|
||||
return False
|
||||
|
||||
def clone(self) -> ModelPatcherProxy:
|
||||
new_id = self._call_rpc("clone")
|
||||
return self._spawn_related_proxy(new_id)
|
||||
|
||||
def clone_has_same_weights(self, clone: Any) -> bool:
|
||||
if isinstance(clone, ModelPatcherProxy):
|
||||
return self._call_rpc("clone_has_same_weights_by_id", clone._instance_id)
|
||||
if not IS_CHILD_PROCESS:
|
||||
return self._call_rpc("is_clone", clone)
|
||||
return False
|
||||
|
||||
def get_model_object(self, name: str) -> Any:
|
||||
return self._call_rpc("get_model_object", name)
|
||||
|
||||
@property
|
||||
def model_options(self) -> dict:
|
||||
data = self._call_rpc("get_model_options")
|
||||
import json
|
||||
|
||||
def _decode_keys(obj):
|
||||
if isinstance(obj, dict):
|
||||
new_d = {}
|
||||
for k, v in obj.items():
|
||||
if isinstance(k, str) and k.startswith("__pyisolate_key__"):
|
||||
try:
|
||||
json_str = k[17:]
|
||||
val = json.loads(json_str)
|
||||
if isinstance(val, list):
|
||||
val = tuple(val)
|
||||
new_d[val] = _decode_keys(v)
|
||||
except:
|
||||
new_d[k] = _decode_keys(v)
|
||||
else:
|
||||
new_d[k] = _decode_keys(v)
|
||||
return new_d
|
||||
if isinstance(obj, list):
|
||||
return [_decode_keys(x) for x in obj]
|
||||
return obj
|
||||
|
||||
return _decode_keys(data)
|
||||
|
||||
@model_options.setter
|
||||
def model_options(self, value: dict) -> None:
|
||||
self._call_rpc("set_model_options", value)
|
||||
|
||||
def apply_hooks(self, hooks: Any) -> Any:
|
||||
return self._call_rpc("apply_hooks", hooks)
|
||||
|
||||
def prepare_state(self, timestep: Any) -> Any:
|
||||
return self._call_rpc("prepare_state", timestep)
|
||||
|
||||
def restore_hook_patches(self) -> None:
|
||||
self._call_rpc("restore_hook_patches")
|
||||
|
||||
def unpatch_hooks(self, whitelist_keys_set: Optional[Set[str]] = None) -> None:
|
||||
self._call_rpc("unpatch_hooks", whitelist_keys_set)
|
||||
|
||||
def model_patches_to(self, device: Any) -> Any:
|
||||
return self._call_rpc("model_patches_to", device)
|
||||
|
||||
def partially_load(
|
||||
self, device: Any, extra_memory: Any, force_patch_weights: bool = False
|
||||
) -> Any:
|
||||
return self._call_rpc(
|
||||
"partially_load", device, extra_memory, force_patch_weights
|
||||
)
|
||||
|
||||
def partially_unload(
|
||||
self, device_to: Any, memory_to_free: int = 0, force_patch_weights: bool = False
|
||||
) -> int:
|
||||
return self._call_rpc(
|
||||
"partially_unload", device_to, memory_to_free, force_patch_weights
|
||||
)
|
||||
|
||||
def load(
|
||||
self,
|
||||
device_to: Any = None,
|
||||
lowvram_model_memory: int = 0,
|
||||
force_patch_weights: bool = False,
|
||||
full_load: bool = False,
|
||||
) -> None:
|
||||
self._call_rpc(
|
||||
"load", device_to, lowvram_model_memory, force_patch_weights, full_load
|
||||
)
|
||||
|
||||
def patch_model(
|
||||
self,
|
||||
device_to: Any = None,
|
||||
lowvram_model_memory: int = 0,
|
||||
load_weights: bool = True,
|
||||
force_patch_weights: bool = False,
|
||||
) -> Any:
|
||||
self._call_rpc(
|
||||
"patch_model",
|
||||
device_to,
|
||||
lowvram_model_memory,
|
||||
load_weights,
|
||||
force_patch_weights,
|
||||
)
|
||||
return self
|
||||
|
||||
def unpatch_model(
|
||||
self, device_to: Any = None, unpatch_weights: bool = True
|
||||
) -> None:
|
||||
self._call_rpc("unpatch_model", device_to, unpatch_weights)
|
||||
|
||||
def detach(self, unpatch_all: bool = True) -> Any:
|
||||
self._call_rpc("detach", unpatch_all)
|
||||
return self.model
|
||||
|
||||
def _cpu_tensor_bytes(self, obj: Any) -> int:
|
||||
import torch
|
||||
|
||||
if isinstance(obj, torch.Tensor):
|
||||
if obj.device.type == "cpu":
|
||||
return obj.nbytes
|
||||
return 0
|
||||
if isinstance(obj, dict):
|
||||
return sum(self._cpu_tensor_bytes(v) for v in obj.values())
|
||||
if isinstance(obj, (list, tuple)):
|
||||
return sum(self._cpu_tensor_bytes(v) for v in obj)
|
||||
return 0
|
||||
|
||||
def _ensure_apply_model_headroom(self, required_bytes: int) -> bool:
|
||||
if required_bytes <= 0:
|
||||
return True
|
||||
|
||||
import torch
|
||||
import comfy.model_management as model_management
|
||||
|
||||
target_raw = self.load_device
|
||||
try:
|
||||
if isinstance(target_raw, torch.device):
|
||||
target = target_raw
|
||||
elif isinstance(target_raw, str):
|
||||
target = torch.device(target_raw)
|
||||
elif isinstance(target_raw, int):
|
||||
target = torch.device(f"cuda:{target_raw}")
|
||||
else:
|
||||
target = torch.device(target_raw)
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
if target.type != "cuda":
|
||||
return True
|
||||
|
||||
required = required_bytes + self._APPLY_MODEL_GUARD_PADDING_BYTES
|
||||
if model_management.get_free_memory(target) >= required:
|
||||
return True
|
||||
|
||||
model_management.cleanup_models_gc()
|
||||
model_management.cleanup_models()
|
||||
model_management.soft_empty_cache()
|
||||
|
||||
if model_management.get_free_memory(target) < required:
|
||||
model_management.free_memory(required, target, for_dynamic=True)
|
||||
model_management.soft_empty_cache()
|
||||
|
||||
if model_management.get_free_memory(target) < required:
|
||||
# Escalate to non-dynamic unloading before dispatching CUDA transfer.
|
||||
model_management.free_memory(required, target, for_dynamic=False)
|
||||
model_management.soft_empty_cache()
|
||||
|
||||
if model_management.get_free_memory(target) < required:
|
||||
model_management.load_models_gpu(
|
||||
[self],
|
||||
minimum_memory_required=required,
|
||||
)
|
||||
|
||||
return model_management.get_free_memory(target) >= required
|
||||
|
||||
def apply_model(self, *args, **kwargs) -> Any:
|
||||
import torch
|
||||
|
||||
def _preferred_device() -> Any:
|
||||
for value in args:
|
||||
if isinstance(value, torch.Tensor):
|
||||
return value.device
|
||||
for value in kwargs.values():
|
||||
if isinstance(value, torch.Tensor):
|
||||
return value.device
|
||||
return None
|
||||
|
||||
def _move_result_to_device(obj: Any, device: Any) -> Any:
|
||||
if device is None:
|
||||
return obj
|
||||
if isinstance(obj, torch.Tensor):
|
||||
return obj.to(device) if obj.device != device else obj
|
||||
if isinstance(obj, dict):
|
||||
return {k: _move_result_to_device(v, device) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_move_result_to_device(v, device) for v in obj]
|
||||
if isinstance(obj, tuple):
|
||||
return tuple(_move_result_to_device(v, device) for v in obj)
|
||||
return obj
|
||||
|
||||
# DynamicVRAM models must keep load/offload decisions in host process.
|
||||
# Child-side CUDA staging here can deadlock before first inference RPC.
|
||||
if self.is_dynamic():
|
||||
out = self._call_rpc("inner_model_apply_model", args, kwargs)
|
||||
return _move_result_to_device(out, _preferred_device())
|
||||
|
||||
required_bytes = self._cpu_tensor_bytes(args) + self._cpu_tensor_bytes(kwargs)
|
||||
self._ensure_apply_model_headroom(required_bytes)
|
||||
|
||||
def _to_cuda(obj: Any) -> Any:
|
||||
if isinstance(obj, torch.Tensor) and obj.device.type == "cpu":
|
||||
return obj.to("cuda")
|
||||
if isinstance(obj, dict):
|
||||
return {k: _to_cuda(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_to_cuda(v) for v in obj]
|
||||
if isinstance(obj, tuple):
|
||||
return tuple(_to_cuda(v) for v in obj)
|
||||
return obj
|
||||
|
||||
try:
|
||||
args_cuda = _to_cuda(args)
|
||||
kwargs_cuda = _to_cuda(kwargs)
|
||||
except torch.OutOfMemoryError:
|
||||
self._ensure_apply_model_headroom(required_bytes)
|
||||
args_cuda = _to_cuda(args)
|
||||
kwargs_cuda = _to_cuda(kwargs)
|
||||
|
||||
out = self._call_rpc("inner_model_apply_model", args_cuda, kwargs_cuda)
|
||||
return _move_result_to_device(out, _preferred_device())
|
||||
|
||||
def model_state_dict(self, filter_prefix: Optional[str] = None) -> Any:
|
||||
keys = self._call_rpc("model_state_dict", filter_prefix)
|
||||
return dict.fromkeys(keys, None)
|
||||
|
||||
def add_patches(self, *args: Any, **kwargs: Any) -> Any:
|
||||
res = self._call_rpc("add_patches", *args, **kwargs)
|
||||
if isinstance(res, list):
|
||||
return [tuple(x) if isinstance(x, list) else x for x in res]
|
||||
return res
|
||||
|
||||
def get_key_patches(self, filter_prefix: Optional[str] = None) -> Any:
|
||||
return self._call_rpc("get_key_patches", filter_prefix)
|
||||
|
||||
def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
|
||||
self._call_rpc("patch_weight_to_device", key, device_to, inplace_update)
|
||||
|
||||
def pin_weight_to_device(self, key):
|
||||
self._call_rpc("pin_weight_to_device", key)
|
||||
|
||||
def unpin_weight(self, key):
|
||||
self._call_rpc("unpin_weight", key)
|
||||
|
||||
def unpin_all_weights(self):
|
||||
self._call_rpc("unpin_all_weights")
|
||||
|
||||
def calculate_weight(self, patches, weight, key, intermediate_dtype=None):
|
||||
return self._call_rpc(
|
||||
"calculate_weight", patches, weight, key, intermediate_dtype
|
||||
)
|
||||
|
||||
def inject_model(self) -> None:
|
||||
self._call_rpc("inject_model")
|
||||
|
||||
def eject_model(self) -> None:
|
||||
self._call_rpc("eject_model")
|
||||
|
||||
def use_ejected(self, skip_and_inject_on_exit_only: bool = False) -> Any:
|
||||
return AutoPatcherEjector(
|
||||
self, skip_and_inject_on_exit_only=skip_and_inject_on_exit_only
|
||||
)
|
||||
|
||||
@property
|
||||
def is_injected(self) -> bool:
|
||||
return self._call_rpc("get_is_injected")
|
||||
|
||||
@property
|
||||
def skip_injection(self) -> bool:
|
||||
return self._call_rpc("get_skip_injection")
|
||||
|
||||
@skip_injection.setter
|
||||
def skip_injection(self, value: bool) -> None:
|
||||
self._call_rpc("set_skip_injection", value)
|
||||
|
||||
def clean_hooks(self) -> None:
|
||||
self._call_rpc("clean_hooks")
|
||||
|
||||
def pre_run(self) -> None:
|
||||
self._call_rpc("pre_run")
|
||||
|
||||
def cleanup(self) -> None:
|
||||
try:
|
||||
self._call_rpc("cleanup")
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"ModelPatcherProxy cleanup RPC failed for %s",
|
||||
self._instance_id,
|
||||
exc_info=True,
|
||||
)
|
||||
finally:
|
||||
super().cleanup()
|
||||
|
||||
@property
|
||||
def model(self) -> _InnerModelProxy:
|
||||
return _InnerModelProxy(self)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
_whitelisted_attrs = {
|
||||
"hook_patches_backup",
|
||||
"hook_backup",
|
||||
"cached_hook_patches",
|
||||
"current_hooks",
|
||||
"forced_hooks",
|
||||
"is_clip",
|
||||
"patches_uuid",
|
||||
"pinned",
|
||||
"attachments",
|
||||
"additional_models",
|
||||
"injections",
|
||||
"hook_patches",
|
||||
"model_lowvram",
|
||||
"model_loaded_weight_memory",
|
||||
"backup",
|
||||
"object_patches_backup",
|
||||
"weight_wrapper_patches",
|
||||
"weight_inplace_update",
|
||||
"force_cast_weights",
|
||||
}
|
||||
if name in _whitelisted_attrs:
|
||||
return self._call_rpc("get_patcher_attr", name)
|
||||
raise AttributeError(
|
||||
f"'{type(self).__name__}' object has no attribute '{name}'"
|
||||
)
|
||||
|
||||
def load_lora(
|
||||
self,
|
||||
lora_path: str,
|
||||
strength_model: float,
|
||||
clip: Optional[Any] = None,
|
||||
strength_clip: float = 1.0,
|
||||
) -> tuple:
|
||||
clip_id = None
|
||||
if clip is not None:
|
||||
clip_id = getattr(clip, "_instance_id", getattr(clip, "_clip_id", None))
|
||||
result = self._call_rpc(
|
||||
"load_lora", lora_path, strength_model, clip_id, strength_clip
|
||||
)
|
||||
new_model = None
|
||||
if result.get("model_id"):
|
||||
new_model = self._spawn_related_proxy(result["model_id"])
|
||||
new_clip = None
|
||||
if result.get("clip_id"):
|
||||
from comfy.isolation.clip_proxy import CLIPProxy
|
||||
|
||||
new_clip = CLIPProxy(result["clip_id"])
|
||||
return (new_model, new_clip)
|
||||
|
||||
@property
|
||||
def load_device(self) -> Any:
|
||||
return self._call_rpc("get_load_device")
|
||||
|
||||
@property
|
||||
def offload_device(self) -> Any:
|
||||
return self._call_rpc("get_offload_device")
|
||||
|
||||
@property
|
||||
def device(self) -> Any:
|
||||
return self.load_device
|
||||
|
||||
def current_loaded_device(self) -> Any:
|
||||
return self._call_rpc("current_loaded_device")
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
return self._call_rpc("get_size")
|
||||
|
||||
def model_size(self) -> Any:
|
||||
return self._call_rpc("model_size")
|
||||
|
||||
def loaded_size(self) -> Any:
|
||||
return self._call_rpc("loaded_size")
|
||||
|
||||
def get_ram_usage(self) -> int:
|
||||
return self._call_rpc("get_ram_usage")
|
||||
|
||||
def lowvram_patch_counter(self) -> int:
|
||||
return self._call_rpc("lowvram_patch_counter")
|
||||
|
||||
def memory_required(self, input_shape: Any) -> Any:
|
||||
return self._call_rpc("memory_required", input_shape)
|
||||
|
||||
def get_operation_state(self) -> Dict[str, Any]:
|
||||
state = self._call_rpc("get_operation_state")
|
||||
return state if isinstance(state, dict) else {}
|
||||
|
||||
def wait_for_idle(self, timeout_ms: int = 0) -> bool:
|
||||
return bool(self._call_rpc("wait_for_idle", timeout_ms))
|
||||
|
||||
def is_dynamic(self) -> bool:
|
||||
return bool(self._call_rpc("is_dynamic"))
|
||||
|
||||
def get_free_memory(self, device: Any) -> Any:
|
||||
return self._call_rpc("get_free_memory", device)
|
||||
|
||||
def partially_unload_ram(self, ram_to_unload: int) -> Any:
|
||||
return self._call_rpc("partially_unload_ram", ram_to_unload)
|
||||
|
||||
def model_dtype(self) -> Any:
|
||||
res = self._call_rpc("model_dtype")
|
||||
if isinstance(res, str) and res.startswith("torch."):
|
||||
try:
|
||||
import torch
|
||||
|
||||
attr = res.split(".")[-1]
|
||||
if hasattr(torch, attr):
|
||||
return getattr(torch, attr)
|
||||
except ImportError:
|
||||
pass
|
||||
return res
|
||||
|
||||
@property
|
||||
def hook_mode(self) -> Any:
|
||||
return self._call_rpc("get_hook_mode")
|
||||
|
||||
@hook_mode.setter
|
||||
def hook_mode(self, value: Any) -> None:
|
||||
self._call_rpc("set_hook_mode", value)
|
||||
|
||||
def set_model_sampler_cfg_function(
|
||||
self, sampler_cfg_function: Any, disable_cfg1_optimization: bool = False
|
||||
) -> None:
|
||||
self._call_rpc(
|
||||
"set_model_sampler_cfg_function",
|
||||
sampler_cfg_function,
|
||||
disable_cfg1_optimization,
|
||||
)
|
||||
|
||||
def set_model_sampler_post_cfg_function(
|
||||
self, post_cfg_function: Any, disable_cfg1_optimization: bool = False
|
||||
) -> None:
|
||||
self._call_rpc(
|
||||
"set_model_sampler_post_cfg_function",
|
||||
post_cfg_function,
|
||||
disable_cfg1_optimization,
|
||||
)
|
||||
|
||||
def set_model_sampler_pre_cfg_function(
|
||||
self, pre_cfg_function: Any, disable_cfg1_optimization: bool = False
|
||||
) -> None:
|
||||
self._call_rpc(
|
||||
"set_model_sampler_pre_cfg_function",
|
||||
pre_cfg_function,
|
||||
disable_cfg1_optimization,
|
||||
)
|
||||
|
||||
def set_model_sampler_calc_cond_batch_function(self, fn: Any) -> None:
|
||||
self._call_rpc("set_model_sampler_calc_cond_batch_function", fn)
|
||||
|
||||
def set_model_unet_function_wrapper(self, unet_wrapper_function: Any) -> None:
|
||||
self._call_rpc("set_model_unet_function_wrapper", unet_wrapper_function)
|
||||
|
||||
def set_model_denoise_mask_function(self, denoise_mask_function: Any) -> None:
|
||||
self._call_rpc("set_model_denoise_mask_function", denoise_mask_function)
|
||||
|
||||
def set_model_patch(self, patch: Any, name: str) -> None:
|
||||
self._call_rpc("set_model_patch", patch, name)
|
||||
|
||||
def set_model_patch_replace(
|
||||
self,
|
||||
patch: Any,
|
||||
name: str,
|
||||
block_name: str,
|
||||
number: int,
|
||||
transformer_index: Optional[int] = None,
|
||||
) -> None:
|
||||
self._call_rpc(
|
||||
"set_model_patch_replace",
|
||||
patch,
|
||||
name,
|
||||
block_name,
|
||||
number,
|
||||
transformer_index,
|
||||
)
|
||||
|
||||
def set_model_attn1_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "attn1_patch")
|
||||
|
||||
def set_model_attn2_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "attn2_patch")
|
||||
|
||||
def set_model_attn1_replace(
|
||||
self,
|
||||
patch: Any,
|
||||
block_name: str,
|
||||
number: int,
|
||||
transformer_index: Optional[int] = None,
|
||||
) -> None:
|
||||
self.set_model_patch_replace(
|
||||
patch, "attn1", block_name, number, transformer_index
|
||||
)
|
||||
|
||||
def set_model_attn2_replace(
|
||||
self,
|
||||
patch: Any,
|
||||
block_name: str,
|
||||
number: int,
|
||||
transformer_index: Optional[int] = None,
|
||||
) -> None:
|
||||
self.set_model_patch_replace(
|
||||
patch, "attn2", block_name, number, transformer_index
|
||||
)
|
||||
|
||||
def set_model_attn1_output_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "attn1_output_patch")
|
||||
|
||||
def set_model_attn2_output_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "attn2_output_patch")
|
||||
|
||||
def set_model_input_block_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "input_block_patch")
|
||||
|
||||
def set_model_input_block_patch_after_skip(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "input_block_patch_after_skip")
|
||||
|
||||
def set_model_output_block_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "output_block_patch")
|
||||
|
||||
def set_model_emb_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "emb_patch")
|
||||
|
||||
def set_model_forward_timestep_embed_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "forward_timestep_embed_patch")
|
||||
|
||||
def set_model_double_block_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "double_block")
|
||||
|
||||
def set_model_post_input_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "post_input")
|
||||
|
||||
def set_model_rope_options(
|
||||
self,
|
||||
scale_x=1.0,
|
||||
shift_x=0.0,
|
||||
scale_y=1.0,
|
||||
shift_y=0.0,
|
||||
scale_t=1.0,
|
||||
shift_t=0.0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
options = {
|
||||
"scale_x": scale_x,
|
||||
"shift_x": shift_x,
|
||||
"scale_y": scale_y,
|
||||
"shift_y": shift_y,
|
||||
"scale_t": scale_t,
|
||||
"shift_t": shift_t,
|
||||
}
|
||||
options.update(kwargs)
|
||||
self._call_rpc("set_model_rope_options", options)
|
||||
|
||||
def set_model_compute_dtype(self, dtype: Any) -> None:
|
||||
self._call_rpc("set_model_compute_dtype", dtype)
|
||||
|
||||
def add_object_patch(self, name: str, obj: Any) -> None:
|
||||
self._call_rpc("add_object_patch", name, obj)
|
||||
|
||||
def add_weight_wrapper(self, name: str, function: Any) -> None:
|
||||
self._call_rpc("add_weight_wrapper", name, function)
|
||||
|
||||
def add_wrapper_with_key(self, wrapper_type: Any, key: str, fn: Any) -> None:
|
||||
self._call_rpc("add_wrapper_with_key", wrapper_type, key, fn)
|
||||
|
||||
def add_wrapper(self, wrapper_type: str, wrapper: Callable) -> None:
|
||||
self.add_wrapper_with_key(wrapper_type, None, wrapper)
|
||||
|
||||
def remove_wrappers_with_key(self, wrapper_type: str, key: str) -> None:
|
||||
self._call_rpc("remove_wrappers_with_key", wrapper_type, key)
|
||||
|
||||
@property
|
||||
def wrappers(self) -> Any:
|
||||
return self._call_rpc("get_wrappers")
|
||||
|
||||
def add_callback_with_key(self, call_type: str, key: str, callback: Any) -> None:
|
||||
self._call_rpc("add_callback_with_key", call_type, key, callback)
|
||||
|
||||
def add_callback(self, call_type: str, callback: Any) -> None:
|
||||
self.add_callback_with_key(call_type, None, callback)
|
||||
|
||||
def remove_callbacks_with_key(self, call_type: str, key: str) -> None:
|
||||
self._call_rpc("remove_callbacks_with_key", call_type, key)
|
||||
|
||||
@property
|
||||
def callbacks(self) -> Any:
|
||||
return self._call_rpc("get_callbacks")
|
||||
|
||||
def set_attachments(self, key: str, attachment: Any) -> None:
|
||||
self._call_rpc("set_attachments", key, attachment)
|
||||
|
||||
def get_attachment(self, key: str) -> Any:
|
||||
return self._call_rpc("get_attachment", key)
|
||||
|
||||
def remove_attachments(self, key: str) -> None:
|
||||
self._call_rpc("remove_attachments", key)
|
||||
|
||||
def set_injections(self, key: str, injections: Any) -> None:
|
||||
self._call_rpc("set_injections", key, injections)
|
||||
|
||||
def get_injections(self, key: str) -> Any:
|
||||
return self._call_rpc("get_injections", key)
|
||||
|
||||
def remove_injections(self, key: str) -> None:
|
||||
self._call_rpc("remove_injections", key)
|
||||
|
||||
def set_additional_models(self, key: str, models: Any) -> None:
|
||||
ids = [m._instance_id for m in models]
|
||||
self._call_rpc("set_additional_models", key, ids)
|
||||
|
||||
def remove_additional_models(self, key: str) -> None:
|
||||
self._call_rpc("remove_additional_models", key)
|
||||
|
||||
def get_nested_additional_models(self) -> Any:
|
||||
return self._call_rpc("get_nested_additional_models")
|
||||
|
||||
def get_additional_models(self) -> List[ModelPatcherProxy]:
|
||||
ids = self._call_rpc("get_additional_models")
|
||||
return [self._spawn_related_proxy(mid) for mid in ids]
|
||||
|
||||
def model_patches_models(self) -> Any:
|
||||
return self._call_rpc("model_patches_models")
|
||||
|
||||
@property
|
||||
def parent(self) -> Any:
|
||||
return self._call_rpc("get_parent")
|
||||
|
||||
def model_mmap_residency(self, free: bool = False) -> tuple:
|
||||
result = self._call_rpc("model_mmap_residency", free)
|
||||
if isinstance(result, list):
|
||||
return tuple(result)
|
||||
return result
|
||||
|
||||
def pinned_memory_size(self) -> int:
|
||||
return self._call_rpc("pinned_memory_size")
|
||||
|
||||
def get_non_dynamic_delegate(self) -> ModelPatcherProxy:
|
||||
new_id = self._call_rpc("get_non_dynamic_delegate")
|
||||
return self._spawn_related_proxy(new_id)
|
||||
|
||||
def disable_model_cfg1_optimization(self) -> None:
|
||||
self._call_rpc("disable_model_cfg1_optimization")
|
||||
|
||||
def set_model_noise_refiner_patch(self, patch: Any) -> None:
|
||||
self.set_model_patch(patch, "noise_refiner")
|
||||
|
||||
|
||||
class _InnerModelProxy:
|
||||
def __init__(self, parent: ModelPatcherProxy):
|
||||
self._parent = parent
|
||||
self._model_sampling = None
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
if name.startswith("_"):
|
||||
raise AttributeError(name)
|
||||
if name == "model_config":
|
||||
from types import SimpleNamespace
|
||||
|
||||
data = self._parent._call_rpc("get_inner_model_attr", name)
|
||||
if isinstance(data, dict):
|
||||
return SimpleNamespace(**data)
|
||||
return data
|
||||
if name in (
|
||||
"latent_format",
|
||||
"model_type",
|
||||
"current_weight_patches_uuid",
|
||||
):
|
||||
return self._parent._call_rpc("get_inner_model_attr", name)
|
||||
if name == "load_device":
|
||||
return self._parent._call_rpc("get_inner_model_attr", "load_device")
|
||||
if name == "device":
|
||||
return self._parent._call_rpc("get_inner_model_attr", "device")
|
||||
if name == "current_patcher":
|
||||
proxy = ModelPatcherProxy(
|
||||
self._parent._instance_id,
|
||||
self._parent._registry,
|
||||
manage_lifecycle=False,
|
||||
)
|
||||
if getattr(self._parent, "_rpc_caller", None) is not None:
|
||||
proxy._rpc_caller = self._parent._rpc_caller
|
||||
return proxy
|
||||
if name == "model_sampling":
|
||||
if self._model_sampling is None:
|
||||
self._model_sampling = self._parent._call_rpc(
|
||||
"get_model_object", "model_sampling"
|
||||
)
|
||||
return self._model_sampling
|
||||
if name == "extra_conds_shapes":
|
||||
return lambda *a, **k: self._parent._call_rpc(
|
||||
"inner_model_extra_conds_shapes", a, k
|
||||
)
|
||||
if name == "extra_conds":
|
||||
return lambda *a, **k: self._parent._call_rpc(
|
||||
"inner_model_extra_conds", a, k
|
||||
)
|
||||
if name == "memory_required":
|
||||
return lambda *a, **k: self._parent._call_rpc(
|
||||
"inner_model_memory_required", a, k
|
||||
)
|
||||
if name == "apply_model":
|
||||
# Delegate to parent's method to get the CPU->CUDA optimization
|
||||
return self._parent.apply_model
|
||||
if name == "process_latent_in":
|
||||
return lambda *a, **k: self._parent._call_rpc("process_latent_in", a, k)
|
||||
if name == "process_latent_out":
|
||||
return lambda *a, **k: self._parent._call_rpc("process_latent_out", a, k)
|
||||
if name == "scale_latent_inpaint":
|
||||
return lambda *a, **k: self._parent._call_rpc("scale_latent_inpaint", a, k)
|
||||
if name == "diffusion_model":
|
||||
return self._parent._call_rpc("get_inner_model_attr", "diffusion_model")
|
||||
raise AttributeError(f"'{name}' not supported on isolated InnerModel")
|
||||
1311
comfy/isolation/model_patcher_proxy_registry.py
Normal file
1311
comfy/isolation/model_patcher_proxy_registry.py
Normal file
File diff suppressed because it is too large
Load Diff
156
comfy/isolation/model_patcher_proxy_utils.py
Normal file
156
comfy/isolation/model_patcher_proxy_utils.py
Normal file
@@ -0,0 +1,156 @@
|
||||
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,protected-access
|
||||
# Isolation utilities and serializers for ModelPatcherProxy
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def maybe_wrap_model_for_isolation(model_patcher: Any) -> Any:
|
||||
from comfy.isolation.model_patcher_proxy_registry import ModelPatcherRegistry
|
||||
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy
|
||||
|
||||
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
isolation_active = args.use_process_isolation or is_child
|
||||
|
||||
if not isolation_active:
|
||||
return model_patcher
|
||||
if is_child:
|
||||
return model_patcher
|
||||
if isinstance(model_patcher, ModelPatcherProxy):
|
||||
return model_patcher
|
||||
|
||||
registry = ModelPatcherRegistry()
|
||||
model_id = registry.register(model_patcher)
|
||||
logger.debug(f"Isolated ModelPatcher: {model_id}")
|
||||
return ModelPatcherProxy(model_id, registry, manage_lifecycle=True)
|
||||
|
||||
|
||||
def register_hooks_serializers(registry=None):
|
||||
from pyisolate._internal.serialization_registry import SerializerRegistry
|
||||
import comfy.hooks
|
||||
|
||||
if registry is None:
|
||||
registry = SerializerRegistry.get_instance()
|
||||
|
||||
def serialize_enum(obj):
|
||||
return {"__enum__": f"{type(obj).__name__}.{obj.name}"}
|
||||
|
||||
def deserialize_enum(data):
|
||||
cls_name, val_name = data["__enum__"].split(".")
|
||||
cls = getattr(comfy.hooks, cls_name)
|
||||
return cls[val_name]
|
||||
|
||||
registry.register("EnumHookType", serialize_enum, deserialize_enum)
|
||||
registry.register("EnumHookScope", serialize_enum, deserialize_enum)
|
||||
registry.register("EnumHookMode", serialize_enum, deserialize_enum)
|
||||
registry.register("EnumWeightTarget", serialize_enum, deserialize_enum)
|
||||
|
||||
def serialize_hook_group(obj):
|
||||
return {"__type__": "HookGroup", "hooks": obj.hooks}
|
||||
|
||||
def deserialize_hook_group(data):
|
||||
hg = comfy.hooks.HookGroup()
|
||||
for h in data["hooks"]:
|
||||
hg.add(h)
|
||||
return hg
|
||||
|
||||
registry.register("HookGroup", serialize_hook_group, deserialize_hook_group)
|
||||
|
||||
def serialize_dict_state(obj):
|
||||
d = obj.__dict__.copy()
|
||||
d["__type__"] = type(obj).__name__
|
||||
if "custom_should_register" in d:
|
||||
del d["custom_should_register"]
|
||||
return d
|
||||
|
||||
def deserialize_dict_state_generic(cls):
|
||||
def _deserialize(data):
|
||||
h = cls()
|
||||
h.__dict__.update(data)
|
||||
return h
|
||||
|
||||
return _deserialize
|
||||
|
||||
def deserialize_hook_keyframe(data):
|
||||
h = comfy.hooks.HookKeyframe(strength=data.get("strength", 1.0))
|
||||
h.__dict__.update(data)
|
||||
return h
|
||||
|
||||
registry.register("HookKeyframe", serialize_dict_state, deserialize_hook_keyframe)
|
||||
|
||||
def deserialize_hook_keyframe_group(data):
|
||||
h = comfy.hooks.HookKeyframeGroup()
|
||||
h.__dict__.update(data)
|
||||
return h
|
||||
|
||||
registry.register(
|
||||
"HookKeyframeGroup", serialize_dict_state, deserialize_hook_keyframe_group
|
||||
)
|
||||
|
||||
def deserialize_hook(data):
|
||||
h = comfy.hooks.Hook()
|
||||
h.__dict__.update(data)
|
||||
return h
|
||||
|
||||
registry.register("Hook", serialize_dict_state, deserialize_hook)
|
||||
|
||||
def deserialize_weight_hook(data):
|
||||
h = comfy.hooks.WeightHook()
|
||||
h.__dict__.update(data)
|
||||
return h
|
||||
|
||||
registry.register("WeightHook", serialize_dict_state, deserialize_weight_hook)
|
||||
|
||||
def serialize_set(obj):
|
||||
return {"__set__": list(obj)}
|
||||
|
||||
def deserialize_set(data):
|
||||
return set(data["__set__"])
|
||||
|
||||
registry.register("set", serialize_set, deserialize_set)
|
||||
|
||||
try:
|
||||
from comfy.weight_adapter.lora import LoRAAdapter
|
||||
|
||||
def serialize_lora(obj):
|
||||
return {"weights": {}, "loaded_keys": list(obj.loaded_keys)}
|
||||
|
||||
def deserialize_lora(data):
|
||||
return LoRAAdapter(set(data["loaded_keys"]), data["weights"])
|
||||
|
||||
registry.register("LoRAAdapter", serialize_lora, deserialize_lora)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
from comfy.hooks import _HookRef
|
||||
import uuid
|
||||
|
||||
def serialize_hook_ref(obj):
|
||||
return {
|
||||
"__hook_ref__": True,
|
||||
"id": getattr(obj, "_pyisolate_id", str(uuid.uuid4())),
|
||||
}
|
||||
|
||||
def deserialize_hook_ref(data):
|
||||
h = _HookRef()
|
||||
h._pyisolate_id = data.get("id", str(uuid.uuid4()))
|
||||
return h
|
||||
|
||||
registry.register("_HookRef", serialize_hook_ref, deserialize_hook_ref)
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to register _HookRef: {e}")
|
||||
|
||||
|
||||
try:
|
||||
register_hooks_serializers()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize hook serializers: {e}")
|
||||
360
comfy/isolation/model_sampling_proxy.py
Normal file
360
comfy/isolation/model_sampling_proxy.py
Normal file
@@ -0,0 +1,360 @@
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from comfy.isolation.proxies.base import (
|
||||
BaseProxy,
|
||||
BaseRegistry,
|
||||
detach_if_grad,
|
||||
get_thread_loop,
|
||||
run_coro_in_new_loop,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _describe_value(obj: Any) -> str:
|
||||
try:
|
||||
import torch
|
||||
except Exception:
|
||||
torch = None
|
||||
try:
|
||||
if torch is not None and isinstance(obj, torch.Tensor):
|
||||
return (
|
||||
"Tensor(shape=%s,dtype=%s,device=%s,id=%s)"
|
||||
% (tuple(obj.shape), obj.dtype, obj.device, id(obj))
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return "%s(id=%s)" % (type(obj).__name__, id(obj))
|
||||
|
||||
|
||||
def _prefer_device(*tensors: Any) -> Any:
|
||||
try:
|
||||
import torch
|
||||
except Exception:
|
||||
return None
|
||||
for t in tensors:
|
||||
if isinstance(t, torch.Tensor) and t.is_cuda:
|
||||
return t.device
|
||||
for t in tensors:
|
||||
if isinstance(t, torch.Tensor):
|
||||
return t.device
|
||||
return None
|
||||
|
||||
|
||||
def _to_device(obj: Any, device: Any) -> Any:
|
||||
try:
|
||||
import torch
|
||||
except Exception:
|
||||
return obj
|
||||
if device is None:
|
||||
return obj
|
||||
if isinstance(obj, torch.Tensor):
|
||||
if obj.device != device:
|
||||
return obj.to(device)
|
||||
return obj
|
||||
if isinstance(obj, (list, tuple)):
|
||||
converted = [_to_device(x, device) for x in obj]
|
||||
return type(obj)(converted) if isinstance(obj, tuple) else converted
|
||||
if isinstance(obj, dict):
|
||||
return {k: _to_device(v, device) for k, v in obj.items()}
|
||||
return obj
|
||||
|
||||
|
||||
def _to_cpu_for_rpc(obj: Any) -> Any:
|
||||
try:
|
||||
import torch
|
||||
except Exception:
|
||||
return obj
|
||||
if isinstance(obj, torch.Tensor):
|
||||
t = obj.detach() if obj.requires_grad else obj
|
||||
if t.is_cuda:
|
||||
return t.to("cpu")
|
||||
return t
|
||||
if isinstance(obj, (list, tuple)):
|
||||
converted = [_to_cpu_for_rpc(x) for x in obj]
|
||||
return type(obj)(converted) if isinstance(obj, tuple) else converted
|
||||
if isinstance(obj, dict):
|
||||
return {k: _to_cpu_for_rpc(v) for k, v in obj.items()}
|
||||
return obj
|
||||
|
||||
|
||||
class ModelSamplingRegistry(BaseRegistry[Any]):
|
||||
_type_prefix = "modelsampling"
|
||||
|
||||
async def calculate_input(self, instance_id: str, sigma: Any, noise: Any) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return detach_if_grad(sampling.calculate_input(sigma, noise))
|
||||
|
||||
async def calculate_denoised(
|
||||
self, instance_id: str, sigma: Any, model_output: Any, model_input: Any
|
||||
) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return detach_if_grad(
|
||||
sampling.calculate_denoised(sigma, model_output, model_input)
|
||||
)
|
||||
|
||||
async def noise_scaling(
|
||||
self,
|
||||
instance_id: str,
|
||||
sigma: Any,
|
||||
noise: Any,
|
||||
latent_image: Any,
|
||||
max_denoise: bool = False,
|
||||
) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return detach_if_grad(
|
||||
sampling.noise_scaling(sigma, noise, latent_image, max_denoise=max_denoise)
|
||||
)
|
||||
|
||||
async def inverse_noise_scaling(
|
||||
self, instance_id: str, sigma: Any, latent: Any
|
||||
) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return detach_if_grad(sampling.inverse_noise_scaling(sigma, latent))
|
||||
|
||||
async def timestep(self, instance_id: str, sigma: Any) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return sampling.timestep(sigma)
|
||||
|
||||
async def sigma(self, instance_id: str, timestep: Any) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return sampling.sigma(timestep)
|
||||
|
||||
async def percent_to_sigma(self, instance_id: str, percent: float) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return sampling.percent_to_sigma(percent)
|
||||
|
||||
async def get_sigma_min(self, instance_id: str) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return detach_if_grad(sampling.sigma_min)
|
||||
|
||||
async def get_sigma_max(self, instance_id: str) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return detach_if_grad(sampling.sigma_max)
|
||||
|
||||
async def get_sigma_data(self, instance_id: str) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return detach_if_grad(sampling.sigma_data)
|
||||
|
||||
async def get_sigmas(self, instance_id: str) -> Any:
|
||||
sampling = self._get_instance(instance_id)
|
||||
return detach_if_grad(sampling.sigmas)
|
||||
|
||||
async def set_sigmas(self, instance_id: str, sigmas: Any) -> None:
|
||||
sampling = self._get_instance(instance_id)
|
||||
sampling.set_sigmas(sigmas)
|
||||
|
||||
|
||||
class ModelSamplingProxy(BaseProxy[ModelSamplingRegistry]):
|
||||
_registry_class = ModelSamplingRegistry
|
||||
__module__ = "comfy.isolation.model_sampling_proxy"
|
||||
|
||||
def _get_rpc(self) -> Any:
|
||||
if self._rpc_caller is None:
|
||||
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
|
||||
|
||||
rpc = get_child_rpc_instance()
|
||||
if rpc is not None:
|
||||
self._rpc_caller = rpc.create_caller(
|
||||
ModelSamplingRegistry, ModelSamplingRegistry.get_remote_id()
|
||||
)
|
||||
else:
|
||||
registry = ModelSamplingRegistry()
|
||||
|
||||
class _LocalCaller:
|
||||
def calculate_input(
|
||||
self, instance_id: str, sigma: Any, noise: Any
|
||||
) -> Any:
|
||||
return registry.calculate_input(instance_id, sigma, noise)
|
||||
|
||||
def calculate_denoised(
|
||||
self,
|
||||
instance_id: str,
|
||||
sigma: Any,
|
||||
model_output: Any,
|
||||
model_input: Any,
|
||||
) -> Any:
|
||||
return registry.calculate_denoised(
|
||||
instance_id, sigma, model_output, model_input
|
||||
)
|
||||
|
||||
def noise_scaling(
|
||||
self,
|
||||
instance_id: str,
|
||||
sigma: Any,
|
||||
noise: Any,
|
||||
latent_image: Any,
|
||||
max_denoise: bool = False,
|
||||
) -> Any:
|
||||
return registry.noise_scaling(
|
||||
instance_id, sigma, noise, latent_image, max_denoise
|
||||
)
|
||||
|
||||
def inverse_noise_scaling(
|
||||
self, instance_id: str, sigma: Any, latent: Any
|
||||
) -> Any:
|
||||
return registry.inverse_noise_scaling(
|
||||
instance_id, sigma, latent
|
||||
)
|
||||
|
||||
def timestep(self, instance_id: str, sigma: Any) -> Any:
|
||||
return registry.timestep(instance_id, sigma)
|
||||
|
||||
def sigma(self, instance_id: str, timestep: Any) -> Any:
|
||||
return registry.sigma(instance_id, timestep)
|
||||
|
||||
def percent_to_sigma(self, instance_id: str, percent: float) -> Any:
|
||||
return registry.percent_to_sigma(instance_id, percent)
|
||||
|
||||
def get_sigma_min(self, instance_id: str) -> Any:
|
||||
return registry.get_sigma_min(instance_id)
|
||||
|
||||
def get_sigma_max(self, instance_id: str) -> Any:
|
||||
return registry.get_sigma_max(instance_id)
|
||||
|
||||
def get_sigma_data(self, instance_id: str) -> Any:
|
||||
return registry.get_sigma_data(instance_id)
|
||||
|
||||
def get_sigmas(self, instance_id: str) -> Any:
|
||||
return registry.get_sigmas(instance_id)
|
||||
|
||||
def set_sigmas(self, instance_id: str, sigmas: Any) -> None:
|
||||
return registry.set_sigmas(instance_id, sigmas)
|
||||
|
||||
self._rpc_caller = _LocalCaller()
|
||||
return self._rpc_caller
|
||||
|
||||
def _call(self, method_name: str, *args: Any) -> Any:
|
||||
rpc = self._get_rpc()
|
||||
method = getattr(rpc, method_name)
|
||||
result = method(self._instance_id, *args)
|
||||
timeout_ms = self._rpc_timeout_ms()
|
||||
start_epoch = time.time()
|
||||
start_perf = time.perf_counter()
|
||||
thread_id = threading.get_ident()
|
||||
call_id = "%s:%s:%s:%.6f" % (
|
||||
self._instance_id,
|
||||
method_name,
|
||||
thread_id,
|
||||
start_perf,
|
||||
)
|
||||
logger.debug(
|
||||
"ISO:modelsampling_rpc_start method=%s instance_id=%s call_id=%s start_ts=%.6f thread=%s timeout_ms=%s",
|
||||
method_name,
|
||||
self._instance_id,
|
||||
call_id,
|
||||
start_epoch,
|
||||
thread_id,
|
||||
timeout_ms,
|
||||
)
|
||||
if asyncio.iscoroutine(result):
|
||||
result = asyncio.wait_for(result, timeout=timeout_ms / 1000.0)
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
out = run_coro_in_new_loop(result)
|
||||
except RuntimeError:
|
||||
loop = get_thread_loop()
|
||||
out = loop.run_until_complete(result)
|
||||
else:
|
||||
out = result
|
||||
logger.debug(
|
||||
"ISO:modelsampling_rpc_after_await method=%s instance_id=%s call_id=%s out=%s",
|
||||
method_name,
|
||||
self._instance_id,
|
||||
call_id,
|
||||
_describe_value(out),
|
||||
)
|
||||
elapsed_ms = (time.perf_counter() - start_perf) * 1000.0
|
||||
logger.debug(
|
||||
"ISO:modelsampling_rpc_end method=%s instance_id=%s call_id=%s elapsed_ms=%.3f thread=%s",
|
||||
method_name,
|
||||
self._instance_id,
|
||||
call_id,
|
||||
elapsed_ms,
|
||||
thread_id,
|
||||
)
|
||||
logger.debug(
|
||||
"ISO:modelsampling_rpc_return method=%s instance_id=%s call_id=%s",
|
||||
method_name,
|
||||
self._instance_id,
|
||||
call_id,
|
||||
)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def _rpc_timeout_ms() -> int:
|
||||
raw = os.environ.get(
|
||||
"COMFY_ISOLATION_MODEL_SAMPLING_RPC_TIMEOUT_MS",
|
||||
os.environ.get("COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS", "30000"),
|
||||
)
|
||||
try:
|
||||
timeout_ms = int(raw)
|
||||
except ValueError:
|
||||
timeout_ms = 30000
|
||||
return max(1, timeout_ms)
|
||||
|
||||
@property
|
||||
def sigma_min(self) -> Any:
|
||||
return self._call("get_sigma_min")
|
||||
|
||||
@property
|
||||
def sigma_max(self) -> Any:
|
||||
return self._call("get_sigma_max")
|
||||
|
||||
@property
|
||||
def sigma_data(self) -> Any:
|
||||
return self._call("get_sigma_data")
|
||||
|
||||
@property
|
||||
def sigmas(self) -> Any:
|
||||
return self._call("get_sigmas")
|
||||
|
||||
def calculate_input(self, sigma: Any, noise: Any) -> Any:
|
||||
return self._call("calculate_input", sigma, noise)
|
||||
|
||||
def calculate_denoised(
|
||||
self, sigma: Any, model_output: Any, model_input: Any
|
||||
) -> Any:
|
||||
return self._call("calculate_denoised", sigma, model_output, model_input)
|
||||
|
||||
def noise_scaling(
|
||||
self, sigma: Any, noise: Any, latent_image: Any, max_denoise: bool = False
|
||||
) -> Any:
|
||||
preferred_device = _prefer_device(noise, latent_image)
|
||||
out = self._call(
|
||||
"noise_scaling",
|
||||
_to_cpu_for_rpc(sigma),
|
||||
_to_cpu_for_rpc(noise),
|
||||
_to_cpu_for_rpc(latent_image),
|
||||
max_denoise,
|
||||
)
|
||||
return _to_device(out, preferred_device)
|
||||
|
||||
def inverse_noise_scaling(self, sigma: Any, latent: Any) -> Any:
|
||||
preferred_device = _prefer_device(latent)
|
||||
out = self._call(
|
||||
"inverse_noise_scaling",
|
||||
_to_cpu_for_rpc(sigma),
|
||||
_to_cpu_for_rpc(latent),
|
||||
)
|
||||
return _to_device(out, preferred_device)
|
||||
|
||||
def timestep(self, sigma: Any) -> Any:
|
||||
return self._call("timestep", sigma)
|
||||
|
||||
def sigma(self, timestep: Any) -> Any:
|
||||
return self._call("sigma", timestep)
|
||||
|
||||
def percent_to_sigma(self, percent: float) -> Any:
|
||||
return self._call("percent_to_sigma", percent)
|
||||
|
||||
def set_sigmas(self, sigmas: Any) -> None:
|
||||
return self._call("set_sigmas", sigmas)
|
||||
17
comfy/isolation/proxies/__init__.py
Normal file
17
comfy/isolation/proxies/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from .base import (
|
||||
IS_CHILD_PROCESS,
|
||||
BaseProxy,
|
||||
BaseRegistry,
|
||||
detach_if_grad,
|
||||
get_thread_loop,
|
||||
run_coro_in_new_loop,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"IS_CHILD_PROCESS",
|
||||
"BaseRegistry",
|
||||
"BaseProxy",
|
||||
"get_thread_loop",
|
||||
"run_coro_in_new_loop",
|
||||
"detach_if_grad",
|
||||
]
|
||||
301
comfy/isolation/proxies/base.py
Normal file
301
comfy/isolation/proxies/base.py
Normal file
@@ -0,0 +1,301 @@
|
||||
# pylint: disable=global-statement,import-outside-toplevel,protected-access
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import weakref
|
||||
from typing import Any, Callable, Dict, Generic, Optional, TypeVar
|
||||
|
||||
try:
|
||||
from pyisolate import ProxiedSingleton
|
||||
except ImportError:
|
||||
|
||||
class ProxiedSingleton: # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
IS_CHILD_PROCESS = os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
_thread_local = threading.local()
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def get_thread_loop() -> asyncio.AbstractEventLoop:
|
||||
loop = getattr(_thread_local, "loop", None)
|
||||
if loop is None or loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
_thread_local.loop = loop
|
||||
return loop
|
||||
|
||||
|
||||
def run_coro_in_new_loop(coro: Any) -> Any:
|
||||
result_box: Dict[str, Any] = {}
|
||||
exc_box: Dict[str, BaseException] = {}
|
||||
|
||||
def runner() -> None:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
result_box["value"] = loop.run_until_complete(coro)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
exc_box["exc"] = exc
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
t = threading.Thread(target=runner, daemon=True)
|
||||
t.start()
|
||||
t.join()
|
||||
if "exc" in exc_box:
|
||||
raise exc_box["exc"]
|
||||
return result_box.get("value")
|
||||
|
||||
|
||||
def detach_if_grad(obj: Any) -> Any:
|
||||
try:
|
||||
import torch
|
||||
except Exception:
|
||||
return obj
|
||||
|
||||
if isinstance(obj, torch.Tensor):
|
||||
return obj.detach() if obj.requires_grad else obj
|
||||
if isinstance(obj, (list, tuple)):
|
||||
return type(obj)(detach_if_grad(x) for x in obj)
|
||||
if isinstance(obj, dict):
|
||||
return {k: detach_if_grad(v) for k, v in obj.items()}
|
||||
return obj
|
||||
|
||||
|
||||
class BaseRegistry(ProxiedSingleton, Generic[T]):
|
||||
_type_prefix: str = "base"
|
||||
|
||||
def __init__(self) -> None:
|
||||
if hasattr(ProxiedSingleton, "__init__") and ProxiedSingleton is not object:
|
||||
super().__init__()
|
||||
self._registry: Dict[str, T] = {}
|
||||
self._id_map: Dict[int, str] = {}
|
||||
self._counter = 0
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def register(self, instance: T) -> str:
|
||||
with self._lock:
|
||||
obj_id = id(instance)
|
||||
if obj_id in self._id_map:
|
||||
return self._id_map[obj_id]
|
||||
instance_id = f"{self._type_prefix}_{self._counter}"
|
||||
self._counter += 1
|
||||
self._registry[instance_id] = instance
|
||||
self._id_map[obj_id] = instance_id
|
||||
return instance_id
|
||||
|
||||
def unregister_sync(self, instance_id: str) -> None:
|
||||
with self._lock:
|
||||
instance = self._registry.pop(instance_id, None)
|
||||
if instance:
|
||||
self._id_map.pop(id(instance), None)
|
||||
|
||||
def _get_instance(self, instance_id: str) -> T:
|
||||
if IS_CHILD_PROCESS:
|
||||
raise RuntimeError(
|
||||
f"[{self.__class__.__name__}] _get_instance called in child"
|
||||
)
|
||||
with self._lock:
|
||||
instance = self._registry.get(instance_id)
|
||||
if instance is None:
|
||||
raise ValueError(f"{instance_id} not found")
|
||||
return instance
|
||||
|
||||
|
||||
_GLOBAL_LOOP: Optional[asyncio.AbstractEventLoop] = None
|
||||
|
||||
|
||||
def set_global_loop(loop: asyncio.AbstractEventLoop) -> None:
|
||||
global _GLOBAL_LOOP
|
||||
_GLOBAL_LOOP = loop
|
||||
|
||||
|
||||
def run_sync_rpc_coro(coro: Any, timeout_ms: Optional[int] = None) -> Any:
|
||||
if timeout_ms is not None:
|
||||
coro = asyncio.wait_for(coro, timeout=timeout_ms / 1000.0)
|
||||
|
||||
try:
|
||||
if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running():
|
||||
try:
|
||||
curr_loop = asyncio.get_running_loop()
|
||||
if curr_loop is _GLOBAL_LOOP:
|
||||
pass
|
||||
except RuntimeError:
|
||||
future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP)
|
||||
return future.result(
|
||||
timeout=(timeout_ms / 1000.0) if timeout_ms is not None else None
|
||||
)
|
||||
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
return run_coro_in_new_loop(coro)
|
||||
except RuntimeError:
|
||||
loop = get_thread_loop()
|
||||
return loop.run_until_complete(coro)
|
||||
except asyncio.TimeoutError as exc:
|
||||
raise TimeoutError(f"Isolation RPC timeout (timeout_ms={timeout_ms})") from exc
|
||||
except concurrent.futures.TimeoutError as exc:
|
||||
raise TimeoutError(f"Isolation RPC timeout (timeout_ms={timeout_ms})") from exc
|
||||
|
||||
|
||||
def call_singleton_rpc(
|
||||
caller: Any,
|
||||
method_name: str,
|
||||
*args: Any,
|
||||
timeout_ms: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if caller is None:
|
||||
raise RuntimeError(f"No RPC caller available for {method_name}")
|
||||
method = getattr(caller, method_name)
|
||||
return run_sync_rpc_coro(method(*args, **kwargs), timeout_ms=timeout_ms)
|
||||
|
||||
|
||||
class BaseProxy(Generic[T]):
|
||||
_registry_class: type = BaseRegistry # type: ignore[type-arg]
|
||||
__module__: str = "comfy.isolation.proxies.base"
|
||||
_TIMEOUT_RPC_METHODS = frozenset(
|
||||
{
|
||||
"partially_load",
|
||||
"partially_unload",
|
||||
"load",
|
||||
"patch_model",
|
||||
"unpatch_model",
|
||||
"inner_model_apply_model",
|
||||
"memory_required",
|
||||
"model_dtype",
|
||||
"inner_model_memory_required",
|
||||
"inner_model_extra_conds_shapes",
|
||||
"inner_model_extra_conds",
|
||||
"process_latent_in",
|
||||
"process_latent_out",
|
||||
"scale_latent_inpaint",
|
||||
}
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
instance_id: str,
|
||||
registry: Optional[Any] = None,
|
||||
manage_lifecycle: bool = False,
|
||||
) -> None:
|
||||
self._instance_id = instance_id
|
||||
self._rpc_caller: Optional[Any] = None
|
||||
self._registry = registry if registry is not None else self._registry_class()
|
||||
self._manage_lifecycle = manage_lifecycle
|
||||
self._cleaned_up = False
|
||||
if manage_lifecycle and not IS_CHILD_PROCESS:
|
||||
self._finalizer = weakref.finalize(
|
||||
self, self._registry.unregister_sync, instance_id
|
||||
)
|
||||
|
||||
def _get_rpc(self) -> Any:
|
||||
if self._rpc_caller is None:
|
||||
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
|
||||
|
||||
rpc = get_child_rpc_instance()
|
||||
if rpc is None:
|
||||
raise RuntimeError(f"[{self.__class__.__name__}] No RPC in child")
|
||||
self._rpc_caller = rpc.create_caller(
|
||||
self._registry_class, self._registry_class.get_remote_id()
|
||||
)
|
||||
return self._rpc_caller
|
||||
|
||||
def _rpc_timeout_ms_for_method(self, method_name: str) -> Optional[int]:
|
||||
if method_name not in self._TIMEOUT_RPC_METHODS:
|
||||
return None
|
||||
try:
|
||||
timeout_ms = int(
|
||||
os.environ.get("COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS", "120000")
|
||||
)
|
||||
except ValueError:
|
||||
timeout_ms = 120000
|
||||
return max(1, timeout_ms)
|
||||
|
||||
def _call_rpc(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
|
||||
rpc = self._get_rpc()
|
||||
method = getattr(rpc, method_name)
|
||||
timeout_ms = self._rpc_timeout_ms_for_method(method_name)
|
||||
coro = method(self._instance_id, *args, **kwargs)
|
||||
if timeout_ms is not None:
|
||||
coro = asyncio.wait_for(coro, timeout=timeout_ms / 1000.0)
|
||||
|
||||
start_epoch = time.time()
|
||||
start_perf = time.perf_counter()
|
||||
thread_id = threading.get_ident()
|
||||
try:
|
||||
running_loop = asyncio.get_running_loop()
|
||||
loop_id: Optional[int] = id(running_loop)
|
||||
except RuntimeError:
|
||||
loop_id = None
|
||||
logger.debug(
|
||||
"ISO:rpc_start proxy=%s method=%s instance_id=%s start_ts=%.6f "
|
||||
"thread=%s loop=%s timeout_ms=%s",
|
||||
self.__class__.__name__,
|
||||
method_name,
|
||||
self._instance_id,
|
||||
start_epoch,
|
||||
thread_id,
|
||||
loop_id,
|
||||
timeout_ms,
|
||||
)
|
||||
|
||||
try:
|
||||
return run_sync_rpc_coro(coro, timeout_ms=timeout_ms)
|
||||
except TimeoutError as exc:
|
||||
raise TimeoutError(
|
||||
f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} "
|
||||
f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})"
|
||||
) from exc
|
||||
finally:
|
||||
end_epoch = time.time()
|
||||
elapsed_ms = (time.perf_counter() - start_perf) * 1000.0
|
||||
logger.debug(
|
||||
"ISO:rpc_end proxy=%s method=%s instance_id=%s end_ts=%.6f "
|
||||
"elapsed_ms=%.3f thread=%s loop=%s",
|
||||
self.__class__.__name__,
|
||||
method_name,
|
||||
self._instance_id,
|
||||
end_epoch,
|
||||
elapsed_ms,
|
||||
thread_id,
|
||||
loop_id,
|
||||
)
|
||||
|
||||
def __getstate__(self) -> Dict[str, Any]:
|
||||
return {"_instance_id": self._instance_id}
|
||||
|
||||
def __setstate__(self, state: Dict[str, Any]) -> None:
|
||||
self._instance_id = state["_instance_id"]
|
||||
self._rpc_caller = None
|
||||
self._registry = self._registry_class()
|
||||
self._manage_lifecycle = False
|
||||
self._cleaned_up = False
|
||||
|
||||
def cleanup(self) -> None:
|
||||
if self._cleaned_up or IS_CHILD_PROCESS:
|
||||
return
|
||||
self._cleaned_up = True
|
||||
finalizer = getattr(self, "_finalizer", None)
|
||||
if finalizer is not None:
|
||||
finalizer.detach()
|
||||
self._registry.unregister_sync(self._instance_id)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__} {self._instance_id}>"
|
||||
|
||||
|
||||
def create_rpc_method(method_name: str) -> Callable[..., Any]:
|
||||
def method(self: BaseProxy[Any], *args: Any, **kwargs: Any) -> Any:
|
||||
return self._call_rpc(method_name, *args, **kwargs)
|
||||
|
||||
method.__name__ = method_name
|
||||
return method
|
||||
202
comfy/isolation/proxies/folder_paths_proxy.py
Normal file
202
comfy/isolation/proxies/folder_paths_proxy.py
Normal file
@@ -0,0 +1,202 @@
|
||||
from __future__ import annotations
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pyisolate import ProxiedSingleton
|
||||
|
||||
from .base import call_singleton_rpc
|
||||
|
||||
|
||||
def _folder_paths():
|
||||
import folder_paths
|
||||
|
||||
return folder_paths
|
||||
|
||||
|
||||
def _is_child_process() -> bool:
|
||||
return os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
|
||||
|
||||
def _serialize_folder_names_and_paths(data: dict[str, tuple[list[str], set[str]]]) -> dict[str, dict[str, list[str]]]:
|
||||
return {
|
||||
key: {"paths": list(paths), "extensions": sorted(list(extensions))}
|
||||
for key, (paths, extensions) in data.items()
|
||||
}
|
||||
|
||||
|
||||
def _deserialize_folder_names_and_paths(data: dict[str, dict[str, list[str]]]) -> dict[str, tuple[list[str], set[str]]]:
|
||||
return {
|
||||
key: (list(value.get("paths", [])), set(value.get("extensions", [])))
|
||||
for key, value in data.items()
|
||||
}
|
||||
|
||||
|
||||
class FolderPathsProxy(ProxiedSingleton):
|
||||
"""
|
||||
Dynamic proxy for folder_paths.
|
||||
Uses __getattr__ for most lookups, with explicit handling for
|
||||
mutable collections to ensure efficient by-value transfer.
|
||||
"""
|
||||
|
||||
_rpc: Optional[Any] = None
|
||||
|
||||
@classmethod
|
||||
def set_rpc(cls, rpc: Any) -> None:
|
||||
cls._rpc = rpc.create_caller(cls, cls.get_remote_id())
|
||||
|
||||
@classmethod
|
||||
def clear_rpc(cls) -> None:
|
||||
cls._rpc = None
|
||||
|
||||
@classmethod
|
||||
def _get_caller(cls) -> Any:
|
||||
if cls._rpc is None:
|
||||
raise RuntimeError("FolderPathsProxy RPC caller is not configured")
|
||||
return cls._rpc
|
||||
|
||||
def __getattr__(self, name):
|
||||
if _is_child_process():
|
||||
property_rpc = {
|
||||
"models_dir": "rpc_get_models_dir",
|
||||
"folder_names_and_paths": "rpc_get_folder_names_and_paths",
|
||||
"extension_mimetypes_cache": "rpc_get_extension_mimetypes_cache",
|
||||
"filename_list_cache": "rpc_get_filename_list_cache",
|
||||
}
|
||||
rpc_name = property_rpc.get(name)
|
||||
if rpc_name is not None:
|
||||
return call_singleton_rpc(self._get_caller(), rpc_name)
|
||||
raise AttributeError(name)
|
||||
return getattr(_folder_paths(), name)
|
||||
|
||||
@property
|
||||
def folder_names_and_paths(self) -> Dict:
|
||||
if _is_child_process():
|
||||
payload = call_singleton_rpc(self._get_caller(), "rpc_get_folder_names_and_paths")
|
||||
return _deserialize_folder_names_and_paths(payload)
|
||||
return _folder_paths().folder_names_and_paths
|
||||
|
||||
@property
|
||||
def extension_mimetypes_cache(self) -> Dict:
|
||||
if _is_child_process():
|
||||
return dict(call_singleton_rpc(self._get_caller(), "rpc_get_extension_mimetypes_cache"))
|
||||
return dict(_folder_paths().extension_mimetypes_cache)
|
||||
|
||||
@property
|
||||
def filename_list_cache(self) -> Dict:
|
||||
if _is_child_process():
|
||||
return dict(call_singleton_rpc(self._get_caller(), "rpc_get_filename_list_cache"))
|
||||
return dict(_folder_paths().filename_list_cache)
|
||||
|
||||
@property
|
||||
def models_dir(self) -> str:
|
||||
if _is_child_process():
|
||||
return str(call_singleton_rpc(self._get_caller(), "rpc_get_models_dir"))
|
||||
return _folder_paths().models_dir
|
||||
|
||||
def get_temp_directory(self) -> str:
|
||||
if _is_child_process():
|
||||
return call_singleton_rpc(self._get_caller(), "rpc_get_temp_directory")
|
||||
return _folder_paths().get_temp_directory()
|
||||
|
||||
def get_input_directory(self) -> str:
|
||||
if _is_child_process():
|
||||
return call_singleton_rpc(self._get_caller(), "rpc_get_input_directory")
|
||||
return _folder_paths().get_input_directory()
|
||||
|
||||
def get_output_directory(self) -> str:
|
||||
if _is_child_process():
|
||||
return call_singleton_rpc(self._get_caller(), "rpc_get_output_directory")
|
||||
return _folder_paths().get_output_directory()
|
||||
|
||||
def get_user_directory(self) -> str:
|
||||
if _is_child_process():
|
||||
return call_singleton_rpc(self._get_caller(), "rpc_get_user_directory")
|
||||
return _folder_paths().get_user_directory()
|
||||
|
||||
def get_annotated_filepath(self, name: str, default_dir: str | None = None) -> str:
|
||||
if _is_child_process():
|
||||
return call_singleton_rpc(
|
||||
self._get_caller(), "rpc_get_annotated_filepath", name, default_dir
|
||||
)
|
||||
return _folder_paths().get_annotated_filepath(name, default_dir)
|
||||
|
||||
def exists_annotated_filepath(self, name: str) -> bool:
|
||||
if _is_child_process():
|
||||
return bool(
|
||||
call_singleton_rpc(self._get_caller(), "rpc_exists_annotated_filepath", name)
|
||||
)
|
||||
return bool(_folder_paths().exists_annotated_filepath(name))
|
||||
|
||||
def add_model_folder_path(
|
||||
self, folder_name: str, full_folder_path: str, is_default: bool = False
|
||||
) -> None:
|
||||
if _is_child_process():
|
||||
call_singleton_rpc(
|
||||
self._get_caller(),
|
||||
"rpc_add_model_folder_path",
|
||||
folder_name,
|
||||
full_folder_path,
|
||||
is_default,
|
||||
)
|
||||
return None
|
||||
_folder_paths().add_model_folder_path(folder_name, full_folder_path, is_default)
|
||||
return None
|
||||
|
||||
def get_folder_paths(self, folder_name: str) -> list[str]:
|
||||
if _is_child_process():
|
||||
return list(call_singleton_rpc(self._get_caller(), "rpc_get_folder_paths", folder_name))
|
||||
return list(_folder_paths().get_folder_paths(folder_name))
|
||||
|
||||
def get_filename_list(self, folder_name: str) -> list[str]:
|
||||
if _is_child_process():
|
||||
return list(call_singleton_rpc(self._get_caller(), "rpc_get_filename_list", folder_name))
|
||||
return list(_folder_paths().get_filename_list(folder_name))
|
||||
|
||||
def get_full_path(self, folder_name: str, filename: str) -> str | None:
|
||||
if _is_child_process():
|
||||
return call_singleton_rpc(self._get_caller(), "rpc_get_full_path", folder_name, filename)
|
||||
return _folder_paths().get_full_path(folder_name, filename)
|
||||
|
||||
async def rpc_get_models_dir(self) -> str:
|
||||
return _folder_paths().models_dir
|
||||
|
||||
async def rpc_get_folder_names_and_paths(self) -> dict[str, dict[str, list[str]]]:
|
||||
return _serialize_folder_names_and_paths(_folder_paths().folder_names_and_paths)
|
||||
|
||||
async def rpc_get_extension_mimetypes_cache(self) -> dict[str, Any]:
|
||||
return dict(_folder_paths().extension_mimetypes_cache)
|
||||
|
||||
async def rpc_get_filename_list_cache(self) -> dict[str, Any]:
|
||||
return dict(_folder_paths().filename_list_cache)
|
||||
|
||||
async def rpc_get_temp_directory(self) -> str:
|
||||
return _folder_paths().get_temp_directory()
|
||||
|
||||
async def rpc_get_input_directory(self) -> str:
|
||||
return _folder_paths().get_input_directory()
|
||||
|
||||
async def rpc_get_output_directory(self) -> str:
|
||||
return _folder_paths().get_output_directory()
|
||||
|
||||
async def rpc_get_user_directory(self) -> str:
|
||||
return _folder_paths().get_user_directory()
|
||||
|
||||
async def rpc_get_annotated_filepath(self, name: str, default_dir: str | None = None) -> str:
|
||||
return _folder_paths().get_annotated_filepath(name, default_dir)
|
||||
|
||||
async def rpc_exists_annotated_filepath(self, name: str) -> bool:
|
||||
return _folder_paths().exists_annotated_filepath(name)
|
||||
|
||||
async def rpc_add_model_folder_path(
|
||||
self, folder_name: str, full_folder_path: str, is_default: bool = False
|
||||
) -> None:
|
||||
_folder_paths().add_model_folder_path(folder_name, full_folder_path, is_default)
|
||||
|
||||
async def rpc_get_folder_paths(self, folder_name: str) -> list[str]:
|
||||
return _folder_paths().get_folder_paths(folder_name)
|
||||
|
||||
async def rpc_get_filename_list(self, folder_name: str) -> list[str]:
|
||||
return _folder_paths().get_filename_list(folder_name)
|
||||
|
||||
async def rpc_get_full_path(self, folder_name: str, filename: str) -> str | None:
|
||||
return _folder_paths().get_full_path(folder_name, filename)
|
||||
158
comfy/isolation/proxies/helper_proxies.py
Normal file
158
comfy/isolation/proxies/helper_proxies.py
Normal file
@@ -0,0 +1,158 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pyisolate import ProxiedSingleton
|
||||
|
||||
from .base import call_singleton_rpc
|
||||
|
||||
|
||||
class AnyTypeProxy(str):
|
||||
"""Replacement for custom AnyType objects used by some nodes."""
|
||||
|
||||
def __new__(cls, value: str = "*"):
|
||||
return super().__new__(cls, value)
|
||||
|
||||
def __ne__(self, other): # type: ignore[override]
|
||||
return False
|
||||
|
||||
|
||||
class FlexibleOptionalInputProxy(dict):
|
||||
"""Replacement for FlexibleOptionalInputType to allow dynamic inputs."""
|
||||
|
||||
def __init__(self, flex_type, data: Optional[Dict[str, object]] = None):
|
||||
super().__init__()
|
||||
self.type = flex_type
|
||||
if data:
|
||||
self.update(data)
|
||||
|
||||
def __getitem__(self, key): # type: ignore[override]
|
||||
return (self.type,)
|
||||
|
||||
def __contains__(self, key): # type: ignore[override]
|
||||
return True
|
||||
|
||||
|
||||
class ByPassTypeTupleProxy(tuple):
|
||||
"""Replacement for ByPassTypeTuple to mirror wildcard fallback behavior."""
|
||||
|
||||
def __new__(cls, values):
|
||||
return super().__new__(cls, values)
|
||||
|
||||
def __getitem__(self, index): # type: ignore[override]
|
||||
if index >= len(self):
|
||||
return AnyTypeProxy("*")
|
||||
return super().__getitem__(index)
|
||||
|
||||
|
||||
def _restore_special_value(value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
if value.get("__pyisolate_any_type__"):
|
||||
return AnyTypeProxy(value.get("value", "*"))
|
||||
if value.get("__pyisolate_flexible_optional__"):
|
||||
flex_type = _restore_special_value(value.get("type"))
|
||||
data_raw = value.get("data")
|
||||
data = (
|
||||
{k: _restore_special_value(v) for k, v in data_raw.items()}
|
||||
if isinstance(data_raw, dict)
|
||||
else {}
|
||||
)
|
||||
return FlexibleOptionalInputProxy(flex_type, data)
|
||||
if value.get("__pyisolate_tuple__") is not None:
|
||||
return tuple(
|
||||
_restore_special_value(v) for v in value["__pyisolate_tuple__"]
|
||||
)
|
||||
if value.get("__pyisolate_bypass_tuple__") is not None:
|
||||
return ByPassTypeTupleProxy(
|
||||
tuple(
|
||||
_restore_special_value(v)
|
||||
for v in value["__pyisolate_bypass_tuple__"]
|
||||
)
|
||||
)
|
||||
return {k: _restore_special_value(v) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [_restore_special_value(v) for v in value]
|
||||
return value
|
||||
|
||||
|
||||
def _serialize_special_value(value: Any) -> Any:
|
||||
if isinstance(value, AnyTypeProxy):
|
||||
return {"__pyisolate_any_type__": True, "value": str(value)}
|
||||
if isinstance(value, FlexibleOptionalInputProxy):
|
||||
return {
|
||||
"__pyisolate_flexible_optional__": True,
|
||||
"type": _serialize_special_value(value.type),
|
||||
"data": {k: _serialize_special_value(v) for k, v in value.items()},
|
||||
}
|
||||
if isinstance(value, ByPassTypeTupleProxy):
|
||||
return {
|
||||
"__pyisolate_bypass_tuple__": [_serialize_special_value(v) for v in value]
|
||||
}
|
||||
if isinstance(value, tuple):
|
||||
return {"__pyisolate_tuple__": [_serialize_special_value(v) for v in value]}
|
||||
if isinstance(value, list):
|
||||
return [_serialize_special_value(v) for v in value]
|
||||
if isinstance(value, dict):
|
||||
return {k: _serialize_special_value(v) for k, v in value.items()}
|
||||
return value
|
||||
|
||||
|
||||
def _restore_input_types_local(raw: Dict[str, object]) -> Dict[str, object]:
|
||||
if not isinstance(raw, dict):
|
||||
return raw # type: ignore[return-value]
|
||||
|
||||
restored: Dict[str, object] = {}
|
||||
for section, entries in raw.items():
|
||||
if isinstance(entries, dict) and entries.get("__pyisolate_flexible_optional__"):
|
||||
restored[section] = _restore_special_value(entries)
|
||||
elif isinstance(entries, dict):
|
||||
restored[section] = {
|
||||
k: _restore_special_value(v) for k, v in entries.items()
|
||||
}
|
||||
else:
|
||||
restored[section] = _restore_special_value(entries)
|
||||
return restored
|
||||
|
||||
|
||||
class HelperProxiesService(ProxiedSingleton):
|
||||
_rpc: Optional[Any] = None
|
||||
|
||||
@classmethod
|
||||
def set_rpc(cls, rpc: Any) -> None:
|
||||
cls._rpc = rpc.create_caller(cls, cls.get_remote_id())
|
||||
|
||||
@classmethod
|
||||
def clear_rpc(cls) -> None:
|
||||
cls._rpc = None
|
||||
|
||||
@classmethod
|
||||
def _get_caller(cls) -> Any:
|
||||
if cls._rpc is None:
|
||||
raise RuntimeError("HelperProxiesService RPC caller is not configured")
|
||||
return cls._rpc
|
||||
|
||||
async def rpc_restore_input_types(self, raw: Dict[str, object]) -> Dict[str, object]:
|
||||
restored = _restore_input_types_local(raw)
|
||||
return _serialize_special_value(restored)
|
||||
|
||||
|
||||
def restore_input_types(raw: Dict[str, object]) -> Dict[str, object]:
|
||||
"""Restore serialized INPUT_TYPES payload back into ComfyUI-compatible objects."""
|
||||
if os.environ.get("PYISOLATE_CHILD") == "1":
|
||||
payload = call_singleton_rpc(
|
||||
HelperProxiesService._get_caller(),
|
||||
"rpc_restore_input_types",
|
||||
raw,
|
||||
)
|
||||
return _restore_input_types_local(payload)
|
||||
return _restore_input_types_local(raw)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AnyTypeProxy",
|
||||
"FlexibleOptionalInputProxy",
|
||||
"ByPassTypeTupleProxy",
|
||||
"HelperProxiesService",
|
||||
"restore_input_types",
|
||||
]
|
||||
142
comfy/isolation/proxies/model_management_proxy.py
Normal file
142
comfy/isolation/proxies/model_management_proxy.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
from pyisolate import ProxiedSingleton
|
||||
|
||||
from .base import call_singleton_rpc
|
||||
|
||||
|
||||
def _mm():
|
||||
import comfy.model_management
|
||||
|
||||
return comfy.model_management
|
||||
|
||||
|
||||
def _is_child_process() -> bool:
|
||||
return os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
|
||||
|
||||
class TorchDeviceProxy:
|
||||
def __init__(self, device_str: str):
|
||||
self._device_str = device_str
|
||||
if ":" in device_str:
|
||||
device_type, index = device_str.split(":", 1)
|
||||
self.type = device_type
|
||||
self.index = int(index)
|
||||
else:
|
||||
self.type = device_str
|
||||
self.index = None
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self._device_str
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"TorchDeviceProxy({self._device_str!r})"
|
||||
|
||||
|
||||
def _serialize_value(value: Any) -> Any:
|
||||
value_type = type(value)
|
||||
if value_type.__module__ == "torch" and value_type.__name__ == "device":
|
||||
return {"__pyisolate_torch_device__": str(value)}
|
||||
if isinstance(value, TorchDeviceProxy):
|
||||
return {"__pyisolate_torch_device__": str(value)}
|
||||
if isinstance(value, tuple):
|
||||
return {"__pyisolate_tuple__": [_serialize_value(item) for item in value]}
|
||||
if isinstance(value, list):
|
||||
return [_serialize_value(item) for item in value]
|
||||
if isinstance(value, dict):
|
||||
return {key: _serialize_value(inner) for key, inner in value.items()}
|
||||
return value
|
||||
|
||||
|
||||
def _deserialize_value(value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
if "__pyisolate_torch_device__" in value:
|
||||
return TorchDeviceProxy(value["__pyisolate_torch_device__"])
|
||||
if "__pyisolate_tuple__" in value:
|
||||
return tuple(_deserialize_value(item) for item in value["__pyisolate_tuple__"])
|
||||
return {key: _deserialize_value(inner) for key, inner in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [_deserialize_value(item) for item in value]
|
||||
return value
|
||||
|
||||
|
||||
def _normalize_argument(value: Any) -> Any:
|
||||
if isinstance(value, TorchDeviceProxy):
|
||||
import torch
|
||||
|
||||
return torch.device(str(value))
|
||||
if isinstance(value, dict):
|
||||
if "__pyisolate_torch_device__" in value:
|
||||
import torch
|
||||
|
||||
return torch.device(value["__pyisolate_torch_device__"])
|
||||
if "__pyisolate_tuple__" in value:
|
||||
return tuple(_normalize_argument(item) for item in value["__pyisolate_tuple__"])
|
||||
return {key: _normalize_argument(inner) for key, inner in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [_normalize_argument(item) for item in value]
|
||||
return value
|
||||
|
||||
|
||||
class ModelManagementProxy(ProxiedSingleton):
|
||||
"""
|
||||
Exact-relay proxy for comfy.model_management.
|
||||
Child calls never import comfy.model_management directly; they serialize
|
||||
arguments, relay to host, and deserialize the host result back.
|
||||
"""
|
||||
|
||||
_rpc: Optional[Any] = None
|
||||
|
||||
@classmethod
|
||||
def set_rpc(cls, rpc: Any) -> None:
|
||||
cls._rpc = rpc.create_caller(cls, cls.get_remote_id())
|
||||
|
||||
@classmethod
|
||||
def clear_rpc(cls) -> None:
|
||||
cls._rpc = None
|
||||
|
||||
@classmethod
|
||||
def _get_caller(cls) -> Any:
|
||||
if cls._rpc is None:
|
||||
raise RuntimeError("ModelManagementProxy RPC caller is not configured")
|
||||
return cls._rpc
|
||||
|
||||
def _relay_call(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
|
||||
payload = call_singleton_rpc(
|
||||
self._get_caller(),
|
||||
"rpc_call",
|
||||
method_name,
|
||||
_serialize_value(args),
|
||||
_serialize_value(kwargs),
|
||||
)
|
||||
return _deserialize_value(payload)
|
||||
|
||||
@property
|
||||
def VRAMState(self):
|
||||
return _mm().VRAMState
|
||||
|
||||
@property
|
||||
def CPUState(self):
|
||||
return _mm().CPUState
|
||||
|
||||
@property
|
||||
def OOM_EXCEPTION(self):
|
||||
return _mm().OOM_EXCEPTION
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
if _is_child_process():
|
||||
def child_method(*args: Any, **kwargs: Any) -> Any:
|
||||
return self._relay_call(name, *args, **kwargs)
|
||||
|
||||
return child_method
|
||||
return getattr(_mm(), name)
|
||||
|
||||
async def rpc_call(self, method_name: str, args: Any, kwargs: Any) -> Any:
|
||||
normalized_args = _normalize_argument(_deserialize_value(args))
|
||||
normalized_kwargs = _normalize_argument(_deserialize_value(kwargs))
|
||||
method = getattr(_mm(), method_name)
|
||||
result = method(*normalized_args, **normalized_kwargs)
|
||||
return _serialize_value(result)
|
||||
87
comfy/isolation/proxies/progress_proxy.py
Normal file
87
comfy/isolation/proxies/progress_proxy.py
Normal file
@@ -0,0 +1,87 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
try:
|
||||
from pyisolate import ProxiedSingleton
|
||||
except ImportError:
|
||||
|
||||
class ProxiedSingleton:
|
||||
pass
|
||||
|
||||
from .base import call_singleton_rpc
|
||||
|
||||
|
||||
def _get_progress_state():
|
||||
from comfy_execution.progress import get_progress_state
|
||||
|
||||
return get_progress_state()
|
||||
|
||||
|
||||
def _is_child_process() -> bool:
|
||||
return os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProgressProxy(ProxiedSingleton):
|
||||
_rpc: Optional[Any] = None
|
||||
|
||||
@classmethod
|
||||
def set_rpc(cls, rpc: Any) -> None:
|
||||
cls._rpc = rpc.create_caller(cls, cls.get_remote_id())
|
||||
|
||||
@classmethod
|
||||
def clear_rpc(cls) -> None:
|
||||
cls._rpc = None
|
||||
|
||||
@classmethod
|
||||
def _get_caller(cls) -> Any:
|
||||
if cls._rpc is None:
|
||||
raise RuntimeError("ProgressProxy RPC caller is not configured")
|
||||
return cls._rpc
|
||||
|
||||
def set_progress(
|
||||
self,
|
||||
value: float,
|
||||
max_value: float,
|
||||
node_id: Optional[str] = None,
|
||||
image: Any = None,
|
||||
) -> None:
|
||||
if _is_child_process():
|
||||
call_singleton_rpc(
|
||||
self._get_caller(),
|
||||
"rpc_set_progress",
|
||||
value,
|
||||
max_value,
|
||||
node_id,
|
||||
image,
|
||||
)
|
||||
return None
|
||||
|
||||
_get_progress_state().update_progress(
|
||||
node_id=node_id,
|
||||
value=value,
|
||||
max_value=max_value,
|
||||
image=image,
|
||||
)
|
||||
return None
|
||||
|
||||
async def rpc_set_progress(
|
||||
self,
|
||||
value: float,
|
||||
max_value: float,
|
||||
node_id: Optional[str] = None,
|
||||
image: Any = None,
|
||||
) -> None:
|
||||
_get_progress_state().update_progress(
|
||||
node_id=node_id,
|
||||
value=value,
|
||||
max_value=max_value,
|
||||
image=image,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["ProgressProxy"]
|
||||
271
comfy/isolation/proxies/prompt_server_impl.py
Normal file
271
comfy/isolation/proxies/prompt_server_impl.py
Normal file
@@ -0,0 +1,271 @@
|
||||
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,redefined-outer-name,reimported,super-init-not-called
|
||||
"""Stateless RPC Implementation for PromptServer.
|
||||
|
||||
Replaces the legacy PromptServerProxy (Singleton) with a clean Service/Stub architecture.
|
||||
- Host: PromptServerService (RPC Handler)
|
||||
- Child: PromptServerStub (Interface Implementation)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Any, Dict, Optional, Callable
|
||||
|
||||
import logging
|
||||
|
||||
# IMPORTS
|
||||
from pyisolate import ProxiedSingleton
|
||||
from .base import call_singleton_rpc
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
LOG_PREFIX = "[Isolation:C<->H]"
|
||||
|
||||
# ...
|
||||
|
||||
# =============================================================================
|
||||
# CHILD SIDE: PromptServerStub
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class PromptServerStub:
|
||||
"""Stateless Stub for PromptServer."""
|
||||
|
||||
# Masquerade as the real server module
|
||||
__module__ = "server"
|
||||
|
||||
_instance: Optional["PromptServerStub"] = None
|
||||
_rpc: Optional[Any] = None # This will be the Caller object
|
||||
_source_file: Optional[str] = None
|
||||
|
||||
def __init__(self):
|
||||
self.routes = RouteStub(self)
|
||||
|
||||
@classmethod
|
||||
def set_rpc(cls, rpc: Any) -> None:
|
||||
"""Inject RPC client (called by adapter.py or manually)."""
|
||||
# Create caller for HOST Service
|
||||
# Assuming Host Service is registered as "PromptServerService" (class name)
|
||||
# We target the Host Service Class
|
||||
target_id = "PromptServerService"
|
||||
# We need to pass a class to create_caller? Usually yes.
|
||||
# But we don't have the Service class imported here necessarily (if running on child).
|
||||
# pyisolate check verify_service type?
|
||||
# If we pass PromptServerStub as the 'class', it might mismatch if checking types.
|
||||
# But we can try passing PromptServerStub if it mirrors the service name? No, stub is PromptServerStub.
|
||||
# We need a dummy class with right name?
|
||||
# Or just rely on string ID if create_caller supports it?
|
||||
# Standard: rpc.create_caller(PromptServerStub, target_id)
|
||||
# But wait, PromptServerStub is the *Local* class.
|
||||
# We want to call *Remote* class.
|
||||
# If we use PromptServerStub as the type, returning object will be typed as PromptServerStub?
|
||||
# The first arg is 'service_cls'.
|
||||
cls._rpc = rpc.create_caller(
|
||||
PromptServerService, target_id
|
||||
) # We import Service below?
|
||||
|
||||
@classmethod
|
||||
def clear_rpc(cls) -> None:
|
||||
cls._rpc = None
|
||||
|
||||
# We need PromptServerService available for the create_caller call?
|
||||
# Or just use the Stub class if ID matches?
|
||||
# prompt_server_impl.py defines BOTH. So PromptServerService IS available!
|
||||
|
||||
@property
|
||||
def instance(self) -> "PromptServerStub":
|
||||
return self
|
||||
|
||||
# ... Compatibility ...
|
||||
@classmethod
|
||||
def _get_source_file(cls) -> str:
|
||||
if cls._source_file is None:
|
||||
import folder_paths
|
||||
|
||||
cls._source_file = os.path.join(folder_paths.base_path, "server.py")
|
||||
return cls._source_file
|
||||
|
||||
@property
|
||||
def __file__(self) -> str:
|
||||
return self._get_source_file()
|
||||
|
||||
# --- Properties ---
|
||||
@property
|
||||
def client_id(self) -> Optional[str]:
|
||||
return "isolated_client"
|
||||
|
||||
def supports(self, feature: str) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def app(self):
|
||||
raise RuntimeError(
|
||||
"PromptServer.app is not accessible in isolated nodes. Use RPC routes instead."
|
||||
)
|
||||
|
||||
@property
|
||||
def prompt_queue(self):
|
||||
raise RuntimeError(
|
||||
"PromptServer.prompt_queue is not accessible in isolated nodes."
|
||||
)
|
||||
|
||||
# --- UI Communication (RPC Delegates) ---
|
||||
async def send_sync(
|
||||
self, event: str, data: Dict[str, Any], sid: Optional[str] = None
|
||||
) -> None:
|
||||
if self._rpc:
|
||||
await self._rpc.ui_send_sync(event, data, sid)
|
||||
|
||||
async def send(
|
||||
self, event: str, data: Dict[str, Any], sid: Optional[str] = None
|
||||
) -> None:
|
||||
if self._rpc:
|
||||
await self._rpc.ui_send(event, data, sid)
|
||||
|
||||
def send_progress_text(self, text: str, node_id: str, sid=None) -> None:
|
||||
if self._rpc:
|
||||
# Fire and forget likely needed. If method is async on host, caller invocation returns coroutine.
|
||||
# We must schedule it?
|
||||
# Or use fire_remote equivalent?
|
||||
# Caller object usually proxies calls. If host method is async, it returns coro.
|
||||
# If we are sync here (send_progress_text checks imply sync usage), we must background it.
|
||||
# But UtilsProxy hook wrapper creates task.
|
||||
# Does send_progress_text need to be sync? Yes, node code calls it sync.
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.create_task(self._rpc.ui_send_progress_text(text, node_id, sid))
|
||||
except RuntimeError:
|
||||
call_singleton_rpc(self._rpc, "ui_send_progress_text", text, node_id, sid)
|
||||
|
||||
# --- Route Registration Logic ---
|
||||
def register_route(self, method: str, path: str, handler: Callable):
|
||||
"""Register a route handler via RPC."""
|
||||
if not self._rpc:
|
||||
logger.error("RPC not initialized in PromptServerStub")
|
||||
return
|
||||
|
||||
# Fire registration async
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.create_task(self._rpc.register_route_rpc(method, path, handler))
|
||||
except RuntimeError:
|
||||
call_singleton_rpc(self._rpc, "register_route_rpc", method, path, handler)
|
||||
|
||||
|
||||
class RouteStub:
|
||||
"""Simulates aiohttp.web.RouteTableDef."""
|
||||
|
||||
def __init__(self, stub: PromptServerStub):
|
||||
self._stub = stub
|
||||
|
||||
def get(self, path: str):
|
||||
def decorator(handler):
|
||||
self._stub.register_route("GET", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
def post(self, path: str):
|
||||
def decorator(handler):
|
||||
self._stub.register_route("POST", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
def patch(self, path: str):
|
||||
def decorator(handler):
|
||||
self._stub.register_route("PATCH", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
def put(self, path: str):
|
||||
def decorator(handler):
|
||||
self._stub.register_route("PUT", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
def delete(self, path: str):
|
||||
def decorator(handler):
|
||||
self._stub.register_route("DELETE", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# HOST SIDE: PromptServerService
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class PromptServerService(ProxiedSingleton):
|
||||
"""Host-side RPC Service for PromptServer."""
|
||||
|
||||
def __init__(self):
|
||||
# We will bind to the real server instance lazily or via global import
|
||||
pass
|
||||
|
||||
@property
|
||||
def server(self):
|
||||
from server import PromptServer
|
||||
|
||||
return PromptServer.instance
|
||||
|
||||
async def ui_send_sync(
|
||||
self, event: str, data: Dict[str, Any], sid: Optional[str] = None
|
||||
):
|
||||
await self.server.send_sync(event, data, sid)
|
||||
|
||||
async def ui_send(
|
||||
self, event: str, data: Dict[str, Any], sid: Optional[str] = None
|
||||
):
|
||||
await self.server.send(event, data, sid)
|
||||
|
||||
async def ui_send_progress_text(self, text: str, node_id: str, sid=None):
|
||||
# Made async to be awaitable by RPC layer
|
||||
self.server.send_progress_text(text, node_id, sid)
|
||||
|
||||
async def register_route_rpc(self, method: str, path: str, child_handler_proxy):
|
||||
"""RPC Target: Register a route that forwards to the Child."""
|
||||
from aiohttp import web
|
||||
logger.debug(f"{LOG_PREFIX} Registering Isolated Route {method} {path}")
|
||||
|
||||
async def route_wrapper(request: web.Request) -> web.Response:
|
||||
# 1. Capture request data
|
||||
req_data = {
|
||||
"method": request.method,
|
||||
"path": request.path,
|
||||
"query": dict(request.query),
|
||||
}
|
||||
if request.can_read_body:
|
||||
req_data["text"] = await request.text()
|
||||
|
||||
try:
|
||||
# 2. Call Child Handler via RPC (child_handler_proxy is async callable)
|
||||
result = await child_handler_proxy(req_data)
|
||||
|
||||
# 3. Serialize Response
|
||||
return self._serialize_response(result)
|
||||
except Exception as e:
|
||||
logger.error(f"{LOG_PREFIX} Isolated Route Error: {e}")
|
||||
return web.Response(status=500, text=str(e))
|
||||
|
||||
# Register loop
|
||||
self.server.app.router.add_route(method, path, route_wrapper)
|
||||
|
||||
def _serialize_response(self, result: Any) -> Any:
|
||||
"""Helper to convert Child result -> web.Response"""
|
||||
from aiohttp import web
|
||||
if isinstance(result, web.Response):
|
||||
return result
|
||||
# Handle dict (json)
|
||||
if isinstance(result, dict):
|
||||
return web.json_response(result)
|
||||
# Handle string
|
||||
if isinstance(result, str):
|
||||
return web.Response(text=result)
|
||||
# Fallback
|
||||
return web.Response(text=str(result))
|
||||
64
comfy/isolation/proxies/utils_proxy.py
Normal file
64
comfy/isolation/proxies/utils_proxy.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# pylint: disable=cyclic-import,import-outside-toplevel
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Any
|
||||
from pyisolate import ProxiedSingleton
|
||||
|
||||
import os
|
||||
|
||||
|
||||
def _comfy_utils():
|
||||
import comfy.utils
|
||||
return comfy.utils
|
||||
|
||||
|
||||
class UtilsProxy(ProxiedSingleton):
|
||||
"""
|
||||
Proxy for comfy.utils.
|
||||
Primarily handles the PROGRESS_BAR_HOOK to ensure progress updates
|
||||
from isolated nodes reach the host.
|
||||
"""
|
||||
|
||||
# _instance and __new__ removed to rely on SingletonMetaclass
|
||||
_rpc: Optional[Any] = None
|
||||
|
||||
@classmethod
|
||||
def set_rpc(cls, rpc: Any) -> None:
|
||||
# Create caller using class name as ID (standard for Singletons)
|
||||
cls._rpc = rpc.create_caller(cls, "UtilsProxy")
|
||||
|
||||
@classmethod
|
||||
def clear_rpc(cls) -> None:
|
||||
cls._rpc = None
|
||||
|
||||
async def progress_bar_hook(
|
||||
self,
|
||||
value: int,
|
||||
total: int,
|
||||
preview: Optional[bytes] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Host-side implementation: forwards the call to the real global hook.
|
||||
Child-side: this method call is intercepted by RPC and sent to host.
|
||||
"""
|
||||
if os.environ.get("PYISOLATE_CHILD") == "1":
|
||||
if UtilsProxy._rpc is None:
|
||||
raise RuntimeError("UtilsProxy RPC caller is not configured")
|
||||
return await UtilsProxy._rpc.progress_bar_hook(
|
||||
value, total, preview, node_id
|
||||
)
|
||||
|
||||
# Host Execution
|
||||
utils = _comfy_utils()
|
||||
if utils.PROGRESS_BAR_HOOK is not None:
|
||||
return utils.PROGRESS_BAR_HOOK(value, total, preview, node_id)
|
||||
return None
|
||||
|
||||
def set_progress_bar_global_hook(self, hook: Any) -> None:
|
||||
"""Forward hook registration (though usually not needed from child)."""
|
||||
if os.environ.get("PYISOLATE_CHILD") == "1":
|
||||
raise RuntimeError(
|
||||
"UtilsProxy.set_progress_bar_global_hook is not available in child without exact relay support"
|
||||
)
|
||||
_comfy_utils().set_progress_bar_global_hook(hook)
|
||||
219
comfy/isolation/proxies/web_directory_proxy.py
Normal file
219
comfy/isolation/proxies/web_directory_proxy.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""WebDirectoryProxy — serves isolated node web assets via RPC.
|
||||
|
||||
Child side: enumerates and reads files from the extension's web/ directory.
|
||||
Host side: gets an RPC proxy that fetches file listings and contents on demand.
|
||||
|
||||
Only files with allowed extensions (.js, .html, .css) are served.
|
||||
Directory traversal is rejected. File contents are base64-encoded for
|
||||
safe JSON-RPC transport.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path, PurePosixPath
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pyisolate import ProxiedSingleton
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALLOWED_EXTENSIONS = frozenset({".js", ".html", ".css"})
|
||||
|
||||
MIME_TYPES = {
|
||||
".js": "application/javascript",
|
||||
".html": "text/html",
|
||||
".css": "text/css",
|
||||
}
|
||||
|
||||
|
||||
class WebDirectoryProxy(ProxiedSingleton):
|
||||
"""Proxy for serving isolated extension web directories.
|
||||
|
||||
On the child side, this class has direct filesystem access to the
|
||||
extension's web/ directory. On the host side, callers get an RPC
|
||||
proxy whose method calls are forwarded to the child.
|
||||
"""
|
||||
|
||||
# {extension_name: absolute_path_to_web_dir}
|
||||
_web_dirs: dict[str, str] = {}
|
||||
|
||||
@classmethod
|
||||
def register_web_dir(cls, extension_name: str, web_dir_path: str) -> None:
|
||||
"""Register an extension's web directory (child-side only)."""
|
||||
cls._web_dirs[extension_name] = web_dir_path
|
||||
logger.info(
|
||||
"][ WebDirectoryProxy: registered %s -> %s",
|
||||
extension_name,
|
||||
web_dir_path,
|
||||
)
|
||||
|
||||
def list_web_files(self, extension_name: str) -> List[Dict[str, str]]:
|
||||
"""Return a list of servable files in the extension's web directory.
|
||||
|
||||
Each entry is {"relative_path": "js/foo.js", "content_type": "application/javascript"}.
|
||||
Only files with allowed extensions are included.
|
||||
"""
|
||||
web_dir = self._web_dirs.get(extension_name)
|
||||
if not web_dir:
|
||||
return []
|
||||
|
||||
root = Path(web_dir)
|
||||
if not root.is_dir():
|
||||
return []
|
||||
|
||||
result: List[Dict[str, str]] = []
|
||||
for path in sorted(root.rglob("*")):
|
||||
if not path.is_file():
|
||||
continue
|
||||
ext = path.suffix.lower()
|
||||
if ext not in ALLOWED_EXTENSIONS:
|
||||
continue
|
||||
rel = path.relative_to(root)
|
||||
result.append({
|
||||
"relative_path": str(PurePosixPath(rel)),
|
||||
"content_type": MIME_TYPES[ext],
|
||||
})
|
||||
return result
|
||||
|
||||
def get_web_file(
|
||||
self, extension_name: str, relative_path: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Return the contents of a single web file as base64.
|
||||
|
||||
Raises ValueError for traversal attempts or disallowed file types.
|
||||
Returns {"content": <base64 str>, "content_type": <MIME str>}.
|
||||
"""
|
||||
_validate_path(relative_path)
|
||||
|
||||
web_dir = self._web_dirs.get(extension_name)
|
||||
if not web_dir:
|
||||
raise FileNotFoundError(
|
||||
f"No web directory registered for {extension_name}"
|
||||
)
|
||||
|
||||
root = Path(web_dir)
|
||||
target = (root / relative_path).resolve()
|
||||
|
||||
# Ensure resolved path is under the web directory
|
||||
if not str(target).startswith(str(root.resolve())):
|
||||
raise ValueError(f"Path escapes web directory: {relative_path}")
|
||||
|
||||
if not target.is_file():
|
||||
raise FileNotFoundError(f"File not found: {relative_path}")
|
||||
|
||||
ext = target.suffix.lower()
|
||||
if ext not in ALLOWED_EXTENSIONS:
|
||||
raise ValueError(f"Disallowed file type: {ext}")
|
||||
|
||||
content_type = MIME_TYPES[ext]
|
||||
raw = target.read_bytes()
|
||||
|
||||
return {
|
||||
"content": base64.b64encode(raw).decode("ascii"),
|
||||
"content_type": content_type,
|
||||
}
|
||||
|
||||
|
||||
def _validate_path(relative_path: str) -> None:
|
||||
"""Reject directory traversal and absolute paths."""
|
||||
if os.path.isabs(relative_path):
|
||||
raise ValueError(f"Absolute paths are not allowed: {relative_path}")
|
||||
if ".." in PurePosixPath(relative_path).parts:
|
||||
raise ValueError(f"Directory traversal is not allowed: {relative_path}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Host-side cache and aiohttp handler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class WebDirectoryCache:
|
||||
"""Host-side in-memory cache for proxied web directory contents.
|
||||
|
||||
Populated lazily via RPC calls to the child's WebDirectoryProxy.
|
||||
Once a file is cached, subsequent requests are served from memory.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# {extension_name: {relative_path: {"content": bytes, "content_type": str}}}
|
||||
self._file_cache: dict[str, dict[str, dict[str, Any]]] = {}
|
||||
# {extension_name: [{"relative_path": str, "content_type": str}, ...]}
|
||||
self._listing_cache: dict[str, list[dict[str, str]]] = {}
|
||||
# {extension_name: WebDirectoryProxy (RPC proxy instance)}
|
||||
self._proxies: dict[str, Any] = {}
|
||||
|
||||
def register_proxy(self, extension_name: str, proxy: Any) -> None:
|
||||
"""Register an RPC proxy for an extension's web directory."""
|
||||
self._proxies[extension_name] = proxy
|
||||
logger.info(
|
||||
"][ WebDirectoryCache: registered proxy for %s", extension_name
|
||||
)
|
||||
|
||||
@property
|
||||
def extension_names(self) -> list[str]:
|
||||
return list(self._proxies.keys())
|
||||
|
||||
def list_files(self, extension_name: str) -> list[dict[str, str]]:
|
||||
"""List servable files for an extension (cached after first call)."""
|
||||
if extension_name not in self._listing_cache:
|
||||
proxy = self._proxies.get(extension_name)
|
||||
if proxy is None:
|
||||
return []
|
||||
try:
|
||||
self._listing_cache[extension_name] = proxy.list_web_files(
|
||||
extension_name
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"][ WebDirectoryCache: failed to list files for %s",
|
||||
extension_name,
|
||||
exc_info=True,
|
||||
)
|
||||
return []
|
||||
return self._listing_cache[extension_name]
|
||||
|
||||
def get_file(
|
||||
self, extension_name: str, relative_path: str
|
||||
) -> dict[str, Any] | None:
|
||||
"""Get file content (cached after first fetch). Returns None on miss."""
|
||||
ext_cache = self._file_cache.get(extension_name)
|
||||
if ext_cache and relative_path in ext_cache:
|
||||
return ext_cache[relative_path]
|
||||
|
||||
proxy = self._proxies.get(extension_name)
|
||||
if proxy is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
result = proxy.get_web_file(extension_name, relative_path)
|
||||
except (FileNotFoundError, ValueError):
|
||||
return None
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"][ WebDirectoryCache: failed to fetch %s/%s",
|
||||
extension_name,
|
||||
relative_path,
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
decoded = {
|
||||
"content": base64.b64decode(result["content"]),
|
||||
"content_type": result["content_type"],
|
||||
}
|
||||
|
||||
if extension_name not in self._file_cache:
|
||||
self._file_cache[extension_name] = {}
|
||||
self._file_cache[extension_name][relative_path] = decoded
|
||||
return decoded
|
||||
|
||||
|
||||
# Global cache instance — populated during isolation loading
|
||||
_web_directory_cache = WebDirectoryCache()
|
||||
|
||||
|
||||
def get_web_directory_cache() -> WebDirectoryCache:
|
||||
return _web_directory_cache
|
||||
49
comfy/isolation/rpc_bridge.py
Normal file
49
comfy/isolation/rpc_bridge.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RpcBridge:
|
||||
"""Minimal helper to run coroutines synchronously inside isolated processes.
|
||||
|
||||
If an event loop is already running, the coroutine is executed on a fresh
|
||||
thread with its own loop to avoid nested run_until_complete errors.
|
||||
"""
|
||||
|
||||
def run_sync(self, maybe_coro):
|
||||
if not asyncio.iscoroutine(maybe_coro):
|
||||
return maybe_coro
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
|
||||
if loop and loop.is_running():
|
||||
result_container = {}
|
||||
exc_container = {}
|
||||
|
||||
def _runner():
|
||||
try:
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
result_container["value"] = new_loop.run_until_complete(maybe_coro)
|
||||
except Exception as exc: # pragma: no cover
|
||||
exc_container["error"] = exc
|
||||
finally:
|
||||
try:
|
||||
new_loop.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
t = threading.Thread(target=_runner, daemon=True)
|
||||
t.start()
|
||||
t.join()
|
||||
|
||||
if "error" in exc_container:
|
||||
raise exc_container["error"]
|
||||
return result_container.get("value")
|
||||
|
||||
return asyncio.run(maybe_coro)
|
||||
471
comfy/isolation/runtime_helpers.py
Normal file
471
comfy/isolation/runtime_helpers.py
Normal file
@@ -0,0 +1,471 @@
|
||||
# pylint: disable=consider-using-from-import,import-outside-toplevel,no-member
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Set, TYPE_CHECKING
|
||||
|
||||
from .proxies.helper_proxies import restore_input_types
|
||||
from .shm_forensics import scan_shm_forensics
|
||||
|
||||
_IMPORT_TORCH = os.environ.get("PYISOLATE_IMPORT_TORCH", "1") == "1"
|
||||
|
||||
_ComfyNodeInternal = object
|
||||
latest_io = None
|
||||
|
||||
if _IMPORT_TORCH:
|
||||
from comfy_api.internal import _ComfyNodeInternal
|
||||
from comfy_api.latest import _io as latest_io
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .extension_wrapper import ComfyNodeExtension
|
||||
|
||||
LOG_PREFIX = "]["
|
||||
_PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
|
||||
|
||||
|
||||
class _RemoteObjectRegistryCaller:
|
||||
def __init__(self, extension: Any) -> None:
|
||||
self._extension = extension
|
||||
|
||||
def __getattr__(self, method_name: str) -> Any:
|
||||
async def _call(instance_id: str, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self._extension.call_remote_object_method(
|
||||
instance_id,
|
||||
method_name,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return _call
|
||||
|
||||
|
||||
def _wrap_remote_handles_as_host_proxies(value: Any, extension: Any) -> Any:
|
||||
from pyisolate._internal.remote_handle import RemoteObjectHandle
|
||||
|
||||
if isinstance(value, RemoteObjectHandle):
|
||||
if value.type_name == "ModelPatcher":
|
||||
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy
|
||||
|
||||
proxy = ModelPatcherProxy(value.object_id, manage_lifecycle=False)
|
||||
proxy._rpc_caller = _RemoteObjectRegistryCaller(extension) # type: ignore[attr-defined]
|
||||
proxy._pyisolate_remote_handle = value # type: ignore[attr-defined]
|
||||
return proxy
|
||||
if value.type_name == "VAE":
|
||||
from comfy.isolation.vae_proxy import VAEProxy
|
||||
|
||||
proxy = VAEProxy(value.object_id, manage_lifecycle=False)
|
||||
proxy._rpc_caller = _RemoteObjectRegistryCaller(extension) # type: ignore[attr-defined]
|
||||
proxy._pyisolate_remote_handle = value # type: ignore[attr-defined]
|
||||
return proxy
|
||||
if value.type_name == "CLIP":
|
||||
from comfy.isolation.clip_proxy import CLIPProxy
|
||||
|
||||
proxy = CLIPProxy(value.object_id, manage_lifecycle=False)
|
||||
proxy._rpc_caller = _RemoteObjectRegistryCaller(extension) # type: ignore[attr-defined]
|
||||
proxy._pyisolate_remote_handle = value # type: ignore[attr-defined]
|
||||
return proxy
|
||||
if value.type_name == "ModelSampling":
|
||||
from comfy.isolation.model_sampling_proxy import ModelSamplingProxy
|
||||
|
||||
proxy = ModelSamplingProxy(value.object_id, manage_lifecycle=False)
|
||||
proxy._rpc_caller = _RemoteObjectRegistryCaller(extension) # type: ignore[attr-defined]
|
||||
proxy._pyisolate_remote_handle = value # type: ignore[attr-defined]
|
||||
return proxy
|
||||
return value
|
||||
|
||||
if isinstance(value, dict):
|
||||
return {
|
||||
k: _wrap_remote_handles_as_host_proxies(v, extension) for k, v in value.items()
|
||||
}
|
||||
|
||||
if isinstance(value, (list, tuple)):
|
||||
wrapped = [_wrap_remote_handles_as_host_proxies(item, extension) for item in value]
|
||||
return type(value)(wrapped)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def _resource_snapshot() -> Dict[str, int]:
|
||||
fd_count = -1
|
||||
shm_sender_files = 0
|
||||
try:
|
||||
fd_count = len(os.listdir("/proc/self/fd"))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
shm_root = Path("/dev/shm")
|
||||
if shm_root.exists():
|
||||
prefix = f"torch_{os.getpid()}_"
|
||||
shm_sender_files = sum(1 for _ in shm_root.glob(f"{prefix}*"))
|
||||
except Exception:
|
||||
pass
|
||||
return {"fd_count": fd_count, "shm_sender_files": shm_sender_files}
|
||||
|
||||
|
||||
def _tensor_transport_summary(value: Any) -> Dict[str, int]:
|
||||
summary: Dict[str, int] = {
|
||||
"tensor_count": 0,
|
||||
"cpu_tensors": 0,
|
||||
"cuda_tensors": 0,
|
||||
"shared_cpu_tensors": 0,
|
||||
"tensor_bytes": 0,
|
||||
}
|
||||
try:
|
||||
import torch
|
||||
except Exception:
|
||||
return summary
|
||||
|
||||
def visit(node: Any) -> None:
|
||||
if isinstance(node, torch.Tensor):
|
||||
summary["tensor_count"] += 1
|
||||
summary["tensor_bytes"] += int(node.numel() * node.element_size())
|
||||
if node.device.type == "cpu":
|
||||
summary["cpu_tensors"] += 1
|
||||
if node.is_shared():
|
||||
summary["shared_cpu_tensors"] += 1
|
||||
elif node.device.type == "cuda":
|
||||
summary["cuda_tensors"] += 1
|
||||
return
|
||||
if isinstance(node, dict):
|
||||
for v in node.values():
|
||||
visit(v)
|
||||
return
|
||||
if isinstance(node, (list, tuple)):
|
||||
for v in node:
|
||||
visit(v)
|
||||
|
||||
visit(value)
|
||||
return summary
|
||||
|
||||
|
||||
def _extract_hidden_unique_id(inputs: Dict[str, Any]) -> str | None:
|
||||
for key, value in inputs.items():
|
||||
key_text = str(key)
|
||||
if "unique_id" in key_text:
|
||||
return str(value)
|
||||
return None
|
||||
|
||||
|
||||
def _flush_tensor_transport_state(marker: str, logger: logging.Logger) -> None:
|
||||
try:
|
||||
from pyisolate import flush_tensor_keeper # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
return
|
||||
if not callable(flush_tensor_keeper):
|
||||
return
|
||||
flushed = flush_tensor_keeper()
|
||||
if flushed > 0:
|
||||
logger.debug(
|
||||
"%s %s flush_tensor_keeper released=%d", LOG_PREFIX, marker, flushed
|
||||
)
|
||||
|
||||
|
||||
def _relieve_host_vram_pressure(marker: str, logger: logging.Logger) -> None:
|
||||
import comfy.model_management as model_management
|
||||
|
||||
model_management.cleanup_models_gc()
|
||||
model_management.cleanup_models()
|
||||
|
||||
device = model_management.get_torch_device()
|
||||
if not hasattr(device, "type") or device.type == "cpu":
|
||||
return
|
||||
|
||||
required = max(
|
||||
model_management.minimum_inference_memory(),
|
||||
_PRE_EXEC_MIN_FREE_VRAM_BYTES,
|
||||
)
|
||||
if model_management.get_free_memory(device) < required:
|
||||
model_management.free_memory(required, device, for_dynamic=True)
|
||||
if model_management.get_free_memory(device) < required:
|
||||
model_management.free_memory(required, device, for_dynamic=False)
|
||||
model_management.cleanup_models()
|
||||
model_management.soft_empty_cache()
|
||||
logger.debug("%s %s free_memory target=%d", LOG_PREFIX, marker, required)
|
||||
|
||||
|
||||
def _detach_shared_cpu_tensors(value: Any) -> Any:
|
||||
try:
|
||||
import torch
|
||||
except Exception:
|
||||
return value
|
||||
|
||||
if isinstance(value, torch.Tensor):
|
||||
if value.device.type == "cpu" and value.is_shared():
|
||||
clone = value.clone()
|
||||
if value.requires_grad:
|
||||
clone.requires_grad_(True)
|
||||
return clone
|
||||
return value
|
||||
if isinstance(value, list):
|
||||
return [_detach_shared_cpu_tensors(v) for v in value]
|
||||
if isinstance(value, tuple):
|
||||
return tuple(_detach_shared_cpu_tensors(v) for v in value)
|
||||
if isinstance(value, dict):
|
||||
return {k: _detach_shared_cpu_tensors(v) for k, v in value.items()}
|
||||
return value
|
||||
|
||||
|
||||
def build_stub_class(
|
||||
node_name: str,
|
||||
info: Dict[str, object],
|
||||
extension: "ComfyNodeExtension",
|
||||
running_extensions: Dict[str, "ComfyNodeExtension"],
|
||||
logger: logging.Logger,
|
||||
) -> type:
|
||||
if latest_io is None:
|
||||
raise RuntimeError("comfy_api.latest._io is required to build isolation stubs")
|
||||
is_v3 = bool(info.get("is_v3", False))
|
||||
function_name = "_pyisolate_execute"
|
||||
restored_input_types = restore_input_types(info.get("input_types", {}))
|
||||
|
||||
async def _execute(self, **inputs):
|
||||
from comfy.isolation import _RUNNING_EXTENSIONS
|
||||
|
||||
# Update BOTH the local dict AND the module-level dict
|
||||
running_extensions[extension.name] = extension
|
||||
_RUNNING_EXTENSIONS[extension.name] = extension
|
||||
prev_child = None
|
||||
node_unique_id = _extract_hidden_unique_id(inputs)
|
||||
summary = _tensor_transport_summary(inputs)
|
||||
resources = _resource_snapshot()
|
||||
logger.debug(
|
||||
"%s ISO:execute_start ext=%s node=%s uid=%s",
|
||||
LOG_PREFIX,
|
||||
extension.name,
|
||||
node_name,
|
||||
node_unique_id or "-",
|
||||
)
|
||||
logger.debug(
|
||||
"%s ISO:execute_start ext=%s node=%s uid=%s tensors=%d cpu=%d cuda=%d shared_cpu=%d bytes=%d fds=%d sender_shm=%d",
|
||||
LOG_PREFIX,
|
||||
extension.name,
|
||||
node_name,
|
||||
node_unique_id or "-",
|
||||
summary["tensor_count"],
|
||||
summary["cpu_tensors"],
|
||||
summary["cuda_tensors"],
|
||||
summary["shared_cpu_tensors"],
|
||||
summary["tensor_bytes"],
|
||||
resources["fd_count"],
|
||||
resources["shm_sender_files"],
|
||||
)
|
||||
scan_shm_forensics("RUNTIME:execute_start", refresh_model_context=True)
|
||||
try:
|
||||
if os.environ.get("PYISOLATE_CHILD") != "1":
|
||||
_relieve_host_vram_pressure("RUNTIME:pre_execute", logger)
|
||||
scan_shm_forensics("RUNTIME:pre_execute", refresh_model_context=True)
|
||||
from pyisolate._internal.model_serialization import (
|
||||
serialize_for_isolation,
|
||||
deserialize_from_isolation,
|
||||
)
|
||||
|
||||
prev_child = os.environ.pop("PYISOLATE_CHILD", None)
|
||||
logger.debug(
|
||||
"%s ISO:serialize_start ext=%s node=%s uid=%s",
|
||||
LOG_PREFIX,
|
||||
extension.name,
|
||||
node_name,
|
||||
node_unique_id or "-",
|
||||
)
|
||||
# Unwrap NodeOutput-like dicts before serialization.
|
||||
# OUTPUT_NODE nodes return {"ui": {...}, "result": (outputs...)}
|
||||
# and the executor may pass this dict as input to downstream nodes.
|
||||
unwrapped_inputs = {}
|
||||
for k, v in inputs.items():
|
||||
if isinstance(v, dict) and "result" in v and ("ui" in v or "__node_output__" in v):
|
||||
result = v.get("result")
|
||||
if isinstance(result, (tuple, list)) and len(result) > 0:
|
||||
unwrapped_inputs[k] = result[0]
|
||||
else:
|
||||
unwrapped_inputs[k] = result
|
||||
else:
|
||||
unwrapped_inputs[k] = v
|
||||
serialized = serialize_for_isolation(unwrapped_inputs)
|
||||
logger.debug(
|
||||
"%s ISO:serialize_done ext=%s node=%s uid=%s",
|
||||
LOG_PREFIX,
|
||||
extension.name,
|
||||
node_name,
|
||||
node_unique_id or "-",
|
||||
)
|
||||
logger.debug(
|
||||
"%s ISO:dispatch_start ext=%s node=%s uid=%s",
|
||||
LOG_PREFIX,
|
||||
extension.name,
|
||||
node_name,
|
||||
node_unique_id or "-",
|
||||
)
|
||||
result = await extension.execute_node(node_name, **serialized)
|
||||
logger.debug(
|
||||
"%s ISO:dispatch_done ext=%s node=%s uid=%s",
|
||||
LOG_PREFIX,
|
||||
extension.name,
|
||||
node_name,
|
||||
node_unique_id or "-",
|
||||
)
|
||||
# Reconstruct NodeOutput if the child serialized one
|
||||
if isinstance(result, dict) and result.get("__node_output__"):
|
||||
from comfy_api.latest import io as latest_io
|
||||
args_raw = result.get("args", ())
|
||||
deserialized_args = await deserialize_from_isolation(args_raw, extension)
|
||||
deserialized_args = _wrap_remote_handles_as_host_proxies(
|
||||
deserialized_args, extension
|
||||
)
|
||||
deserialized_args = _detach_shared_cpu_tensors(deserialized_args)
|
||||
ui_raw = result.get("ui")
|
||||
deserialized_ui = None
|
||||
if ui_raw is not None:
|
||||
deserialized_ui = await deserialize_from_isolation(ui_raw, extension)
|
||||
deserialized_ui = _wrap_remote_handles_as_host_proxies(
|
||||
deserialized_ui, extension
|
||||
)
|
||||
deserialized_ui = _detach_shared_cpu_tensors(deserialized_ui)
|
||||
scan_shm_forensics("RUNTIME:post_execute", refresh_model_context=True)
|
||||
return latest_io.NodeOutput(
|
||||
*deserialized_args,
|
||||
ui=deserialized_ui,
|
||||
expand=result.get("expand"),
|
||||
block_execution=result.get("block_execution"),
|
||||
)
|
||||
# OUTPUT_NODE: if sealed worker returned a tuple/list whose first
|
||||
# element is a {"ui": ...} dict, unwrap it for the executor.
|
||||
if (isinstance(result, (tuple, list)) and len(result) == 1
|
||||
and isinstance(result[0], dict) and "ui" in result[0]):
|
||||
return result[0]
|
||||
deserialized = await deserialize_from_isolation(result, extension)
|
||||
deserialized = _wrap_remote_handles_as_host_proxies(deserialized, extension)
|
||||
scan_shm_forensics("RUNTIME:post_execute", refresh_model_context=True)
|
||||
return _detach_shared_cpu_tensors(deserialized)
|
||||
except ImportError:
|
||||
return await extension.execute_node(node_name, **inputs)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"%s ISO:execute_error ext=%s node=%s uid=%s",
|
||||
LOG_PREFIX,
|
||||
extension.name,
|
||||
node_name,
|
||||
node_unique_id or "-",
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
if prev_child is not None:
|
||||
os.environ["PYISOLATE_CHILD"] = prev_child
|
||||
logger.debug(
|
||||
"%s ISO:execute_end ext=%s node=%s uid=%s",
|
||||
LOG_PREFIX,
|
||||
extension.name,
|
||||
node_name,
|
||||
node_unique_id or "-",
|
||||
)
|
||||
scan_shm_forensics("RUNTIME:execute_end", refresh_model_context=True)
|
||||
|
||||
def _input_types(
|
||||
cls,
|
||||
include_hidden: bool = True,
|
||||
return_schema: bool = False,
|
||||
live_inputs: Any = None,
|
||||
):
|
||||
if not is_v3:
|
||||
return restored_input_types
|
||||
|
||||
inputs_copy = copy.deepcopy(restored_input_types)
|
||||
if not include_hidden:
|
||||
inputs_copy.pop("hidden", None)
|
||||
|
||||
v3_data: Dict[str, Any] = {"hidden_inputs": {}}
|
||||
dynamic = inputs_copy.pop("dynamic_paths", None)
|
||||
if dynamic is not None:
|
||||
v3_data["dynamic_paths"] = dynamic
|
||||
|
||||
if return_schema:
|
||||
hidden_vals = info.get("hidden", []) or []
|
||||
hidden_enums = []
|
||||
for h in hidden_vals:
|
||||
try:
|
||||
hidden_enums.append(latest_io.Hidden(h))
|
||||
except Exception:
|
||||
hidden_enums.append(h)
|
||||
|
||||
class SchemaProxy:
|
||||
hidden = hidden_enums
|
||||
|
||||
return inputs_copy, SchemaProxy, v3_data
|
||||
return inputs_copy
|
||||
|
||||
def _validate_class(cls):
|
||||
return True
|
||||
|
||||
def _get_node_info_v1(cls):
|
||||
node_info = copy.deepcopy(info.get("schema_v1", {}))
|
||||
relative_python_module = node_info.get("python_module")
|
||||
if not isinstance(relative_python_module, str) or not relative_python_module:
|
||||
relative_python_module = f"custom_nodes.{extension.name}"
|
||||
node_info["python_module"] = relative_python_module
|
||||
return node_info
|
||||
|
||||
def _get_base_class(cls):
|
||||
return latest_io.ComfyNode
|
||||
|
||||
attributes: Dict[str, object] = {
|
||||
"FUNCTION": function_name,
|
||||
"CATEGORY": info.get("category", ""),
|
||||
"OUTPUT_NODE": info.get("output_node", False),
|
||||
"RETURN_TYPES": tuple(info.get("return_types", ()) or ()),
|
||||
"RETURN_NAMES": info.get("return_names"),
|
||||
function_name: _execute,
|
||||
"_pyisolate_extension": extension,
|
||||
"_pyisolate_node_name": node_name,
|
||||
"INPUT_TYPES": classmethod(_input_types),
|
||||
}
|
||||
|
||||
output_is_list = info.get("output_is_list")
|
||||
if output_is_list is not None:
|
||||
attributes["OUTPUT_IS_LIST"] = tuple(output_is_list)
|
||||
|
||||
if is_v3:
|
||||
attributes["VALIDATE_CLASS"] = classmethod(_validate_class)
|
||||
attributes["GET_NODE_INFO_V1"] = classmethod(_get_node_info_v1)
|
||||
attributes["GET_BASE_CLASS"] = classmethod(_get_base_class)
|
||||
attributes["DESCRIPTION"] = info.get("description", "")
|
||||
attributes["EXPERIMENTAL"] = info.get("experimental", False)
|
||||
attributes["DEPRECATED"] = info.get("deprecated", False)
|
||||
attributes["API_NODE"] = info.get("api_node", False)
|
||||
attributes["NOT_IDEMPOTENT"] = info.get("not_idempotent", False)
|
||||
attributes["ACCEPT_ALL_INPUTS"] = info.get("accept_all_inputs", False)
|
||||
attributes["_ACCEPT_ALL_INPUTS"] = info.get("accept_all_inputs", False)
|
||||
attributes["INPUT_IS_LIST"] = info.get("input_is_list", False)
|
||||
|
||||
class_name = f"PyIsolate_{node_name}".replace(" ", "_")
|
||||
bases = (_ComfyNodeInternal,) if is_v3 else ()
|
||||
stub_cls = type(class_name, bases, attributes)
|
||||
|
||||
if is_v3:
|
||||
try:
|
||||
stub_cls.VALIDATE_CLASS()
|
||||
except Exception as e:
|
||||
logger.error("%s VALIDATE_CLASS failed: %s - %s", LOG_PREFIX, node_name, e)
|
||||
|
||||
return stub_cls
|
||||
|
||||
|
||||
def get_class_types_for_extension(
|
||||
extension_name: str,
|
||||
running_extensions: Dict[str, "ComfyNodeExtension"],
|
||||
specs: List[Any],
|
||||
) -> Set[str]:
|
||||
extension = running_extensions.get(extension_name)
|
||||
if not extension:
|
||||
return set()
|
||||
|
||||
ext_path = Path(extension.module_path)
|
||||
class_types = set()
|
||||
for spec in specs:
|
||||
if spec.module_path.resolve() == ext_path.resolve():
|
||||
class_types.add(spec.node_name)
|
||||
return class_types
|
||||
|
||||
|
||||
__all__ = ["build_stub_class", "get_class_types_for_extension"]
|
||||
217
comfy/isolation/shm_forensics.py
Normal file
217
comfy/isolation/shm_forensics.py
Normal file
@@ -0,0 +1,217 @@
|
||||
# pylint: disable=consider-using-from-import,import-outside-toplevel
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Set
|
||||
|
||||
LOG_PREFIX = "]["
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _shm_debug_enabled() -> bool:
|
||||
return os.environ.get("COMFY_ISO_SHM_DEBUG") == "1"
|
||||
|
||||
|
||||
class _SHMForensicsTracker:
|
||||
def __init__(self) -> None:
|
||||
self._started = False
|
||||
self._tracked_files: Set[str] = set()
|
||||
self._current_model_context: Dict[str, str] = {
|
||||
"id": "unknown",
|
||||
"name": "unknown",
|
||||
"hash": "????",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _snapshot_shm() -> Set[str]:
|
||||
shm_path = Path("/dev/shm")
|
||||
if not shm_path.exists():
|
||||
return set()
|
||||
return {f.name for f in shm_path.glob("torch_*")}
|
||||
|
||||
def start(self) -> None:
|
||||
if self._started or not _shm_debug_enabled():
|
||||
return
|
||||
self._tracked_files = self._snapshot_shm()
|
||||
self._started = True
|
||||
logger.debug(
|
||||
"%s SHM:forensics_enabled tracked=%d", LOG_PREFIX, len(self._tracked_files)
|
||||
)
|
||||
|
||||
def stop(self) -> None:
|
||||
if not self._started:
|
||||
return
|
||||
self.scan("shutdown", refresh_model_context=True)
|
||||
self._started = False
|
||||
logger.debug("%s SHM:forensics_disabled", LOG_PREFIX)
|
||||
|
||||
def _compute_model_hash(self, model_patcher: Any) -> str:
|
||||
try:
|
||||
model_instance_id = getattr(model_patcher, "_instance_id", None)
|
||||
if model_instance_id is not None:
|
||||
model_id_text = str(model_instance_id)
|
||||
return model_id_text[-4:] if len(model_id_text) >= 4 else model_id_text
|
||||
|
||||
import torch
|
||||
|
||||
real_model = (
|
||||
model_patcher.model
|
||||
if hasattr(model_patcher, "model")
|
||||
else model_patcher
|
||||
)
|
||||
tensor = None
|
||||
if hasattr(real_model, "parameters"):
|
||||
for p in real_model.parameters():
|
||||
if torch.is_tensor(p) and p.numel() > 0:
|
||||
tensor = p
|
||||
break
|
||||
|
||||
if tensor is None:
|
||||
return "0000"
|
||||
|
||||
flat = tensor.flatten()
|
||||
values = []
|
||||
indices = [0, flat.shape[0] // 2, flat.shape[0] - 1]
|
||||
for i in indices:
|
||||
if i < flat.shape[0]:
|
||||
values.append(flat[i].item())
|
||||
|
||||
size = 0
|
||||
if hasattr(model_patcher, "model_size"):
|
||||
size = model_patcher.model_size()
|
||||
sample_str = f"{values}_{id(model_patcher):016x}_{size}"
|
||||
return hashlib.sha256(sample_str.encode()).hexdigest()[-4:]
|
||||
except Exception:
|
||||
return "err!"
|
||||
|
||||
def _get_models_snapshot(self) -> List[Dict[str, Any]]:
|
||||
try:
|
||||
import comfy.model_management as model_management
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
snapshot: List[Dict[str, Any]] = []
|
||||
try:
|
||||
for loaded_model in model_management.current_loaded_models:
|
||||
model = loaded_model.model
|
||||
if model is None:
|
||||
continue
|
||||
if str(getattr(loaded_model, "device", "")) != "cuda:0":
|
||||
continue
|
||||
|
||||
name = (
|
||||
model.model.__class__.__name__
|
||||
if hasattr(model, "model")
|
||||
else type(model).__name__
|
||||
)
|
||||
model_hash = self._compute_model_hash(model)
|
||||
model_instance_id = getattr(model, "_instance_id", None)
|
||||
if model_instance_id is None:
|
||||
model_instance_id = model_hash
|
||||
snapshot.append(
|
||||
{
|
||||
"name": str(name),
|
||||
"id": str(model_instance_id),
|
||||
"hash": str(model_hash or "????"),
|
||||
"used": bool(getattr(loaded_model, "currently_used", False)),
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
return snapshot
|
||||
|
||||
def _update_model_context(self) -> None:
|
||||
snapshot = self._get_models_snapshot()
|
||||
selected = None
|
||||
|
||||
used_models = [m for m in snapshot if m.get("used") and m.get("id")]
|
||||
if used_models:
|
||||
selected = used_models[-1]
|
||||
else:
|
||||
live_models = [m for m in snapshot if m.get("id")]
|
||||
if live_models:
|
||||
selected = live_models[-1]
|
||||
|
||||
if selected is None:
|
||||
self._current_model_context = {
|
||||
"id": "unknown",
|
||||
"name": "unknown",
|
||||
"hash": "????",
|
||||
}
|
||||
return
|
||||
|
||||
self._current_model_context = {
|
||||
"id": str(selected.get("id", "unknown")),
|
||||
"name": str(selected.get("name", "unknown")),
|
||||
"hash": str(selected.get("hash", "????") or "????"),
|
||||
}
|
||||
|
||||
def scan(self, marker: str, refresh_model_context: bool = True) -> None:
|
||||
if not self._started or not _shm_debug_enabled():
|
||||
return
|
||||
|
||||
if refresh_model_context:
|
||||
self._update_model_context()
|
||||
|
||||
current = self._snapshot_shm()
|
||||
added = current - self._tracked_files
|
||||
removed = self._tracked_files - current
|
||||
self._tracked_files = current
|
||||
|
||||
if not added and not removed:
|
||||
logger.debug("%s SHM:scan marker=%s changes=0", LOG_PREFIX, marker)
|
||||
return
|
||||
|
||||
for filename in sorted(added):
|
||||
logger.info("%s SHM:created | %s", LOG_PREFIX, filename)
|
||||
model_id = self._current_model_context["id"]
|
||||
if model_id == "unknown":
|
||||
logger.error(
|
||||
"%s SHM:model_association_missing | file=%s | reason=no_active_model_context",
|
||||
LOG_PREFIX,
|
||||
filename,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"%s SHM:model_association | model=%s | file=%s | name=%s | hash=%s",
|
||||
LOG_PREFIX,
|
||||
model_id,
|
||||
filename,
|
||||
self._current_model_context["name"],
|
||||
self._current_model_context["hash"],
|
||||
)
|
||||
|
||||
for filename in sorted(removed):
|
||||
logger.info("%s SHM:deleted | %s", LOG_PREFIX, filename)
|
||||
|
||||
logger.debug(
|
||||
"%s SHM:scan marker=%s created=%d deleted=%d active=%d",
|
||||
LOG_PREFIX,
|
||||
marker,
|
||||
len(added),
|
||||
len(removed),
|
||||
len(self._tracked_files),
|
||||
)
|
||||
|
||||
|
||||
_TRACKER = _SHMForensicsTracker()
|
||||
|
||||
|
||||
def start_shm_forensics() -> None:
|
||||
_TRACKER.start()
|
||||
|
||||
|
||||
def scan_shm_forensics(marker: str, refresh_model_context: bool = True) -> None:
|
||||
_TRACKER.scan(marker, refresh_model_context=refresh_model_context)
|
||||
|
||||
|
||||
def stop_shm_forensics() -> None:
|
||||
_TRACKER.stop()
|
||||
|
||||
|
||||
atexit.register(stop_shm_forensics)
|
||||
214
comfy/isolation/vae_proxy.py
Normal file
214
comfy/isolation/vae_proxy.py
Normal file
@@ -0,0 +1,214 @@
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from comfy.isolation.proxies.base import (
|
||||
IS_CHILD_PROCESS,
|
||||
BaseProxy,
|
||||
BaseRegistry,
|
||||
detach_if_grad,
|
||||
)
|
||||
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy, ModelPatcherRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FirstStageModelRegistry(BaseRegistry[Any]):
|
||||
_type_prefix = "first_stage_model"
|
||||
|
||||
async def get_property(self, instance_id: str, name: str) -> Any:
|
||||
obj = self._get_instance(instance_id)
|
||||
return getattr(obj, name)
|
||||
|
||||
async def has_property(self, instance_id: str, name: str) -> bool:
|
||||
obj = self._get_instance(instance_id)
|
||||
return hasattr(obj, name)
|
||||
|
||||
|
||||
class FirstStageModelProxy(BaseProxy[FirstStageModelRegistry]):
|
||||
_registry_class = FirstStageModelRegistry
|
||||
__module__ = "comfy.ldm.models.autoencoder"
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
try:
|
||||
return self._call_rpc("get_property", name)
|
||||
except Exception as e:
|
||||
raise AttributeError(
|
||||
f"'{self.__class__.__name__}' object has no attribute '{name}'"
|
||||
) from e
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<FirstStageModelProxy {self._instance_id}>"
|
||||
|
||||
|
||||
class VAERegistry(BaseRegistry[Any]):
|
||||
_type_prefix = "vae"
|
||||
|
||||
async def get_patcher_id(self, instance_id: str) -> str:
|
||||
vae = self._get_instance(instance_id)
|
||||
return ModelPatcherRegistry().register(vae.patcher)
|
||||
|
||||
async def get_first_stage_model_id(self, instance_id: str) -> str:
|
||||
vae = self._get_instance(instance_id)
|
||||
return FirstStageModelRegistry().register(vae.first_stage_model)
|
||||
|
||||
async def encode(self, instance_id: str, pixels: Any) -> Any:
|
||||
return detach_if_grad(self._get_instance(instance_id).encode(pixels))
|
||||
|
||||
async def encode_tiled(
|
||||
self,
|
||||
instance_id: str,
|
||||
pixels: Any,
|
||||
tile_x: int = 512,
|
||||
tile_y: int = 512,
|
||||
overlap: int = 64,
|
||||
) -> Any:
|
||||
return detach_if_grad(
|
||||
self._get_instance(instance_id).encode_tiled(
|
||||
pixels, tile_x=tile_x, tile_y=tile_y, overlap=overlap
|
||||
)
|
||||
)
|
||||
|
||||
async def decode(self, instance_id: str, samples: Any, **kwargs: Any) -> Any:
|
||||
return detach_if_grad(self._get_instance(instance_id).decode(samples, **kwargs))
|
||||
|
||||
async def decode_tiled(
|
||||
self,
|
||||
instance_id: str,
|
||||
samples: Any,
|
||||
tile_x: int = 64,
|
||||
tile_y: int = 64,
|
||||
overlap: int = 16,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return detach_if_grad(
|
||||
self._get_instance(instance_id).decode_tiled(
|
||||
samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap, **kwargs
|
||||
)
|
||||
)
|
||||
|
||||
async def get_property(self, instance_id: str, name: str) -> Any:
|
||||
return getattr(self._get_instance(instance_id), name)
|
||||
|
||||
async def memory_used_encode(self, instance_id: str, shape: Any, dtype: Any) -> int:
|
||||
return self._get_instance(instance_id).memory_used_encode(shape, dtype)
|
||||
|
||||
async def memory_used_decode(self, instance_id: str, shape: Any, dtype: Any) -> int:
|
||||
return self._get_instance(instance_id).memory_used_decode(shape, dtype)
|
||||
|
||||
async def process_input(self, instance_id: str, image: Any) -> Any:
|
||||
return detach_if_grad(self._get_instance(instance_id).process_input(image))
|
||||
|
||||
async def process_output(self, instance_id: str, image: Any) -> Any:
|
||||
return detach_if_grad(self._get_instance(instance_id).process_output(image))
|
||||
|
||||
|
||||
class VAEProxy(BaseProxy[VAERegistry]):
|
||||
_registry_class = VAERegistry
|
||||
__module__ = "comfy.sd"
|
||||
|
||||
@property
|
||||
def patcher(self) -> ModelPatcherProxy:
|
||||
if not hasattr(self, "_patcher_proxy"):
|
||||
patcher_id = self._call_rpc("get_patcher_id")
|
||||
self._patcher_proxy = ModelPatcherProxy(patcher_id, manage_lifecycle=False)
|
||||
return self._patcher_proxy
|
||||
|
||||
@property
|
||||
def first_stage_model(self) -> FirstStageModelProxy:
|
||||
if not hasattr(self, "_first_stage_model_proxy"):
|
||||
fsm_id = self._call_rpc("get_first_stage_model_id")
|
||||
self._first_stage_model_proxy = FirstStageModelProxy(
|
||||
fsm_id, manage_lifecycle=False
|
||||
)
|
||||
return self._first_stage_model_proxy
|
||||
|
||||
@property
|
||||
def vae_dtype(self) -> Any:
|
||||
return self._get_property("vae_dtype")
|
||||
|
||||
def encode(self, pixels: Any) -> Any:
|
||||
return self._call_rpc("encode", pixels)
|
||||
|
||||
def encode_tiled(
|
||||
self, pixels: Any, tile_x: int = 512, tile_y: int = 512, overlap: int = 64
|
||||
) -> Any:
|
||||
return self._call_rpc("encode_tiled", pixels, tile_x, tile_y, overlap)
|
||||
|
||||
def decode(self, samples: Any, **kwargs: Any) -> Any:
|
||||
return self._call_rpc("decode", samples, **kwargs)
|
||||
|
||||
def decode_tiled(
|
||||
self,
|
||||
samples: Any,
|
||||
tile_x: int = 64,
|
||||
tile_y: int = 64,
|
||||
overlap: int = 16,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return self._call_rpc(
|
||||
"decode_tiled", samples, tile_x, tile_y, overlap, **kwargs
|
||||
)
|
||||
|
||||
def get_sd(self) -> Any:
|
||||
return self._call_rpc("get_sd")
|
||||
|
||||
def _get_property(self, name: str) -> Any:
|
||||
return self._call_rpc("get_property", name)
|
||||
|
||||
@property
|
||||
def latent_dim(self) -> int:
|
||||
return self._get_property("latent_dim")
|
||||
|
||||
@property
|
||||
def latent_channels(self) -> int:
|
||||
return self._get_property("latent_channels")
|
||||
|
||||
@property
|
||||
def downscale_ratio(self) -> Any:
|
||||
return self._get_property("downscale_ratio")
|
||||
|
||||
@property
|
||||
def upscale_ratio(self) -> Any:
|
||||
return self._get_property("upscale_ratio")
|
||||
|
||||
@property
|
||||
def output_channels(self) -> int:
|
||||
return self._get_property("output_channels")
|
||||
|
||||
@property
|
||||
def check_not_vide(self) -> bool:
|
||||
return self._get_property("not_video")
|
||||
|
||||
@property
|
||||
def device(self) -> Any:
|
||||
return self._get_property("device")
|
||||
|
||||
@property
|
||||
def working_dtypes(self) -> Any:
|
||||
return self._get_property("working_dtypes")
|
||||
|
||||
@property
|
||||
def disable_offload(self) -> bool:
|
||||
return self._get_property("disable_offload")
|
||||
|
||||
@property
|
||||
def size(self) -> Any:
|
||||
return self._get_property("size")
|
||||
|
||||
def memory_used_encode(self, shape: Any, dtype: Any) -> int:
|
||||
return self._call_rpc("memory_used_encode", shape, dtype)
|
||||
|
||||
def memory_used_decode(self, shape: Any, dtype: Any) -> int:
|
||||
return self._call_rpc("memory_used_decode", shape, dtype)
|
||||
|
||||
def process_input(self, image: Any) -> Any:
|
||||
return self._call_rpc("process_input", image)
|
||||
|
||||
def process_output(self, image: Any) -> Any:
|
||||
return self._call_rpc("process_output", image)
|
||||
|
||||
|
||||
if not IS_CHILD_PROCESS:
|
||||
_VAE_REGISTRY_SINGLETON = VAERegistry()
|
||||
_FIRST_STAGE_MODEL_REGISTRY_SINGLETON = FirstStageModelRegistry()
|
||||
@@ -1,4 +1,5 @@
|
||||
import math
|
||||
import os
|
||||
from functools import partial
|
||||
|
||||
from scipy import integrate
|
||||
@@ -12,8 +13,8 @@ from . import deis
|
||||
from . import sa_solver
|
||||
import comfy.model_patcher
|
||||
import comfy.model_sampling
|
||||
|
||||
import comfy.memory_management
|
||||
from comfy.cli_args import args
|
||||
from comfy.utils import model_trange as trange
|
||||
|
||||
def append_zero(x):
|
||||
@@ -191,6 +192,13 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
if isolation_active:
|
||||
target_device = sigmas.device
|
||||
if x.device != target_device:
|
||||
x = x.to(target_device)
|
||||
s_in = s_in.to(target_device)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
if s_churn > 0:
|
||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||
|
||||
@@ -44,22 +44,6 @@ class FluxParams:
|
||||
txt_norm: bool = False
|
||||
|
||||
|
||||
def invert_slices(slices, length):
|
||||
sorted_slices = sorted(slices)
|
||||
result = []
|
||||
current = 0
|
||||
|
||||
for start, end in sorted_slices:
|
||||
if current < start:
|
||||
result.append((current, start))
|
||||
current = max(current, end)
|
||||
|
||||
if current < length:
|
||||
result.append((current, length))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class Flux(nn.Module):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
@@ -154,7 +138,6 @@ class Flux(nn.Module):
|
||||
y: Tensor,
|
||||
guidance: Tensor = None,
|
||||
control = None,
|
||||
timestep_zero_index=None,
|
||||
transformer_options={},
|
||||
attn_mask: Tensor = None,
|
||||
) -> Tensor:
|
||||
@@ -181,6 +164,10 @@ class Flux(nn.Module):
|
||||
txt = self.txt_norm(txt)
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
vec_orig = vec
|
||||
if self.params.global_modulation:
|
||||
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig))
|
||||
|
||||
if "post_input" in patches:
|
||||
for p in patches["post_input"]:
|
||||
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
|
||||
@@ -195,24 +182,6 @@ class Flux(nn.Module):
|
||||
else:
|
||||
pe = None
|
||||
|
||||
vec_orig = vec
|
||||
txt_vec = vec
|
||||
extra_kwargs = {}
|
||||
if timestep_zero_index is not None:
|
||||
modulation_dims = []
|
||||
batch = vec.shape[0] // 2
|
||||
vec_orig = vec_orig.reshape(2, batch, vec.shape[1]).movedim(0, 1)
|
||||
invert = invert_slices(timestep_zero_index, img.shape[1])
|
||||
for s in invert:
|
||||
modulation_dims.append((s[0], s[1], 0))
|
||||
for s in timestep_zero_index:
|
||||
modulation_dims.append((s[0], s[1], 1))
|
||||
extra_kwargs["modulation_dims_img"] = modulation_dims
|
||||
txt_vec = vec[:batch]
|
||||
|
||||
if self.params.global_modulation:
|
||||
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(txt_vec))
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
@@ -226,8 +195,7 @@ class Flux(nn.Module):
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"),
|
||||
**extra_kwargs)
|
||||
transformer_options=args.get("transformer_options"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img,
|
||||
@@ -245,8 +213,7 @@ class Flux(nn.Module):
|
||||
vec=vec,
|
||||
pe=pe,
|
||||
attn_mask=attn_mask,
|
||||
transformer_options=transformer_options,
|
||||
**extra_kwargs)
|
||||
transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
@@ -263,12 +230,6 @@ class Flux(nn.Module):
|
||||
if self.params.global_modulation:
|
||||
vec, _ = self.single_stream_modulation(vec_orig)
|
||||
|
||||
extra_kwargs = {}
|
||||
if timestep_zero_index is not None:
|
||||
lambda a: 0 if a == 0 else a + txt.shape[1]
|
||||
modulation_dims_combined = list(map(lambda x: (0 if x[0] == 0 else x[0] + txt.shape[1], x[1] + txt.shape[1], x[2]), modulation_dims))
|
||||
extra_kwargs["modulation_dims"] = modulation_dims_combined
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
|
||||
@@ -281,8 +242,7 @@ class Flux(nn.Module):
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"),
|
||||
**extra_kwargs)
|
||||
transformer_options=args.get("transformer_options"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img,
|
||||
@@ -293,7 +253,7 @@ class Flux(nn.Module):
|
||||
{"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options, **extra_kwargs)
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
@@ -304,11 +264,7 @@ class Flux(nn.Module):
|
||||
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
extra_kwargs = {}
|
||||
if timestep_zero_index is not None:
|
||||
extra_kwargs["modulation_dims"] = modulation_dims
|
||||
|
||||
img = self.final_layer(img, vec_orig, **extra_kwargs) # (N, T, patch_size ** 2 * out_channels)
|
||||
img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
|
||||
@@ -356,16 +312,13 @@ class Flux(nn.Module):
|
||||
w_len = ((w_orig + (patch_size // 2)) // patch_size)
|
||||
img, img_ids = self.process_img(x, transformer_options=transformer_options)
|
||||
img_tokens = img.shape[1]
|
||||
timestep_zero_index = None
|
||||
if ref_latents is not None:
|
||||
ref_num_tokens = []
|
||||
h = 0
|
||||
w = 0
|
||||
index = 0
|
||||
ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method)
|
||||
timestep_zero = ref_latents_method == "index_timestep_zero"
|
||||
for ref in ref_latents:
|
||||
if ref_latents_method in ("index", "index_timestep_zero"):
|
||||
if ref_latents_method == "index":
|
||||
index += self.params.ref_index_scale
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
@@ -389,13 +342,6 @@ class Flux(nn.Module):
|
||||
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
||||
img = torch.cat([img, kontext], dim=1)
|
||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||
ref_num_tokens.append(kontext.shape[1])
|
||||
if timestep_zero:
|
||||
if index > 0:
|
||||
timestep = torch.cat([timestep, timestep * 0], dim=0)
|
||||
timestep_zero_index = [[img_tokens, img_ids.shape[1]]]
|
||||
transformer_options = transformer_options.copy()
|
||||
transformer_options["reference_image_num_tokens"] = ref_num_tokens
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
|
||||
|
||||
@@ -403,6 +349,6 @@ class Flux(nn.Module):
|
||||
for i in self.params.txt_ids_dims:
|
||||
txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
|
||||
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||
out = out[:, :img_tokens]
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig]
|
||||
|
||||
@@ -20,6 +20,7 @@ import comfy.ldm.hunyuan3dv2_1
|
||||
import comfy.ldm.hunyuan3dv2_1.hunyuandit
|
||||
import torch
|
||||
import logging
|
||||
import os
|
||||
import comfy.ldm.lightricks.av_model
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||
from comfy.ldm.cascade.stage_c import StageC
|
||||
@@ -112,8 +113,20 @@ def model_sampling(model_config, model_type):
|
||||
elif model_type == ModelType.IMG_TO_IMG_FLOW:
|
||||
c = comfy.model_sampling.IMG_TO_IMG_FLOW
|
||||
|
||||
from comfy.cli_args import args
|
||||
isolation_runtime_enabled = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
|
||||
class ModelSampling(s, c):
|
||||
pass
|
||||
if isolation_runtime_enabled:
|
||||
def __reduce__(self):
|
||||
"""Ensure pickling yields a proxy instead of failing on local class."""
|
||||
try:
|
||||
from comfy.isolation.model_sampling_proxy import ModelSamplingRegistry, ModelSamplingProxy
|
||||
registry = ModelSamplingRegistry()
|
||||
ms_id = registry.register(self)
|
||||
return (ModelSamplingProxy, (ms_id,))
|
||||
except Exception as exc:
|
||||
raise RuntimeError("Failed to serialize ModelSampling for isolation.") from exc
|
||||
|
||||
return ModelSampling(model_config)
|
||||
|
||||
|
||||
@@ -372,7 +372,7 @@ AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN'
|
||||
|
||||
try:
|
||||
if is_amd():
|
||||
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName.split(':')[0]
|
||||
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
|
||||
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
|
||||
if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1':
|
||||
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
|
||||
@@ -400,7 +400,7 @@ try:
|
||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
|
||||
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
if rocm_version >= (7, 0):
|
||||
if any((a in arch) for a in ["gfx1200", "gfx1201"]):
|
||||
@@ -497,6 +497,9 @@ except:
|
||||
|
||||
current_loaded_models = []
|
||||
|
||||
def _isolation_mode_enabled():
|
||||
return args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
|
||||
def module_size(module):
|
||||
module_mem = 0
|
||||
sd = module.state_dict()
|
||||
@@ -576,8 +579,9 @@ class LoadedModel:
|
||||
if freed >= memory_to_free:
|
||||
return False
|
||||
self.model.detach(unpatch_weights)
|
||||
self.model_finalizer.detach()
|
||||
self.model_finalizer = None
|
||||
if self.model_finalizer is not None:
|
||||
self.model_finalizer.detach()
|
||||
self.model_finalizer = None
|
||||
self.real_model = None
|
||||
return True
|
||||
|
||||
@@ -591,8 +595,15 @@ class LoadedModel:
|
||||
if self._patcher_finalizer is not None:
|
||||
self._patcher_finalizer.detach()
|
||||
|
||||
def dead_state(self):
|
||||
model_ref_gone = self.model is None
|
||||
real_model_ref = self.real_model
|
||||
real_model_ref_gone = callable(real_model_ref) and real_model_ref() is None
|
||||
return model_ref_gone, real_model_ref_gone
|
||||
|
||||
def is_dead(self):
|
||||
return self.real_model() is not None and self.model is None
|
||||
model_ref_gone, real_model_ref_gone = self.dead_state()
|
||||
return model_ref_gone or real_model_ref_gone
|
||||
|
||||
|
||||
def use_more_memory(extra_memory, loaded_models, device):
|
||||
@@ -638,6 +649,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
|
||||
unloaded_model = []
|
||||
can_unload = []
|
||||
unloaded_models = []
|
||||
isolation_active = _isolation_mode_enabled()
|
||||
|
||||
for i in range(len(current_loaded_models) -1, -1, -1):
|
||||
shift_model = current_loaded_models[i]
|
||||
@@ -646,6 +658,17 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
|
||||
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||
shift_model.currently_used = False
|
||||
|
||||
if can_unload and isolation_active:
|
||||
try:
|
||||
from pyisolate import flush_tensor_keeper # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
flush_tensor_keeper = None
|
||||
if callable(flush_tensor_keeper):
|
||||
flushed = flush_tensor_keeper()
|
||||
if flushed > 0:
|
||||
logging.debug("][ MM:tensor_keeper_flush | released=%d", flushed)
|
||||
gc.collect()
|
||||
|
||||
for x in sorted(can_unload):
|
||||
i = x[-1]
|
||||
memory_to_free = 1e32
|
||||
@@ -666,7 +689,13 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
|
||||
current_loaded_models[i].model.partially_unload_ram(ram_to_free)
|
||||
|
||||
for i in sorted(unloaded_model, reverse=True):
|
||||
unloaded_models.append(current_loaded_models.pop(i))
|
||||
unloaded = current_loaded_models.pop(i)
|
||||
model_obj = unloaded.model
|
||||
if model_obj is not None:
|
||||
cleanup = getattr(model_obj, "cleanup", None)
|
||||
if callable(cleanup):
|
||||
cleanup()
|
||||
unloaded_models.append(unloaded)
|
||||
|
||||
if len(unloaded_model) > 0:
|
||||
soft_empty_cache()
|
||||
@@ -725,7 +754,9 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
for i in to_unload:
|
||||
model_to_unload = current_loaded_models.pop(i)
|
||||
model_to_unload.model.detach(unpatch_all=False)
|
||||
model_to_unload.model_finalizer.detach()
|
||||
if model_to_unload.model_finalizer is not None:
|
||||
model_to_unload.model_finalizer.detach()
|
||||
model_to_unload.model_finalizer = None
|
||||
|
||||
|
||||
total_memory_required = {}
|
||||
@@ -788,25 +819,62 @@ def loaded_models(only_currently_used=False):
|
||||
|
||||
|
||||
def cleanup_models_gc():
|
||||
do_gc = False
|
||||
|
||||
reset_cast_buffers()
|
||||
if not _isolation_mode_enabled():
|
||||
dead_found = False
|
||||
for i in range(len(current_loaded_models)):
|
||||
if current_loaded_models[i].is_dead():
|
||||
dead_found = True
|
||||
break
|
||||
|
||||
if dead_found:
|
||||
logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.")
|
||||
gc.collect()
|
||||
soft_empty_cache()
|
||||
|
||||
for i in range(len(current_loaded_models) - 1, -1, -1):
|
||||
cur = current_loaded_models[i]
|
||||
if cur.is_dead():
|
||||
logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.")
|
||||
leaked = current_loaded_models.pop(i)
|
||||
model_obj = getattr(leaked, "model", None)
|
||||
if model_obj is not None:
|
||||
cleanup = getattr(model_obj, "cleanup", None)
|
||||
if callable(cleanup):
|
||||
cleanup()
|
||||
return
|
||||
|
||||
dead_found = False
|
||||
has_real_model_leak = False
|
||||
for i in range(len(current_loaded_models)):
|
||||
cur = current_loaded_models[i]
|
||||
if cur.is_dead():
|
||||
logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(cur.real_model().__class__.__name__))
|
||||
do_gc = True
|
||||
break
|
||||
model_ref_gone, real_model_ref_gone = current_loaded_models[i].dead_state()
|
||||
if model_ref_gone or real_model_ref_gone:
|
||||
dead_found = True
|
||||
if real_model_ref_gone and not model_ref_gone:
|
||||
has_real_model_leak = True
|
||||
|
||||
if do_gc:
|
||||
if dead_found:
|
||||
if has_real_model_leak:
|
||||
logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.")
|
||||
else:
|
||||
logging.debug("Cleaning stale loaded-model entries with released patcher references.")
|
||||
gc.collect()
|
||||
soft_empty_cache()
|
||||
|
||||
for i in range(len(current_loaded_models)):
|
||||
for i in range(len(current_loaded_models) - 1, -1, -1):
|
||||
cur = current_loaded_models[i]
|
||||
if cur.is_dead():
|
||||
logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
|
||||
model_ref_gone, real_model_ref_gone = cur.dead_state()
|
||||
if model_ref_gone or real_model_ref_gone:
|
||||
if real_model_ref_gone and not model_ref_gone:
|
||||
logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.")
|
||||
else:
|
||||
logging.debug("Cleaning stale loaded-model entry with released patcher reference.")
|
||||
leaked = current_loaded_models.pop(i)
|
||||
model_obj = getattr(leaked, "model", None)
|
||||
if model_obj is not None:
|
||||
cleanup = getattr(model_obj, "cleanup", None)
|
||||
if callable(cleanup):
|
||||
cleanup()
|
||||
|
||||
|
||||
def archive_model_dtypes(model):
|
||||
@@ -820,11 +888,20 @@ def archive_model_dtypes(model):
|
||||
def cleanup_models():
|
||||
to_delete = []
|
||||
for i in range(len(current_loaded_models)):
|
||||
if current_loaded_models[i].real_model() is None:
|
||||
real_model_ref = current_loaded_models[i].real_model
|
||||
if real_model_ref is None:
|
||||
to_delete = [i] + to_delete
|
||||
continue
|
||||
if callable(real_model_ref) and real_model_ref() is None:
|
||||
to_delete = [i] + to_delete
|
||||
|
||||
for i in to_delete:
|
||||
x = current_loaded_models.pop(i)
|
||||
model_obj = getattr(x, "model", None)
|
||||
if model_obj is not None:
|
||||
cleanup = getattr(model_obj, "cleanup", None)
|
||||
if callable(cleanup):
|
||||
cleanup()
|
||||
del x
|
||||
|
||||
def dtype_size(dtype):
|
||||
|
||||
@@ -11,12 +11,14 @@ from functools import partial
|
||||
import collections
|
||||
import math
|
||||
import logging
|
||||
import os
|
||||
import comfy.sampler_helpers
|
||||
import comfy.model_patcher
|
||||
import comfy.patcher_extension
|
||||
import comfy.hooks
|
||||
import comfy.context_windows
|
||||
import comfy.utils
|
||||
from comfy.cli_args import args
|
||||
import scipy.stats
|
||||
import numpy
|
||||
|
||||
@@ -210,9 +212,11 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc
|
||||
_calc_cond_batch,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
|
||||
)
|
||||
return executor.execute(model, conds, x_in, timestep, model_options)
|
||||
result = executor.execute(model, conds, x_in, timestep, model_options)
|
||||
return result
|
||||
|
||||
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
out_conds = []
|
||||
out_counts = []
|
||||
# separate conds by matching hooks
|
||||
@@ -269,7 +273,8 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
for k, v in to_run[tt][0].conditioning.items():
|
||||
cond_shapes[k].append(v.size())
|
||||
|
||||
if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory:
|
||||
memory_required = model.memory_required(input_shape, cond_shapes=cond_shapes)
|
||||
if memory_required * 1.5 < free_memory:
|
||||
to_batch = batch_amount
|
||||
break
|
||||
|
||||
@@ -294,9 +299,17 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
patches = p.patches
|
||||
|
||||
batch_chunks = len(cond_or_uncond)
|
||||
input_x = torch.cat(input_x)
|
||||
if isolation_active:
|
||||
target_device = model.load_device if hasattr(model, "load_device") else input_x[0].device
|
||||
input_x = torch.cat(input_x).to(target_device)
|
||||
else:
|
||||
input_x = torch.cat(input_x)
|
||||
c = cond_cat(c)
|
||||
timestep_ = torch.cat([timestep] * batch_chunks)
|
||||
if isolation_active:
|
||||
timestep_ = torch.cat([timestep] * batch_chunks).to(target_device)
|
||||
mult = [m.to(target_device) if hasattr(m, "to") else m for m in mult]
|
||||
else:
|
||||
timestep_ = torch.cat([timestep] * batch_chunks)
|
||||
|
||||
transformer_options = model.current_patcher.apply_hooks(hooks=hooks)
|
||||
if 'transformer_options' in model_options:
|
||||
@@ -327,9 +340,17 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
for o in range(batch_chunks):
|
||||
cond_index = cond_or_uncond[o]
|
||||
a = area[o]
|
||||
out_t = output[o]
|
||||
mult_t = mult[o]
|
||||
if isolation_active:
|
||||
target_dev = out_conds[cond_index].device
|
||||
if hasattr(out_t, "device") and out_t.device != target_dev:
|
||||
out_t = out_t.to(target_dev)
|
||||
if hasattr(mult_t, "device") and mult_t.device != target_dev:
|
||||
mult_t = mult_t.to(target_dev)
|
||||
if a is None:
|
||||
out_conds[cond_index] += output[o] * mult[o]
|
||||
out_counts[cond_index] += mult[o]
|
||||
out_conds[cond_index] += out_t * mult_t
|
||||
out_counts[cond_index] += mult_t
|
||||
else:
|
||||
out_c = out_conds[cond_index]
|
||||
out_cts = out_counts[cond_index]
|
||||
@@ -337,8 +358,8 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
for i in range(dims):
|
||||
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
|
||||
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
|
||||
out_c += output[o] * mult[o]
|
||||
out_cts += mult[o]
|
||||
out_c += out_t * mult_t
|
||||
out_cts += mult_t
|
||||
|
||||
for i in range(len(out_conds)):
|
||||
out_conds[i] /= out_counts[i]
|
||||
@@ -392,14 +413,31 @@ class KSamplerX0Inpaint:
|
||||
self.inner_model = model
|
||||
self.sigmas = sigmas
|
||||
def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None):
|
||||
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
if denoise_mask is not None:
|
||||
if isolation_active and denoise_mask.device != x.device:
|
||||
denoise_mask = denoise_mask.to(x.device)
|
||||
if "denoise_mask_function" in model_options:
|
||||
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
|
||||
latent_mask = 1. - denoise_mask
|
||||
x = x * denoise_mask + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image) * latent_mask
|
||||
if isolation_active:
|
||||
latent_image = self.latent_image
|
||||
if hasattr(latent_image, "device") and latent_image.device != x.device:
|
||||
latent_image = latent_image.to(x.device)
|
||||
scaled = self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=latent_image)
|
||||
if hasattr(scaled, "device") and scaled.device != x.device:
|
||||
scaled = scaled.to(x.device)
|
||||
else:
|
||||
scaled = self.inner_model.inner_model.scale_latent_inpaint(
|
||||
x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image
|
||||
)
|
||||
x = x * denoise_mask + scaled * latent_mask
|
||||
out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
|
||||
if denoise_mask is not None:
|
||||
out = out * denoise_mask + self.latent_image * latent_mask
|
||||
latent_image = self.latent_image
|
||||
if isolation_active and hasattr(latent_image, "device") and latent_image.device != out.device:
|
||||
latent_image = latent_image.to(out.device)
|
||||
out = out * denoise_mask + latent_image * latent_mask
|
||||
return out
|
||||
|
||||
def simple_scheduler(model_sampling, steps):
|
||||
@@ -741,7 +779,11 @@ class KSAMPLER(Sampler):
|
||||
else:
|
||||
model_k.noise = noise
|
||||
|
||||
noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas[0], noise, latent_image, self.max_denoise(model_wrap, sigmas))
|
||||
max_denoise = self.max_denoise(model_wrap, sigmas)
|
||||
model_sampling = model_wrap.inner_model.model_sampling
|
||||
noise = model_sampling.noise_scaling(
|
||||
sigmas[0], noise, latent_image, max_denoise
|
||||
)
|
||||
|
||||
k_callback = None
|
||||
total_steps = len(sigmas) - 1
|
||||
|
||||
@@ -272,7 +272,7 @@ class VideoFromFile(VideoInput):
|
||||
has_first_frame = False
|
||||
for frame in frames:
|
||||
offset_seconds = start_time - frame.pts * audio_stream.time_base
|
||||
to_skip = max(0, int(offset_seconds * audio_stream.sample_rate))
|
||||
to_skip = int(offset_seconds * audio_stream.sample_rate)
|
||||
if to_skip < frame.samples:
|
||||
has_first_frame = True
|
||||
break
|
||||
@@ -280,7 +280,7 @@ class VideoFromFile(VideoInput):
|
||||
audio_frames.append(frame.to_ndarray()[..., to_skip:])
|
||||
|
||||
for frame in frames:
|
||||
if self.__duration and frame.time > start_time + self.__duration:
|
||||
if frame.time > start_time + self.__duration:
|
||||
break
|
||||
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
|
||||
if len(audio_frames) > 0:
|
||||
|
||||
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
|
||||
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
||||
prune_dict, shallow_clone_class)
|
||||
from comfy_execution.graph_utils import ExecutionBlocker
|
||||
from ._util import MESH, VOXEL, SVG as _SVG, File3D
|
||||
from ._util import MESH, VOXEL, SVG as _SVG, File3D, PLY as _PLY, NPZ as _NPZ
|
||||
|
||||
|
||||
class FolderType(str, Enum):
|
||||
@@ -297,7 +297,7 @@ class Float(ComfyTypeIO):
|
||||
'''Float input.'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
|
||||
display_mode: NumberDisplay=None, gradient_stops: list[dict]=None,
|
||||
display_mode: NumberDisplay=None, gradient_stops: list[list[float]]=None,
|
||||
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
|
||||
self.min = min
|
||||
@@ -678,6 +678,16 @@ class Mesh(ComfyTypeIO):
|
||||
Type = MESH
|
||||
|
||||
|
||||
@comfytype(io_type="PLY")
|
||||
class Ply(ComfyTypeIO):
|
||||
Type = _PLY
|
||||
|
||||
|
||||
@comfytype(io_type="NPZ")
|
||||
class Npz(ComfyTypeIO):
|
||||
Type = _NPZ
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D")
|
||||
class File3DAny(ComfyTypeIO):
|
||||
"""General 3D file type - accepts any supported 3D format."""
|
||||
@@ -2197,6 +2207,8 @@ __all__ = [
|
||||
"LossMap",
|
||||
"Voxel",
|
||||
"Mesh",
|
||||
"Ply",
|
||||
"Npz",
|
||||
"File3DAny",
|
||||
"File3DGLB",
|
||||
"File3DGLTF",
|
||||
|
||||
@@ -65,6 +65,22 @@ class SavedAudios(_UIOutput):
|
||||
return {"audio": self.results}
|
||||
|
||||
|
||||
def _is_isolated_child() -> bool:
|
||||
return os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
|
||||
|
||||
def _get_preview_folder_type() -> FolderType:
|
||||
if _is_isolated_child():
|
||||
return FolderType.output
|
||||
return FolderType.temp
|
||||
|
||||
|
||||
def _get_preview_route_prefix(folder_type: FolderType) -> str:
|
||||
if folder_type == FolderType.output:
|
||||
return "output"
|
||||
return "temp"
|
||||
|
||||
|
||||
def _get_directory_by_folder_type(folder_type: FolderType) -> str:
|
||||
if folder_type == FolderType.input:
|
||||
return folder_paths.get_input_directory()
|
||||
@@ -388,10 +404,11 @@ class AudioSaveHelper:
|
||||
|
||||
class PreviewImage(_UIOutput):
|
||||
def __init__(self, image: Image.Type, animated: bool = False, cls: type[ComfyNode] = None, **kwargs):
|
||||
folder_type = _get_preview_folder_type()
|
||||
self.values = ImageSaveHelper.save_images(
|
||||
image,
|
||||
filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)),
|
||||
folder_type=FolderType.temp,
|
||||
folder_type=folder_type,
|
||||
cls=cls,
|
||||
compress_level=1,
|
||||
)
|
||||
@@ -412,10 +429,11 @@ class PreviewMask(PreviewImage):
|
||||
|
||||
class PreviewAudio(_UIOutput):
|
||||
def __init__(self, audio: dict, cls: type[ComfyNode] = None, **kwargs):
|
||||
folder_type = _get_preview_folder_type()
|
||||
self.values = AudioSaveHelper.save_audio(
|
||||
audio,
|
||||
filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)),
|
||||
folder_type=FolderType.temp,
|
||||
folder_type=folder_type,
|
||||
cls=cls,
|
||||
format="flac",
|
||||
quality="128k",
|
||||
@@ -438,15 +456,16 @@ class PreviewUI3D(_UIOutput):
|
||||
self.model_file = model_file
|
||||
self.camera_info = camera_info
|
||||
self.bg_image_path = None
|
||||
folder_type = _get_preview_folder_type()
|
||||
bg_image = kwargs.get("bg_image", None)
|
||||
if bg_image is not None:
|
||||
img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8)
|
||||
img = PILImage.fromarray(img_array)
|
||||
temp_dir = folder_paths.get_temp_directory()
|
||||
preview_dir = _get_directory_by_folder_type(folder_type)
|
||||
filename = f"bg_{uuid.uuid4().hex}.png"
|
||||
bg_image_path = os.path.join(temp_dir, filename)
|
||||
bg_image_path = os.path.join(preview_dir, filename)
|
||||
img.save(bg_image_path, compress_level=1)
|
||||
self.bg_image_path = f"temp/{filename}"
|
||||
self.bg_image_path = f"{_get_preview_route_prefix(folder_type)}/{filename}"
|
||||
|
||||
def as_dict(self):
|
||||
return {"result": [self.model_file, self.camera_info, self.bg_image_path]}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from .video_types import VideoContainer, VideoCodec, VideoComponents
|
||||
from .geometry_types import VOXEL, MESH, File3D
|
||||
from .image_types import SVG
|
||||
from .ply_types import PLY
|
||||
from .npz_types import NPZ
|
||||
|
||||
__all__ = [
|
||||
# Utility Types
|
||||
@@ -11,4 +13,6 @@ __all__ = [
|
||||
"MESH",
|
||||
"File3D",
|
||||
"SVG",
|
||||
"PLY",
|
||||
"NPZ",
|
||||
]
|
||||
|
||||
27
comfy_api/latest/_util/npz_types.py
Normal file
27
comfy_api/latest/_util/npz_types.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
|
||||
class NPZ:
|
||||
"""Ordered collection of NPZ file payloads.
|
||||
|
||||
Each entry in ``frames`` is a complete compressed ``.npz`` file stored
|
||||
as raw bytes (produced by ``numpy.savez_compressed`` into a BytesIO).
|
||||
``save_to`` writes numbered files into a directory.
|
||||
"""
|
||||
|
||||
def __init__(self, frames: list[bytes]) -> None:
|
||||
self.frames = frames
|
||||
|
||||
@property
|
||||
def num_frames(self) -> int:
|
||||
return len(self.frames)
|
||||
|
||||
def save_to(self, directory: str, prefix: str = "frame") -> str:
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
for i, frame_bytes in enumerate(self.frames):
|
||||
path = os.path.join(directory, f"{prefix}_{i:06d}.npz")
|
||||
with open(path, "wb") as f:
|
||||
f.write(frame_bytes)
|
||||
return directory
|
||||
97
comfy_api/latest/_util/ply_types.py
Normal file
97
comfy_api/latest/_util/ply_types.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class PLY:
|
||||
"""Point cloud payload for PLY file output.
|
||||
|
||||
Supports two schemas:
|
||||
- Pointcloud: xyz positions with optional colors, confidence, view_id (ASCII format)
|
||||
- Gaussian: raw binary PLY data built by producer nodes using plyfile (binary format)
|
||||
|
||||
When ``raw_data`` is provided, the object acts as an opaque binary PLY
|
||||
carrier and ``save_to`` writes the bytes directly.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
points: np.ndarray | None = None,
|
||||
colors: np.ndarray | None = None,
|
||||
confidence: np.ndarray | None = None,
|
||||
view_id: np.ndarray | None = None,
|
||||
raw_data: bytes | None = None,
|
||||
) -> None:
|
||||
self.raw_data = raw_data
|
||||
if raw_data is not None:
|
||||
self.points = None
|
||||
self.colors = None
|
||||
self.confidence = None
|
||||
self.view_id = None
|
||||
return
|
||||
if points is None:
|
||||
raise ValueError("Either points or raw_data must be provided")
|
||||
if points.ndim != 2 or points.shape[1] != 3:
|
||||
raise ValueError(f"points must be (N, 3), got {points.shape}")
|
||||
self.points = np.ascontiguousarray(points, dtype=np.float32)
|
||||
self.colors = np.ascontiguousarray(colors, dtype=np.float32) if colors is not None else None
|
||||
self.confidence = np.ascontiguousarray(confidence, dtype=np.float32) if confidence is not None else None
|
||||
self.view_id = np.ascontiguousarray(view_id, dtype=np.int32) if view_id is not None else None
|
||||
|
||||
@property
|
||||
def is_gaussian(self) -> bool:
|
||||
return self.raw_data is not None
|
||||
|
||||
@property
|
||||
def num_points(self) -> int:
|
||||
if self.points is not None:
|
||||
return self.points.shape[0]
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
def _to_numpy(arr, dtype):
|
||||
if arr is None:
|
||||
return None
|
||||
if hasattr(arr, "numpy"):
|
||||
arr = arr.cpu().numpy() if hasattr(arr, "cpu") else arr.numpy()
|
||||
return np.ascontiguousarray(arr, dtype=dtype)
|
||||
|
||||
def save_to(self, path: str) -> str:
|
||||
if self.raw_data is not None:
|
||||
with open(path, "wb") as f:
|
||||
f.write(self.raw_data)
|
||||
return path
|
||||
self.points = self._to_numpy(self.points, np.float32)
|
||||
self.colors = self._to_numpy(self.colors, np.float32)
|
||||
self.confidence = self._to_numpy(self.confidence, np.float32)
|
||||
self.view_id = self._to_numpy(self.view_id, np.int32)
|
||||
N = self.num_points
|
||||
header_lines = [
|
||||
"ply",
|
||||
"format ascii 1.0",
|
||||
f"element vertex {N}",
|
||||
"property float x",
|
||||
"property float y",
|
||||
"property float z",
|
||||
]
|
||||
if self.colors is not None:
|
||||
header_lines += ["property uchar red", "property uchar green", "property uchar blue"]
|
||||
if self.confidence is not None:
|
||||
header_lines.append("property float confidence")
|
||||
if self.view_id is not None:
|
||||
header_lines.append("property int view_id")
|
||||
header_lines.append("end_header")
|
||||
|
||||
with open(path, "w") as f:
|
||||
f.write("\n".join(header_lines) + "\n")
|
||||
for i in range(N):
|
||||
parts = [f"{self.points[i, 0]} {self.points[i, 1]} {self.points[i, 2]}"]
|
||||
if self.colors is not None:
|
||||
r, g, b = (self.colors[i] * 255).clip(0, 255).astype(np.uint8)
|
||||
parts.append(f"{r} {g} {b}")
|
||||
if self.confidence is not None:
|
||||
parts.append(f"{self.confidence[i]}")
|
||||
if self.view_id is not None:
|
||||
parts.append(f"{int(self.view_id[i])}")
|
||||
f.write(" ".join(parts) + "\n")
|
||||
return path
|
||||
259
comfy_api/latest/_util/trimesh_types.py
Normal file
259
comfy_api/latest/_util/trimesh_types.py
Normal file
@@ -0,0 +1,259 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TrimeshData:
|
||||
"""Triangular mesh payload for cross-process transfer.
|
||||
|
||||
Lightweight carrier for mesh geometry that does not depend on the
|
||||
``trimesh`` library. Serializers create this on the host side;
|
||||
isolated child processes convert to/from ``trimesh.Trimesh`` as needed.
|
||||
|
||||
Supports both ColorVisuals (vertex_colors) and TextureVisuals
|
||||
(uv + material with textures).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vertices: np.ndarray,
|
||||
faces: np.ndarray,
|
||||
vertex_normals: np.ndarray | None = None,
|
||||
face_normals: np.ndarray | None = None,
|
||||
vertex_colors: np.ndarray | None = None,
|
||||
uv: np.ndarray | None = None,
|
||||
material: dict | None = None,
|
||||
vertex_attributes: dict | None = None,
|
||||
face_attributes: dict | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> None:
|
||||
self.vertices = np.ascontiguousarray(vertices, dtype=np.float64)
|
||||
self.faces = np.ascontiguousarray(faces, dtype=np.int64)
|
||||
self.vertex_normals = (
|
||||
np.ascontiguousarray(vertex_normals, dtype=np.float64)
|
||||
if vertex_normals is not None
|
||||
else None
|
||||
)
|
||||
self.face_normals = (
|
||||
np.ascontiguousarray(face_normals, dtype=np.float64)
|
||||
if face_normals is not None
|
||||
else None
|
||||
)
|
||||
self.vertex_colors = (
|
||||
np.ascontiguousarray(vertex_colors, dtype=np.uint8)
|
||||
if vertex_colors is not None
|
||||
else None
|
||||
)
|
||||
self.uv = (
|
||||
np.ascontiguousarray(uv, dtype=np.float64)
|
||||
if uv is not None
|
||||
else None
|
||||
)
|
||||
self.material = material
|
||||
self.vertex_attributes = vertex_attributes or {}
|
||||
self.face_attributes = face_attributes or {}
|
||||
self.metadata = self._detensorize_dict(metadata) if metadata else {}
|
||||
|
||||
@staticmethod
|
||||
def _detensorize_dict(d):
|
||||
"""Recursively convert any tensors in a dict back to numpy arrays."""
|
||||
if not isinstance(d, dict):
|
||||
return d
|
||||
result = {}
|
||||
for k, v in d.items():
|
||||
if hasattr(v, "numpy"):
|
||||
result[k] = v.cpu().numpy() if hasattr(v, "cpu") else v.numpy()
|
||||
elif isinstance(v, dict):
|
||||
result[k] = TrimeshData._detensorize_dict(v)
|
||||
elif isinstance(v, list):
|
||||
result[k] = [
|
||||
item.cpu().numpy() if hasattr(item, "numpy") and hasattr(item, "cpu")
|
||||
else item.numpy() if hasattr(item, "numpy")
|
||||
else item
|
||||
for item in v
|
||||
]
|
||||
else:
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _to_numpy(arr, dtype):
|
||||
if arr is None:
|
||||
return None
|
||||
if hasattr(arr, "numpy"):
|
||||
arr = arr.cpu().numpy() if hasattr(arr, "cpu") else arr.numpy()
|
||||
return np.ascontiguousarray(arr, dtype=dtype)
|
||||
|
||||
@property
|
||||
def num_vertices(self) -> int:
|
||||
return self.vertices.shape[0]
|
||||
|
||||
@property
|
||||
def num_faces(self) -> int:
|
||||
return self.faces.shape[0]
|
||||
|
||||
@property
|
||||
def has_texture(self) -> bool:
|
||||
return self.uv is not None and self.material is not None
|
||||
|
||||
def to_trimesh(self):
|
||||
"""Convert to trimesh.Trimesh (requires trimesh in the environment)."""
|
||||
import trimesh
|
||||
from trimesh.visual import TextureVisuals
|
||||
|
||||
kwargs = {}
|
||||
if self.vertex_normals is not None:
|
||||
kwargs["vertex_normals"] = self.vertex_normals
|
||||
if self.face_normals is not None:
|
||||
kwargs["face_normals"] = self.face_normals
|
||||
if self.metadata:
|
||||
kwargs["metadata"] = self.metadata
|
||||
|
||||
mesh = trimesh.Trimesh(
|
||||
vertices=self.vertices, faces=self.faces, process=False, **kwargs
|
||||
)
|
||||
|
||||
# Reconstruct visual
|
||||
if self.has_texture:
|
||||
material = self._dict_to_material(self.material)
|
||||
mesh.visual = TextureVisuals(uv=self.uv, material=material)
|
||||
elif self.vertex_colors is not None:
|
||||
mesh.visual.vertex_colors = self.vertex_colors
|
||||
|
||||
for k, v in self.vertex_attributes.items():
|
||||
mesh.vertex_attributes[k] = v
|
||||
|
||||
for k, v in self.face_attributes.items():
|
||||
mesh.face_attributes[k] = v
|
||||
|
||||
return mesh
|
||||
|
||||
@staticmethod
|
||||
def _material_to_dict(material) -> dict:
|
||||
"""Serialize a trimesh material to a plain dict."""
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from trimesh.visual.material import PBRMaterial, SimpleMaterial
|
||||
|
||||
result = {"type": type(material).__name__, "name": getattr(material, "name", None)}
|
||||
|
||||
if isinstance(material, PBRMaterial):
|
||||
result["baseColorFactor"] = material.baseColorFactor
|
||||
result["metallicFactor"] = material.metallicFactor
|
||||
result["roughnessFactor"] = material.roughnessFactor
|
||||
result["emissiveFactor"] = material.emissiveFactor
|
||||
result["alphaMode"] = material.alphaMode
|
||||
result["alphaCutoff"] = material.alphaCutoff
|
||||
result["doubleSided"] = material.doubleSided
|
||||
|
||||
for tex_name in ("baseColorTexture", "normalTexture", "emissiveTexture",
|
||||
"metallicRoughnessTexture", "occlusionTexture"):
|
||||
tex = getattr(material, tex_name, None)
|
||||
if tex is not None:
|
||||
buf = BytesIO()
|
||||
tex.save(buf, format="PNG")
|
||||
result[tex_name] = base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
|
||||
elif isinstance(material, SimpleMaterial):
|
||||
result["main_color"] = list(material.main_color) if material.main_color is not None else None
|
||||
result["glossiness"] = material.glossiness
|
||||
if hasattr(material, "image") and material.image is not None:
|
||||
buf = BytesIO()
|
||||
material.image.save(buf, format="PNG")
|
||||
result["image"] = base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _dict_to_material(d: dict):
|
||||
"""Reconstruct a trimesh material from a plain dict."""
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from trimesh.visual.material import PBRMaterial, SimpleMaterial
|
||||
|
||||
mat_type = d.get("type", "PBRMaterial")
|
||||
|
||||
if mat_type == "PBRMaterial":
|
||||
kwargs = {
|
||||
"name": d.get("name"),
|
||||
"baseColorFactor": d.get("baseColorFactor"),
|
||||
"metallicFactor": d.get("metallicFactor"),
|
||||
"roughnessFactor": d.get("roughnessFactor"),
|
||||
"emissiveFactor": d.get("emissiveFactor"),
|
||||
"alphaMode": d.get("alphaMode"),
|
||||
"alphaCutoff": d.get("alphaCutoff"),
|
||||
"doubleSided": d.get("doubleSided"),
|
||||
}
|
||||
for tex_name in ("baseColorTexture", "normalTexture", "emissiveTexture",
|
||||
"metallicRoughnessTexture", "occlusionTexture"):
|
||||
if tex_name in d and d[tex_name] is not None:
|
||||
img = Image.open(BytesIO(base64.b64decode(d[tex_name])))
|
||||
kwargs[tex_name] = img
|
||||
return PBRMaterial(**{k: v for k, v in kwargs.items() if v is not None})
|
||||
|
||||
elif mat_type == "SimpleMaterial":
|
||||
kwargs = {
|
||||
"name": d.get("name"),
|
||||
"glossiness": d.get("glossiness"),
|
||||
}
|
||||
if d.get("main_color") is not None:
|
||||
kwargs["diffuse"] = d["main_color"]
|
||||
if d.get("image") is not None:
|
||||
kwargs["image"] = Image.open(BytesIO(base64.b64decode(d["image"])))
|
||||
return SimpleMaterial(**kwargs)
|
||||
|
||||
raise ValueError(f"Unknown material type: {mat_type}")
|
||||
|
||||
@classmethod
|
||||
def from_trimesh(cls, mesh) -> TrimeshData:
|
||||
"""Create from a trimesh.Trimesh object."""
|
||||
from trimesh.visual.texture import TextureVisuals
|
||||
|
||||
vertex_normals = None
|
||||
if mesh._cache.cache.get("vertex_normals") is not None:
|
||||
vertex_normals = np.asarray(mesh.vertex_normals)
|
||||
|
||||
face_normals = None
|
||||
if mesh._cache.cache.get("face_normals") is not None:
|
||||
face_normals = np.asarray(mesh.face_normals)
|
||||
|
||||
vertex_colors = None
|
||||
uv = None
|
||||
material = None
|
||||
|
||||
if isinstance(mesh.visual, TextureVisuals):
|
||||
if mesh.visual.uv is not None:
|
||||
uv = np.asarray(mesh.visual.uv, dtype=np.float64)
|
||||
if mesh.visual.material is not None:
|
||||
material = cls._material_to_dict(mesh.visual.material)
|
||||
else:
|
||||
try:
|
||||
vc = mesh.visual.vertex_colors
|
||||
if vc is not None and len(vc) > 0:
|
||||
vertex_colors = np.asarray(vc, dtype=np.uint8)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
va = {}
|
||||
if hasattr(mesh, "vertex_attributes") and mesh.vertex_attributes:
|
||||
for k, v in mesh.vertex_attributes.items():
|
||||
va[k] = np.asarray(v) if hasattr(v, "__array__") else v
|
||||
|
||||
fa = {}
|
||||
if hasattr(mesh, "face_attributes") and mesh.face_attributes:
|
||||
for k, v in mesh.face_attributes.items():
|
||||
fa[k] = np.asarray(v) if hasattr(v, "__array__") else v
|
||||
|
||||
return cls(
|
||||
vertices=np.asarray(mesh.vertices),
|
||||
faces=np.asarray(mesh.faces),
|
||||
vertex_normals=vertex_normals,
|
||||
face_normals=face_normals,
|
||||
vertex_colors=vertex_colors,
|
||||
uv=uv,
|
||||
material=material,
|
||||
vertex_attributes=va if va else None,
|
||||
face_attributes=fa if fa else None,
|
||||
metadata=mesh.metadata if mesh.metadata else None,
|
||||
)
|
||||
18
comfy_api_sealed_worker/__init__.py
Normal file
18
comfy_api_sealed_worker/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""comfy_api_sealed_worker — torch-free type definitions for sealed worker children.
|
||||
|
||||
Drop-in replacement for comfy_api.latest._util type imports in sealed workers
|
||||
that do not have torch installed. Contains only data type definitions (TrimeshData,
|
||||
PLY, NPZ, etc.) with numpy-only dependencies.
|
||||
|
||||
Usage in serializers:
|
||||
if _IMPORT_TORCH:
|
||||
from comfy_api.latest._util.trimesh_types import TrimeshData
|
||||
else:
|
||||
from comfy_api_sealed_worker.trimesh_types import TrimeshData
|
||||
"""
|
||||
|
||||
from .trimesh_types import TrimeshData
|
||||
from .ply_types import PLY
|
||||
from .npz_types import NPZ
|
||||
|
||||
__all__ = ["TrimeshData", "PLY", "NPZ"]
|
||||
27
comfy_api_sealed_worker/npz_types.py
Normal file
27
comfy_api_sealed_worker/npz_types.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
|
||||
class NPZ:
|
||||
"""Ordered collection of NPZ file payloads.
|
||||
|
||||
Each entry in ``frames`` is a complete compressed ``.npz`` file stored
|
||||
as raw bytes (produced by ``numpy.savez_compressed`` into a BytesIO).
|
||||
``save_to`` writes numbered files into a directory.
|
||||
"""
|
||||
|
||||
def __init__(self, frames: list[bytes]) -> None:
|
||||
self.frames = frames
|
||||
|
||||
@property
|
||||
def num_frames(self) -> int:
|
||||
return len(self.frames)
|
||||
|
||||
def save_to(self, directory: str, prefix: str = "frame") -> str:
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
for i, frame_bytes in enumerate(self.frames):
|
||||
path = os.path.join(directory, f"{prefix}_{i:06d}.npz")
|
||||
with open(path, "wb") as f:
|
||||
f.write(frame_bytes)
|
||||
return directory
|
||||
97
comfy_api_sealed_worker/ply_types.py
Normal file
97
comfy_api_sealed_worker/ply_types.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class PLY:
|
||||
"""Point cloud payload for PLY file output.
|
||||
|
||||
Supports two schemas:
|
||||
- Pointcloud: xyz positions with optional colors, confidence, view_id (ASCII format)
|
||||
- Gaussian: raw binary PLY data built by producer nodes using plyfile (binary format)
|
||||
|
||||
When ``raw_data`` is provided, the object acts as an opaque binary PLY
|
||||
carrier and ``save_to`` writes the bytes directly.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
points: np.ndarray | None = None,
|
||||
colors: np.ndarray | None = None,
|
||||
confidence: np.ndarray | None = None,
|
||||
view_id: np.ndarray | None = None,
|
||||
raw_data: bytes | None = None,
|
||||
) -> None:
|
||||
self.raw_data = raw_data
|
||||
if raw_data is not None:
|
||||
self.points = None
|
||||
self.colors = None
|
||||
self.confidence = None
|
||||
self.view_id = None
|
||||
return
|
||||
if points is None:
|
||||
raise ValueError("Either points or raw_data must be provided")
|
||||
if points.ndim != 2 or points.shape[1] != 3:
|
||||
raise ValueError(f"points must be (N, 3), got {points.shape}")
|
||||
self.points = np.ascontiguousarray(points, dtype=np.float32)
|
||||
self.colors = np.ascontiguousarray(colors, dtype=np.float32) if colors is not None else None
|
||||
self.confidence = np.ascontiguousarray(confidence, dtype=np.float32) if confidence is not None else None
|
||||
self.view_id = np.ascontiguousarray(view_id, dtype=np.int32) if view_id is not None else None
|
||||
|
||||
@property
|
||||
def is_gaussian(self) -> bool:
|
||||
return self.raw_data is not None
|
||||
|
||||
@property
|
||||
def num_points(self) -> int:
|
||||
if self.points is not None:
|
||||
return self.points.shape[0]
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
def _to_numpy(arr, dtype):
|
||||
if arr is None:
|
||||
return None
|
||||
if hasattr(arr, "numpy"):
|
||||
arr = arr.cpu().numpy() if hasattr(arr, "cpu") else arr.numpy()
|
||||
return np.ascontiguousarray(arr, dtype=dtype)
|
||||
|
||||
def save_to(self, path: str) -> str:
|
||||
if self.raw_data is not None:
|
||||
with open(path, "wb") as f:
|
||||
f.write(self.raw_data)
|
||||
return path
|
||||
self.points = self._to_numpy(self.points, np.float32)
|
||||
self.colors = self._to_numpy(self.colors, np.float32)
|
||||
self.confidence = self._to_numpy(self.confidence, np.float32)
|
||||
self.view_id = self._to_numpy(self.view_id, np.int32)
|
||||
N = self.num_points
|
||||
header_lines = [
|
||||
"ply",
|
||||
"format ascii 1.0",
|
||||
f"element vertex {N}",
|
||||
"property float x",
|
||||
"property float y",
|
||||
"property float z",
|
||||
]
|
||||
if self.colors is not None:
|
||||
header_lines += ["property uchar red", "property uchar green", "property uchar blue"]
|
||||
if self.confidence is not None:
|
||||
header_lines.append("property float confidence")
|
||||
if self.view_id is not None:
|
||||
header_lines.append("property int view_id")
|
||||
header_lines.append("end_header")
|
||||
|
||||
with open(path, "w") as f:
|
||||
f.write("\n".join(header_lines) + "\n")
|
||||
for i in range(N):
|
||||
parts = [f"{self.points[i, 0]} {self.points[i, 1]} {self.points[i, 2]}"]
|
||||
if self.colors is not None:
|
||||
r, g, b = (self.colors[i] * 255).clip(0, 255).astype(np.uint8)
|
||||
parts.append(f"{r} {g} {b}")
|
||||
if self.confidence is not None:
|
||||
parts.append(f"{self.confidence[i]}")
|
||||
if self.view_id is not None:
|
||||
parts.append(f"{int(self.view_id[i])}")
|
||||
f.write(" ".join(parts) + "\n")
|
||||
return path
|
||||
259
comfy_api_sealed_worker/trimesh_types.py
Normal file
259
comfy_api_sealed_worker/trimesh_types.py
Normal file
@@ -0,0 +1,259 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TrimeshData:
|
||||
"""Triangular mesh payload for cross-process transfer.
|
||||
|
||||
Lightweight carrier for mesh geometry that does not depend on the
|
||||
``trimesh`` library. Serializers create this on the host side;
|
||||
isolated child processes convert to/from ``trimesh.Trimesh`` as needed.
|
||||
|
||||
Supports both ColorVisuals (vertex_colors) and TextureVisuals
|
||||
(uv + material with textures).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vertices: np.ndarray,
|
||||
faces: np.ndarray,
|
||||
vertex_normals: np.ndarray | None = None,
|
||||
face_normals: np.ndarray | None = None,
|
||||
vertex_colors: np.ndarray | None = None,
|
||||
uv: np.ndarray | None = None,
|
||||
material: dict | None = None,
|
||||
vertex_attributes: dict | None = None,
|
||||
face_attributes: dict | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> None:
|
||||
self.vertices = np.ascontiguousarray(vertices, dtype=np.float64)
|
||||
self.faces = np.ascontiguousarray(faces, dtype=np.int64)
|
||||
self.vertex_normals = (
|
||||
np.ascontiguousarray(vertex_normals, dtype=np.float64)
|
||||
if vertex_normals is not None
|
||||
else None
|
||||
)
|
||||
self.face_normals = (
|
||||
np.ascontiguousarray(face_normals, dtype=np.float64)
|
||||
if face_normals is not None
|
||||
else None
|
||||
)
|
||||
self.vertex_colors = (
|
||||
np.ascontiguousarray(vertex_colors, dtype=np.uint8)
|
||||
if vertex_colors is not None
|
||||
else None
|
||||
)
|
||||
self.uv = (
|
||||
np.ascontiguousarray(uv, dtype=np.float64)
|
||||
if uv is not None
|
||||
else None
|
||||
)
|
||||
self.material = material
|
||||
self.vertex_attributes = vertex_attributes or {}
|
||||
self.face_attributes = face_attributes or {}
|
||||
self.metadata = self._detensorize_dict(metadata) if metadata else {}
|
||||
|
||||
@staticmethod
|
||||
def _detensorize_dict(d):
|
||||
"""Recursively convert any tensors in a dict back to numpy arrays."""
|
||||
if not isinstance(d, dict):
|
||||
return d
|
||||
result = {}
|
||||
for k, v in d.items():
|
||||
if hasattr(v, "numpy"):
|
||||
result[k] = v.cpu().numpy() if hasattr(v, "cpu") else v.numpy()
|
||||
elif isinstance(v, dict):
|
||||
result[k] = TrimeshData._detensorize_dict(v)
|
||||
elif isinstance(v, list):
|
||||
result[k] = [
|
||||
item.cpu().numpy() if hasattr(item, "numpy") and hasattr(item, "cpu")
|
||||
else item.numpy() if hasattr(item, "numpy")
|
||||
else item
|
||||
for item in v
|
||||
]
|
||||
else:
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _to_numpy(arr, dtype):
|
||||
if arr is None:
|
||||
return None
|
||||
if hasattr(arr, "numpy"):
|
||||
arr = arr.cpu().numpy() if hasattr(arr, "cpu") else arr.numpy()
|
||||
return np.ascontiguousarray(arr, dtype=dtype)
|
||||
|
||||
@property
|
||||
def num_vertices(self) -> int:
|
||||
return self.vertices.shape[0]
|
||||
|
||||
@property
|
||||
def num_faces(self) -> int:
|
||||
return self.faces.shape[0]
|
||||
|
||||
@property
|
||||
def has_texture(self) -> bool:
|
||||
return self.uv is not None and self.material is not None
|
||||
|
||||
def to_trimesh(self):
|
||||
"""Convert to trimesh.Trimesh (requires trimesh in the environment)."""
|
||||
import trimesh
|
||||
from trimesh.visual import TextureVisuals
|
||||
|
||||
kwargs = {}
|
||||
if self.vertex_normals is not None:
|
||||
kwargs["vertex_normals"] = self.vertex_normals
|
||||
if self.face_normals is not None:
|
||||
kwargs["face_normals"] = self.face_normals
|
||||
if self.metadata:
|
||||
kwargs["metadata"] = self.metadata
|
||||
|
||||
mesh = trimesh.Trimesh(
|
||||
vertices=self.vertices, faces=self.faces, process=False, **kwargs
|
||||
)
|
||||
|
||||
# Reconstruct visual
|
||||
if self.has_texture:
|
||||
material = self._dict_to_material(self.material)
|
||||
mesh.visual = TextureVisuals(uv=self.uv, material=material)
|
||||
elif self.vertex_colors is not None:
|
||||
mesh.visual.vertex_colors = self.vertex_colors
|
||||
|
||||
for k, v in self.vertex_attributes.items():
|
||||
mesh.vertex_attributes[k] = v
|
||||
|
||||
for k, v in self.face_attributes.items():
|
||||
mesh.face_attributes[k] = v
|
||||
|
||||
return mesh
|
||||
|
||||
@staticmethod
|
||||
def _material_to_dict(material) -> dict:
|
||||
"""Serialize a trimesh material to a plain dict."""
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from trimesh.visual.material import PBRMaterial, SimpleMaterial
|
||||
|
||||
result = {"type": type(material).__name__, "name": getattr(material, "name", None)}
|
||||
|
||||
if isinstance(material, PBRMaterial):
|
||||
result["baseColorFactor"] = material.baseColorFactor
|
||||
result["metallicFactor"] = material.metallicFactor
|
||||
result["roughnessFactor"] = material.roughnessFactor
|
||||
result["emissiveFactor"] = material.emissiveFactor
|
||||
result["alphaMode"] = material.alphaMode
|
||||
result["alphaCutoff"] = material.alphaCutoff
|
||||
result["doubleSided"] = material.doubleSided
|
||||
|
||||
for tex_name in ("baseColorTexture", "normalTexture", "emissiveTexture",
|
||||
"metallicRoughnessTexture", "occlusionTexture"):
|
||||
tex = getattr(material, tex_name, None)
|
||||
if tex is not None:
|
||||
buf = BytesIO()
|
||||
tex.save(buf, format="PNG")
|
||||
result[tex_name] = base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
|
||||
elif isinstance(material, SimpleMaterial):
|
||||
result["main_color"] = list(material.main_color) if material.main_color is not None else None
|
||||
result["glossiness"] = material.glossiness
|
||||
if hasattr(material, "image") and material.image is not None:
|
||||
buf = BytesIO()
|
||||
material.image.save(buf, format="PNG")
|
||||
result["image"] = base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _dict_to_material(d: dict):
|
||||
"""Reconstruct a trimesh material from a plain dict."""
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from trimesh.visual.material import PBRMaterial, SimpleMaterial
|
||||
|
||||
mat_type = d.get("type", "PBRMaterial")
|
||||
|
||||
if mat_type == "PBRMaterial":
|
||||
kwargs = {
|
||||
"name": d.get("name"),
|
||||
"baseColorFactor": d.get("baseColorFactor"),
|
||||
"metallicFactor": d.get("metallicFactor"),
|
||||
"roughnessFactor": d.get("roughnessFactor"),
|
||||
"emissiveFactor": d.get("emissiveFactor"),
|
||||
"alphaMode": d.get("alphaMode"),
|
||||
"alphaCutoff": d.get("alphaCutoff"),
|
||||
"doubleSided": d.get("doubleSided"),
|
||||
}
|
||||
for tex_name in ("baseColorTexture", "normalTexture", "emissiveTexture",
|
||||
"metallicRoughnessTexture", "occlusionTexture"):
|
||||
if tex_name in d and d[tex_name] is not None:
|
||||
img = Image.open(BytesIO(base64.b64decode(d[tex_name])))
|
||||
kwargs[tex_name] = img
|
||||
return PBRMaterial(**{k: v for k, v in kwargs.items() if v is not None})
|
||||
|
||||
elif mat_type == "SimpleMaterial":
|
||||
kwargs = {
|
||||
"name": d.get("name"),
|
||||
"glossiness": d.get("glossiness"),
|
||||
}
|
||||
if d.get("main_color") is not None:
|
||||
kwargs["diffuse"] = d["main_color"]
|
||||
if d.get("image") is not None:
|
||||
kwargs["image"] = Image.open(BytesIO(base64.b64decode(d["image"])))
|
||||
return SimpleMaterial(**kwargs)
|
||||
|
||||
raise ValueError(f"Unknown material type: {mat_type}")
|
||||
|
||||
@classmethod
|
||||
def from_trimesh(cls, mesh) -> TrimeshData:
|
||||
"""Create from a trimesh.Trimesh object."""
|
||||
from trimesh.visual.texture import TextureVisuals
|
||||
|
||||
vertex_normals = None
|
||||
if mesh._cache.cache.get("vertex_normals") is not None:
|
||||
vertex_normals = np.asarray(mesh.vertex_normals)
|
||||
|
||||
face_normals = None
|
||||
if mesh._cache.cache.get("face_normals") is not None:
|
||||
face_normals = np.asarray(mesh.face_normals)
|
||||
|
||||
vertex_colors = None
|
||||
uv = None
|
||||
material = None
|
||||
|
||||
if isinstance(mesh.visual, TextureVisuals):
|
||||
if mesh.visual.uv is not None:
|
||||
uv = np.asarray(mesh.visual.uv, dtype=np.float64)
|
||||
if mesh.visual.material is not None:
|
||||
material = cls._material_to_dict(mesh.visual.material)
|
||||
else:
|
||||
try:
|
||||
vc = mesh.visual.vertex_colors
|
||||
if vc is not None and len(vc) > 0:
|
||||
vertex_colors = np.asarray(vc, dtype=np.uint8)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
va = {}
|
||||
if hasattr(mesh, "vertex_attributes") and mesh.vertex_attributes:
|
||||
for k, v in mesh.vertex_attributes.items():
|
||||
va[k] = np.asarray(v) if hasattr(v, "__array__") else v
|
||||
|
||||
fa = {}
|
||||
if hasattr(mesh, "face_attributes") and mesh.face_attributes:
|
||||
for k, v in mesh.face_attributes.items():
|
||||
fa[k] = np.asarray(v) if hasattr(v, "__array__") else v
|
||||
|
||||
return cls(
|
||||
vertices=np.asarray(mesh.vertices),
|
||||
faces=np.asarray(mesh.faces),
|
||||
vertex_normals=vertex_normals,
|
||||
face_normals=face_normals,
|
||||
vertex_colors=vertex_colors,
|
||||
uv=uv,
|
||||
material=material,
|
||||
vertex_attributes=va if va else None,
|
||||
face_attributes=fa if fa else None,
|
||||
metadata=mesh.metadata if mesh.metadata else None,
|
||||
)
|
||||
@@ -6,7 +6,6 @@ import comfy.model_management
|
||||
import torch
|
||||
import math
|
||||
import nodes
|
||||
import comfy.ldm.flux.math
|
||||
|
||||
class CLIPTextEncodeFlux(io.ComfyNode):
|
||||
@classmethod
|
||||
@@ -232,68 +231,6 @@ class Flux2Scheduler(io.ComfyNode):
|
||||
sigmas = get_schedule(steps, round(seq_len))
|
||||
return io.NodeOutput(sigmas)
|
||||
|
||||
class KV_Attn_Input:
|
||||
def __init__(self):
|
||||
self.cache = {}
|
||||
|
||||
def __call__(self, q, k, v, extra_options, **kwargs):
|
||||
reference_image_num_tokens = extra_options.get("reference_image_num_tokens", [])
|
||||
if len(reference_image_num_tokens) == 0:
|
||||
return {}
|
||||
|
||||
ref_toks = sum(reference_image_num_tokens)
|
||||
cache_key = "{}_{}".format(extra_options["block_type"], extra_options["block_index"])
|
||||
if cache_key in self.cache:
|
||||
kk, vv = self.cache[cache_key]
|
||||
self.set_cache = False
|
||||
return {"q": q, "k": torch.cat((k, kk), dim=2), "v": torch.cat((v, vv), dim=2)}
|
||||
|
||||
self.cache[cache_key] = (k[:, :, -ref_toks:].clone(), v[:, :, -ref_toks:].clone())
|
||||
self.set_cache = True
|
||||
return {"q": q, "k": k, "v": v}
|
||||
|
||||
def cleanup(self):
|
||||
self.cache = {}
|
||||
|
||||
|
||||
class FluxKVCache(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="FluxKVCache",
|
||||
display_name="Flux KV Cache",
|
||||
description="Enables KV Cache optimization for reference images on Flux family models.",
|
||||
category="",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Model.Input("model", tooltip="The model to use KV Cache on."),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(tooltip="The patched model with KV Cache enabled."),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type) -> io.NodeOutput:
|
||||
m = model.clone()
|
||||
input_patch_obj = KV_Attn_Input()
|
||||
|
||||
def model_input_patch(inputs):
|
||||
if len(input_patch_obj.cache) > 0:
|
||||
ref_image_tokens = sum(inputs["transformer_options"].get("reference_image_num_tokens", []))
|
||||
if ref_image_tokens > 0:
|
||||
img = inputs["img"]
|
||||
inputs["img"] = img[:, :-ref_image_tokens]
|
||||
return inputs
|
||||
|
||||
m.set_model_attn1_patch(input_patch_obj)
|
||||
m.set_model_post_input_patch(model_input_patch)
|
||||
if hasattr(model.model.diffusion_model, "params"):
|
||||
m.add_object_patch("diffusion_model.params.default_ref_method", "index_timestep_zero")
|
||||
else:
|
||||
m.add_object_patch("diffusion_model.default_ref_method", "index_timestep_zero")
|
||||
|
||||
return io.NodeOutput(m)
|
||||
|
||||
class FluxExtension(ComfyExtension):
|
||||
@override
|
||||
@@ -306,7 +243,6 @@ class FluxExtension(ComfyExtension):
|
||||
FluxKontextMultiReferenceLatentMethod,
|
||||
EmptyFlux2LatentImage,
|
||||
Flux2Scheduler,
|
||||
FluxKVCache,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,127 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
import folder_paths
|
||||
import node_helpers
|
||||
from comfy_api.latest import ComfyExtension, io, UI
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
def hex_to_rgb(hex_color: str) -> tuple[float, float, float]:
|
||||
hex_color = hex_color.lstrip("#")
|
||||
if len(hex_color) != 6:
|
||||
return (0.0, 0.0, 0.0)
|
||||
r = int(hex_color[0:2], 16) / 255.0
|
||||
g = int(hex_color[2:4], 16) / 255.0
|
||||
b = int(hex_color[4:6], 16) / 255.0
|
||||
return (r, g, b)
|
||||
|
||||
|
||||
class PainterNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="Painter",
|
||||
display_name="Painter",
|
||||
category="image",
|
||||
inputs=[
|
||||
io.Image.Input(
|
||||
"image",
|
||||
optional=True,
|
||||
tooltip="Optional base image to paint over",
|
||||
),
|
||||
io.String.Input(
|
||||
"mask",
|
||||
default="",
|
||||
socketless=True,
|
||||
extra_dict={"widgetType": "PAINTER", "image_upload": True},
|
||||
),
|
||||
io.Int.Input(
|
||||
"width",
|
||||
default=512,
|
||||
min=64,
|
||||
max=4096,
|
||||
step=64,
|
||||
socketless=True,
|
||||
extra_dict={"hidden": True},
|
||||
),
|
||||
io.Int.Input(
|
||||
"height",
|
||||
default=512,
|
||||
min=64,
|
||||
max=4096,
|
||||
step=64,
|
||||
socketless=True,
|
||||
extra_dict={"hidden": True},
|
||||
),
|
||||
io.Color.Input("bg_color", default="#000000"),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output("IMAGE"),
|
||||
io.Mask.Output("MASK"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mask, width, height, bg_color="#000000", image=None) -> io.NodeOutput:
|
||||
if image is not None:
|
||||
base_image = image[:1]
|
||||
h, w = base_image.shape[1], base_image.shape[2]
|
||||
else:
|
||||
h, w = height, width
|
||||
r, g, b = hex_to_rgb(bg_color)
|
||||
base_image = torch.zeros((1, h, w, 3), dtype=torch.float32)
|
||||
base_image[0, :, :, 0] = r
|
||||
base_image[0, :, :, 1] = g
|
||||
base_image[0, :, :, 2] = b
|
||||
|
||||
if mask and mask.strip():
|
||||
mask_path = folder_paths.get_annotated_filepath(mask)
|
||||
painter_img = node_helpers.pillow(Image.open, mask_path)
|
||||
painter_img = painter_img.convert("RGBA")
|
||||
|
||||
if painter_img.size != (w, h):
|
||||
painter_img = painter_img.resize((w, h), Image.LANCZOS)
|
||||
|
||||
painter_np = np.array(painter_img).astype(np.float32) / 255.0
|
||||
painter_rgb = painter_np[:, :, :3]
|
||||
painter_alpha = painter_np[:, :, 3:4]
|
||||
|
||||
mask_tensor = torch.from_numpy(painter_np[:, :, 3]).unsqueeze(0)
|
||||
|
||||
base_np = base_image[0].cpu().numpy()
|
||||
composited = painter_rgb * painter_alpha + base_np * (1.0 - painter_alpha)
|
||||
out_image = torch.from_numpy(composited).unsqueeze(0)
|
||||
else:
|
||||
mask_tensor = torch.zeros((1, h, w), dtype=torch.float32)
|
||||
out_image = base_image
|
||||
|
||||
return io.NodeOutput(out_image, mask_tensor, ui=UI.PreviewImage(out_image))
|
||||
|
||||
@classmethod
|
||||
def fingerprint_inputs(cls, mask, width, height, bg_color="#000000", image=None):
|
||||
if mask and mask.strip():
|
||||
mask_path = folder_paths.get_annotated_filepath(mask)
|
||||
if os.path.exists(mask_path):
|
||||
m = hashlib.sha256()
|
||||
with open(mask_path, "rb") as f:
|
||||
m.update(f.read())
|
||||
return m.digest().hex()
|
||||
return ""
|
||||
|
||||
|
||||
|
||||
class PainterExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self):
|
||||
return [PainterNode]
|
||||
|
||||
|
||||
async def comfy_entrypoint():
|
||||
return PainterExtension()
|
||||
40
comfy_extras/nodes_save_npz.py
Normal file
40
comfy_extras/nodes_save_npz.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import os
|
||||
|
||||
import folder_paths
|
||||
from comfy_api.latest import io
|
||||
from comfy_api.latest._util.npz_types import NPZ
|
||||
|
||||
|
||||
class SaveNPZ(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SaveNPZ",
|
||||
display_name="Save NPZ",
|
||||
category="3d",
|
||||
is_output_node=True,
|
||||
inputs=[
|
||||
io.Npz.Input("npz"),
|
||||
io.String.Input("filename_prefix", default="da3_streaming/ComfyUI"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, npz: NPZ, filename_prefix: str) -> io.NodeOutput:
|
||||
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||
filename_prefix, folder_paths.get_output_directory()
|
||||
)
|
||||
batch_dir = os.path.join(full_output_folder, f"{filename}_{counter:05}")
|
||||
os.makedirs(batch_dir, exist_ok=True)
|
||||
filenames = []
|
||||
for i, frame_bytes in enumerate(npz.frames):
|
||||
f = f"frame_{i:06d}.npz"
|
||||
with open(os.path.join(batch_dir, f), "wb") as fh:
|
||||
fh.write(frame_bytes)
|
||||
filenames.append(f)
|
||||
return io.NodeOutput(ui={"npz_files": [{"folder": os.path.join(subfolder, f"{filename}_{counter:05}"), "count": len(filenames), "type": "output"}]})
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"SaveNPZ": SaveNPZ,
|
||||
}
|
||||
34
comfy_extras/nodes_save_ply.py
Normal file
34
comfy_extras/nodes_save_ply.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import os
|
||||
|
||||
import folder_paths
|
||||
from comfy_api.latest import io
|
||||
from comfy_api.latest._util.ply_types import PLY
|
||||
|
||||
|
||||
class SavePLY(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SavePLY",
|
||||
display_name="Save PLY",
|
||||
category="3d",
|
||||
is_output_node=True,
|
||||
inputs=[
|
||||
io.Ply.Input("ply"),
|
||||
io.String.Input("filename_prefix", default="pointcloud/ComfyUI"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, ply: PLY, filename_prefix: str) -> io.NodeOutput:
|
||||
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||
filename_prefix, folder_paths.get_output_directory()
|
||||
)
|
||||
f = f"{filename}_{counter:05}_.ply"
|
||||
ply.save_to(os.path.join(full_output_folder, f))
|
||||
return io.NodeOutput(ui={"pointclouds": [{"filename": f, "subfolder": subfolder, "type": "output"}]})
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"SavePLY": SavePLY,
|
||||
}
|
||||
@@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.17.2"
|
||||
__version__ = "0.16.4"
|
||||
|
||||
@@ -92,7 +92,7 @@ if args.cuda_malloc:
|
||||
env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
|
||||
if env_var is None:
|
||||
env_var = "backend:cudaMallocAsync"
|
||||
else:
|
||||
elif not args.use_process_isolation:
|
||||
env_var += ",backend:cudaMallocAsync"
|
||||
|
||||
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
|
||||
|
||||
145
execution.py
145
execution.py
@@ -1,7 +1,9 @@
|
||||
import copy
|
||||
import gc
|
||||
import heapq
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
@@ -41,6 +43,8 @@ from comfy_execution.utils import CurrentNodeContext
|
||||
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
|
||||
from comfy_api.latest import io, _io
|
||||
|
||||
_AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED = False
|
||||
|
||||
|
||||
class ExecutionResult(Enum):
|
||||
SUCCESS = 0
|
||||
@@ -261,20 +265,31 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
|
||||
pre_execute_cb(index)
|
||||
# V3
|
||||
if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)):
|
||||
# if is just a class, then assign no state, just create clone
|
||||
if is_class(obj):
|
||||
type_obj = obj
|
||||
obj.VALIDATE_CLASS()
|
||||
class_clone = obj.PREPARE_CLASS_CLONE(v3_data)
|
||||
# otherwise, use class instance to populate/reuse some fields
|
||||
# Check for isolated node - skip validation and class cloning
|
||||
if hasattr(obj, "_pyisolate_extension"):
|
||||
# Isolated Node: The stub is just a proxy; real validation happens in child process
|
||||
if v3_data is not None:
|
||||
inputs = _io.build_nested_inputs(inputs, v3_data)
|
||||
# Inject hidden inputs so they're available in the isolated child process
|
||||
inputs.update(v3_data.get("hidden_inputs", {}))
|
||||
f = getattr(obj, func)
|
||||
# Standard V3 Node (Existing Logic)
|
||||
|
||||
else:
|
||||
type_obj = type(obj)
|
||||
type_obj.VALIDATE_CLASS()
|
||||
class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data)
|
||||
f = make_locked_method_func(type_obj, func, class_clone)
|
||||
# in case of dynamic inputs, restructure inputs to expected nested dict
|
||||
if v3_data is not None:
|
||||
inputs = _io.build_nested_inputs(inputs, v3_data)
|
||||
# if is just a class, then assign no resources or state, just create clone
|
||||
if is_class(obj):
|
||||
type_obj = obj
|
||||
obj.VALIDATE_CLASS()
|
||||
class_clone = obj.PREPARE_CLASS_CLONE(v3_data)
|
||||
# otherwise, use class instance to populate/reuse some fields
|
||||
else:
|
||||
type_obj = type(obj)
|
||||
type_obj.VALIDATE_CLASS()
|
||||
class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data)
|
||||
f = make_locked_method_func(type_obj, func, class_clone)
|
||||
# in case of dynamic inputs, restructure inputs to expected nested dict
|
||||
if v3_data is not None:
|
||||
inputs = _io.build_nested_inputs(inputs, v3_data)
|
||||
# V1
|
||||
else:
|
||||
f = getattr(obj, func)
|
||||
@@ -527,7 +542,17 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
if args.verbose == "DEBUG":
|
||||
comfy_aimdo.control.analyze()
|
||||
comfy.model_management.reset_cast_buffers()
|
||||
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
|
||||
vbar_lib = getattr(comfy_aimdo.model_vbar, "lib", None)
|
||||
if vbar_lib is not None:
|
||||
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
|
||||
else:
|
||||
global _AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED
|
||||
if not _AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED:
|
||||
logging.warning(
|
||||
"DynamicVRAM backend unavailable for watermark reset; "
|
||||
"skipping vbar reset for this process."
|
||||
)
|
||||
_AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED = True
|
||||
|
||||
if has_pending_tasks:
|
||||
pending_async_nodes[unique_id] = output_data
|
||||
@@ -536,6 +561,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
tasks = [x for x in output_data if isinstance(x, asyncio.Task)]
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
unblock()
|
||||
|
||||
# Keep isolation node execution deterministic by default, but allow
|
||||
# opt-out for diagnostics.
|
||||
isolation_sequential = os.environ.get("COMFY_ISOLATE_SEQUENTIAL", "1").lower() in ("1", "true", "yes")
|
||||
if args.use_process_isolation and isolation_sequential:
|
||||
await await_completion()
|
||||
return await execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs)
|
||||
|
||||
asyncio.create_task(await_completion())
|
||||
return (ExecutionResult.PENDING, None, None)
|
||||
if len(output_ui) > 0:
|
||||
@@ -647,6 +680,46 @@ class PromptExecutor:
|
||||
self.status_messages = []
|
||||
self.success = True
|
||||
|
||||
async def _notify_execution_graph_safe(self, class_types: set[str], *, fail_loud: bool = False) -> None:
|
||||
if not args.use_process_isolation:
|
||||
return
|
||||
try:
|
||||
from comfy.isolation import notify_execution_graph
|
||||
await notify_execution_graph(class_types, caches=self.caches.all)
|
||||
except Exception:
|
||||
if fail_loud:
|
||||
raise
|
||||
logging.debug("][ EX:notify_execution_graph failed", exc_info=True)
|
||||
|
||||
async def _flush_running_extensions_transport_state_safe(self) -> None:
|
||||
if not args.use_process_isolation:
|
||||
return
|
||||
try:
|
||||
from comfy.isolation import flush_running_extensions_transport_state
|
||||
await flush_running_extensions_transport_state()
|
||||
except Exception:
|
||||
logging.debug("][ EX:flush_running_extensions_transport_state failed", exc_info=True)
|
||||
|
||||
async def _wait_model_patcher_quiescence_safe(
|
||||
self,
|
||||
*,
|
||||
fail_loud: bool = False,
|
||||
timeout_ms: int = 120000,
|
||||
marker: str = "EX:wait_model_patcher_idle",
|
||||
) -> None:
|
||||
if not args.use_process_isolation:
|
||||
return
|
||||
try:
|
||||
from comfy.isolation import wait_for_model_patcher_quiescence
|
||||
|
||||
await wait_for_model_patcher_quiescence(
|
||||
timeout_ms=timeout_ms, fail_loud=fail_loud, marker=marker
|
||||
)
|
||||
except Exception:
|
||||
if fail_loud:
|
||||
raise
|
||||
logging.debug("][ EX:wait_model_patcher_quiescence failed", exc_info=True)
|
||||
|
||||
def add_message(self, event, data: dict, broadcast: bool):
|
||||
data = {
|
||||
**data,
|
||||
@@ -688,6 +761,18 @@ class PromptExecutor:
|
||||
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
|
||||
|
||||
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||
if args.use_process_isolation:
|
||||
# Update RPC event loops for all isolated extensions.
|
||||
# This is critical for serial workflow execution - each asyncio.run() creates
|
||||
# a new event loop, and RPC instances must be updated to use it.
|
||||
try:
|
||||
from comfy.isolation import update_rpc_event_loops
|
||||
update_rpc_event_loops()
|
||||
except ImportError:
|
||||
pass # Isolation not available
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"Failed to update RPC event loops: {e}")
|
||||
|
||||
set_preview_method(extra_data.get("preview_method"))
|
||||
|
||||
nodes.interrupt_processing(False)
|
||||
@@ -700,6 +785,25 @@ class PromptExecutor:
|
||||
self.status_messages = []
|
||||
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
|
||||
|
||||
if args.use_process_isolation:
|
||||
try:
|
||||
# Boundary cleanup runs at the start of the next workflow in
|
||||
# isolation mode, matching non-isolated "next prompt" timing.
|
||||
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
|
||||
await self._wait_model_patcher_quiescence_safe(
|
||||
fail_loud=False,
|
||||
timeout_ms=120000,
|
||||
marker="EX:boundary_cleanup_wait_idle",
|
||||
)
|
||||
await self._flush_running_extensions_transport_state_safe()
|
||||
comfy.model_management.unload_all_models()
|
||||
comfy.model_management.cleanup_models_gc()
|
||||
comfy.model_management.cleanup_models()
|
||||
gc.collect()
|
||||
comfy.model_management.soft_empty_cache()
|
||||
except Exception:
|
||||
logging.debug("][ EX:isolation_boundary_cleanup_start failed", exc_info=True)
|
||||
|
||||
with torch.inference_mode():
|
||||
dynamic_prompt = DynamicPrompt(prompt)
|
||||
reset_progress_state(prompt_id, dynamic_prompt)
|
||||
@@ -727,6 +831,18 @@ class PromptExecutor:
|
||||
for node_id in list(execute_outputs):
|
||||
execution_list.add_node(node_id)
|
||||
|
||||
if args.use_process_isolation:
|
||||
pending_class_types = set()
|
||||
for node_id in execution_list.pendingNodes.keys():
|
||||
class_type = dynamic_prompt.get_node(node_id)["class_type"]
|
||||
pending_class_types.add(class_type)
|
||||
await self._wait_model_patcher_quiescence_safe(
|
||||
fail_loud=True,
|
||||
timeout_ms=120000,
|
||||
marker="EX:notify_graph_wait_idle",
|
||||
)
|
||||
await self._notify_execution_graph_safe(pending_class_types, fail_loud=True)
|
||||
|
||||
while not execution_list.is_empty():
|
||||
node_id, error, ex = await execution_list.stage_node_execution()
|
||||
if error is not None:
|
||||
@@ -757,6 +873,7 @@ class PromptExecutor:
|
||||
"outputs": ui_outputs,
|
||||
"meta": meta_outputs,
|
||||
}
|
||||
comfy.model_management.cleanup_models_gc()
|
||||
self.server.last_node_id = None
|
||||
if comfy.model_management.DISABLE_SMART_MEMORY:
|
||||
comfy.model_management.unload_all_models()
|
||||
|
||||
109
main.py
109
main.py
@@ -1,7 +1,21 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
IS_PYISOLATE_CHILD = os.environ.get("PYISOLATE_CHILD") == "1"
|
||||
|
||||
if __name__ == "__main__" and IS_PYISOLATE_CHILD:
|
||||
del os.environ["PYISOLATE_CHILD"]
|
||||
IS_PYISOLATE_CHILD = False
|
||||
|
||||
CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||
if CURRENT_DIR not in sys.path:
|
||||
sys.path.insert(0, CURRENT_DIR)
|
||||
|
||||
IS_PRIMARY_PROCESS = (not IS_PYISOLATE_CHILD) and __name__ == "__main__"
|
||||
|
||||
import comfy.options
|
||||
comfy.options.enable_args_parsing()
|
||||
|
||||
import os
|
||||
import importlib.util
|
||||
import shutil
|
||||
import importlib.metadata
|
||||
@@ -10,7 +24,7 @@ import time
|
||||
from comfy.cli_args import args, enables_dynamic_vram
|
||||
from app.logger import setup_logger
|
||||
import itertools
|
||||
import utils.extra_config
|
||||
import utils.extra_config # noqa: F401
|
||||
from utils.mime_types import init_mime_types
|
||||
import faulthandler
|
||||
import logging
|
||||
@@ -20,12 +34,45 @@ from comfy_execution.utils import get_executing_context
|
||||
from comfy_api import feature_flags
|
||||
from app.database.db import init_db, dependencies_available
|
||||
|
||||
if __name__ == "__main__":
|
||||
#NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
|
||||
import comfy_aimdo.control
|
||||
|
||||
if enables_dynamic_vram():
|
||||
if not comfy_aimdo.control.init():
|
||||
logging.warning(
|
||||
"DynamicVRAM requested, but comfy-aimdo failed to initialize early. "
|
||||
"Will fall back to legacy model loading if device init fails."
|
||||
)
|
||||
|
||||
if '--use-process-isolation' in sys.argv:
|
||||
from comfy.isolation import initialize_proxies
|
||||
initialize_proxies()
|
||||
|
||||
# Explicitly register the ComfyUI adapter for pyisolate (v1.0 architecture)
|
||||
try:
|
||||
import pyisolate
|
||||
from comfy.isolation.adapter import ComfyUIAdapter
|
||||
pyisolate.register_adapter(ComfyUIAdapter())
|
||||
logging.info("PyIsolate adapter registered: comfyui")
|
||||
except ImportError:
|
||||
logging.warning("PyIsolate not installed or version too old for explicit registration")
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to register PyIsolate adapter: {e}")
|
||||
|
||||
if not IS_PYISOLATE_CHILD:
|
||||
if 'PYTORCH_CUDA_ALLOC_CONF' not in os.environ:
|
||||
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'backend:native'
|
||||
|
||||
if not IS_PYISOLATE_CHILD:
|
||||
from comfy_execution.progress import get_progress_state
|
||||
from comfy_execution.utils import get_executing_context
|
||||
from comfy_api import feature_flags
|
||||
|
||||
if IS_PRIMARY_PROCESS:
|
||||
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
||||
os.environ['DO_NOT_TRACK'] = '1'
|
||||
|
||||
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
|
||||
if not IS_PYISOLATE_CHILD:
|
||||
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
|
||||
|
||||
faulthandler.enable(file=sys.stderr, all_threads=False)
|
||||
|
||||
@@ -91,14 +138,15 @@ if args.enable_manager:
|
||||
|
||||
|
||||
def apply_custom_paths():
|
||||
from utils import extra_config # Deferred import - spawn re-runs main.py
|
||||
# extra model paths
|
||||
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
|
||||
if os.path.isfile(extra_model_paths_config_path):
|
||||
utils.extra_config.load_extra_path_config(extra_model_paths_config_path)
|
||||
extra_config.load_extra_path_config(extra_model_paths_config_path)
|
||||
|
||||
if args.extra_model_paths_config:
|
||||
for config_path in itertools.chain(*args.extra_model_paths_config):
|
||||
utils.extra_config.load_extra_path_config(config_path)
|
||||
extra_config.load_extra_path_config(config_path)
|
||||
|
||||
# --output-directory, --input-directory, --user-directory
|
||||
if args.output_directory:
|
||||
@@ -171,15 +219,17 @@ def execute_prestartup_script():
|
||||
else:
|
||||
import_message = " (PRESTARTUP FAILED)"
|
||||
logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1]))
|
||||
logging.info("")
|
||||
logging.info("")
|
||||
|
||||
apply_custom_paths()
|
||||
init_mime_types()
|
||||
if not IS_PYISOLATE_CHILD:
|
||||
apply_custom_paths()
|
||||
init_mime_types()
|
||||
|
||||
if args.enable_manager:
|
||||
if args.enable_manager and not IS_PYISOLATE_CHILD:
|
||||
comfyui_manager.prestartup()
|
||||
|
||||
execute_prestartup_script()
|
||||
if not IS_PYISOLATE_CHILD:
|
||||
execute_prestartup_script()
|
||||
|
||||
|
||||
# Main code
|
||||
@@ -190,18 +240,18 @@ import gc
|
||||
if 'torch' in sys.modules:
|
||||
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
|
||||
|
||||
|
||||
import comfy.utils
|
||||
from app.assets.seeder import asset_seeder
|
||||
|
||||
import execution
|
||||
import server
|
||||
from protocol import BinaryEventTypes
|
||||
import nodes
|
||||
import comfy.model_management
|
||||
import comfyui_version
|
||||
import app.logger
|
||||
import hook_breaker_ac10a0
|
||||
if not IS_PYISOLATE_CHILD:
|
||||
import execution
|
||||
import server
|
||||
from protocol import BinaryEventTypes
|
||||
import nodes
|
||||
import comfy.model_management
|
||||
import comfyui_version
|
||||
import app.logger
|
||||
import hook_breaker_ac10a0
|
||||
|
||||
import comfy.memory_management
|
||||
import comfy.model_patcher
|
||||
@@ -417,6 +467,10 @@ def start_comfyui(asyncio_loop=None):
|
||||
asyncio.set_event_loop(asyncio_loop)
|
||||
prompt_server = server.PromptServer(asyncio_loop)
|
||||
|
||||
if args.use_process_isolation:
|
||||
from comfy.isolation import start_isolation_loading_early
|
||||
start_isolation_loading_early(asyncio_loop)
|
||||
|
||||
if args.enable_manager and not args.disable_manager_ui:
|
||||
comfyui_manager.start()
|
||||
|
||||
@@ -461,12 +515,13 @@ def start_comfyui(asyncio_loop=None):
|
||||
if __name__ == "__main__":
|
||||
# Running directly, just start ComfyUI.
|
||||
logging.info("Python version: {}".format(sys.version))
|
||||
logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
|
||||
for package in ("comfy-aimdo", "comfy-kitchen"):
|
||||
try:
|
||||
logging.info("{} version: {}".format(package, importlib.metadata.version(package)))
|
||||
except:
|
||||
pass
|
||||
if not IS_PYISOLATE_CHILD:
|
||||
logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
|
||||
for package in ("comfy-aimdo", "comfy-kitchen"):
|
||||
try:
|
||||
logging.info("{} version: {}".format(package, importlib.metadata.version(package)))
|
||||
except:
|
||||
pass
|
||||
|
||||
if sys.version_info.major == 3 and sys.version_info.minor < 10:
|
||||
logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.")
|
||||
|
||||
45
nodes.py
45
nodes.py
@@ -1925,6 +1925,7 @@ class ImageInvert:
|
||||
|
||||
class ImageBatch:
|
||||
SEARCH_ALIASES = ["combine images", "merge images", "stack images"]
|
||||
ESSENTIALS_CATEGORY = "Image Tools"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -2306,6 +2307,27 @@ async def init_external_custom_nodes():
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
whitelist = set()
|
||||
isolated_module_paths = set()
|
||||
if args.use_process_isolation:
|
||||
from pathlib import Path
|
||||
from comfy.isolation import await_isolation_loading, get_claimed_paths
|
||||
from comfy.isolation.host_policy import load_host_policy
|
||||
|
||||
# Load Global Host Policy
|
||||
host_policy = load_host_policy(Path(folder_paths.base_path))
|
||||
whitelist_dict = host_policy.get("whitelist", {})
|
||||
# Normalize whitelist keys to lowercase for case-insensitive matching
|
||||
# (matches ComfyUI-Manager's normalization: project.name.strip().lower())
|
||||
whitelist = set(k.strip().lower() for k in whitelist_dict.keys())
|
||||
logging.info(f"][ Loaded Whitelist: {len(whitelist)} nodes allowed.")
|
||||
|
||||
isolated_specs = await await_isolation_loading()
|
||||
for spec in isolated_specs:
|
||||
NODE_CLASS_MAPPINGS.setdefault(spec.node_name, spec.stub_class)
|
||||
NODE_DISPLAY_NAME_MAPPINGS.setdefault(spec.node_name, spec.display_name)
|
||||
isolated_module_paths = get_claimed_paths()
|
||||
|
||||
base_node_names = set(NODE_CLASS_MAPPINGS.keys())
|
||||
node_paths = folder_paths.get_folder_paths("custom_nodes")
|
||||
node_import_times = []
|
||||
@@ -2329,6 +2351,16 @@ async def init_external_custom_nodes():
|
||||
logging.info(f"Blocked by policy: {module_path}")
|
||||
continue
|
||||
|
||||
if args.use_process_isolation:
|
||||
if Path(module_path).resolve() in isolated_module_paths:
|
||||
continue
|
||||
|
||||
# Tri-State Enforcement: If not Isolated (checked above), MUST be Whitelisted.
|
||||
# Normalize to lowercase for case-insensitive matching (matches ComfyUI-Manager)
|
||||
if possible_module.strip().lower() not in whitelist:
|
||||
logging.warning(f"][ REJECTED: Node '{possible_module}' is blocked by security policy (not whitelisted/isolated).")
|
||||
continue
|
||||
|
||||
time_before = time.perf_counter()
|
||||
success = await load_custom_node(module_path, base_node_names, module_parent="custom_nodes")
|
||||
node_import_times.append((time.perf_counter() - time_before, module_path, success))
|
||||
@@ -2343,6 +2375,14 @@ async def init_external_custom_nodes():
|
||||
logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1]))
|
||||
logging.info("")
|
||||
|
||||
if args.use_process_isolation:
|
||||
from comfy.isolation import isolated_node_timings
|
||||
if isolated_node_timings:
|
||||
logging.info("\nImport times for isolated custom nodes:")
|
||||
for timing, path, count in sorted(isolated_node_timings):
|
||||
logging.info("{:6.1f} seconds: {} ({})".format(timing, path, count))
|
||||
logging.info("")
|
||||
|
||||
async def init_builtin_extra_nodes():
|
||||
"""
|
||||
Initializes the built-in extra nodes in ComfyUI.
|
||||
@@ -2415,6 +2455,8 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_wan.py",
|
||||
"nodes_lotus.py",
|
||||
"nodes_hunyuan3d.py",
|
||||
"nodes_save_ply.py",
|
||||
"nodes_save_npz.py",
|
||||
"nodes_primitive.py",
|
||||
"nodes_cfg.py",
|
||||
"nodes_optimalsteps.py",
|
||||
@@ -2435,7 +2477,6 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_audio_encoder.py",
|
||||
"nodes_rope.py",
|
||||
"nodes_logic.py",
|
||||
"nodes_resolution.py",
|
||||
"nodes_nop.py",
|
||||
"nodes_kandinsky5.py",
|
||||
"nodes_wanmove.py",
|
||||
@@ -2443,14 +2484,12 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_zimage.py",
|
||||
"nodes_glsl.py",
|
||||
"nodes_lora_debug.py",
|
||||
"nodes_textgen.py",
|
||||
"nodes_color.py",
|
||||
"nodes_toolkit.py",
|
||||
"nodes_replacements.py",
|
||||
"nodes_nag.py",
|
||||
"nodes_sdpose.py",
|
||||
"nodes_math.py",
|
||||
"nodes_painter.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.17.2"
|
||||
version = "0.16.4"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.10"
|
||||
@@ -10,6 +10,17 @@ homepage = "https://www.comfy.org/"
|
||||
repository = "https://github.com/comfyanonymous/ComfyUI"
|
||||
documentation = "https://docs.comfy.org/"
|
||||
|
||||
[tool.comfy.host]
|
||||
sandbox_mode = "disabled"
|
||||
allow_network = false
|
||||
writable_paths = ["/dev/shm", "/tmp"]
|
||||
|
||||
[tool.comfy.host.whitelist]
|
||||
"ComfyUI-GGUF" = "*"
|
||||
"ComfyUI-KJNodes" = "*"
|
||||
"ComfyUI-Manager" = "*"
|
||||
"websocket_image_save.py" = "*"
|
||||
|
||||
[tool.ruff]
|
||||
lint.select = [
|
||||
"N805", # invalid-first-argument-name-for-method
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.41.20
|
||||
comfyui-workflow-templates==0.9.21
|
||||
comfyui-frontend-package==1.41.16
|
||||
comfyui-workflow-templates==0.9.18
|
||||
comfyui-embedded-docs==0.4.3
|
||||
torch
|
||||
torchsde
|
||||
@@ -22,7 +22,7 @@ alembic
|
||||
SQLAlchemy
|
||||
filelock
|
||||
av>=14.2.0
|
||||
comfy-kitchen>=0.2.8
|
||||
comfy-kitchen>=0.2.7
|
||||
comfy-aimdo>=0.2.10
|
||||
requests
|
||||
simpleeval>=1.0.0
|
||||
@@ -35,3 +35,5 @@ pydantic~=2.0
|
||||
pydantic-settings~=2.0
|
||||
PyOpenGL
|
||||
glfw
|
||||
|
||||
pyisolate==0.10.0
|
||||
|
||||
48
server.py
48
server.py
@@ -3,7 +3,6 @@ import sys
|
||||
import asyncio
|
||||
import traceback
|
||||
import time
|
||||
|
||||
import nodes
|
||||
import folder_paths
|
||||
import execution
|
||||
@@ -196,6 +195,8 @@ def create_block_external_middleware():
|
||||
class PromptServer():
|
||||
def __init__(self, loop):
|
||||
PromptServer.instance = self
|
||||
if loop is None:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
self.user_manager = UserManager()
|
||||
self.model_file_manager = ModelFileManager()
|
||||
@@ -346,6 +347,17 @@ class PromptServer():
|
||||
extensions.extend(list(map(lambda f: "/extensions/" + urllib.parse.quote(
|
||||
name) + "/" + os.path.relpath(f, dir).replace("\\", "/"), files)))
|
||||
|
||||
# Include JS files from proxied web directories (isolated nodes)
|
||||
if args.use_process_isolation:
|
||||
from comfy.isolation.proxies.web_directory_proxy import get_web_directory_cache
|
||||
cache = get_web_directory_cache()
|
||||
for ext_name in cache.extension_names:
|
||||
for entry in cache.list_files(ext_name):
|
||||
if entry["relative_path"].endswith(".js"):
|
||||
extensions.append(
|
||||
"/extensions/" + urllib.parse.quote(ext_name) + "/" + entry["relative_path"]
|
||||
)
|
||||
|
||||
return web.json_response(extensions)
|
||||
|
||||
def get_dir_by_type(dir_type):
|
||||
@@ -1021,6 +1033,40 @@ class PromptServer():
|
||||
for name, dir in nodes.EXTENSION_WEB_DIRS.items():
|
||||
self.app.add_routes([web.static('/extensions/' + name, dir)])
|
||||
|
||||
# Add dynamic handler for proxied web directories (isolated nodes)
|
||||
if args.use_process_isolation:
|
||||
from comfy.isolation.proxies.web_directory_proxy import (
|
||||
get_web_directory_cache,
|
||||
ALLOWED_EXTENSIONS,
|
||||
)
|
||||
|
||||
async def proxied_web_handler(request):
|
||||
ext_name = request.match_info["ext_name"]
|
||||
file_path = request.match_info["file_path"]
|
||||
|
||||
suffix = os.path.splitext(file_path)[1].lower()
|
||||
if suffix not in ALLOWED_EXTENSIONS:
|
||||
return web.Response(status=403, text="Forbidden file type")
|
||||
|
||||
cache = get_web_directory_cache()
|
||||
result = cache.get_file(ext_name, file_path)
|
||||
if result is None:
|
||||
return web.Response(status=404, text="Not found")
|
||||
|
||||
content_type = {
|
||||
".js": "application/javascript",
|
||||
".css": "text/css",
|
||||
".html": "text/html",
|
||||
".json": "application/json",
|
||||
}.get(suffix, "application/octet-stream")
|
||||
|
||||
return web.Response(body=result, content_type=content_type)
|
||||
|
||||
self.app.router.add_get(
|
||||
"/extensions/{ext_name}/{file_path:.*}",
|
||||
proxied_web_handler,
|
||||
)
|
||||
|
||||
installed_templates_version = FrontendManager.get_installed_templates_version()
|
||||
use_legacy_templates = True
|
||||
if installed_templates_version:
|
||||
|
||||
209
tests/isolation/conda_sealed_worker/__init__.py
Normal file
209
tests/isolation/conda_sealed_worker/__init__.py
Normal file
@@ -0,0 +1,209 @@
|
||||
# pylint: disable=import-outside-toplevel,import-error
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _artifact_dir() -> Path | None:
|
||||
raw = os.environ.get("PYISOLATE_ARTIFACT_DIR")
|
||||
if not raw:
|
||||
return None
|
||||
path = Path(raw)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
|
||||
def _write_artifact(name: str, content: str) -> None:
|
||||
artifact_dir = _artifact_dir()
|
||||
if artifact_dir is None:
|
||||
return
|
||||
(artifact_dir / name).write_text(content, encoding="utf-8")
|
||||
|
||||
|
||||
def _contains_tensor_marker(value: Any) -> bool:
|
||||
if isinstance(value, dict):
|
||||
if value.get("__type__") == "TensorValue":
|
||||
return True
|
||||
return any(_contains_tensor_marker(v) for v in value.values())
|
||||
if isinstance(value, (list, tuple)):
|
||||
return any(_contains_tensor_marker(v) for v in value)
|
||||
return False
|
||||
|
||||
|
||||
class InspectRuntimeNode:
|
||||
RETURN_TYPES = (
|
||||
"STRING",
|
||||
"STRING",
|
||||
"BOOLEAN",
|
||||
"BOOLEAN",
|
||||
"STRING",
|
||||
"STRING",
|
||||
"BOOLEAN",
|
||||
)
|
||||
RETURN_NAMES = (
|
||||
"path_dump",
|
||||
"runtime_report",
|
||||
"saw_comfy_root",
|
||||
"imported_comfy_wrapper",
|
||||
"comfy_module_dump",
|
||||
"python_exe",
|
||||
"saw_user_site",
|
||||
)
|
||||
FUNCTION = "inspect"
|
||||
CATEGORY = "PyIsolated/SealedWorker"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802
|
||||
return {"required": {}}
|
||||
|
||||
def inspect(self) -> tuple[str, str, bool, bool, str, str, bool]:
|
||||
import cfgrib
|
||||
import eccodes
|
||||
import xarray as xr
|
||||
|
||||
path_dump = "\n".join(sys.path)
|
||||
comfy_root = "/home/johnj/ComfyUI"
|
||||
saw_comfy_root = any(
|
||||
entry == comfy_root
|
||||
or entry.startswith(f"{comfy_root}/comfy")
|
||||
or entry.startswith(f"{comfy_root}/.venv")
|
||||
for entry in sys.path
|
||||
)
|
||||
imported_comfy_wrapper = "comfy.isolation.extension_wrapper" in sys.modules
|
||||
comfy_module_dump = "\n".join(
|
||||
sorted(name for name in sys.modules if name.startswith("comfy"))
|
||||
)
|
||||
saw_user_site = any("/.local/lib/" in entry for entry in sys.path)
|
||||
python_exe = sys.executable
|
||||
|
||||
runtime_lines = [
|
||||
"Conda sealed worker runtime probe",
|
||||
f"python_exe={python_exe}",
|
||||
f"xarray_origin={getattr(xr, '__file__', '<missing>')}",
|
||||
f"cfgrib_origin={getattr(cfgrib, '__file__', '<missing>')}",
|
||||
f"eccodes_origin={getattr(eccodes, '__file__', '<missing>')}",
|
||||
f"saw_comfy_root={saw_comfy_root}",
|
||||
f"imported_comfy_wrapper={imported_comfy_wrapper}",
|
||||
f"saw_user_site={saw_user_site}",
|
||||
]
|
||||
runtime_report = "\n".join(runtime_lines)
|
||||
|
||||
_write_artifact("child_bootstrap_paths.txt", path_dump)
|
||||
_write_artifact("child_import_trace.txt", comfy_module_dump)
|
||||
_write_artifact("child_dependency_dump.txt", runtime_report)
|
||||
logger.warning("][ Conda sealed runtime probe executed")
|
||||
logger.warning("][ conda python executable: %s", python_exe)
|
||||
logger.warning(
|
||||
"][ conda dependency origins: xarray=%s cfgrib=%s eccodes=%s",
|
||||
getattr(xr, "__file__", "<missing>"),
|
||||
getattr(cfgrib, "__file__", "<missing>"),
|
||||
getattr(eccodes, "__file__", "<missing>"),
|
||||
)
|
||||
|
||||
return (
|
||||
path_dump,
|
||||
runtime_report,
|
||||
saw_comfy_root,
|
||||
imported_comfy_wrapper,
|
||||
comfy_module_dump,
|
||||
python_exe,
|
||||
saw_user_site,
|
||||
)
|
||||
|
||||
|
||||
class OpenWeatherDatasetNode:
|
||||
RETURN_TYPES = ("FLOAT", "STRING", "STRING")
|
||||
RETURN_NAMES = ("sum_value", "grib_path", "dependency_report")
|
||||
FUNCTION = "open_dataset"
|
||||
CATEGORY = "PyIsolated/SealedWorker"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802
|
||||
return {"required": {}}
|
||||
|
||||
def open_dataset(self) -> tuple[float, str, str]:
|
||||
import eccodes
|
||||
import xarray as xr
|
||||
|
||||
artifact_dir = _artifact_dir()
|
||||
if artifact_dir is None:
|
||||
artifact_dir = Path(os.environ.get("HOME", ".")) / "pyisolate_artifacts"
|
||||
artifact_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
grib_path = artifact_dir / "toolkit_weather_fixture.grib2"
|
||||
|
||||
gid = eccodes.codes_grib_new_from_samples("GRIB2")
|
||||
for key, value in [
|
||||
("gridType", "regular_ll"),
|
||||
("Nx", 2),
|
||||
("Ny", 2),
|
||||
("latitudeOfFirstGridPointInDegrees", 1.0),
|
||||
("longitudeOfFirstGridPointInDegrees", 0.0),
|
||||
("latitudeOfLastGridPointInDegrees", 0.0),
|
||||
("longitudeOfLastGridPointInDegrees", 1.0),
|
||||
("iDirectionIncrementInDegrees", 1.0),
|
||||
("jDirectionIncrementInDegrees", 1.0),
|
||||
("jScansPositively", 0),
|
||||
("shortName", "t"),
|
||||
("typeOfLevel", "surface"),
|
||||
("level", 0),
|
||||
("date", 20260315),
|
||||
("time", 0),
|
||||
("step", 0),
|
||||
]:
|
||||
eccodes.codes_set(gid, key, value)
|
||||
|
||||
eccodes.codes_set_values(gid, [1.0, 2.0, 3.0, 4.0])
|
||||
with grib_path.open("wb") as handle:
|
||||
eccodes.codes_write(gid, handle)
|
||||
eccodes.codes_release(gid)
|
||||
|
||||
dataset = xr.open_dataset(grib_path, engine="cfgrib")
|
||||
sum_value = float(dataset["t"].sum().item())
|
||||
dependency_report = "\n".join(
|
||||
[
|
||||
f"dataset_sum={sum_value}",
|
||||
f"grib_path={grib_path}",
|
||||
"xarray_engine=cfgrib",
|
||||
]
|
||||
)
|
||||
_write_artifact("weather_dependency_report.txt", dependency_report)
|
||||
logger.warning("][ cfgrib import ok")
|
||||
logger.warning("][ xarray open_dataset engine=cfgrib path=%s", grib_path)
|
||||
logger.warning("][ conda weather dataset sum=%s", sum_value)
|
||||
return sum_value, str(grib_path), dependency_report
|
||||
|
||||
|
||||
class EchoLatentNode:
|
||||
RETURN_TYPES = ("LATENT", "BOOLEAN")
|
||||
RETURN_NAMES = ("latent", "saw_json_tensor")
|
||||
FUNCTION = "echo_latent"
|
||||
CATEGORY = "PyIsolated/SealedWorker"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802
|
||||
return {"required": {"latent": ("LATENT",)}}
|
||||
|
||||
def echo_latent(self, latent: Any) -> tuple[Any, bool]:
|
||||
saw_json_tensor = _contains_tensor_marker(latent)
|
||||
logger.warning("][ conda latent echo json_marker=%s", saw_json_tensor)
|
||||
return latent, saw_json_tensor
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"CondaSealedRuntimeProbe": InspectRuntimeNode,
|
||||
"CondaSealedOpenWeatherDataset": OpenWeatherDatasetNode,
|
||||
"CondaSealedLatentEcho": EchoLatentNode,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"CondaSealedRuntimeProbe": "Conda Sealed Runtime Probe",
|
||||
"CondaSealedOpenWeatherDataset": "Conda Sealed Open Weather Dataset",
|
||||
"CondaSealedLatentEcho": "Conda Sealed Latent Echo",
|
||||
}
|
||||
13
tests/isolation/conda_sealed_worker/pyproject.toml
Normal file
13
tests/isolation/conda_sealed_worker/pyproject.toml
Normal file
@@ -0,0 +1,13 @@
|
||||
[project]
|
||||
name = "comfyui-toolkit-conda-sealed-worker"
|
||||
version = "0.1.0"
|
||||
dependencies = ["xarray", "cfgrib"]
|
||||
|
||||
[tool.comfy.isolation]
|
||||
can_isolate = true
|
||||
share_torch = false
|
||||
package_manager = "conda"
|
||||
execution_model = "sealed_worker"
|
||||
standalone = true
|
||||
conda_channels = ["conda-forge"]
|
||||
conda_dependencies = ["eccodes", "cfgrib"]
|
||||
7
tests/isolation/internal_probe_host_policy.toml
Normal file
7
tests/isolation/internal_probe_host_policy.toml
Normal file
@@ -0,0 +1,7 @@
|
||||
[tool.comfy.host]
|
||||
sandbox_mode = "required"
|
||||
allow_network = false
|
||||
writable_paths = [
|
||||
"/dev/shm",
|
||||
"/home/johnj/ComfyUI/output",
|
||||
]
|
||||
6
tests/isolation/internal_probe_node/__init__.py
Normal file
6
tests/isolation/internal_probe_node/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .probe_nodes import (
|
||||
NODE_CLASS_MAPPINGS as NODE_CLASS_MAPPINGS,
|
||||
NODE_DISPLAY_NAME_MAPPINGS as NODE_DISPLAY_NAME_MAPPINGS,
|
||||
)
|
||||
|
||||
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
|
||||
75
tests/isolation/internal_probe_node/probe_nodes.py
Normal file
75
tests/isolation/internal_probe_node/probe_nodes.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class InternalIsolationProbeImage:
|
||||
CATEGORY = "tests/isolation"
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "run"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {"required": {}}
|
||||
|
||||
def run(self):
|
||||
from comfy_api.latest import UI
|
||||
import torch
|
||||
|
||||
image = torch.zeros((1, 2, 2, 3), dtype=torch.float32)
|
||||
image[:, :, :, 0] = 1.0
|
||||
ui = UI.PreviewImage(image)
|
||||
return {"ui": ui.as_dict(), "result": ()}
|
||||
|
||||
|
||||
class InternalIsolationProbeAudio:
|
||||
CATEGORY = "tests/isolation"
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "run"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {"required": {}}
|
||||
|
||||
def run(self):
|
||||
from comfy_api.latest import UI
|
||||
import torch
|
||||
|
||||
waveform = torch.zeros((1, 1, 32), dtype=torch.float32)
|
||||
audio = {"waveform": waveform, "sample_rate": 44100}
|
||||
ui = UI.PreviewAudio(audio)
|
||||
return {"ui": ui.as_dict(), "result": ()}
|
||||
|
||||
|
||||
class InternalIsolationProbeUI3D:
|
||||
CATEGORY = "tests/isolation"
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "run"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {"required": {}}
|
||||
|
||||
def run(self):
|
||||
from comfy_api.latest import UI
|
||||
import torch
|
||||
|
||||
bg_image = torch.zeros((1, 2, 2, 3), dtype=torch.float32)
|
||||
bg_image[:, :, :, 1] = 1.0
|
||||
camera_info = {"distance": 1.0}
|
||||
ui = UI.PreviewUI3D("internal_probe_preview.obj", camera_info, bg_image=bg_image)
|
||||
return {"ui": ui.as_dict(), "result": ()}
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"InternalIsolationProbeImage": InternalIsolationProbeImage,
|
||||
"InternalIsolationProbeAudio": InternalIsolationProbeAudio,
|
||||
"InternalIsolationProbeUI3D": InternalIsolationProbeUI3D,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"InternalIsolationProbeImage": "Internal Isolation Probe Image",
|
||||
"InternalIsolationProbeAudio": "Internal Isolation Probe Audio",
|
||||
"InternalIsolationProbeUI3D": "Internal Isolation Probe UI3D",
|
||||
}
|
||||
955
tests/isolation/singleton_boundary_helpers.py
Normal file
955
tests/isolation/singleton_boundary_helpers.py
Normal file
@@ -0,0 +1,955 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
COMFYUI_ROOT = Path(__file__).resolve().parents[2]
|
||||
UV_SEALED_WORKER_MODULE = COMFYUI_ROOT / "tests" / "isolation" / "uv_sealed_worker" / "__init__.py"
|
||||
FORBIDDEN_MINIMAL_SEALED_MODULES = (
|
||||
"torch",
|
||||
"folder_paths",
|
||||
"comfy.utils",
|
||||
"comfy.model_management",
|
||||
"main",
|
||||
"comfy.isolation.extension_wrapper",
|
||||
)
|
||||
FORBIDDEN_SEALED_SINGLETON_MODULES = (
|
||||
"torch",
|
||||
"folder_paths",
|
||||
"comfy.utils",
|
||||
"comfy_execution.progress",
|
||||
)
|
||||
FORBIDDEN_EXACT_SMALL_PROXY_MODULES = FORBIDDEN_SEALED_SINGLETON_MODULES
|
||||
FORBIDDEN_MODEL_MANAGEMENT_MODULES = (
|
||||
"comfy.model_management",
|
||||
)
|
||||
|
||||
|
||||
def _load_module_from_path(module_name: str, module_path: Path):
|
||||
spec = importlib.util.spec_from_file_location(module_name, module_path)
|
||||
if spec is None or spec.loader is None:
|
||||
raise RuntimeError(f"unable to build import spec for {module_path}")
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
try:
|
||||
spec.loader.exec_module(module)
|
||||
except Exception:
|
||||
sys.modules.pop(module_name, None)
|
||||
raise
|
||||
return module
|
||||
|
||||
|
||||
def matching_modules(prefixes: tuple[str, ...], modules: set[str]) -> list[str]:
|
||||
return sorted(
|
||||
module_name
|
||||
for module_name in modules
|
||||
if any(
|
||||
module_name == prefix or module_name.startswith(f"{prefix}.")
|
||||
for prefix in prefixes
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _load_helper_proxy_service() -> Any | None:
|
||||
try:
|
||||
from comfy.isolation.proxies.helper_proxies import HelperProxiesService
|
||||
except (ImportError, AttributeError):
|
||||
return None
|
||||
return HelperProxiesService
|
||||
|
||||
|
||||
def _load_model_management_proxy() -> Any | None:
|
||||
try:
|
||||
from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy
|
||||
except (ImportError, AttributeError):
|
||||
return None
|
||||
return ModelManagementProxy
|
||||
|
||||
|
||||
async def _capture_minimal_sealed_worker_imports() -> dict[str, object]:
|
||||
from pyisolate.sealed import SealedNodeExtension
|
||||
|
||||
module_name = "tests.isolation.uv_sealed_worker_boundary_probe"
|
||||
before = set(sys.modules)
|
||||
extension = SealedNodeExtension()
|
||||
module = _load_module_from_path(module_name, UV_SEALED_WORKER_MODULE)
|
||||
try:
|
||||
await extension.on_module_loaded(module)
|
||||
node_list = await extension.list_nodes()
|
||||
node_details = await extension.get_node_details("UVSealedRuntimeProbe")
|
||||
imported = set(sys.modules) - before
|
||||
return {
|
||||
"mode": "minimal_sealed_worker",
|
||||
"node_names": sorted(node_list),
|
||||
"runtime_probe_function": node_details["function"],
|
||||
"modules": sorted(imported),
|
||||
"forbidden_matches": matching_modules(FORBIDDEN_MINIMAL_SEALED_MODULES, imported),
|
||||
}
|
||||
finally:
|
||||
sys.modules.pop(module_name, None)
|
||||
|
||||
|
||||
def capture_minimal_sealed_worker_imports() -> dict[str, object]:
|
||||
return asyncio.run(_capture_minimal_sealed_worker_imports())
|
||||
|
||||
|
||||
class FakeSingletonCaller:
|
||||
def __init__(self, methods: dict[str, Any], calls: list[dict[str, Any]], object_id: str):
|
||||
self._methods = methods
|
||||
self._calls = calls
|
||||
self._object_id = object_id
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
if name not in self._methods:
|
||||
raise AttributeError(name)
|
||||
|
||||
async def method(*args: Any, **kwargs: Any) -> Any:
|
||||
self._calls.append(
|
||||
{
|
||||
"object_id": self._object_id,
|
||||
"method": name,
|
||||
"args": list(args),
|
||||
"kwargs": dict(kwargs),
|
||||
}
|
||||
)
|
||||
result = self._methods[name]
|
||||
return result(*args, **kwargs) if callable(result) else result
|
||||
|
||||
return method
|
||||
|
||||
|
||||
class FakeSingletonRPC:
|
||||
def __init__(self) -> None:
|
||||
self.calls: list[dict[str, Any]] = []
|
||||
self._device = {"__pyisolate_torch_device__": "cpu"}
|
||||
self._services: dict[str, dict[str, Any]] = {
|
||||
"FolderPathsProxy": {
|
||||
"rpc_get_models_dir": lambda: "/sandbox/models",
|
||||
"rpc_get_folder_names_and_paths": lambda: {
|
||||
"checkpoints": {
|
||||
"paths": ["/sandbox/models/checkpoints"],
|
||||
"extensions": [".ckpt", ".safetensors"],
|
||||
}
|
||||
},
|
||||
"rpc_get_extension_mimetypes_cache": lambda: {"webp": "image"},
|
||||
"rpc_get_filename_list_cache": lambda: {},
|
||||
"rpc_get_temp_directory": lambda: "/sandbox/temp",
|
||||
"rpc_get_input_directory": lambda: "/sandbox/input",
|
||||
"rpc_get_output_directory": lambda: "/sandbox/output",
|
||||
"rpc_get_user_directory": lambda: "/sandbox/user",
|
||||
"rpc_get_annotated_filepath": self._get_annotated_filepath,
|
||||
"rpc_exists_annotated_filepath": lambda _name: False,
|
||||
"rpc_add_model_folder_path": lambda *_args, **_kwargs: None,
|
||||
"rpc_get_folder_paths": lambda folder_name: [f"/sandbox/models/{folder_name}"],
|
||||
"rpc_get_filename_list": lambda folder_name: [f"{folder_name}_fixture.safetensors"],
|
||||
"rpc_get_full_path": lambda folder_name, filename: f"/sandbox/models/{folder_name}/{filename}",
|
||||
},
|
||||
"UtilsProxy": {
|
||||
"progress_bar_hook": lambda value, total, preview=None, node_id=None: {
|
||||
"value": value,
|
||||
"total": total,
|
||||
"preview": preview,
|
||||
"node_id": node_id,
|
||||
}
|
||||
},
|
||||
"ProgressProxy": {
|
||||
"rpc_set_progress": lambda value, max_value, node_id=None, image=None: {
|
||||
"value": value,
|
||||
"max_value": max_value,
|
||||
"node_id": node_id,
|
||||
"image": image,
|
||||
}
|
||||
},
|
||||
"HelperProxiesService": {
|
||||
"rpc_restore_input_types": lambda raw: raw,
|
||||
},
|
||||
"ModelManagementProxy": {
|
||||
"rpc_call": self._model_management_rpc_call,
|
||||
},
|
||||
}
|
||||
|
||||
def _model_management_rpc_call(self, method_name: str, args: Any = None, kwargs: Any = None) -> Any:
|
||||
if method_name == "get_torch_device":
|
||||
return self._device
|
||||
elif method_name == "get_torch_device_name":
|
||||
return "cpu"
|
||||
elif method_name == "get_free_memory":
|
||||
return 34359738368
|
||||
raise AssertionError(f"unexpected model_management method {method_name}")
|
||||
|
||||
@staticmethod
|
||||
def _get_annotated_filepath(name: str, default_dir: str | None = None) -> str:
|
||||
if name.endswith("[output]"):
|
||||
return f"/sandbox/output/{name[:-8]}"
|
||||
if name.endswith("[input]"):
|
||||
return f"/sandbox/input/{name[:-7]}"
|
||||
if name.endswith("[temp]"):
|
||||
return f"/sandbox/temp/{name[:-6]}"
|
||||
base_dir = default_dir or "/sandbox/input"
|
||||
return f"{base_dir}/{name}"
|
||||
|
||||
def create_caller(self, cls: Any, object_id: str):
|
||||
methods = self._services.get(object_id) or self._services.get(getattr(cls, "__name__", object_id))
|
||||
if methods is None:
|
||||
raise KeyError(object_id)
|
||||
return FakeSingletonCaller(methods, self.calls, object_id)
|
||||
|
||||
|
||||
def _clear_proxy_rpcs() -> None:
|
||||
from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy
|
||||
from comfy.isolation.proxies.progress_proxy import ProgressProxy
|
||||
from comfy.isolation.proxies.utils_proxy import UtilsProxy
|
||||
|
||||
FolderPathsProxy.clear_rpc()
|
||||
ProgressProxy.clear_rpc()
|
||||
UtilsProxy.clear_rpc()
|
||||
helper_proxy_service = _load_helper_proxy_service()
|
||||
if helper_proxy_service is not None:
|
||||
helper_proxy_service.clear_rpc()
|
||||
model_management_proxy = _load_model_management_proxy()
|
||||
if model_management_proxy is not None and hasattr(model_management_proxy, "clear_rpc"):
|
||||
model_management_proxy.clear_rpc()
|
||||
|
||||
|
||||
def prepare_sealed_singleton_proxies(fake_rpc: FakeSingletonRPC) -> None:
|
||||
os.environ["PYISOLATE_CHILD"] = "1"
|
||||
os.environ["PYISOLATE_IMPORT_TORCH"] = "0"
|
||||
_clear_proxy_rpcs()
|
||||
|
||||
from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy
|
||||
from comfy.isolation.proxies.progress_proxy import ProgressProxy
|
||||
from comfy.isolation.proxies.utils_proxy import UtilsProxy
|
||||
|
||||
FolderPathsProxy.set_rpc(fake_rpc)
|
||||
ProgressProxy.set_rpc(fake_rpc)
|
||||
UtilsProxy.set_rpc(fake_rpc)
|
||||
helper_proxy_service = _load_helper_proxy_service()
|
||||
if helper_proxy_service is not None:
|
||||
helper_proxy_service.set_rpc(fake_rpc)
|
||||
model_management_proxy = _load_model_management_proxy()
|
||||
if model_management_proxy is not None and hasattr(model_management_proxy, "set_rpc"):
|
||||
model_management_proxy.set_rpc(fake_rpc)
|
||||
|
||||
|
||||
def reset_forbidden_singleton_modules() -> None:
|
||||
for module_name in (
|
||||
"folder_paths",
|
||||
"comfy.utils",
|
||||
"comfy_execution.progress",
|
||||
):
|
||||
sys.modules.pop(module_name, None)
|
||||
|
||||
|
||||
class FakeExactRelayCaller:
|
||||
def __init__(self, methods: dict[str, Any], transcripts: list[dict[str, Any]], object_id: str):
|
||||
self._methods = methods
|
||||
self._transcripts = transcripts
|
||||
self._object_id = object_id
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
if name not in self._methods:
|
||||
raise AttributeError(name)
|
||||
|
||||
async def method(*args: Any, **kwargs: Any) -> Any:
|
||||
self._transcripts.append(
|
||||
{
|
||||
"phase": "child_call",
|
||||
"object_id": self._object_id,
|
||||
"method": name,
|
||||
"args": list(args),
|
||||
"kwargs": dict(kwargs),
|
||||
}
|
||||
)
|
||||
impl = self._methods[name]
|
||||
self._transcripts.append(
|
||||
{
|
||||
"phase": "host_invocation",
|
||||
"object_id": self._object_id,
|
||||
"method": name,
|
||||
"target": impl["target"],
|
||||
"args": list(args),
|
||||
"kwargs": dict(kwargs),
|
||||
}
|
||||
)
|
||||
result = impl["result"](*args, **kwargs) if callable(impl["result"]) else impl["result"]
|
||||
self._transcripts.append(
|
||||
{
|
||||
"phase": "result",
|
||||
"object_id": self._object_id,
|
||||
"method": name,
|
||||
"result": result,
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
return method
|
||||
|
||||
|
||||
class FakeExactRelayRPC:
|
||||
def __init__(self) -> None:
|
||||
self.transcripts: list[dict[str, Any]] = []
|
||||
self._device = {"__pyisolate_torch_device__": "cpu"}
|
||||
self._services: dict[str, dict[str, Any]] = {
|
||||
"FolderPathsProxy": {
|
||||
"rpc_get_models_dir": {
|
||||
"target": "folder_paths.models_dir",
|
||||
"result": "/sandbox/models",
|
||||
},
|
||||
"rpc_get_temp_directory": {
|
||||
"target": "folder_paths.get_temp_directory",
|
||||
"result": "/sandbox/temp",
|
||||
},
|
||||
"rpc_get_input_directory": {
|
||||
"target": "folder_paths.get_input_directory",
|
||||
"result": "/sandbox/input",
|
||||
},
|
||||
"rpc_get_output_directory": {
|
||||
"target": "folder_paths.get_output_directory",
|
||||
"result": "/sandbox/output",
|
||||
},
|
||||
"rpc_get_user_directory": {
|
||||
"target": "folder_paths.get_user_directory",
|
||||
"result": "/sandbox/user",
|
||||
},
|
||||
"rpc_get_folder_names_and_paths": {
|
||||
"target": "folder_paths.folder_names_and_paths",
|
||||
"result": {
|
||||
"checkpoints": {
|
||||
"paths": ["/sandbox/models/checkpoints"],
|
||||
"extensions": [".ckpt", ".safetensors"],
|
||||
}
|
||||
},
|
||||
},
|
||||
"rpc_get_extension_mimetypes_cache": {
|
||||
"target": "folder_paths.extension_mimetypes_cache",
|
||||
"result": {"webp": "image"},
|
||||
},
|
||||
"rpc_get_filename_list_cache": {
|
||||
"target": "folder_paths.filename_list_cache",
|
||||
"result": {},
|
||||
},
|
||||
"rpc_get_annotated_filepath": {
|
||||
"target": "folder_paths.get_annotated_filepath",
|
||||
"result": lambda name, default_dir=None: FakeSingletonRPC._get_annotated_filepath(name, default_dir),
|
||||
},
|
||||
"rpc_exists_annotated_filepath": {
|
||||
"target": "folder_paths.exists_annotated_filepath",
|
||||
"result": False,
|
||||
},
|
||||
"rpc_add_model_folder_path": {
|
||||
"target": "folder_paths.add_model_folder_path",
|
||||
"result": None,
|
||||
},
|
||||
"rpc_get_folder_paths": {
|
||||
"target": "folder_paths.get_folder_paths",
|
||||
"result": lambda folder_name: [f"/sandbox/models/{folder_name}"],
|
||||
},
|
||||
"rpc_get_filename_list": {
|
||||
"target": "folder_paths.get_filename_list",
|
||||
"result": lambda folder_name: [f"{folder_name}_fixture.safetensors"],
|
||||
},
|
||||
"rpc_get_full_path": {
|
||||
"target": "folder_paths.get_full_path",
|
||||
"result": lambda folder_name, filename: f"/sandbox/models/{folder_name}/{filename}",
|
||||
},
|
||||
},
|
||||
"UtilsProxy": {
|
||||
"progress_bar_hook": {
|
||||
"target": "comfy.utils.PROGRESS_BAR_HOOK",
|
||||
"result": lambda value, total, preview=None, node_id=None: {
|
||||
"value": value,
|
||||
"total": total,
|
||||
"preview": preview,
|
||||
"node_id": node_id,
|
||||
},
|
||||
},
|
||||
},
|
||||
"ProgressProxy": {
|
||||
"rpc_set_progress": {
|
||||
"target": "comfy_execution.progress.get_progress_state().update_progress",
|
||||
"result": None,
|
||||
},
|
||||
},
|
||||
"HelperProxiesService": {
|
||||
"rpc_restore_input_types": {
|
||||
"target": "comfy.isolation.proxies.helper_proxies.restore_input_types",
|
||||
"result": lambda raw: raw,
|
||||
}
|
||||
},
|
||||
"ModelManagementProxy": {
|
||||
"rpc_call": {
|
||||
"target": "comfy.model_management.*",
|
||||
"result": self._model_management_rpc_call,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def _model_management_rpc_call(self, method_name: str, args: Any = None, kwargs: Any = None) -> Any:
|
||||
device = {"__pyisolate_torch_device__": "cpu"}
|
||||
if method_name == "get_torch_device":
|
||||
return device
|
||||
elif method_name == "get_torch_device_name":
|
||||
return "cpu"
|
||||
elif method_name == "get_free_memory":
|
||||
return 34359738368
|
||||
raise AssertionError(f"unexpected exact-relay method {method_name}")
|
||||
|
||||
def create_caller(self, cls: Any, object_id: str):
|
||||
methods = self._services.get(object_id) or self._services.get(getattr(cls, "__name__", object_id))
|
||||
if methods is None:
|
||||
raise KeyError(object_id)
|
||||
return FakeExactRelayCaller(methods, self.transcripts, object_id)
|
||||
|
||||
|
||||
def capture_exact_small_proxy_relay() -> dict[str, object]:
|
||||
reset_forbidden_singleton_modules()
|
||||
fake_rpc = FakeExactRelayRPC()
|
||||
previous_child = os.environ.get("PYISOLATE_CHILD")
|
||||
previous_import_torch = os.environ.get("PYISOLATE_IMPORT_TORCH")
|
||||
try:
|
||||
prepare_sealed_singleton_proxies(fake_rpc)
|
||||
|
||||
from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy
|
||||
from comfy.isolation.proxies.helper_proxies import restore_input_types
|
||||
from comfy.isolation.proxies.progress_proxy import ProgressProxy
|
||||
from comfy.isolation.proxies.utils_proxy import UtilsProxy
|
||||
|
||||
folder_proxy = FolderPathsProxy()
|
||||
utils_proxy = UtilsProxy()
|
||||
progress_proxy = ProgressProxy()
|
||||
before = set(sys.modules)
|
||||
|
||||
restored = restore_input_types(
|
||||
{
|
||||
"required": {
|
||||
"image": {"__pyisolate_any_type__": True, "value": "*"},
|
||||
}
|
||||
}
|
||||
)
|
||||
folder_path = folder_proxy.get_annotated_filepath("demo.png[input]")
|
||||
models_dir = folder_proxy.models_dir
|
||||
folder_names_and_paths = folder_proxy.folder_names_and_paths
|
||||
asyncio.run(utils_proxy.progress_bar_hook(2, 5, node_id="node-17"))
|
||||
progress_proxy.set_progress(1.5, 5.0, node_id="node-17")
|
||||
|
||||
imported = set(sys.modules) - before
|
||||
return {
|
||||
"mode": "exact_small_proxy_relay",
|
||||
"folder_path": folder_path,
|
||||
"models_dir": models_dir,
|
||||
"folder_names_and_paths": folder_names_and_paths,
|
||||
"restored_any_type": str(restored["required"]["image"]),
|
||||
"transcripts": fake_rpc.transcripts,
|
||||
"modules": sorted(imported),
|
||||
"forbidden_matches": matching_modules(FORBIDDEN_EXACT_SMALL_PROXY_MODULES, imported),
|
||||
}
|
||||
finally:
|
||||
_clear_proxy_rpcs()
|
||||
if previous_child is None:
|
||||
os.environ.pop("PYISOLATE_CHILD", None)
|
||||
else:
|
||||
os.environ["PYISOLATE_CHILD"] = previous_child
|
||||
if previous_import_torch is None:
|
||||
os.environ.pop("PYISOLATE_IMPORT_TORCH", None)
|
||||
else:
|
||||
os.environ["PYISOLATE_IMPORT_TORCH"] = previous_import_torch
|
||||
|
||||
|
||||
class FakeModelManagementExactRelayRPC:
|
||||
def __init__(self) -> None:
|
||||
self.transcripts: list[dict[str, object]] = []
|
||||
self._device = {"__pyisolate_torch_device__": "cpu"}
|
||||
self._services: dict[str, dict[str, Any]] = {
|
||||
"ModelManagementProxy": {
|
||||
"rpc_call": self._rpc_call,
|
||||
}
|
||||
}
|
||||
|
||||
def create_caller(self, cls: Any, object_id: str):
|
||||
methods = self._services.get(object_id) or self._services.get(getattr(cls, "__name__", object_id))
|
||||
if methods is None:
|
||||
raise KeyError(object_id)
|
||||
return _ModelManagementExactRelayCaller(methods)
|
||||
|
||||
def _rpc_call(self, method_name: str, args: Any, kwargs: Any) -> Any:
|
||||
self.transcripts.append(
|
||||
{
|
||||
"phase": "child_call",
|
||||
"object_id": "ModelManagementProxy",
|
||||
"method": method_name,
|
||||
"args": _json_safe(args),
|
||||
"kwargs": _json_safe(kwargs),
|
||||
}
|
||||
)
|
||||
target = f"comfy.model_management.{method_name}"
|
||||
self.transcripts.append(
|
||||
{
|
||||
"phase": "host_invocation",
|
||||
"object_id": "ModelManagementProxy",
|
||||
"method": method_name,
|
||||
"target": target,
|
||||
"args": _json_safe(args),
|
||||
"kwargs": _json_safe(kwargs),
|
||||
}
|
||||
)
|
||||
if method_name == "get_torch_device":
|
||||
result = self._device
|
||||
elif method_name == "get_torch_device_name":
|
||||
result = "cpu"
|
||||
elif method_name == "get_free_memory":
|
||||
result = 34359738368
|
||||
else:
|
||||
raise AssertionError(f"unexpected exact-relay method {method_name}")
|
||||
self.transcripts.append(
|
||||
{
|
||||
"phase": "result",
|
||||
"object_id": "ModelManagementProxy",
|
||||
"method": method_name,
|
||||
"result": _json_safe(result),
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class _ModelManagementExactRelayCaller:
|
||||
def __init__(self, methods: dict[str, Any]):
|
||||
self._methods = methods
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
if name not in self._methods:
|
||||
raise AttributeError(name)
|
||||
|
||||
async def method(*args: Any, **kwargs: Any) -> Any:
|
||||
impl = self._methods[name]
|
||||
return impl(*args, **kwargs) if callable(impl) else impl
|
||||
|
||||
return method
|
||||
|
||||
|
||||
def _json_safe(value: Any) -> Any:
|
||||
if callable(value):
|
||||
return f"<callable {getattr(value, '__name__', 'anonymous')}>"
|
||||
if isinstance(value, tuple):
|
||||
return [_json_safe(item) for item in value]
|
||||
if isinstance(value, list):
|
||||
return [_json_safe(item) for item in value]
|
||||
if isinstance(value, dict):
|
||||
return {key: _json_safe(inner) for key, inner in value.items()}
|
||||
return value
|
||||
|
||||
|
||||
def capture_model_management_exact_relay() -> dict[str, object]:
|
||||
for module_name in FORBIDDEN_MODEL_MANAGEMENT_MODULES:
|
||||
sys.modules.pop(module_name, None)
|
||||
|
||||
fake_rpc = FakeModelManagementExactRelayRPC()
|
||||
previous_child = os.environ.get("PYISOLATE_CHILD")
|
||||
previous_import_torch = os.environ.get("PYISOLATE_IMPORT_TORCH")
|
||||
try:
|
||||
os.environ["PYISOLATE_CHILD"] = "1"
|
||||
os.environ["PYISOLATE_IMPORT_TORCH"] = "0"
|
||||
|
||||
from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy
|
||||
|
||||
if hasattr(ModelManagementProxy, "clear_rpc"):
|
||||
ModelManagementProxy.clear_rpc()
|
||||
if hasattr(ModelManagementProxy, "set_rpc"):
|
||||
ModelManagementProxy.set_rpc(fake_rpc)
|
||||
|
||||
proxy = ModelManagementProxy()
|
||||
before = set(sys.modules)
|
||||
device = proxy.get_torch_device()
|
||||
device_name = proxy.get_torch_device_name(device)
|
||||
free_memory = proxy.get_free_memory(device)
|
||||
imported = set(sys.modules) - before
|
||||
return {
|
||||
"mode": "model_management_exact_relay",
|
||||
"device": str(device),
|
||||
"device_type": getattr(device, "type", None),
|
||||
"device_name": device_name,
|
||||
"free_memory": free_memory,
|
||||
"transcripts": fake_rpc.transcripts,
|
||||
"modules": sorted(imported),
|
||||
"forbidden_matches": matching_modules(FORBIDDEN_MODEL_MANAGEMENT_MODULES, imported),
|
||||
}
|
||||
finally:
|
||||
model_management_proxy = _load_model_management_proxy()
|
||||
if model_management_proxy is not None and hasattr(model_management_proxy, "clear_rpc"):
|
||||
model_management_proxy.clear_rpc()
|
||||
if previous_child is None:
|
||||
os.environ.pop("PYISOLATE_CHILD", None)
|
||||
else:
|
||||
os.environ["PYISOLATE_CHILD"] = previous_child
|
||||
if previous_import_torch is None:
|
||||
os.environ.pop("PYISOLATE_IMPORT_TORCH", None)
|
||||
else:
|
||||
os.environ["PYISOLATE_IMPORT_TORCH"] = previous_import_torch
|
||||
|
||||
|
||||
FORBIDDEN_PROMPT_WEB_MODULES = (
|
||||
"server",
|
||||
"aiohttp",
|
||||
"comfy.isolation.extension_wrapper",
|
||||
)
|
||||
FORBIDDEN_EXACT_BOOTSTRAP_MODULES = (
|
||||
"comfy.isolation.adapter",
|
||||
"folder_paths",
|
||||
"comfy.utils",
|
||||
"comfy.model_management",
|
||||
"server",
|
||||
"main",
|
||||
"comfy.isolation.extension_wrapper",
|
||||
)
|
||||
|
||||
|
||||
class _PromptServiceExactRelayCaller:
|
||||
def __init__(self, methods: dict[str, Any], transcripts: list[dict[str, Any]], object_id: str):
|
||||
self._methods = methods
|
||||
self._transcripts = transcripts
|
||||
self._object_id = object_id
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
if name not in self._methods:
|
||||
raise AttributeError(name)
|
||||
|
||||
async def method(*args: Any, **kwargs: Any) -> Any:
|
||||
self._transcripts.append(
|
||||
{
|
||||
"phase": "child_call",
|
||||
"object_id": self._object_id,
|
||||
"method": name,
|
||||
"args": _json_safe(args),
|
||||
"kwargs": _json_safe(kwargs),
|
||||
}
|
||||
)
|
||||
impl = self._methods[name]
|
||||
self._transcripts.append(
|
||||
{
|
||||
"phase": "host_invocation",
|
||||
"object_id": self._object_id,
|
||||
"method": name,
|
||||
"target": impl["target"],
|
||||
"args": _json_safe(args),
|
||||
"kwargs": _json_safe(kwargs),
|
||||
}
|
||||
)
|
||||
result = impl["result"](*args, **kwargs) if callable(impl["result"]) else impl["result"]
|
||||
self._transcripts.append(
|
||||
{
|
||||
"phase": "result",
|
||||
"object_id": self._object_id,
|
||||
"method": name,
|
||||
"result": _json_safe(result),
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
return method
|
||||
|
||||
|
||||
class FakePromptWebRPC:
|
||||
def __init__(self) -> None:
|
||||
self.transcripts: list[dict[str, Any]] = []
|
||||
self._services = {
|
||||
"PromptServerService": {
|
||||
"ui_send_progress_text": {
|
||||
"target": "server.PromptServer.instance.send_progress_text",
|
||||
"result": None,
|
||||
},
|
||||
"register_route_rpc": {
|
||||
"target": "server.PromptServer.instance.routes.add_route",
|
||||
"result": None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
def create_caller(self, cls: Any, object_id: str):
|
||||
methods = self._services.get(object_id) or self._services.get(getattr(cls, "__name__", object_id))
|
||||
if methods is None:
|
||||
raise KeyError(object_id)
|
||||
return _PromptServiceExactRelayCaller(methods, self.transcripts, object_id)
|
||||
|
||||
|
||||
class FakeWebDirectoryProxy:
|
||||
def __init__(self, transcripts: list[dict[str, Any]]):
|
||||
self._transcripts = transcripts
|
||||
|
||||
def get_web_file(self, extension_name: str, relative_path: str) -> dict[str, Any]:
|
||||
self._transcripts.append(
|
||||
{
|
||||
"phase": "child_call",
|
||||
"object_id": "WebDirectoryProxy",
|
||||
"method": "get_web_file",
|
||||
"args": [extension_name, relative_path],
|
||||
"kwargs": {},
|
||||
}
|
||||
)
|
||||
self._transcripts.append(
|
||||
{
|
||||
"phase": "host_invocation",
|
||||
"object_id": "WebDirectoryProxy",
|
||||
"method": "get_web_file",
|
||||
"target": "comfy.isolation.proxies.web_directory_proxy.WebDirectoryProxy.get_web_file",
|
||||
"args": [extension_name, relative_path],
|
||||
"kwargs": {},
|
||||
}
|
||||
)
|
||||
result = {
|
||||
"content": "Y29uc29sZS5sb2coJ2RlbycpOw==",
|
||||
"content_type": "application/javascript",
|
||||
}
|
||||
self._transcripts.append(
|
||||
{
|
||||
"phase": "result",
|
||||
"object_id": "WebDirectoryProxy",
|
||||
"method": "get_web_file",
|
||||
"result": result,
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def capture_prompt_web_exact_relay() -> dict[str, object]:
|
||||
for module_name in FORBIDDEN_PROMPT_WEB_MODULES:
|
||||
sys.modules.pop(module_name, None)
|
||||
|
||||
fake_rpc = FakePromptWebRPC()
|
||||
|
||||
from comfy.isolation.proxies.prompt_server_impl import PromptServerStub
|
||||
from comfy.isolation.proxies.web_directory_proxy import WebDirectoryCache
|
||||
|
||||
PromptServerStub.set_rpc(fake_rpc)
|
||||
stub = PromptServerStub()
|
||||
cache = WebDirectoryCache()
|
||||
cache.register_proxy("demo_ext", FakeWebDirectoryProxy(fake_rpc.transcripts))
|
||||
|
||||
before = set(sys.modules)
|
||||
|
||||
def demo_handler(_request):
|
||||
return {"ok": True}
|
||||
|
||||
stub.send_progress_text("hello", "node-17")
|
||||
stub.routes.get("/demo")(demo_handler)
|
||||
web_file = cache.get_file("demo_ext", "js/app.js")
|
||||
imported = set(sys.modules) - before
|
||||
return {
|
||||
"mode": "prompt_web_exact_relay",
|
||||
"web_file": {
|
||||
"content_type": web_file["content_type"] if web_file else None,
|
||||
"content": web_file["content"].decode("utf-8") if web_file else None,
|
||||
},
|
||||
"transcripts": fake_rpc.transcripts,
|
||||
"modules": sorted(imported),
|
||||
"forbidden_matches": matching_modules(FORBIDDEN_PROMPT_WEB_MODULES, imported),
|
||||
}
|
||||
|
||||
|
||||
class FakeExactBootstrapRPC:
|
||||
def __init__(self) -> None:
|
||||
self.transcripts: list[dict[str, Any]] = []
|
||||
self._device = {"__pyisolate_torch_device__": "cpu"}
|
||||
self._services: dict[str, dict[str, Any]] = {
|
||||
"FolderPathsProxy": FakeExactRelayRPC()._services["FolderPathsProxy"],
|
||||
"HelperProxiesService": FakeExactRelayRPC()._services["HelperProxiesService"],
|
||||
"ProgressProxy": FakeExactRelayRPC()._services["ProgressProxy"],
|
||||
"UtilsProxy": FakeExactRelayRPC()._services["UtilsProxy"],
|
||||
"PromptServerService": {
|
||||
"ui_send_sync": {
|
||||
"target": "server.PromptServer.instance.send_sync",
|
||||
"result": None,
|
||||
},
|
||||
"ui_send": {
|
||||
"target": "server.PromptServer.instance.send",
|
||||
"result": None,
|
||||
},
|
||||
"ui_send_progress_text": {
|
||||
"target": "server.PromptServer.instance.send_progress_text",
|
||||
"result": None,
|
||||
},
|
||||
"register_route_rpc": {
|
||||
"target": "server.PromptServer.instance.routes.add_route",
|
||||
"result": None,
|
||||
},
|
||||
},
|
||||
"ModelManagementProxy": {
|
||||
"rpc_call": self._rpc_call,
|
||||
},
|
||||
}
|
||||
|
||||
def create_caller(self, cls: Any, object_id: str):
|
||||
methods = self._services.get(object_id) or self._services.get(getattr(cls, "__name__", object_id))
|
||||
if methods is None:
|
||||
raise KeyError(object_id)
|
||||
if object_id == "ModelManagementProxy":
|
||||
return _ModelManagementExactRelayCaller(methods)
|
||||
return _PromptServiceExactRelayCaller(methods, self.transcripts, object_id)
|
||||
|
||||
def _rpc_call(self, method_name: str, args: Any, kwargs: Any) -> Any:
|
||||
self.transcripts.append(
|
||||
{
|
||||
"phase": "child_call",
|
||||
"object_id": "ModelManagementProxy",
|
||||
"method": method_name,
|
||||
"args": _json_safe(args),
|
||||
"kwargs": _json_safe(kwargs),
|
||||
}
|
||||
)
|
||||
self.transcripts.append(
|
||||
{
|
||||
"phase": "host_invocation",
|
||||
"object_id": "ModelManagementProxy",
|
||||
"method": method_name,
|
||||
"target": f"comfy.model_management.{method_name}",
|
||||
"args": _json_safe(args),
|
||||
"kwargs": _json_safe(kwargs),
|
||||
}
|
||||
)
|
||||
result = self._device if method_name == "get_torch_device" else None
|
||||
self.transcripts.append(
|
||||
{
|
||||
"phase": "result",
|
||||
"object_id": "ModelManagementProxy",
|
||||
"method": method_name,
|
||||
"result": _json_safe(result),
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def capture_exact_proxy_bootstrap_contract() -> dict[str, object]:
|
||||
from pyisolate._internal.rpc_protocol import get_child_rpc_instance, set_child_rpc_instance
|
||||
|
||||
from comfy.isolation.adapter import ComfyUIAdapter
|
||||
from comfy.isolation.child_hooks import initialize_child_process
|
||||
from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy
|
||||
from comfy.isolation.proxies.helper_proxies import HelperProxiesService
|
||||
from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy
|
||||
from comfy.isolation.proxies.progress_proxy import ProgressProxy
|
||||
from comfy.isolation.proxies.prompt_server_impl import PromptServerStub
|
||||
from comfy.isolation.proxies.utils_proxy import UtilsProxy
|
||||
|
||||
host_services = sorted(cls.__name__ for cls in ComfyUIAdapter().provide_rpc_services())
|
||||
|
||||
for module_name in FORBIDDEN_EXACT_BOOTSTRAP_MODULES:
|
||||
sys.modules.pop(module_name, None)
|
||||
|
||||
previous_child = os.environ.get("PYISOLATE_CHILD")
|
||||
previous_import_torch = os.environ.get("PYISOLATE_IMPORT_TORCH")
|
||||
os.environ["PYISOLATE_CHILD"] = "1"
|
||||
os.environ["PYISOLATE_IMPORT_TORCH"] = "0"
|
||||
|
||||
_clear_proxy_rpcs()
|
||||
if hasattr(PromptServerStub, "clear_rpc"):
|
||||
PromptServerStub.clear_rpc()
|
||||
else:
|
||||
PromptServerStub._rpc = None # type: ignore[attr-defined]
|
||||
fake_rpc = FakeExactBootstrapRPC()
|
||||
set_child_rpc_instance(fake_rpc)
|
||||
|
||||
before = set(sys.modules)
|
||||
try:
|
||||
initialize_child_process()
|
||||
imported = set(sys.modules) - before
|
||||
matrix = {
|
||||
"base.py": {
|
||||
"bound": get_child_rpc_instance() is fake_rpc,
|
||||
"details": {"child_rpc_instance": get_child_rpc_instance() is fake_rpc},
|
||||
},
|
||||
"folder_paths_proxy.py": {
|
||||
"bound": "FolderPathsProxy" in host_services and FolderPathsProxy._rpc is not None,
|
||||
"details": {"host_service": "FolderPathsProxy" in host_services, "child_rpc": FolderPathsProxy._rpc is not None},
|
||||
},
|
||||
"helper_proxies.py": {
|
||||
"bound": "HelperProxiesService" in host_services and HelperProxiesService._rpc is not None,
|
||||
"details": {"host_service": "HelperProxiesService" in host_services, "child_rpc": HelperProxiesService._rpc is not None},
|
||||
},
|
||||
"model_management_proxy.py": {
|
||||
"bound": "ModelManagementProxy" in host_services and ModelManagementProxy._rpc is not None,
|
||||
"details": {"host_service": "ModelManagementProxy" in host_services, "child_rpc": ModelManagementProxy._rpc is not None},
|
||||
},
|
||||
"progress_proxy.py": {
|
||||
"bound": "ProgressProxy" in host_services and ProgressProxy._rpc is not None,
|
||||
"details": {"host_service": "ProgressProxy" in host_services, "child_rpc": ProgressProxy._rpc is not None},
|
||||
},
|
||||
"prompt_server_impl.py": {
|
||||
"bound": "PromptServerService" in host_services and PromptServerStub._rpc is not None,
|
||||
"details": {"host_service": "PromptServerService" in host_services, "child_rpc": PromptServerStub._rpc is not None},
|
||||
},
|
||||
"utils_proxy.py": {
|
||||
"bound": "UtilsProxy" in host_services and UtilsProxy._rpc is not None,
|
||||
"details": {"host_service": "UtilsProxy" in host_services, "child_rpc": UtilsProxy._rpc is not None},
|
||||
},
|
||||
"web_directory_proxy.py": {
|
||||
"bound": "WebDirectoryProxy" in host_services,
|
||||
"details": {"host_service": "WebDirectoryProxy" in host_services},
|
||||
},
|
||||
}
|
||||
finally:
|
||||
set_child_rpc_instance(None)
|
||||
if previous_child is None:
|
||||
os.environ.pop("PYISOLATE_CHILD", None)
|
||||
else:
|
||||
os.environ["PYISOLATE_CHILD"] = previous_child
|
||||
if previous_import_torch is None:
|
||||
os.environ.pop("PYISOLATE_IMPORT_TORCH", None)
|
||||
else:
|
||||
os.environ["PYISOLATE_IMPORT_TORCH"] = previous_import_torch
|
||||
|
||||
omitted = sorted(name for name, status in matrix.items() if not status["bound"])
|
||||
return {
|
||||
"mode": "exact_proxy_bootstrap_contract",
|
||||
"host_services": host_services,
|
||||
"matrix": matrix,
|
||||
"omitted_proxies": omitted,
|
||||
"modules": sorted(imported),
|
||||
"forbidden_matches": matching_modules(FORBIDDEN_EXACT_BOOTSTRAP_MODULES, imported),
|
||||
}
|
||||
|
||||
def capture_sealed_singleton_imports() -> dict[str, object]:
|
||||
reset_forbidden_singleton_modules()
|
||||
fake_rpc = FakeSingletonRPC()
|
||||
previous_child = os.environ.get("PYISOLATE_CHILD")
|
||||
previous_import_torch = os.environ.get("PYISOLATE_IMPORT_TORCH")
|
||||
before = set(sys.modules)
|
||||
try:
|
||||
prepare_sealed_singleton_proxies(fake_rpc)
|
||||
|
||||
from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy
|
||||
from comfy.isolation.proxies.progress_proxy import ProgressProxy
|
||||
from comfy.isolation.proxies.utils_proxy import UtilsProxy
|
||||
|
||||
folder_proxy = FolderPathsProxy()
|
||||
progress_proxy = ProgressProxy()
|
||||
utils_proxy = UtilsProxy()
|
||||
|
||||
folder_path = folder_proxy.get_annotated_filepath("demo.png[input]")
|
||||
temp_dir = folder_proxy.get_temp_directory()
|
||||
models_dir = folder_proxy.models_dir
|
||||
asyncio.run(utils_proxy.progress_bar_hook(2, 5, node_id="node-17"))
|
||||
progress_proxy.set_progress(1.5, 5.0, node_id="node-17")
|
||||
|
||||
imported = set(sys.modules) - before
|
||||
return {
|
||||
"mode": "sealed_singletons",
|
||||
"folder_path": folder_path,
|
||||
"temp_dir": temp_dir,
|
||||
"models_dir": models_dir,
|
||||
"rpc_calls": fake_rpc.calls,
|
||||
"modules": sorted(imported),
|
||||
"forbidden_matches": matching_modules(FORBIDDEN_SEALED_SINGLETON_MODULES, imported),
|
||||
}
|
||||
finally:
|
||||
_clear_proxy_rpcs()
|
||||
if previous_child is None:
|
||||
os.environ.pop("PYISOLATE_CHILD", None)
|
||||
else:
|
||||
os.environ["PYISOLATE_CHILD"] = previous_child
|
||||
if previous_import_torch is None:
|
||||
os.environ.pop("PYISOLATE_IMPORT_TORCH", None)
|
||||
else:
|
||||
os.environ["PYISOLATE_IMPORT_TORCH"] = previous_import_torch
|
||||
69
tests/isolation/stage_internal_probe_node.py
Normal file
69
tests/isolation/stage_internal_probe_node.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
|
||||
COMFYUI_ROOT = Path(__file__).resolve().parents[2]
|
||||
PROBE_SOURCE_ROOT = COMFYUI_ROOT / "tests" / "isolation" / "internal_probe_node"
|
||||
PROBE_NODE_NAME = "InternalIsolationProbeNode"
|
||||
|
||||
PYPROJECT_CONTENT = """[project]
|
||||
name = "InternalIsolationProbeNode"
|
||||
version = "0.0.1"
|
||||
|
||||
[tool.comfy.isolation]
|
||||
can_isolate = true
|
||||
share_torch = true
|
||||
"""
|
||||
|
||||
|
||||
def _probe_target_root(comfy_root: Path) -> Path:
|
||||
return Path(comfy_root) / "custom_nodes" / PROBE_NODE_NAME
|
||||
|
||||
|
||||
def stage_probe_node(comfy_root: Path) -> Path:
|
||||
if not PROBE_SOURCE_ROOT.is_dir():
|
||||
raise RuntimeError(f"Missing probe source directory: {PROBE_SOURCE_ROOT}")
|
||||
|
||||
target_root = _probe_target_root(comfy_root)
|
||||
target_root.mkdir(parents=True, exist_ok=True)
|
||||
for source_path in PROBE_SOURCE_ROOT.iterdir():
|
||||
destination_path = target_root / source_path.name
|
||||
if source_path.is_dir():
|
||||
shutil.copytree(source_path, destination_path, dirs_exist_ok=True)
|
||||
else:
|
||||
shutil.copy2(source_path, destination_path)
|
||||
|
||||
(target_root / "pyproject.toml").write_text(PYPROJECT_CONTENT, encoding="utf-8")
|
||||
return target_root
|
||||
|
||||
|
||||
@contextmanager
|
||||
def staged_probe_node() -> Iterator[Path]:
|
||||
staging_root = Path(tempfile.mkdtemp(prefix="comfyui_internal_probe_"))
|
||||
try:
|
||||
yield stage_probe_node(staging_root)
|
||||
finally:
|
||||
shutil.rmtree(staging_root, ignore_errors=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Stage the internal isolation probe node under an explicit ComfyUI root."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target-root",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Explicit ComfyUI root to stage under. Caller owns cleanup.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
staged = stage_probe_node(args.target_root)
|
||||
sys.stdout.write(f"{staged}\n")
|
||||
122
tests/isolation/test_client_snapshot.py
Normal file
122
tests/isolation/test_client_snapshot.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""Tests for pyisolate._internal.client import-time snapshot handling."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# Paths needed for subprocess
|
||||
PYISOLATE_ROOT = str(Path(__file__).parent.parent)
|
||||
COMFYUI_ROOT = os.environ.get("COMFYUI_ROOT") or str(Path.home() / "ComfyUI")
|
||||
|
||||
SCRIPT = """
|
||||
import json, sys
|
||||
import pyisolate._internal.client # noqa: F401 # triggers snapshot logic
|
||||
print(json.dumps(sys.path[:6]))
|
||||
"""
|
||||
|
||||
|
||||
def _run_client_process(env):
|
||||
# Ensure subprocess can find pyisolate and ComfyUI
|
||||
pythonpath_parts = [PYISOLATE_ROOT, COMFYUI_ROOT]
|
||||
existing = env.get("PYTHONPATH", "")
|
||||
if existing:
|
||||
pythonpath_parts.append(existing)
|
||||
env["PYTHONPATH"] = os.pathsep.join(pythonpath_parts)
|
||||
|
||||
result = subprocess.run( # noqa: S603
|
||||
[sys.executable, "-c", SCRIPT],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
env=env,
|
||||
check=True,
|
||||
)
|
||||
stdout = result.stdout.strip().splitlines()[-1]
|
||||
return json.loads(stdout)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def comfy_module_path(tmp_path):
|
||||
comfy_root = tmp_path / "ComfyUI"
|
||||
module_path = comfy_root / "custom_nodes" / "TestNode"
|
||||
module_path.mkdir(parents=True)
|
||||
return comfy_root, module_path
|
||||
|
||||
|
||||
def test_snapshot_applied_and_comfy_root_prepend(tmp_path, comfy_module_path):
|
||||
comfy_root, module_path = comfy_module_path
|
||||
# Must include real ComfyUI path for utils validation to pass
|
||||
host_paths = [COMFYUI_ROOT, "/host/lib1", "/host/lib2"]
|
||||
snapshot = {
|
||||
"sys_path": host_paths,
|
||||
"sys_executable": sys.executable,
|
||||
"sys_prefix": sys.prefix,
|
||||
"environment": {},
|
||||
}
|
||||
snapshot_path = tmp_path / "snapshot.json"
|
||||
snapshot_path.write_text(json.dumps(snapshot), encoding="utf-8")
|
||||
|
||||
env = os.environ.copy()
|
||||
env.update(
|
||||
{
|
||||
"PYISOLATE_CHILD": "1",
|
||||
"PYISOLATE_HOST_SNAPSHOT": str(snapshot_path),
|
||||
"PYISOLATE_MODULE_PATH": str(module_path),
|
||||
}
|
||||
)
|
||||
|
||||
path_prefix = _run_client_process(env)
|
||||
|
||||
# Current client behavior preserves the runtime bootstrap path order and
|
||||
# keeps the resolved ComfyUI root available for imports.
|
||||
assert COMFYUI_ROOT in path_prefix
|
||||
# Module path should not override runtime root selection.
|
||||
assert str(comfy_root) not in path_prefix
|
||||
|
||||
|
||||
def test_missing_snapshot_file_does_not_crash(tmp_path, comfy_module_path):
|
||||
_, module_path = comfy_module_path
|
||||
missing_snapshot = tmp_path / "missing.json"
|
||||
|
||||
env = os.environ.copy()
|
||||
env.update(
|
||||
{
|
||||
"PYISOLATE_CHILD": "1",
|
||||
"PYISOLATE_HOST_SNAPSHOT": str(missing_snapshot),
|
||||
"PYISOLATE_MODULE_PATH": str(module_path),
|
||||
}
|
||||
)
|
||||
|
||||
# Should not raise even though snapshot path is missing
|
||||
paths = _run_client_process(env)
|
||||
assert len(paths) > 0
|
||||
|
||||
|
||||
def test_no_comfy_root_when_module_path_absent(tmp_path):
|
||||
# Must include real ComfyUI path for utils validation to pass
|
||||
host_paths = [COMFYUI_ROOT, "/alpha", "/beta"]
|
||||
snapshot = {
|
||||
"sys_path": host_paths,
|
||||
"sys_executable": sys.executable,
|
||||
"sys_prefix": sys.prefix,
|
||||
"environment": {},
|
||||
}
|
||||
snapshot_path = tmp_path / "snapshot.json"
|
||||
snapshot_path.write_text(json.dumps(snapshot), encoding="utf-8")
|
||||
|
||||
env = os.environ.copy()
|
||||
env.update(
|
||||
{
|
||||
"PYISOLATE_CHILD": "1",
|
||||
"PYISOLATE_HOST_SNAPSHOT": str(snapshot_path),
|
||||
}
|
||||
)
|
||||
|
||||
paths = _run_client_process(env)
|
||||
# Runtime path bootstrap keeps ComfyUI importability regardless of host
|
||||
# snapshot extras.
|
||||
assert COMFYUI_ROOT in paths
|
||||
assert "/alpha" not in paths and "/beta" not in paths
|
||||
460
tests/isolation/test_cuda_wheels_and_env_flags.py
Normal file
460
tests/isolation/test_cuda_wheels_and_env_flags.py
Normal file
@@ -0,0 +1,460 @@
|
||||
"""Synthetic integration coverage for manifest plumbing and env flags.
|
||||
|
||||
These tests do not perform a real wheel install or a real ComfyUI E2E run.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
|
||||
import pytest
|
||||
|
||||
import comfy.isolation as isolation_pkg
|
||||
from comfy.isolation import runtime_helpers
|
||||
from comfy.isolation import extension_loader as extension_loader_module
|
||||
from comfy.isolation import extension_wrapper as extension_wrapper_module
|
||||
from comfy.isolation import model_patcher_proxy_utils
|
||||
from comfy.isolation.extension_loader import ExtensionLoadError, load_isolated_node
|
||||
from comfy.isolation.extension_wrapper import ComfyNodeExtension
|
||||
from comfy.isolation.model_patcher_proxy_utils import maybe_wrap_model_for_isolation
|
||||
from pyisolate._internal.environment_conda import _generate_pixi_toml
|
||||
|
||||
|
||||
class _DummyExtension:
|
||||
def __init__(self) -> None:
|
||||
self.name = "demo-extension"
|
||||
|
||||
async def stop(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def _write_manifest(node_dir, manifest_text: str) -> None:
|
||||
(node_dir / "pyproject.toml").write_text(manifest_text, encoding="utf-8")
|
||||
|
||||
|
||||
def test_load_isolated_node_passes_normalized_cuda_wheels_config(tmp_path, monkeypatch):
|
||||
node_dir = tmp_path / "node"
|
||||
node_dir.mkdir()
|
||||
manifest_path = node_dir / "pyproject.toml"
|
||||
_write_manifest(
|
||||
node_dir,
|
||||
"""
|
||||
[project]
|
||||
name = "demo-node"
|
||||
dependencies = ["flash-attn>=1.0", "sageattention==0.1"]
|
||||
|
||||
[tool.comfy.isolation]
|
||||
can_isolate = true
|
||||
share_torch = true
|
||||
|
||||
[tool.comfy.isolation.cuda_wheels]
|
||||
index_url = "https://example.invalid/cuda-wheels"
|
||||
packages = ["flash_attn", "sageattention"]
|
||||
|
||||
[tool.comfy.isolation.cuda_wheels.package_map]
|
||||
flash_attn = "flash-attn-special"
|
||||
""".strip(),
|
||||
)
|
||||
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
class DummyManager:
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
return None
|
||||
|
||||
def load_extension(self, config):
|
||||
captured.update(config)
|
||||
return _DummyExtension()
|
||||
|
||||
monkeypatch.setattr(extension_loader_module.pyisolate, "ExtensionManager", DummyManager)
|
||||
monkeypatch.setattr(
|
||||
extension_loader_module,
|
||||
"load_host_policy",
|
||||
lambda base_path: {
|
||||
"sandbox_mode": "required",
|
||||
"allow_network": False,
|
||||
"writable_paths": [],
|
||||
"readonly_paths": [],
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(extension_loader_module, "is_cache_valid", lambda *args, **kwargs: True)
|
||||
monkeypatch.setattr(
|
||||
extension_loader_module,
|
||||
"load_from_cache",
|
||||
lambda *args, **kwargs: {"Node": {"display_name": "Node", "schema_v1": {}}},
|
||||
)
|
||||
monkeypatch.setitem(sys.modules, "folder_paths", SimpleNamespace(base_path=str(tmp_path)))
|
||||
|
||||
specs = asyncio.run(
|
||||
load_isolated_node(
|
||||
node_dir,
|
||||
manifest_path,
|
||||
logging.getLogger("test"),
|
||||
lambda *args, **kwargs: object,
|
||||
tmp_path / "venvs",
|
||||
[],
|
||||
)
|
||||
)
|
||||
|
||||
assert len(specs) == 1
|
||||
assert captured["sandbox_mode"] == "required"
|
||||
assert captured["cuda_wheels"] == {
|
||||
"index_url": "https://example.invalid/cuda-wheels/",
|
||||
"packages": ["flash-attn", "sageattention"],
|
||||
"package_map": {"flash-attn": "flash-attn-special"},
|
||||
}
|
||||
|
||||
|
||||
def test_load_isolated_node_rejects_undeclared_cuda_wheel_dependency(
|
||||
tmp_path, monkeypatch
|
||||
):
|
||||
node_dir = tmp_path / "node"
|
||||
node_dir.mkdir()
|
||||
manifest_path = node_dir / "pyproject.toml"
|
||||
_write_manifest(
|
||||
node_dir,
|
||||
"""
|
||||
[project]
|
||||
name = "demo-node"
|
||||
dependencies = ["numpy>=1.0"]
|
||||
|
||||
[tool.comfy.isolation]
|
||||
can_isolate = true
|
||||
|
||||
[tool.comfy.isolation.cuda_wheels]
|
||||
index_url = "https://example.invalid/cuda-wheels"
|
||||
packages = ["flash-attn"]
|
||||
""".strip(),
|
||||
)
|
||||
|
||||
monkeypatch.setitem(sys.modules, "folder_paths", SimpleNamespace(base_path=str(tmp_path)))
|
||||
|
||||
with pytest.raises(ExtensionLoadError, match="undeclared dependencies"):
|
||||
asyncio.run(
|
||||
load_isolated_node(
|
||||
node_dir,
|
||||
manifest_path,
|
||||
logging.getLogger("test"),
|
||||
lambda *args, **kwargs: object,
|
||||
tmp_path / "venvs",
|
||||
[],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_conda_cuda_wheels_declared_packages_do_not_force_pixi_solve(tmp_path, monkeypatch):
|
||||
node_dir = tmp_path / "node"
|
||||
node_dir.mkdir()
|
||||
manifest_path = node_dir / "pyproject.toml"
|
||||
_write_manifest(
|
||||
node_dir,
|
||||
"""
|
||||
[project]
|
||||
name = "demo-node"
|
||||
dependencies = ["numpy>=1.0", "spconv", "cumm", "flash-attn"]
|
||||
|
||||
[tool.comfy.isolation]
|
||||
can_isolate = true
|
||||
package_manager = "conda"
|
||||
conda_channels = ["conda-forge"]
|
||||
|
||||
[tool.comfy.isolation.cuda_wheels]
|
||||
index_url = "https://example.invalid/cuda-wheels"
|
||||
packages = ["spconv", "cumm", "flash-attn"]
|
||||
""".strip(),
|
||||
)
|
||||
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
class DummyManager:
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
return None
|
||||
|
||||
def load_extension(self, config):
|
||||
captured.update(config)
|
||||
return _DummyExtension()
|
||||
|
||||
monkeypatch.setattr(extension_loader_module.pyisolate, "ExtensionManager", DummyManager)
|
||||
monkeypatch.setattr(
|
||||
extension_loader_module,
|
||||
"load_host_policy",
|
||||
lambda base_path: {
|
||||
"sandbox_mode": "disabled",
|
||||
"allow_network": False,
|
||||
"writable_paths": [],
|
||||
"readonly_paths": [],
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(extension_loader_module, "is_cache_valid", lambda *args, **kwargs: True)
|
||||
monkeypatch.setattr(
|
||||
extension_loader_module,
|
||||
"load_from_cache",
|
||||
lambda *args, **kwargs: {"Node": {"display_name": "Node", "schema_v1": {}}},
|
||||
)
|
||||
monkeypatch.setitem(sys.modules, "folder_paths", SimpleNamespace(base_path=str(tmp_path)))
|
||||
|
||||
asyncio.run(
|
||||
load_isolated_node(
|
||||
node_dir,
|
||||
manifest_path,
|
||||
logging.getLogger("test"),
|
||||
lambda *args, **kwargs: object,
|
||||
tmp_path / "venvs",
|
||||
[],
|
||||
)
|
||||
)
|
||||
|
||||
generated = _generate_pixi_toml(captured)
|
||||
assert 'numpy = ">=1.0"' in generated
|
||||
assert "spconv =" not in generated
|
||||
assert "cumm =" not in generated
|
||||
assert "flash-attn =" not in generated
|
||||
|
||||
|
||||
def test_conda_cuda_wheels_loader_accepts_sam3d_contract(tmp_path, monkeypatch):
|
||||
node_dir = tmp_path / "node"
|
||||
node_dir.mkdir()
|
||||
manifest_path = node_dir / "pyproject.toml"
|
||||
_write_manifest(
|
||||
node_dir,
|
||||
"""
|
||||
[project]
|
||||
name = "demo-node"
|
||||
dependencies = [
|
||||
"torch",
|
||||
"torchvision",
|
||||
"pytorch3d",
|
||||
"gsplat",
|
||||
"nvdiffrast",
|
||||
"flash-attn",
|
||||
"sageattention",
|
||||
"spconv",
|
||||
"cumm",
|
||||
]
|
||||
|
||||
[tool.comfy.isolation]
|
||||
can_isolate = true
|
||||
package_manager = "conda"
|
||||
conda_channels = ["conda-forge"]
|
||||
|
||||
[tool.comfy.isolation.cuda_wheels]
|
||||
index_url = "https://example.invalid/cuda-wheels"
|
||||
packages = ["pytorch3d", "gsplat", "nvdiffrast", "flash-attn", "sageattention", "spconv", "cumm"]
|
||||
""".strip(),
|
||||
)
|
||||
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
class DummyManager:
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
return None
|
||||
|
||||
def load_extension(self, config):
|
||||
captured.update(config)
|
||||
return _DummyExtension()
|
||||
|
||||
monkeypatch.setattr(extension_loader_module.pyisolate, "ExtensionManager", DummyManager)
|
||||
monkeypatch.setattr(
|
||||
extension_loader_module,
|
||||
"load_host_policy",
|
||||
lambda base_path: {
|
||||
"sandbox_mode": "disabled",
|
||||
"allow_network": False,
|
||||
"writable_paths": [],
|
||||
"readonly_paths": [],
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(extension_loader_module, "is_cache_valid", lambda *args, **kwargs: True)
|
||||
monkeypatch.setattr(
|
||||
extension_loader_module,
|
||||
"load_from_cache",
|
||||
lambda *args, **kwargs: {"Node": {"display_name": "Node", "schema_v1": {}}},
|
||||
)
|
||||
monkeypatch.setitem(sys.modules, "folder_paths", SimpleNamespace(base_path=str(tmp_path)))
|
||||
|
||||
asyncio.run(
|
||||
load_isolated_node(
|
||||
node_dir,
|
||||
manifest_path,
|
||||
logging.getLogger("test"),
|
||||
lambda *args, **kwargs: object,
|
||||
tmp_path / "venvs",
|
||||
[],
|
||||
)
|
||||
)
|
||||
|
||||
assert captured["package_manager"] == "conda"
|
||||
assert captured["cuda_wheels"] == {
|
||||
"index_url": "https://example.invalid/cuda-wheels/",
|
||||
"packages": [
|
||||
"pytorch3d",
|
||||
"gsplat",
|
||||
"nvdiffrast",
|
||||
"flash-attn",
|
||||
"sageattention",
|
||||
"spconv",
|
||||
"cumm",
|
||||
],
|
||||
"package_map": {},
|
||||
}
|
||||
|
||||
|
||||
def test_load_isolated_node_omits_cuda_wheels_when_not_configured(tmp_path, monkeypatch):
|
||||
node_dir = tmp_path / "node"
|
||||
node_dir.mkdir()
|
||||
manifest_path = node_dir / "pyproject.toml"
|
||||
_write_manifest(
|
||||
node_dir,
|
||||
"""
|
||||
[project]
|
||||
name = "demo-node"
|
||||
dependencies = ["numpy>=1.0"]
|
||||
|
||||
[tool.comfy.isolation]
|
||||
can_isolate = true
|
||||
""".strip(),
|
||||
)
|
||||
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
class DummyManager:
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
return None
|
||||
|
||||
def load_extension(self, config):
|
||||
captured.update(config)
|
||||
return _DummyExtension()
|
||||
|
||||
monkeypatch.setattr(extension_loader_module.pyisolate, "ExtensionManager", DummyManager)
|
||||
monkeypatch.setattr(
|
||||
extension_loader_module,
|
||||
"load_host_policy",
|
||||
lambda base_path: {
|
||||
"sandbox_mode": "disabled",
|
||||
"allow_network": False,
|
||||
"writable_paths": [],
|
||||
"readonly_paths": [],
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(extension_loader_module, "is_cache_valid", lambda *args, **kwargs: True)
|
||||
monkeypatch.setattr(
|
||||
extension_loader_module,
|
||||
"load_from_cache",
|
||||
lambda *args, **kwargs: {"Node": {"display_name": "Node", "schema_v1": {}}},
|
||||
)
|
||||
monkeypatch.setitem(sys.modules, "folder_paths", SimpleNamespace(base_path=str(tmp_path)))
|
||||
|
||||
asyncio.run(
|
||||
load_isolated_node(
|
||||
node_dir,
|
||||
manifest_path,
|
||||
logging.getLogger("test"),
|
||||
lambda *args, **kwargs: object,
|
||||
tmp_path / "venvs",
|
||||
[],
|
||||
)
|
||||
)
|
||||
|
||||
assert captured["sandbox_mode"] == "disabled"
|
||||
assert "cuda_wheels" not in captured
|
||||
|
||||
|
||||
def test_maybe_wrap_model_for_isolation_uses_runtime_flag(monkeypatch):
|
||||
class DummyRegistry:
|
||||
def register(self, model):
|
||||
return "model-123"
|
||||
|
||||
class DummyProxy:
|
||||
def __init__(self, model_id, registry, manage_lifecycle):
|
||||
self.model_id = model_id
|
||||
self.registry = registry
|
||||
self.manage_lifecycle = manage_lifecycle
|
||||
|
||||
monkeypatch.setattr(model_patcher_proxy_utils.args, "use_process_isolation", True)
|
||||
monkeypatch.delenv("PYISOLATE_ISOLATION_ACTIVE", raising=False)
|
||||
monkeypatch.delenv("PYISOLATE_CHILD", raising=False)
|
||||
monkeypatch.setitem(
|
||||
sys.modules,
|
||||
"comfy.isolation.model_patcher_proxy_registry",
|
||||
SimpleNamespace(ModelPatcherRegistry=DummyRegistry),
|
||||
)
|
||||
monkeypatch.setitem(
|
||||
sys.modules,
|
||||
"comfy.isolation.model_patcher_proxy",
|
||||
SimpleNamespace(ModelPatcherProxy=DummyProxy),
|
||||
)
|
||||
|
||||
wrapped = cast(Any, maybe_wrap_model_for_isolation(object()))
|
||||
|
||||
assert isinstance(wrapped, DummyProxy)
|
||||
assert getattr(wrapped, "model_id") == "model-123"
|
||||
assert getattr(wrapped, "manage_lifecycle") is True
|
||||
|
||||
|
||||
def test_flush_transport_state_uses_child_env_without_legacy_flag(monkeypatch):
|
||||
monkeypatch.setenv("PYISOLATE_CHILD", "1")
|
||||
monkeypatch.delenv("PYISOLATE_ISOLATION_ACTIVE", raising=False)
|
||||
monkeypatch.setattr(extension_wrapper_module, "_flush_tensor_transport_state", lambda marker: 3)
|
||||
monkeypatch.setitem(
|
||||
sys.modules,
|
||||
"comfy.isolation.model_patcher_proxy_registry",
|
||||
SimpleNamespace(
|
||||
ModelPatcherRegistry=lambda: SimpleNamespace(
|
||||
sweep_pending_cleanup=lambda: 0
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
flushed = asyncio.run(
|
||||
ComfyNodeExtension.flush_transport_state(SimpleNamespace(name="demo"))
|
||||
)
|
||||
|
||||
assert flushed == 3
|
||||
|
||||
|
||||
def test_build_stub_class_relieves_host_vram_without_legacy_flag(monkeypatch):
|
||||
relieve_calls: list[str] = []
|
||||
|
||||
async def deserialize_from_isolation(result, extension):
|
||||
return result
|
||||
|
||||
monkeypatch.delenv("PYISOLATE_CHILD", raising=False)
|
||||
monkeypatch.delenv("PYISOLATE_ISOLATION_ACTIVE", raising=False)
|
||||
monkeypatch.setattr(
|
||||
runtime_helpers, "_relieve_host_vram_pressure", lambda marker, logger: relieve_calls.append(marker)
|
||||
)
|
||||
monkeypatch.setattr(runtime_helpers, "scan_shm_forensics", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(isolation_pkg, "_RUNNING_EXTENSIONS", {}, raising=False)
|
||||
monkeypatch.setitem(
|
||||
sys.modules,
|
||||
"pyisolate._internal.model_serialization",
|
||||
SimpleNamespace(
|
||||
serialize_for_isolation=lambda payload: payload,
|
||||
deserialize_from_isolation=deserialize_from_isolation,
|
||||
),
|
||||
)
|
||||
|
||||
class DummyExtension:
|
||||
name = "demo-extension"
|
||||
module_path = os.getcwd()
|
||||
|
||||
async def execute_node(self, node_name, **inputs):
|
||||
return inputs
|
||||
|
||||
stub_cls = runtime_helpers.build_stub_class(
|
||||
"DemoNode",
|
||||
{"input_types": {}},
|
||||
DummyExtension(),
|
||||
{},
|
||||
logging.getLogger("test"),
|
||||
)
|
||||
|
||||
result = asyncio.run(
|
||||
getattr(stub_cls, "_pyisolate_execute")(SimpleNamespace(), value=1)
|
||||
)
|
||||
|
||||
assert relieve_calls == ["RUNTIME:pre_execute"]
|
||||
assert result == {"value": 1}
|
||||
22
tests/isolation/test_exact_proxy_bootstrap_contract.py
Normal file
22
tests/isolation/test_exact_proxy_bootstrap_contract.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from tests.isolation.singleton_boundary_helpers import (
|
||||
capture_exact_proxy_bootstrap_contract,
|
||||
)
|
||||
|
||||
|
||||
def test_no_proxy_omission_allowed() -> None:
|
||||
payload = capture_exact_proxy_bootstrap_contract()
|
||||
|
||||
assert payload["omitted_proxies"] == []
|
||||
assert payload["forbidden_matches"] == []
|
||||
|
||||
matrix = payload["matrix"]
|
||||
assert matrix["base.py"]["bound"] is True
|
||||
assert matrix["folder_paths_proxy.py"]["bound"] is True
|
||||
assert matrix["helper_proxies.py"]["bound"] is True
|
||||
assert matrix["model_management_proxy.py"]["bound"] is True
|
||||
assert matrix["progress_proxy.py"]["bound"] is True
|
||||
assert matrix["prompt_server_impl.py"]["bound"] is True
|
||||
assert matrix["utils_proxy.py"]["bound"] is True
|
||||
assert matrix["web_directory_proxy.py"]["bound"] is True
|
||||
128
tests/isolation/test_exact_proxy_relay_matrix.py
Normal file
128
tests/isolation/test_exact_proxy_relay_matrix.py
Normal file
@@ -0,0 +1,128 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from tests.isolation.singleton_boundary_helpers import (
|
||||
capture_exact_small_proxy_relay,
|
||||
capture_model_management_exact_relay,
|
||||
capture_prompt_web_exact_relay,
|
||||
)
|
||||
|
||||
|
||||
def _transcripts_for(payload: dict[str, object], object_id: str, method: str) -> list[dict[str, object]]:
|
||||
return [
|
||||
entry
|
||||
for entry in payload["transcripts"]
|
||||
if entry["object_id"] == object_id and entry["method"] == method
|
||||
]
|
||||
|
||||
|
||||
def test_folder_paths_exact_relay() -> None:
|
||||
payload = capture_exact_small_proxy_relay()
|
||||
|
||||
assert payload["forbidden_matches"] == []
|
||||
assert payload["models_dir"] == "/sandbox/models"
|
||||
assert payload["folder_path"] == "/sandbox/input/demo.png"
|
||||
|
||||
models_dir_calls = _transcripts_for(payload, "FolderPathsProxy", "rpc_get_models_dir")
|
||||
annotated_calls = _transcripts_for(payload, "FolderPathsProxy", "rpc_get_annotated_filepath")
|
||||
|
||||
assert models_dir_calls
|
||||
assert annotated_calls
|
||||
assert all(entry["phase"] != "child_call" or entry["method"] != "rpc_snapshot" for entry in payload["transcripts"])
|
||||
|
||||
|
||||
def test_progress_exact_relay() -> None:
|
||||
payload = capture_exact_small_proxy_relay()
|
||||
|
||||
progress_calls = _transcripts_for(payload, "ProgressProxy", "rpc_set_progress")
|
||||
|
||||
assert progress_calls
|
||||
host_targets = [entry["target"] for entry in progress_calls if entry["phase"] == "host_invocation"]
|
||||
assert host_targets == ["comfy_execution.progress.get_progress_state().update_progress"]
|
||||
result_entries = [entry for entry in progress_calls if entry["phase"] == "result"]
|
||||
assert result_entries == [{"phase": "result", "object_id": "ProgressProxy", "method": "rpc_set_progress", "result": None}]
|
||||
|
||||
|
||||
def test_utils_exact_relay() -> None:
|
||||
payload = capture_exact_small_proxy_relay()
|
||||
|
||||
utils_calls = _transcripts_for(payload, "UtilsProxy", "progress_bar_hook")
|
||||
|
||||
assert utils_calls
|
||||
host_targets = [entry["target"] for entry in utils_calls if entry["phase"] == "host_invocation"]
|
||||
assert host_targets == ["comfy.utils.PROGRESS_BAR_HOOK"]
|
||||
result_entries = [entry for entry in utils_calls if entry["phase"] == "result"]
|
||||
assert result_entries
|
||||
assert result_entries[0]["result"]["value"] == 2
|
||||
assert result_entries[0]["result"]["total"] == 5
|
||||
|
||||
|
||||
def test_helper_proxy_exact_relay() -> None:
|
||||
payload = capture_exact_small_proxy_relay()
|
||||
|
||||
helper_calls = _transcripts_for(payload, "HelperProxiesService", "rpc_restore_input_types")
|
||||
|
||||
assert helper_calls
|
||||
host_targets = [entry["target"] for entry in helper_calls if entry["phase"] == "host_invocation"]
|
||||
assert host_targets == ["comfy.isolation.proxies.helper_proxies.restore_input_types"]
|
||||
assert payload["restored_any_type"] == "*"
|
||||
|
||||
|
||||
def test_model_management_exact_relay() -> None:
|
||||
payload = capture_model_management_exact_relay()
|
||||
|
||||
model_calls = _transcripts_for(payload, "ModelManagementProxy", "get_torch_device")
|
||||
model_calls += _transcripts_for(payload, "ModelManagementProxy", "get_torch_device_name")
|
||||
model_calls += _transcripts_for(payload, "ModelManagementProxy", "get_free_memory")
|
||||
|
||||
assert payload["forbidden_matches"] == []
|
||||
assert model_calls
|
||||
host_targets = [
|
||||
entry["target"]
|
||||
for entry in payload["transcripts"]
|
||||
if entry["phase"] == "host_invocation"
|
||||
]
|
||||
assert host_targets == [
|
||||
"comfy.model_management.get_torch_device",
|
||||
"comfy.model_management.get_torch_device_name",
|
||||
"comfy.model_management.get_free_memory",
|
||||
]
|
||||
|
||||
|
||||
def test_model_management_capability_preserved() -> None:
|
||||
payload = capture_model_management_exact_relay()
|
||||
|
||||
assert payload["device"] == "cpu"
|
||||
assert payload["device_type"] == "cpu"
|
||||
assert payload["device_name"] == "cpu"
|
||||
assert payload["free_memory"] == 34359738368
|
||||
|
||||
|
||||
def test_prompt_server_exact_relay() -> None:
|
||||
payload = capture_prompt_web_exact_relay()
|
||||
|
||||
prompt_calls = _transcripts_for(payload, "PromptServerService", "ui_send_progress_text")
|
||||
prompt_calls += _transcripts_for(payload, "PromptServerService", "register_route_rpc")
|
||||
|
||||
assert payload["forbidden_matches"] == []
|
||||
assert prompt_calls
|
||||
host_targets = [
|
||||
entry["target"]
|
||||
for entry in payload["transcripts"]
|
||||
if entry["object_id"] == "PromptServerService" and entry["phase"] == "host_invocation"
|
||||
]
|
||||
assert host_targets == [
|
||||
"server.PromptServer.instance.send_progress_text",
|
||||
"server.PromptServer.instance.routes.add_route",
|
||||
]
|
||||
|
||||
|
||||
def test_web_directory_exact_relay() -> None:
|
||||
payload = capture_prompt_web_exact_relay()
|
||||
|
||||
web_calls = _transcripts_for(payload, "WebDirectoryProxy", "get_web_file")
|
||||
|
||||
assert web_calls
|
||||
host_targets = [entry["target"] for entry in web_calls if entry["phase"] == "host_invocation"]
|
||||
assert host_targets == ["comfy.isolation.proxies.web_directory_proxy.WebDirectoryProxy.get_web_file"]
|
||||
assert payload["web_file"]["content_type"] == "application/javascript"
|
||||
assert payload["web_file"]["content"] == "console.log('deo');"
|
||||
428
tests/isolation/test_extension_loader_conda.py
Normal file
428
tests/isolation/test_extension_loader_conda.py
Normal file
@@ -0,0 +1,428 @@
|
||||
"""Tests for conda config parsing in extension_loader.py (Slice 5).
|
||||
|
||||
These tests verify that extension_loader.py correctly parses conda-related
|
||||
fields from pyproject.toml manifests and passes them into the extension config
|
||||
dict given to pyisolate. The torch import chain is broken by pre-mocking
|
||||
extension_wrapper before importing extension_loader.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _make_manifest(
|
||||
*,
|
||||
package_manager: str = "uv",
|
||||
conda_channels: list[str] | None = None,
|
||||
conda_dependencies: list[str] | None = None,
|
||||
conda_platforms: list[str] | None = None,
|
||||
share_torch: bool = False,
|
||||
can_isolate: bool = True,
|
||||
dependencies: list[str] | None = None,
|
||||
cuda_wheels: list[str] | None = None,
|
||||
) -> dict:
|
||||
"""Build a manifest dict matching tomllib.load() output."""
|
||||
isolation: dict = {"can_isolate": can_isolate}
|
||||
if package_manager != "uv":
|
||||
isolation["package_manager"] = package_manager
|
||||
if conda_channels is not None:
|
||||
isolation["conda_channels"] = conda_channels
|
||||
if conda_dependencies is not None:
|
||||
isolation["conda_dependencies"] = conda_dependencies
|
||||
if conda_platforms is not None:
|
||||
isolation["conda_platforms"] = conda_platforms
|
||||
if share_torch:
|
||||
isolation["share_torch"] = True
|
||||
if cuda_wheels is not None:
|
||||
isolation["cuda_wheels"] = cuda_wheels
|
||||
|
||||
return {
|
||||
"project": {
|
||||
"name": "test-extension",
|
||||
"dependencies": dependencies or ["numpy"],
|
||||
},
|
||||
"tool": {"comfy": {"isolation": isolation}},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def manifest_file(tmp_path):
|
||||
"""Create a dummy pyproject.toml so manifest_path.open('rb') succeeds."""
|
||||
path = tmp_path / "pyproject.toml"
|
||||
path.write_bytes(b"") # content is overridden by tomllib mock
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def loader_module(monkeypatch):
|
||||
"""Import extension_loader under a mocked isolation package for this test only."""
|
||||
mock_wrapper = MagicMock()
|
||||
mock_wrapper.ComfyNodeExtension = type("ComfyNodeExtension", (), {})
|
||||
|
||||
iso_mod = types.ModuleType("comfy.isolation")
|
||||
iso_mod.__path__ = [ # type: ignore[attr-defined]
|
||||
str(Path(__file__).resolve().parent.parent.parent / "comfy" / "isolation")
|
||||
]
|
||||
iso_mod.__package__ = "comfy.isolation"
|
||||
|
||||
manifest_loader = types.SimpleNamespace(
|
||||
is_cache_valid=lambda *args, **kwargs: False,
|
||||
load_from_cache=lambda *args, **kwargs: None,
|
||||
save_to_cache=lambda *args, **kwargs: None,
|
||||
)
|
||||
host_policy = types.SimpleNamespace(
|
||||
load_host_policy=lambda base_path: {
|
||||
"sandbox_mode": "required",
|
||||
"allow_network": False,
|
||||
"writable_paths": [],
|
||||
"readonly_paths": [],
|
||||
}
|
||||
)
|
||||
folder_paths = types.SimpleNamespace(base_path="/fake/comfyui")
|
||||
|
||||
monkeypatch.setitem(sys.modules, "comfy.isolation", iso_mod)
|
||||
monkeypatch.setitem(sys.modules, "comfy.isolation.extension_wrapper", mock_wrapper)
|
||||
monkeypatch.setitem(sys.modules, "comfy.isolation.runtime_helpers", MagicMock())
|
||||
monkeypatch.setitem(sys.modules, "comfy.isolation.manifest_loader", manifest_loader)
|
||||
monkeypatch.setitem(sys.modules, "comfy.isolation.host_policy", host_policy)
|
||||
monkeypatch.setitem(sys.modules, "folder_paths", folder_paths)
|
||||
sys.modules.pop("comfy.isolation.extension_loader", None)
|
||||
|
||||
module = importlib.import_module("comfy.isolation.extension_loader")
|
||||
try:
|
||||
yield module, mock_wrapper
|
||||
finally:
|
||||
sys.modules.pop("comfy.isolation.extension_loader", None)
|
||||
comfy_pkg = sys.modules.get("comfy")
|
||||
if comfy_pkg is not None and hasattr(comfy_pkg, "isolation"):
|
||||
delattr(comfy_pkg, "isolation")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pyisolate(loader_module):
|
||||
"""Mock pyisolate to avoid real venv creation."""
|
||||
module, mock_wrapper = loader_module
|
||||
mock_ext = AsyncMock()
|
||||
mock_ext.list_nodes = AsyncMock(return_value={})
|
||||
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.load_extension = MagicMock(return_value=mock_ext)
|
||||
sealed_type = type("SealedNodeExtension", (), {})
|
||||
|
||||
with patch.object(module, "pyisolate") as mock_pi:
|
||||
mock_pi.ExtensionManager = MagicMock(return_value=mock_manager)
|
||||
mock_pi.SealedNodeExtension = sealed_type
|
||||
yield module, mock_pi, mock_manager, mock_ext, mock_wrapper
|
||||
|
||||
|
||||
def load_isolated_node(*args, **kwargs):
|
||||
return sys.modules["comfy.isolation.extension_loader"].load_isolated_node(
|
||||
*args, **kwargs
|
||||
)
|
||||
|
||||
|
||||
class TestCondaPackageManagerParsing:
|
||||
"""Verify extension_loader.py parses conda config from pyproject.toml."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conda_package_manager_in_config(
|
||||
self, mock_pyisolate, manifest_file, tmp_path
|
||||
):
|
||||
"""package_manager='conda' must appear in extension_config."""
|
||||
|
||||
manifest = _make_manifest(
|
||||
package_manager="conda",
|
||||
conda_channels=["conda-forge"],
|
||||
conda_dependencies=["eccodes"],
|
||||
)
|
||||
|
||||
_, _, mock_manager, _, _ = mock_pyisolate
|
||||
|
||||
with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib:
|
||||
mock_tomllib.load.return_value = manifest
|
||||
await load_isolated_node(
|
||||
node_dir=tmp_path,
|
||||
manifest_path=manifest_file,
|
||||
logger=MagicMock(),
|
||||
build_stub_class=MagicMock(),
|
||||
venv_root=tmp_path / "venvs",
|
||||
extension_managers=[],
|
||||
)
|
||||
|
||||
config = mock_manager.load_extension.call_args[0][0]
|
||||
assert config["package_manager"] == "conda"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conda_channels_in_config(
|
||||
self, mock_pyisolate, manifest_file, tmp_path
|
||||
):
|
||||
"""conda_channels must be passed through to extension_config."""
|
||||
|
||||
manifest = _make_manifest(
|
||||
package_manager="conda",
|
||||
conda_channels=["conda-forge", "nvidia"],
|
||||
conda_dependencies=["eccodes"],
|
||||
)
|
||||
|
||||
_, _, mock_manager, _, _ = mock_pyisolate
|
||||
|
||||
with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib:
|
||||
mock_tomllib.load.return_value = manifest
|
||||
await load_isolated_node(
|
||||
node_dir=tmp_path,
|
||||
manifest_path=manifest_file,
|
||||
logger=MagicMock(),
|
||||
build_stub_class=MagicMock(),
|
||||
venv_root=tmp_path / "venvs",
|
||||
extension_managers=[],
|
||||
)
|
||||
|
||||
config = mock_manager.load_extension.call_args[0][0]
|
||||
assert config["conda_channels"] == ["conda-forge", "nvidia"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conda_dependencies_in_config(
|
||||
self, mock_pyisolate, manifest_file, tmp_path
|
||||
):
|
||||
"""conda_dependencies must be passed through to extension_config."""
|
||||
|
||||
manifest = _make_manifest(
|
||||
package_manager="conda",
|
||||
conda_channels=["conda-forge"],
|
||||
conda_dependencies=["eccodes", "cfgrib"],
|
||||
)
|
||||
|
||||
_, _, mock_manager, _, _ = mock_pyisolate
|
||||
|
||||
with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib:
|
||||
mock_tomllib.load.return_value = manifest
|
||||
await load_isolated_node(
|
||||
node_dir=tmp_path,
|
||||
manifest_path=manifest_file,
|
||||
logger=MagicMock(),
|
||||
build_stub_class=MagicMock(),
|
||||
venv_root=tmp_path / "venvs",
|
||||
extension_managers=[],
|
||||
)
|
||||
|
||||
config = mock_manager.load_extension.call_args[0][0]
|
||||
assert config["conda_dependencies"] == ["eccodes", "cfgrib"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conda_platforms_in_config(
|
||||
self, mock_pyisolate, manifest_file, tmp_path
|
||||
):
|
||||
"""conda_platforms must be passed through to extension_config."""
|
||||
|
||||
manifest = _make_manifest(
|
||||
package_manager="conda",
|
||||
conda_channels=["conda-forge"],
|
||||
conda_dependencies=["eccodes"],
|
||||
conda_platforms=["linux-64"],
|
||||
)
|
||||
|
||||
_, _, mock_manager, _, _ = mock_pyisolate
|
||||
|
||||
with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib:
|
||||
mock_tomllib.load.return_value = manifest
|
||||
await load_isolated_node(
|
||||
node_dir=tmp_path,
|
||||
manifest_path=manifest_file,
|
||||
logger=MagicMock(),
|
||||
build_stub_class=MagicMock(),
|
||||
venv_root=tmp_path / "venvs",
|
||||
extension_managers=[],
|
||||
)
|
||||
|
||||
config = mock_manager.load_extension.call_args[0][0]
|
||||
assert config["conda_platforms"] == ["linux-64"]
|
||||
|
||||
|
||||
class TestCondaForcedOverrides:
|
||||
"""Verify conda forces share_torch=False, share_cuda_ipc=False."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conda_forces_share_torch_false(
|
||||
self, mock_pyisolate, manifest_file, tmp_path
|
||||
):
|
||||
"""share_torch must be forced False for conda, even if manifest says True."""
|
||||
|
||||
manifest = _make_manifest(
|
||||
package_manager="conda",
|
||||
conda_channels=["conda-forge"],
|
||||
conda_dependencies=["eccodes"],
|
||||
share_torch=True, # manifest requests True — must be overridden
|
||||
)
|
||||
|
||||
_, _, mock_manager, _, _ = mock_pyisolate
|
||||
|
||||
with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib:
|
||||
mock_tomllib.load.return_value = manifest
|
||||
await load_isolated_node(
|
||||
node_dir=tmp_path,
|
||||
manifest_path=manifest_file,
|
||||
logger=MagicMock(),
|
||||
build_stub_class=MagicMock(),
|
||||
venv_root=tmp_path / "venvs",
|
||||
extension_managers=[],
|
||||
)
|
||||
|
||||
config = mock_manager.load_extension.call_args[0][0]
|
||||
assert config["share_torch"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conda_forces_share_cuda_ipc_false(
|
||||
self, mock_pyisolate, manifest_file, tmp_path
|
||||
):
|
||||
"""share_cuda_ipc must be forced False for conda."""
|
||||
|
||||
manifest = _make_manifest(
|
||||
package_manager="conda",
|
||||
conda_channels=["conda-forge"],
|
||||
conda_dependencies=["eccodes"],
|
||||
share_torch=True,
|
||||
)
|
||||
|
||||
_, _, mock_manager, _, _ = mock_pyisolate
|
||||
|
||||
with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib:
|
||||
mock_tomllib.load.return_value = manifest
|
||||
await load_isolated_node(
|
||||
node_dir=tmp_path,
|
||||
manifest_path=manifest_file,
|
||||
logger=MagicMock(),
|
||||
build_stub_class=MagicMock(),
|
||||
venv_root=tmp_path / "venvs",
|
||||
extension_managers=[],
|
||||
)
|
||||
|
||||
config = mock_manager.load_extension.call_args[0][0]
|
||||
assert config["share_cuda_ipc"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conda_sealed_worker_uses_host_policy_sandbox_config(
|
||||
self, mock_pyisolate, manifest_file, tmp_path
|
||||
):
|
||||
"""Conda sealed_worker must carry the host-policy sandbox config on Linux."""
|
||||
|
||||
manifest = _make_manifest(
|
||||
package_manager="conda",
|
||||
conda_channels=["conda-forge"],
|
||||
conda_dependencies=["eccodes"],
|
||||
)
|
||||
|
||||
_, _, mock_manager, _, _ = mock_pyisolate
|
||||
|
||||
with (
|
||||
patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib,
|
||||
patch(
|
||||
"comfy.isolation.extension_loader.platform.system",
|
||||
return_value="Linux",
|
||||
),
|
||||
):
|
||||
mock_tomllib.load.return_value = manifest
|
||||
await load_isolated_node(
|
||||
node_dir=tmp_path,
|
||||
manifest_path=manifest_file,
|
||||
logger=MagicMock(),
|
||||
build_stub_class=MagicMock(),
|
||||
venv_root=tmp_path / "venvs",
|
||||
extension_managers=[],
|
||||
)
|
||||
|
||||
config = mock_manager.load_extension.call_args[0][0]
|
||||
assert config["sandbox"] == {
|
||||
"network": False,
|
||||
"writable_paths": [],
|
||||
"readonly_paths": [],
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conda_uses_sealed_extension_type(
|
||||
self, mock_pyisolate, manifest_file, tmp_path
|
||||
):
|
||||
"""Conda must not launch through ComfyNodeExtension."""
|
||||
|
||||
_, mock_pi, _, _, mock_wrapper = mock_pyisolate
|
||||
manifest = _make_manifest(
|
||||
package_manager="conda",
|
||||
conda_channels=["conda-forge"],
|
||||
conda_dependencies=["eccodes"],
|
||||
)
|
||||
|
||||
with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib:
|
||||
mock_tomllib.load.return_value = manifest
|
||||
await load_isolated_node(
|
||||
node_dir=tmp_path,
|
||||
manifest_path=manifest_file,
|
||||
logger=MagicMock(),
|
||||
build_stub_class=MagicMock(),
|
||||
venv_root=tmp_path / "venvs",
|
||||
extension_managers=[],
|
||||
)
|
||||
|
||||
extension_type = mock_pi.ExtensionManager.call_args[0][0]
|
||||
assert extension_type.__name__ == "SealedNodeExtension"
|
||||
assert extension_type is not mock_wrapper.ComfyNodeExtension
|
||||
|
||||
|
||||
class TestUvUnchanged:
|
||||
"""Verify uv extensions are NOT affected by conda changes."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uv_default_no_conda_keys(
|
||||
self, mock_pyisolate, manifest_file, tmp_path
|
||||
):
|
||||
"""Default uv extension must NOT have package_manager or conda keys."""
|
||||
|
||||
manifest = _make_manifest() # defaults: uv, no conda fields
|
||||
|
||||
_, _, mock_manager, _, _ = mock_pyisolate
|
||||
|
||||
with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib:
|
||||
mock_tomllib.load.return_value = manifest
|
||||
await load_isolated_node(
|
||||
node_dir=tmp_path,
|
||||
manifest_path=manifest_file,
|
||||
logger=MagicMock(),
|
||||
build_stub_class=MagicMock(),
|
||||
venv_root=tmp_path / "venvs",
|
||||
extension_managers=[],
|
||||
)
|
||||
|
||||
config = mock_manager.load_extension.call_args[0][0]
|
||||
# uv extensions should not have conda-specific keys
|
||||
assert config.get("package_manager", "uv") == "uv"
|
||||
assert "conda_channels" not in config
|
||||
assert "conda_dependencies" not in config
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uv_keeps_comfy_extension_type(
|
||||
self, mock_pyisolate, manifest_file, tmp_path
|
||||
):
|
||||
"""uv keeps the existing ComfyNodeExtension path."""
|
||||
|
||||
_, mock_pi, _, _, _ = mock_pyisolate
|
||||
manifest = _make_manifest()
|
||||
|
||||
with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib:
|
||||
mock_tomllib.load.return_value = manifest
|
||||
await load_isolated_node(
|
||||
node_dir=tmp_path,
|
||||
manifest_path=manifest_file,
|
||||
logger=MagicMock(),
|
||||
build_stub_class=MagicMock(),
|
||||
venv_root=tmp_path / "venvs",
|
||||
extension_managers=[],
|
||||
)
|
||||
|
||||
extension_type = mock_pi.ExtensionManager.call_args[0][0]
|
||||
assert extension_type.__name__ == "ComfyNodeExtension"
|
||||
assert extension_type is not mock_pi.SealedNodeExtension
|
||||
281
tests/isolation/test_extension_loader_sealed_worker.py
Normal file
281
tests/isolation/test_extension_loader_sealed_worker.py
Normal file
@@ -0,0 +1,281 @@
|
||||
"""Tests for execution_model parsing and sealed-worker loader selection."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _make_manifest(
|
||||
*,
|
||||
package_manager: str = "uv",
|
||||
execution_model: str | None = None,
|
||||
can_isolate: bool = True,
|
||||
dependencies: list[str] | None = None,
|
||||
sealed_host_ro_paths: list[str] | None = None,
|
||||
) -> dict:
|
||||
isolation: dict = {"can_isolate": can_isolate}
|
||||
if package_manager != "uv":
|
||||
isolation["package_manager"] = package_manager
|
||||
if execution_model is not None:
|
||||
isolation["execution_model"] = execution_model
|
||||
if sealed_host_ro_paths is not None:
|
||||
isolation["sealed_host_ro_paths"] = sealed_host_ro_paths
|
||||
|
||||
return {
|
||||
"project": {
|
||||
"name": "test-extension",
|
||||
"dependencies": dependencies or ["numpy"],
|
||||
},
|
||||
"tool": {"comfy": {"isolation": isolation}},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def manifest_file(tmp_path):
|
||||
path = tmp_path / "pyproject.toml"
|
||||
path.write_bytes(b"")
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def loader_module(monkeypatch):
|
||||
mock_wrapper = MagicMock()
|
||||
mock_wrapper.ComfyNodeExtension = type("ComfyNodeExtension", (), {})
|
||||
|
||||
iso_mod = types.ModuleType("comfy.isolation")
|
||||
iso_mod.__path__ = [ # type: ignore[attr-defined]
|
||||
str(Path(__file__).resolve().parent.parent.parent / "comfy" / "isolation")
|
||||
]
|
||||
iso_mod.__package__ = "comfy.isolation"
|
||||
|
||||
manifest_loader = types.SimpleNamespace(
|
||||
is_cache_valid=lambda *args, **kwargs: False,
|
||||
load_from_cache=lambda *args, **kwargs: None,
|
||||
save_to_cache=lambda *args, **kwargs: None,
|
||||
)
|
||||
host_policy = types.SimpleNamespace(
|
||||
load_host_policy=lambda base_path: {
|
||||
"sandbox_mode": "required",
|
||||
"allow_network": False,
|
||||
"writable_paths": [],
|
||||
"readonly_paths": [],
|
||||
"sealed_worker_ro_import_paths": [],
|
||||
}
|
||||
)
|
||||
folder_paths = types.SimpleNamespace(base_path="/fake/comfyui")
|
||||
|
||||
monkeypatch.setitem(sys.modules, "comfy.isolation", iso_mod)
|
||||
monkeypatch.setitem(sys.modules, "comfy.isolation.extension_wrapper", mock_wrapper)
|
||||
monkeypatch.setitem(sys.modules, "comfy.isolation.runtime_helpers", MagicMock())
|
||||
monkeypatch.setitem(sys.modules, "comfy.isolation.manifest_loader", manifest_loader)
|
||||
monkeypatch.setitem(sys.modules, "comfy.isolation.host_policy", host_policy)
|
||||
monkeypatch.setitem(sys.modules, "folder_paths", folder_paths)
|
||||
sys.modules.pop("comfy.isolation.extension_loader", None)
|
||||
|
||||
module = importlib.import_module("comfy.isolation.extension_loader")
|
||||
try:
|
||||
yield module
|
||||
finally:
|
||||
sys.modules.pop("comfy.isolation.extension_loader", None)
|
||||
comfy_pkg = sys.modules.get("comfy")
|
||||
if comfy_pkg is not None and hasattr(comfy_pkg, "isolation"):
|
||||
delattr(comfy_pkg, "isolation")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pyisolate(loader_module):
|
||||
mock_ext = AsyncMock()
|
||||
mock_ext.list_nodes = AsyncMock(return_value={})
|
||||
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.load_extension = MagicMock(return_value=mock_ext)
|
||||
sealed_type = type("SealedNodeExtension", (), {})
|
||||
|
||||
with patch.object(loader_module, "pyisolate") as mock_pi:
|
||||
mock_pi.ExtensionManager = MagicMock(return_value=mock_manager)
|
||||
mock_pi.SealedNodeExtension = sealed_type
|
||||
yield loader_module, mock_pi, mock_manager, mock_ext, sealed_type
|
||||
|
||||
|
||||
def load_isolated_node(*args, **kwargs):
|
||||
return sys.modules["comfy.isolation.extension_loader"].load_isolated_node(*args, **kwargs)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uv_sealed_worker_selects_sealed_extension_type(
|
||||
mock_pyisolate, manifest_file, tmp_path
|
||||
):
|
||||
manifest = _make_manifest(execution_model="sealed_worker")
|
||||
|
||||
_, mock_pi, mock_manager, _, sealed_type = mock_pyisolate
|
||||
|
||||
with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib:
|
||||
mock_tomllib.load.return_value = manifest
|
||||
await load_isolated_node(
|
||||
node_dir=tmp_path,
|
||||
manifest_path=manifest_file,
|
||||
logger=MagicMock(),
|
||||
build_stub_class=MagicMock(),
|
||||
venv_root=tmp_path / "venvs",
|
||||
extension_managers=[],
|
||||
)
|
||||
|
||||
extension_type = mock_pi.ExtensionManager.call_args[0][0]
|
||||
config = mock_manager.load_extension.call_args[0][0]
|
||||
assert extension_type is sealed_type
|
||||
assert config["execution_model"] == "sealed_worker"
|
||||
assert "apis" not in config
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_uv_keeps_host_coupled_extension_type(
|
||||
mock_pyisolate, manifest_file, tmp_path
|
||||
):
|
||||
manifest = _make_manifest()
|
||||
|
||||
_, mock_pi, mock_manager, _, sealed_type = mock_pyisolate
|
||||
|
||||
with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib:
|
||||
mock_tomllib.load.return_value = manifest
|
||||
await load_isolated_node(
|
||||
node_dir=tmp_path,
|
||||
manifest_path=manifest_file,
|
||||
logger=MagicMock(),
|
||||
build_stub_class=MagicMock(),
|
||||
venv_root=tmp_path / "venvs",
|
||||
extension_managers=[],
|
||||
)
|
||||
|
||||
extension_type = mock_pi.ExtensionManager.call_args[0][0]
|
||||
config = mock_manager.load_extension.call_args[0][0]
|
||||
assert extension_type is not sealed_type
|
||||
assert "execution_model" not in config
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conda_without_execution_model_remains_sealed_worker(
|
||||
mock_pyisolate, manifest_file, tmp_path
|
||||
):
|
||||
manifest = _make_manifest(package_manager="conda")
|
||||
manifest["tool"]["comfy"]["isolation"]["conda_channels"] = ["conda-forge"]
|
||||
manifest["tool"]["comfy"]["isolation"]["conda_dependencies"] = ["eccodes"]
|
||||
|
||||
_, mock_pi, mock_manager, _, sealed_type = mock_pyisolate
|
||||
|
||||
with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib:
|
||||
mock_tomllib.load.return_value = manifest
|
||||
await load_isolated_node(
|
||||
node_dir=tmp_path,
|
||||
manifest_path=manifest_file,
|
||||
logger=MagicMock(),
|
||||
build_stub_class=MagicMock(),
|
||||
venv_root=tmp_path / "venvs",
|
||||
extension_managers=[],
|
||||
)
|
||||
|
||||
extension_type = mock_pi.ExtensionManager.call_args[0][0]
|
||||
config = mock_manager.load_extension.call_args[0][0]
|
||||
assert extension_type is sealed_type
|
||||
assert config["execution_model"] == "sealed_worker"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sealed_worker_uses_host_policy_ro_import_paths(
|
||||
mock_pyisolate, manifest_file, tmp_path
|
||||
):
|
||||
manifest = _make_manifest(execution_model="sealed_worker")
|
||||
|
||||
module, _, mock_manager, _, _ = mock_pyisolate
|
||||
|
||||
with (
|
||||
patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib,
|
||||
patch.object(
|
||||
module,
|
||||
"load_host_policy",
|
||||
return_value={
|
||||
"sandbox_mode": "required",
|
||||
"allow_network": False,
|
||||
"writable_paths": [],
|
||||
"readonly_paths": [],
|
||||
"sealed_worker_ro_import_paths": ["/home/johnj/ComfyUI"],
|
||||
},
|
||||
),
|
||||
):
|
||||
mock_tomllib.load.return_value = manifest
|
||||
await load_isolated_node(
|
||||
node_dir=tmp_path,
|
||||
manifest_path=manifest_file,
|
||||
logger=MagicMock(),
|
||||
build_stub_class=MagicMock(),
|
||||
venv_root=tmp_path / "venvs",
|
||||
extension_managers=[],
|
||||
)
|
||||
|
||||
config = mock_manager.load_extension.call_args[0][0]
|
||||
assert config["sealed_host_ro_paths"] == ["/home/johnj/ComfyUI"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_host_coupled_does_not_emit_sealed_host_ro_paths(
|
||||
mock_pyisolate, manifest_file, tmp_path
|
||||
):
|
||||
manifest = _make_manifest(execution_model="host-coupled")
|
||||
|
||||
module, _, mock_manager, _, _ = mock_pyisolate
|
||||
|
||||
with (
|
||||
patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib,
|
||||
patch.object(
|
||||
module,
|
||||
"load_host_policy",
|
||||
return_value={
|
||||
"sandbox_mode": "required",
|
||||
"allow_network": False,
|
||||
"writable_paths": [],
|
||||
"readonly_paths": [],
|
||||
"sealed_worker_ro_import_paths": ["/home/johnj/ComfyUI"],
|
||||
},
|
||||
),
|
||||
):
|
||||
mock_tomllib.load.return_value = manifest
|
||||
await load_isolated_node(
|
||||
node_dir=tmp_path,
|
||||
manifest_path=manifest_file,
|
||||
logger=MagicMock(),
|
||||
build_stub_class=MagicMock(),
|
||||
venv_root=tmp_path / "venvs",
|
||||
extension_managers=[],
|
||||
)
|
||||
|
||||
config = mock_manager.load_extension.call_args[0][0]
|
||||
assert "sealed_host_ro_paths" not in config
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sealed_worker_manifest_ro_import_paths_blocked(
|
||||
mock_pyisolate, manifest_file, tmp_path
|
||||
):
|
||||
manifest = _make_manifest(
|
||||
execution_model="sealed_worker",
|
||||
sealed_host_ro_paths=["/home/johnj/ComfyUI"],
|
||||
)
|
||||
|
||||
_, _, _mock_manager, _, _ = mock_pyisolate
|
||||
|
||||
with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib:
|
||||
mock_tomllib.load.return_value = manifest
|
||||
with pytest.raises(ValueError, match="Manifest field 'sealed_host_ro_paths' is not allowed"):
|
||||
await load_isolated_node(
|
||||
node_dir=tmp_path,
|
||||
manifest_path=manifest_file,
|
||||
logger=MagicMock(),
|
||||
build_stub_class=MagicMock(),
|
||||
venv_root=tmp_path / "venvs",
|
||||
extension_managers=[],
|
||||
)
|
||||
122
tests/isolation/test_folder_paths_proxy.py
Normal file
122
tests/isolation/test_folder_paths_proxy.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""Unit tests for FolderPathsProxy."""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy
|
||||
from tests.isolation.singleton_boundary_helpers import capture_sealed_singleton_imports
|
||||
|
||||
|
||||
class TestFolderPathsProxy:
|
||||
"""Test FolderPathsProxy methods."""
|
||||
|
||||
@pytest.fixture
|
||||
def proxy(self):
|
||||
"""Create a FolderPathsProxy instance for testing."""
|
||||
return FolderPathsProxy()
|
||||
|
||||
def test_get_temp_directory_returns_string(self, proxy):
|
||||
"""Verify get_temp_directory returns a non-empty string."""
|
||||
result = proxy.get_temp_directory()
|
||||
assert isinstance(result, str), f"Expected str, got {type(result)}"
|
||||
assert len(result) > 0, "Temp directory path is empty"
|
||||
|
||||
def test_get_temp_directory_returns_absolute_path(self, proxy):
|
||||
"""Verify get_temp_directory returns an absolute path."""
|
||||
result = proxy.get_temp_directory()
|
||||
path = Path(result)
|
||||
assert path.is_absolute(), f"Path is not absolute: {result}"
|
||||
|
||||
def test_get_input_directory_returns_string(self, proxy):
|
||||
"""Verify get_input_directory returns a non-empty string."""
|
||||
result = proxy.get_input_directory()
|
||||
assert isinstance(result, str), f"Expected str, got {type(result)}"
|
||||
assert len(result) > 0, "Input directory path is empty"
|
||||
|
||||
def test_get_input_directory_returns_absolute_path(self, proxy):
|
||||
"""Verify get_input_directory returns an absolute path."""
|
||||
result = proxy.get_input_directory()
|
||||
path = Path(result)
|
||||
assert path.is_absolute(), f"Path is not absolute: {result}"
|
||||
|
||||
def test_get_annotated_filepath_plain_name(self, proxy):
|
||||
"""Verify get_annotated_filepath works with plain filename."""
|
||||
result = proxy.get_annotated_filepath("test.png")
|
||||
assert isinstance(result, str), f"Expected str, got {type(result)}"
|
||||
assert "test.png" in result, f"Filename not in result: {result}"
|
||||
|
||||
def test_get_annotated_filepath_with_output_annotation(self, proxy):
|
||||
"""Verify get_annotated_filepath handles [output] annotation."""
|
||||
result = proxy.get_annotated_filepath("test.png[output]")
|
||||
assert isinstance(result, str), f"Expected str, got {type(result)}"
|
||||
assert "test.pn" in result, f"Filename base not in result: {result}"
|
||||
# Should resolve to output directory
|
||||
assert "output" in result.lower() or Path(result).parent.name == "output"
|
||||
|
||||
def test_get_annotated_filepath_with_input_annotation(self, proxy):
|
||||
"""Verify get_annotated_filepath handles [input] annotation."""
|
||||
result = proxy.get_annotated_filepath("test.png[input]")
|
||||
assert isinstance(result, str), f"Expected str, got {type(result)}"
|
||||
assert "test.pn" in result, f"Filename base not in result: {result}"
|
||||
|
||||
def test_get_annotated_filepath_with_temp_annotation(self, proxy):
|
||||
"""Verify get_annotated_filepath handles [temp] annotation."""
|
||||
result = proxy.get_annotated_filepath("test.png[temp]")
|
||||
assert isinstance(result, str), f"Expected str, got {type(result)}"
|
||||
assert "test.pn" in result, f"Filename base not in result: {result}"
|
||||
|
||||
def test_exists_annotated_filepath_returns_bool(self, proxy):
|
||||
"""Verify exists_annotated_filepath returns a boolean."""
|
||||
result = proxy.exists_annotated_filepath("nonexistent.png")
|
||||
assert isinstance(result, bool), f"Expected bool, got {type(result)}"
|
||||
|
||||
def test_exists_annotated_filepath_nonexistent_file(self, proxy):
|
||||
"""Verify exists_annotated_filepath returns False for nonexistent file."""
|
||||
result = proxy.exists_annotated_filepath("definitely_does_not_exist_12345.png")
|
||||
assert result is False, "Expected False for nonexistent file"
|
||||
|
||||
def test_exists_annotated_filepath_with_annotation(self, proxy):
|
||||
"""Verify exists_annotated_filepath works with annotation suffix."""
|
||||
# Even for nonexistent files, should return bool without error
|
||||
result = proxy.exists_annotated_filepath("test.png[output]")
|
||||
assert isinstance(result, bool), f"Expected bool, got {type(result)}"
|
||||
|
||||
def test_models_dir_property_returns_string(self, proxy):
|
||||
"""Verify models_dir property returns valid path string."""
|
||||
result = proxy.models_dir
|
||||
assert isinstance(result, str), f"Expected str, got {type(result)}"
|
||||
assert len(result) > 0, "Models directory path is empty"
|
||||
|
||||
def test_models_dir_is_absolute_path(self, proxy):
|
||||
"""Verify models_dir returns an absolute path."""
|
||||
result = proxy.models_dir
|
||||
path = Path(result)
|
||||
assert path.is_absolute(), f"Path is not absolute: {result}"
|
||||
|
||||
def test_add_model_folder_path_runs_without_error(self, proxy):
|
||||
"""Verify add_model_folder_path executes without raising."""
|
||||
test_path = "/tmp/test_models_florence2"
|
||||
# Should not raise
|
||||
proxy.add_model_folder_path("TEST_FLORENCE2", test_path)
|
||||
|
||||
def test_get_folder_paths_returns_list(self, proxy):
|
||||
"""Verify get_folder_paths returns a list."""
|
||||
# Use known folder type that should exist
|
||||
result = proxy.get_folder_paths("checkpoints")
|
||||
assert isinstance(result, list), f"Expected list, got {type(result)}"
|
||||
|
||||
def test_get_folder_paths_checkpoints_not_empty(self, proxy):
|
||||
"""Verify checkpoints folder paths list is not empty."""
|
||||
result = proxy.get_folder_paths("checkpoints")
|
||||
# Should have at least one checkpoint path registered
|
||||
assert len(result) > 0, "Checkpoints folder paths is empty"
|
||||
|
||||
def test_sealed_child_safe_uses_rpc_without_importing_folder_paths(self, monkeypatch):
|
||||
monkeypatch.setenv("PYISOLATE_CHILD", "1")
|
||||
monkeypatch.setenv("PYISOLATE_IMPORT_TORCH", "0")
|
||||
|
||||
payload = capture_sealed_singleton_imports()
|
||||
|
||||
assert payload["temp_dir"] == "/sandbox/temp"
|
||||
assert payload["models_dir"] == "/sandbox/models"
|
||||
assert "folder_paths" not in payload["modules"]
|
||||
209
tests/isolation/test_host_policy.py
Normal file
209
tests/isolation/test_host_policy.py
Normal file
@@ -0,0 +1,209 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _write_pyproject(path: Path, content: str) -> None:
|
||||
path.write_text(content, encoding="utf-8")
|
||||
|
||||
|
||||
def test_load_host_policy_defaults_when_pyproject_missing(tmp_path):
|
||||
from comfy.isolation.host_policy import DEFAULT_POLICY, load_host_policy
|
||||
|
||||
policy = load_host_policy(tmp_path)
|
||||
|
||||
assert policy["sandbox_mode"] == DEFAULT_POLICY["sandbox_mode"]
|
||||
assert policy["allow_network"] == DEFAULT_POLICY["allow_network"]
|
||||
assert policy["writable_paths"] == DEFAULT_POLICY["writable_paths"]
|
||||
assert policy["readonly_paths"] == DEFAULT_POLICY["readonly_paths"]
|
||||
assert policy["whitelist"] == DEFAULT_POLICY["whitelist"]
|
||||
|
||||
|
||||
def test_load_host_policy_defaults_when_section_missing(tmp_path):
|
||||
from comfy.isolation.host_policy import DEFAULT_POLICY, load_host_policy
|
||||
|
||||
_write_pyproject(
|
||||
tmp_path / "pyproject.toml",
|
||||
"""
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
""".strip(),
|
||||
)
|
||||
|
||||
policy = load_host_policy(tmp_path)
|
||||
assert policy["sandbox_mode"] == DEFAULT_POLICY["sandbox_mode"]
|
||||
assert policy["allow_network"] == DEFAULT_POLICY["allow_network"]
|
||||
assert policy["whitelist"] == {}
|
||||
|
||||
|
||||
def test_load_host_policy_reads_values(tmp_path):
|
||||
from comfy.isolation.host_policy import load_host_policy
|
||||
|
||||
_write_pyproject(
|
||||
tmp_path / "pyproject.toml",
|
||||
"""
|
||||
[tool.comfy.host]
|
||||
sandbox_mode = "disabled"
|
||||
allow_network = true
|
||||
writable_paths = ["/tmp/a", "/tmp/b"]
|
||||
readonly_paths = ["/opt/readonly"]
|
||||
|
||||
[tool.comfy.host.whitelist]
|
||||
ExampleNode = "*"
|
||||
""".strip(),
|
||||
)
|
||||
|
||||
policy = load_host_policy(tmp_path)
|
||||
assert policy["sandbox_mode"] == "disabled"
|
||||
assert policy["allow_network"] is True
|
||||
assert policy["writable_paths"] == ["/tmp/a", "/tmp/b"]
|
||||
assert policy["readonly_paths"] == ["/opt/readonly"]
|
||||
assert policy["whitelist"] == {"ExampleNode": "*"}
|
||||
|
||||
|
||||
def test_load_host_policy_ignores_invalid_whitelist_type(tmp_path):
|
||||
from comfy.isolation.host_policy import DEFAULT_POLICY, load_host_policy
|
||||
|
||||
_write_pyproject(
|
||||
tmp_path / "pyproject.toml",
|
||||
"""
|
||||
[tool.comfy.host]
|
||||
allow_network = true
|
||||
whitelist = ["bad"]
|
||||
""".strip(),
|
||||
)
|
||||
|
||||
policy = load_host_policy(tmp_path)
|
||||
assert policy["allow_network"] is True
|
||||
assert policy["whitelist"] == DEFAULT_POLICY["whitelist"]
|
||||
|
||||
|
||||
def test_load_host_policy_ignores_invalid_sandbox_mode(tmp_path):
|
||||
from comfy.isolation.host_policy import DEFAULT_POLICY, load_host_policy
|
||||
|
||||
_write_pyproject(
|
||||
tmp_path / "pyproject.toml",
|
||||
"""
|
||||
[tool.comfy.host]
|
||||
sandbox_mode = "surprise"
|
||||
""".strip(),
|
||||
)
|
||||
|
||||
policy = load_host_policy(tmp_path)
|
||||
|
||||
assert policy["sandbox_mode"] == DEFAULT_POLICY["sandbox_mode"]
|
||||
|
||||
|
||||
def test_load_host_policy_uses_env_override_path(tmp_path, monkeypatch):
|
||||
from comfy.isolation.host_policy import load_host_policy
|
||||
|
||||
override_path = tmp_path / "host_policy_override.toml"
|
||||
_write_pyproject(
|
||||
override_path,
|
||||
"""
|
||||
[tool.comfy.host]
|
||||
sandbox_mode = "disabled"
|
||||
allow_network = true
|
||||
""".strip(),
|
||||
)
|
||||
|
||||
monkeypatch.setenv("COMFY_HOST_POLICY_PATH", str(override_path))
|
||||
|
||||
policy = load_host_policy(tmp_path / "missing-root")
|
||||
|
||||
assert policy["sandbox_mode"] == "disabled"
|
||||
assert policy["allow_network"] is True
|
||||
|
||||
|
||||
def test_disallows_host_tmp_default_or_override_defaults(tmp_path):
|
||||
from comfy.isolation.host_policy import DEFAULT_POLICY, load_host_policy
|
||||
|
||||
policy = load_host_policy(tmp_path)
|
||||
|
||||
assert "/tmp" not in DEFAULT_POLICY["writable_paths"]
|
||||
assert "/tmp" not in policy["writable_paths"]
|
||||
|
||||
|
||||
def test_disallows_host_tmp_default_or_override_config(tmp_path):
|
||||
from comfy.isolation.host_policy import load_host_policy
|
||||
|
||||
_write_pyproject(
|
||||
tmp_path / "pyproject.toml",
|
||||
"""
|
||||
[tool.comfy.host]
|
||||
writable_paths = ["/dev/shm", "/tmp", "/tmp/", "/work/cache"]
|
||||
""".strip(),
|
||||
)
|
||||
|
||||
policy = load_host_policy(tmp_path)
|
||||
|
||||
assert policy["writable_paths"] == ["/dev/shm", "/work/cache"]
|
||||
|
||||
|
||||
def test_sealed_worker_ro_import_paths_defaults_off_and_parse(tmp_path):
|
||||
from comfy.isolation.host_policy import load_host_policy
|
||||
|
||||
policy = load_host_policy(tmp_path)
|
||||
assert policy["sealed_worker_ro_import_paths"] == []
|
||||
|
||||
_write_pyproject(
|
||||
tmp_path / "pyproject.toml",
|
||||
"""
|
||||
[tool.comfy.host]
|
||||
sealed_worker_ro_import_paths = ["/home/johnj/ComfyUI", "/opt/comfy-shared"]
|
||||
""".strip(),
|
||||
)
|
||||
|
||||
policy = load_host_policy(tmp_path)
|
||||
assert policy["sealed_worker_ro_import_paths"] == [
|
||||
"/home/johnj/ComfyUI",
|
||||
"/opt/comfy-shared",
|
||||
]
|
||||
|
||||
|
||||
def test_sealed_worker_ro_import_paths_rejects_non_list_or_relative(tmp_path):
|
||||
from comfy.isolation.host_policy import load_host_policy
|
||||
|
||||
_write_pyproject(
|
||||
tmp_path / "pyproject.toml",
|
||||
"""
|
||||
[tool.comfy.host]
|
||||
sealed_worker_ro_import_paths = "/home/johnj/ComfyUI"
|
||||
""".strip(),
|
||||
)
|
||||
with pytest.raises(ValueError, match="must be a list of absolute paths"):
|
||||
load_host_policy(tmp_path)
|
||||
|
||||
_write_pyproject(
|
||||
tmp_path / "pyproject.toml",
|
||||
"""
|
||||
[tool.comfy.host]
|
||||
sealed_worker_ro_import_paths = ["relative/path"]
|
||||
""".strip(),
|
||||
)
|
||||
with pytest.raises(ValueError, match="entries must be absolute paths"):
|
||||
load_host_policy(tmp_path)
|
||||
|
||||
|
||||
def test_host_policy_path_override_controls_ro_import_paths(tmp_path, monkeypatch):
|
||||
from comfy.isolation.host_policy import load_host_policy
|
||||
|
||||
_write_pyproject(
|
||||
tmp_path / "pyproject.toml",
|
||||
"""
|
||||
[tool.comfy.host]
|
||||
sealed_worker_ro_import_paths = ["/ignored/base/path"]
|
||||
""".strip(),
|
||||
)
|
||||
override_path = tmp_path / "host_policy_override.toml"
|
||||
_write_pyproject(
|
||||
override_path,
|
||||
"""
|
||||
[tool.comfy.host]
|
||||
sealed_worker_ro_import_paths = ["/override/ro/path"]
|
||||
""".strip(),
|
||||
)
|
||||
monkeypatch.setenv("COMFY_HOST_POLICY_PATH", str(override_path))
|
||||
|
||||
policy = load_host_policy(tmp_path)
|
||||
assert policy["sealed_worker_ro_import_paths"] == ["/override/ro/path"]
|
||||
80
tests/isolation/test_init.py
Normal file
80
tests/isolation/test_init.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Unit tests for PyIsolate isolation system initialization."""
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
|
||||
from tests.isolation.singleton_boundary_helpers import (
|
||||
FakeSingletonRPC,
|
||||
reset_forbidden_singleton_modules,
|
||||
)
|
||||
|
||||
|
||||
def test_log_prefix():
|
||||
"""Verify LOG_PREFIX constant is correctly defined."""
|
||||
from comfy.isolation import LOG_PREFIX
|
||||
assert LOG_PREFIX == "]["
|
||||
assert isinstance(LOG_PREFIX, str)
|
||||
|
||||
|
||||
def test_module_initialization():
|
||||
"""Verify module initializes without errors."""
|
||||
isolation_pkg = importlib.import_module("comfy.isolation")
|
||||
assert hasattr(isolation_pkg, "LOG_PREFIX")
|
||||
assert hasattr(isolation_pkg, "initialize_proxies")
|
||||
|
||||
|
||||
class TestInitializeProxies:
|
||||
def test_initialize_proxies_runs_without_error(self):
|
||||
from comfy.isolation import initialize_proxies
|
||||
initialize_proxies()
|
||||
|
||||
def test_initialize_proxies_registers_folder_paths_proxy(self):
|
||||
from comfy.isolation import initialize_proxies
|
||||
from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy
|
||||
initialize_proxies()
|
||||
proxy = FolderPathsProxy()
|
||||
assert proxy is not None
|
||||
assert hasattr(proxy, "get_temp_directory")
|
||||
|
||||
def test_initialize_proxies_registers_model_management_proxy(self):
|
||||
from comfy.isolation import initialize_proxies
|
||||
from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy
|
||||
initialize_proxies()
|
||||
proxy = ModelManagementProxy()
|
||||
assert proxy is not None
|
||||
assert hasattr(proxy, "get_torch_device")
|
||||
|
||||
def test_initialize_proxies_can_be_called_multiple_times(self):
|
||||
from comfy.isolation import initialize_proxies
|
||||
initialize_proxies()
|
||||
initialize_proxies()
|
||||
initialize_proxies()
|
||||
|
||||
def test_dev_proxies_accessible_when_dev_mode(self, monkeypatch):
|
||||
"""Verify dev mode does not break core proxy initialization."""
|
||||
monkeypatch.setenv("PYISOLATE_DEV", "1")
|
||||
from comfy.isolation import initialize_proxies
|
||||
from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy
|
||||
from comfy.isolation.proxies.utils_proxy import UtilsProxy
|
||||
initialize_proxies()
|
||||
folder_proxy = FolderPathsProxy()
|
||||
utils_proxy = UtilsProxy()
|
||||
assert folder_proxy is not None
|
||||
assert utils_proxy is not None
|
||||
|
||||
def test_sealed_child_safe_initialize_proxies_avoids_real_utils_import(self, monkeypatch):
|
||||
monkeypatch.setenv("PYISOLATE_CHILD", "1")
|
||||
monkeypatch.setenv("PYISOLATE_IMPORT_TORCH", "0")
|
||||
reset_forbidden_singleton_modules()
|
||||
|
||||
from pyisolate._internal import rpc_protocol
|
||||
from comfy.isolation import initialize_proxies
|
||||
|
||||
fake_rpc = FakeSingletonRPC()
|
||||
monkeypatch.setattr(rpc_protocol, "get_child_rpc_instance", lambda: fake_rpc)
|
||||
|
||||
initialize_proxies()
|
||||
|
||||
assert "comfy.utils" not in sys.modules
|
||||
assert "folder_paths" not in sys.modules
|
||||
assert "comfy_execution.progress" not in sys.modules
|
||||
105
tests/isolation/test_internal_probe_node_assets.py
Normal file
105
tests/isolation/test_internal_probe_node_assets.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
COMFYUI_ROOT = Path(__file__).resolve().parents[2]
|
||||
ISOLATION_ROOT = COMFYUI_ROOT / "tests" / "isolation"
|
||||
PROBE_ROOT = ISOLATION_ROOT / "internal_probe_node"
|
||||
WORKFLOW_ROOT = ISOLATION_ROOT / "workflows"
|
||||
TOOLKIT_ROOT = COMFYUI_ROOT / "custom_nodes" / "ComfyUI-IsolationToolkit"
|
||||
|
||||
EXPECTED_PROBE_FILES = {
|
||||
"__init__.py",
|
||||
"probe_nodes.py",
|
||||
}
|
||||
EXPECTED_WORKFLOWS = {
|
||||
"internal_probe_preview_image_audio.json",
|
||||
"internal_probe_ui3d.json",
|
||||
}
|
||||
BANNED_REFERENCES = (
|
||||
"ComfyUI-IsolationToolkit",
|
||||
"toolkit_smoke_playlist",
|
||||
"run_isolation_toolkit_smoke.sh",
|
||||
)
|
||||
|
||||
|
||||
def _text_assets() -> list[Path]:
|
||||
return sorted(list(PROBE_ROOT.rglob("*.py")) + list(WORKFLOW_ROOT.glob("internal_probe_*.json")))
|
||||
|
||||
|
||||
def _load_probe_package():
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"internal_probe_node",
|
||||
PROBE_ROOT / "__init__.py",
|
||||
submodule_search_locations=[str(PROBE_ROOT)],
|
||||
)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
assert spec is not None
|
||||
assert spec.loader is not None
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def test_inventory_is_minimal_and_isolation_owned():
|
||||
assert PROBE_ROOT.is_dir()
|
||||
assert WORKFLOW_ROOT.is_dir()
|
||||
assert PROBE_ROOT.is_relative_to(ISOLATION_ROOT)
|
||||
assert WORKFLOW_ROOT.is_relative_to(ISOLATION_ROOT)
|
||||
assert not PROBE_ROOT.is_relative_to(TOOLKIT_ROOT)
|
||||
|
||||
probe_files = {path.name for path in PROBE_ROOT.iterdir() if path.is_file()}
|
||||
workflow_files = {path.name for path in WORKFLOW_ROOT.glob("internal_probe_*.json")}
|
||||
|
||||
assert probe_files == EXPECTED_PROBE_FILES
|
||||
assert workflow_files == EXPECTED_WORKFLOWS
|
||||
|
||||
module = _load_probe_package()
|
||||
mappings = module.NODE_CLASS_MAPPINGS
|
||||
|
||||
assert sorted(mappings.keys()) == [
|
||||
"InternalIsolationProbeAudio",
|
||||
"InternalIsolationProbeImage",
|
||||
"InternalIsolationProbeUI3D",
|
||||
]
|
||||
|
||||
preview_workflow = json.loads(
|
||||
(WORKFLOW_ROOT / "internal_probe_preview_image_audio.json").read_text(
|
||||
encoding="utf-8"
|
||||
)
|
||||
)
|
||||
ui3d_workflow = json.loads(
|
||||
(WORKFLOW_ROOT / "internal_probe_ui3d.json").read_text(encoding="utf-8")
|
||||
)
|
||||
|
||||
assert [preview_workflow[node_id]["class_type"] for node_id in ("1", "2")] == [
|
||||
"InternalIsolationProbeImage",
|
||||
"InternalIsolationProbeAudio",
|
||||
]
|
||||
assert [ui3d_workflow[node_id]["class_type"] for node_id in ("1",)] == [
|
||||
"InternalIsolationProbeUI3D",
|
||||
]
|
||||
|
||||
|
||||
def test_zero_toolkit_references_in_probe_assets():
|
||||
for asset in _text_assets():
|
||||
content = asset.read_text(encoding="utf-8")
|
||||
for banned in BANNED_REFERENCES:
|
||||
assert banned not in content, f"{asset} unexpectedly references {banned}"
|
||||
|
||||
|
||||
def test_replacement_contract_has_zero_toolkit_references():
|
||||
contract_assets = [
|
||||
*(PROBE_ROOT.rglob("*.py")),
|
||||
*WORKFLOW_ROOT.glob("internal_probe_*.json"),
|
||||
ISOLATION_ROOT / "stage_internal_probe_node.py",
|
||||
ISOLATION_ROOT / "internal_probe_host_policy.toml",
|
||||
]
|
||||
|
||||
for asset in sorted(contract_assets):
|
||||
assert asset.exists(), f"Missing replacement-contract asset: {asset}"
|
||||
content = asset.read_text(encoding="utf-8")
|
||||
for banned in BANNED_REFERENCES:
|
||||
assert banned not in content, f"{asset} unexpectedly references {banned}"
|
||||
180
tests/isolation/test_internal_probe_node_loading.py
Normal file
180
tests/isolation/test_internal_probe_node_loading.py
Normal file
@@ -0,0 +1,180 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
import nodes
|
||||
from tests.isolation.stage_internal_probe_node import (
|
||||
PROBE_NODE_NAME,
|
||||
stage_probe_node,
|
||||
staged_probe_node,
|
||||
)
|
||||
|
||||
|
||||
COMFYUI_ROOT = Path(__file__).resolve().parents[2]
|
||||
ISOLATION_ROOT = COMFYUI_ROOT / "tests" / "isolation"
|
||||
PROBE_SOURCE_ROOT = ISOLATION_ROOT / "internal_probe_node"
|
||||
EXPECTED_NODE_IDS = [
|
||||
"InternalIsolationProbeAudio",
|
||||
"InternalIsolationProbeImage",
|
||||
"InternalIsolationProbeUI3D",
|
||||
]
|
||||
|
||||
CLIENT_SCRIPT = """
|
||||
import importlib.util
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pyisolate._internal.client # noqa: F401 # triggers snapshot bootstrap
|
||||
|
||||
module_path = os.environ["PYISOLATE_MODULE_PATH"]
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"internal_probe_node",
|
||||
os.path.join(module_path, "__init__.py"),
|
||||
submodule_search_locations=[module_path],
|
||||
)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
assert spec is not None
|
||||
assert spec.loader is not None
|
||||
sys.modules["internal_probe_node"] = module
|
||||
spec.loader.exec_module(module)
|
||||
print(
|
||||
json.dumps(
|
||||
{
|
||||
"sys_path": list(sys.path),
|
||||
"module_path": module_path,
|
||||
"node_ids": sorted(module.NODE_CLASS_MAPPINGS.keys()),
|
||||
}
|
||||
)
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
def _run_client_process(env: dict[str, str]) -> dict:
|
||||
pythonpath_parts = [str(COMFYUI_ROOT)]
|
||||
existing = env.get("PYTHONPATH", "")
|
||||
if existing:
|
||||
pythonpath_parts.append(existing)
|
||||
env["PYTHONPATH"] = ":".join(pythonpath_parts)
|
||||
|
||||
result = subprocess.run( # noqa: S603
|
||||
[sys.executable, "-c", CLIENT_SCRIPT],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
env=env,
|
||||
check=True,
|
||||
)
|
||||
return json.loads(result.stdout.strip().splitlines()[-1])
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def staged_probe_module(tmp_path: Path) -> tuple[Path, Path]:
|
||||
staged_comfy_root = tmp_path / "ComfyUI"
|
||||
module_path = staged_comfy_root / "custom_nodes" / "InternalIsolationProbeNode"
|
||||
shutil.copytree(PROBE_SOURCE_ROOT, module_path)
|
||||
return staged_comfy_root, module_path
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_staged_probe_node_discovered(staged_probe_module: tuple[Path, Path]) -> None:
|
||||
_, module_path = staged_probe_module
|
||||
class_mappings_snapshot = dict(nodes.NODE_CLASS_MAPPINGS)
|
||||
display_name_snapshot = dict(nodes.NODE_DISPLAY_NAME_MAPPINGS)
|
||||
loaded_module_dirs_snapshot = dict(nodes.LOADED_MODULE_DIRS)
|
||||
|
||||
try:
|
||||
ignore = set(nodes.NODE_CLASS_MAPPINGS.keys())
|
||||
loaded = await nodes.load_custom_node(
|
||||
str(module_path), ignore=ignore, module_parent="custom_nodes"
|
||||
)
|
||||
|
||||
assert loaded is True
|
||||
assert nodes.LOADED_MODULE_DIRS["InternalIsolationProbeNode"] == str(
|
||||
module_path.resolve()
|
||||
)
|
||||
|
||||
for node_id in EXPECTED_NODE_IDS:
|
||||
assert node_id in nodes.NODE_CLASS_MAPPINGS
|
||||
node_cls = nodes.NODE_CLASS_MAPPINGS[node_id]
|
||||
assert (
|
||||
getattr(node_cls, "RELATIVE_PYTHON_MODULE", None)
|
||||
== "custom_nodes.InternalIsolationProbeNode"
|
||||
)
|
||||
finally:
|
||||
nodes.NODE_CLASS_MAPPINGS.clear()
|
||||
nodes.NODE_CLASS_MAPPINGS.update(class_mappings_snapshot)
|
||||
nodes.NODE_DISPLAY_NAME_MAPPINGS.clear()
|
||||
nodes.NODE_DISPLAY_NAME_MAPPINGS.update(display_name_snapshot)
|
||||
nodes.LOADED_MODULE_DIRS.clear()
|
||||
nodes.LOADED_MODULE_DIRS.update(loaded_module_dirs_snapshot)
|
||||
|
||||
|
||||
def test_staged_probe_node_module_path_is_valid_for_child_bootstrap(
|
||||
tmp_path: Path, staged_probe_module: tuple[Path, Path]
|
||||
) -> None:
|
||||
staged_comfy_root, module_path = staged_probe_module
|
||||
snapshot = {
|
||||
"sys_path": [str(COMFYUI_ROOT), "/host/lib1", "/host/lib2"],
|
||||
"sys_executable": sys.executable,
|
||||
"sys_prefix": sys.prefix,
|
||||
"environment": {},
|
||||
}
|
||||
snapshot_path = tmp_path / "snapshot.json"
|
||||
snapshot_path.write_text(json.dumps(snapshot), encoding="utf-8")
|
||||
|
||||
env = os.environ.copy()
|
||||
env.update(
|
||||
{
|
||||
"PYISOLATE_CHILD": "1",
|
||||
"PYISOLATE_HOST_SNAPSHOT": str(snapshot_path),
|
||||
"PYISOLATE_MODULE_PATH": str(module_path),
|
||||
}
|
||||
)
|
||||
|
||||
payload = _run_client_process(env)
|
||||
|
||||
assert payload["module_path"] == str(module_path)
|
||||
assert payload["node_ids"] == EXPECTED_NODE_IDS
|
||||
assert str(COMFYUI_ROOT) in payload["sys_path"]
|
||||
assert str(staged_comfy_root) not in payload["sys_path"]
|
||||
|
||||
|
||||
def test_stage_probe_node_stages_only_under_explicit_root(tmp_path: Path) -> None:
|
||||
comfy_root = tmp_path / "sandbox-root"
|
||||
|
||||
module_path = stage_probe_node(comfy_root)
|
||||
|
||||
assert module_path == comfy_root / "custom_nodes" / PROBE_NODE_NAME
|
||||
assert module_path.is_dir()
|
||||
assert (module_path / "__init__.py").is_file()
|
||||
assert (module_path / "probe_nodes.py").is_file()
|
||||
assert (module_path / "pyproject.toml").is_file()
|
||||
|
||||
|
||||
def test_staged_probe_node_context_cleans_up_temp_root() -> None:
|
||||
with staged_probe_node() as module_path:
|
||||
staging_root = module_path.parents[1]
|
||||
assert module_path.name == PROBE_NODE_NAME
|
||||
assert module_path.is_dir()
|
||||
assert staging_root.is_dir()
|
||||
|
||||
assert not staging_root.exists()
|
||||
|
||||
|
||||
def test_stage_script_requires_explicit_target_root() -> None:
|
||||
result = subprocess.run( # noqa: S603
|
||||
[sys.executable, str(ISOLATION_ROOT / "stage_internal_probe_node.py")],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
|
||||
assert result.returncode != 0
|
||||
assert "--target-root" in result.stderr
|
||||
434
tests/isolation/test_manifest_loader_cache.py
Normal file
434
tests/isolation/test_manifest_loader_cache.py
Normal file
@@ -0,0 +1,434 @@
|
||||
"""
|
||||
Unit tests for manifest_loader.py cache functions.
|
||||
|
||||
Phase 1 tests verify:
|
||||
1. Cache miss on first run (no cache exists)
|
||||
2. Cache hit when nothing changes
|
||||
3. Invalidation on .py file touch
|
||||
4. Invalidation on manifest change
|
||||
5. Cache location correctness (in venv_root, NOT in custom_nodes)
|
||||
6. Corrupt cache handling (graceful failure)
|
||||
|
||||
These tests verify the cache implementation is correct BEFORE it's activated
|
||||
in extension_loader.py (Phase 2).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
|
||||
|
||||
class TestComputeCacheKey:
|
||||
"""Tests for compute_cache_key() function."""
|
||||
|
||||
def test_key_includes_manifest_content(self, tmp_path: Path) -> None:
|
||||
"""Cache key changes when manifest content changes."""
|
||||
from comfy.isolation.manifest_loader import compute_cache_key
|
||||
|
||||
node_dir = tmp_path / "test_node"
|
||||
node_dir.mkdir()
|
||||
manifest = node_dir / "pyisolate.yaml"
|
||||
|
||||
# Initial manifest
|
||||
manifest.write_text("isolated: true\ndependencies: []\n")
|
||||
key1 = compute_cache_key(node_dir, manifest)
|
||||
|
||||
# Modified manifest
|
||||
manifest.write_text("isolated: true\ndependencies: [numpy]\n")
|
||||
key2 = compute_cache_key(node_dir, manifest)
|
||||
|
||||
assert key1 != key2, "Key should change when manifest content changes"
|
||||
|
||||
def test_key_includes_py_file_mtime(self, tmp_path: Path) -> None:
|
||||
"""Cache key changes when any .py file is touched."""
|
||||
from comfy.isolation.manifest_loader import compute_cache_key
|
||||
|
||||
node_dir = tmp_path / "test_node"
|
||||
node_dir.mkdir()
|
||||
manifest = node_dir / "pyisolate.yaml"
|
||||
manifest.write_text("isolated: true\n")
|
||||
|
||||
py_file = node_dir / "nodes.py"
|
||||
py_file.write_text("# test code")
|
||||
|
||||
key1 = compute_cache_key(node_dir, manifest)
|
||||
|
||||
# Wait a moment to ensure mtime changes
|
||||
time.sleep(0.01)
|
||||
py_file.write_text("# modified code")
|
||||
|
||||
key2 = compute_cache_key(node_dir, manifest)
|
||||
|
||||
assert key1 != key2, "Key should change when .py file mtime changes"
|
||||
|
||||
def test_key_includes_python_version(self, tmp_path: Path) -> None:
|
||||
"""Cache key changes when Python version changes."""
|
||||
from comfy.isolation.manifest_loader import compute_cache_key
|
||||
|
||||
node_dir = tmp_path / "test_node"
|
||||
node_dir.mkdir()
|
||||
manifest = node_dir / "pyisolate.yaml"
|
||||
manifest.write_text("isolated: true\n")
|
||||
|
||||
key1 = compute_cache_key(node_dir, manifest)
|
||||
|
||||
# Mock different Python version
|
||||
with mock.patch.object(sys, "version", "3.99.0 (fake)"):
|
||||
key2 = compute_cache_key(node_dir, manifest)
|
||||
|
||||
assert key1 != key2, "Key should change when Python version changes"
|
||||
|
||||
def test_key_includes_pyisolate_version(self, tmp_path: Path) -> None:
|
||||
"""Cache key changes when PyIsolate version changes."""
|
||||
from comfy.isolation.manifest_loader import compute_cache_key
|
||||
|
||||
node_dir = tmp_path / "test_node"
|
||||
node_dir.mkdir()
|
||||
manifest = node_dir / "pyisolate.yaml"
|
||||
manifest.write_text("isolated: true\n")
|
||||
|
||||
key1 = compute_cache_key(node_dir, manifest)
|
||||
|
||||
# Mock different pyisolate version
|
||||
with mock.patch.dict(sys.modules, {"pyisolate": mock.MagicMock(__version__="99.99.99")}):
|
||||
# Need to reimport to pick up the mock
|
||||
import importlib
|
||||
from comfy.isolation import manifest_loader
|
||||
importlib.reload(manifest_loader)
|
||||
key2 = manifest_loader.compute_cache_key(node_dir, manifest)
|
||||
|
||||
# Keys should be different (though the mock approach is tricky)
|
||||
# At minimum, verify key is a valid hex string
|
||||
assert len(key1) == 16, "Key should be 16 hex characters"
|
||||
assert all(c in "0123456789abcdef" for c in key1), "Key should be hex"
|
||||
assert len(key2) == 16, "Key should be 16 hex characters"
|
||||
assert all(c in "0123456789abcdef" for c in key2), "Key should be hex"
|
||||
|
||||
def test_key_excludes_pycache(self, tmp_path: Path) -> None:
|
||||
"""Cache key ignores __pycache__ directory changes."""
|
||||
from comfy.isolation.manifest_loader import compute_cache_key
|
||||
|
||||
node_dir = tmp_path / "test_node"
|
||||
node_dir.mkdir()
|
||||
manifest = node_dir / "pyisolate.yaml"
|
||||
manifest.write_text("isolated: true\n")
|
||||
|
||||
py_file = node_dir / "nodes.py"
|
||||
py_file.write_text("# test code")
|
||||
|
||||
key1 = compute_cache_key(node_dir, manifest)
|
||||
|
||||
# Add __pycache__ file
|
||||
pycache = node_dir / "__pycache__"
|
||||
pycache.mkdir()
|
||||
(pycache / "nodes.cpython-310.pyc").write_bytes(b"compiled")
|
||||
|
||||
key2 = compute_cache_key(node_dir, manifest)
|
||||
|
||||
assert key1 == key2, "Key should NOT change when __pycache__ modified"
|
||||
|
||||
def test_key_is_deterministic(self, tmp_path: Path) -> None:
|
||||
"""Same inputs produce same key."""
|
||||
from comfy.isolation.manifest_loader import compute_cache_key
|
||||
|
||||
node_dir = tmp_path / "test_node"
|
||||
node_dir.mkdir()
|
||||
manifest = node_dir / "pyisolate.yaml"
|
||||
manifest.write_text("isolated: true\n")
|
||||
(node_dir / "nodes.py").write_text("# code")
|
||||
|
||||
key1 = compute_cache_key(node_dir, manifest)
|
||||
key2 = compute_cache_key(node_dir, manifest)
|
||||
|
||||
assert key1 == key2, "Key should be deterministic"
|
||||
|
||||
|
||||
class TestGetCachePath:
|
||||
"""Tests for get_cache_path() function."""
|
||||
|
||||
def test_returns_correct_paths(self, tmp_path: Path) -> None:
|
||||
"""Cache paths are in venv_root, not in node_dir."""
|
||||
from comfy.isolation.manifest_loader import get_cache_path
|
||||
|
||||
node_dir = tmp_path / "custom_nodes" / "MyNode"
|
||||
venv_root = tmp_path / ".pyisolate_venvs"
|
||||
|
||||
key_file, data_file = get_cache_path(node_dir, venv_root)
|
||||
|
||||
assert key_file == venv_root / "MyNode" / "cache" / "cache_key"
|
||||
assert data_file == venv_root / "MyNode" / "cache" / "node_info.json"
|
||||
|
||||
def test_cache_not_in_custom_nodes(self, tmp_path: Path) -> None:
|
||||
"""Verify cache is NOT stored in custom_nodes directory."""
|
||||
from comfy.isolation.manifest_loader import get_cache_path
|
||||
|
||||
node_dir = tmp_path / "custom_nodes" / "MyNode"
|
||||
venv_root = tmp_path / ".pyisolate_venvs"
|
||||
|
||||
key_file, data_file = get_cache_path(node_dir, venv_root)
|
||||
|
||||
# Neither path should be under node_dir
|
||||
assert not str(key_file).startswith(str(node_dir))
|
||||
assert not str(data_file).startswith(str(node_dir))
|
||||
|
||||
|
||||
class TestIsCacheValid:
|
||||
"""Tests for is_cache_valid() function."""
|
||||
|
||||
def test_false_when_no_cache_exists(self, tmp_path: Path) -> None:
|
||||
"""Returns False when cache files don't exist."""
|
||||
from comfy.isolation.manifest_loader import is_cache_valid
|
||||
|
||||
node_dir = tmp_path / "test_node"
|
||||
node_dir.mkdir()
|
||||
manifest = node_dir / "pyisolate.yaml"
|
||||
manifest.write_text("isolated: true\n")
|
||||
venv_root = tmp_path / ".pyisolate_venvs"
|
||||
|
||||
assert is_cache_valid(node_dir, manifest, venv_root) is False
|
||||
|
||||
def test_true_when_cache_matches(self, tmp_path: Path) -> None:
|
||||
"""Returns True when cache key matches current state."""
|
||||
from comfy.isolation.manifest_loader import (
|
||||
compute_cache_key,
|
||||
get_cache_path,
|
||||
is_cache_valid,
|
||||
)
|
||||
|
||||
node_dir = tmp_path / "test_node"
|
||||
node_dir.mkdir()
|
||||
manifest = node_dir / "pyisolate.yaml"
|
||||
manifest.write_text("isolated: true\n")
|
||||
(node_dir / "nodes.py").write_text("# code")
|
||||
venv_root = tmp_path / ".pyisolate_venvs"
|
||||
|
||||
# Create valid cache
|
||||
cache_key = compute_cache_key(node_dir, manifest)
|
||||
key_file, data_file = get_cache_path(node_dir, venv_root)
|
||||
key_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
key_file.write_text(cache_key)
|
||||
data_file.write_text("{}")
|
||||
|
||||
assert is_cache_valid(node_dir, manifest, venv_root) is True
|
||||
|
||||
def test_false_when_key_mismatch(self, tmp_path: Path) -> None:
|
||||
"""Returns False when stored key doesn't match current state."""
|
||||
from comfy.isolation.manifest_loader import get_cache_path, is_cache_valid
|
||||
|
||||
node_dir = tmp_path / "test_node"
|
||||
node_dir.mkdir()
|
||||
manifest = node_dir / "pyisolate.yaml"
|
||||
manifest.write_text("isolated: true\n")
|
||||
venv_root = tmp_path / ".pyisolate_venvs"
|
||||
|
||||
# Create cache with wrong key
|
||||
key_file, data_file = get_cache_path(node_dir, venv_root)
|
||||
key_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
key_file.write_text("wrong_key_12345")
|
||||
data_file.write_text("{}")
|
||||
|
||||
assert is_cache_valid(node_dir, manifest, venv_root) is False
|
||||
|
||||
def test_false_when_data_file_missing(self, tmp_path: Path) -> None:
|
||||
"""Returns False when node_info.json is missing."""
|
||||
from comfy.isolation.manifest_loader import (
|
||||
compute_cache_key,
|
||||
get_cache_path,
|
||||
is_cache_valid,
|
||||
)
|
||||
|
||||
node_dir = tmp_path / "test_node"
|
||||
node_dir.mkdir()
|
||||
manifest = node_dir / "pyisolate.yaml"
|
||||
manifest.write_text("isolated: true\n")
|
||||
venv_root = tmp_path / ".pyisolate_venvs"
|
||||
|
||||
# Create only key file, not data file
|
||||
cache_key = compute_cache_key(node_dir, manifest)
|
||||
key_file, _ = get_cache_path(node_dir, venv_root)
|
||||
key_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
key_file.write_text(cache_key)
|
||||
|
||||
assert is_cache_valid(node_dir, manifest, venv_root) is False
|
||||
|
||||
def test_invalidation_on_py_change(self, tmp_path: Path) -> None:
|
||||
"""Cache invalidates when .py file is modified."""
|
||||
from comfy.isolation.manifest_loader import (
|
||||
compute_cache_key,
|
||||
get_cache_path,
|
||||
is_cache_valid,
|
||||
)
|
||||
|
||||
node_dir = tmp_path / "test_node"
|
||||
node_dir.mkdir()
|
||||
manifest = node_dir / "pyisolate.yaml"
|
||||
manifest.write_text("isolated: true\n")
|
||||
py_file = node_dir / "nodes.py"
|
||||
py_file.write_text("# original")
|
||||
venv_root = tmp_path / ".pyisolate_venvs"
|
||||
|
||||
# Create valid cache
|
||||
cache_key = compute_cache_key(node_dir, manifest)
|
||||
key_file, data_file = get_cache_path(node_dir, venv_root)
|
||||
key_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
key_file.write_text(cache_key)
|
||||
data_file.write_text("{}")
|
||||
|
||||
# Verify cache is valid initially
|
||||
assert is_cache_valid(node_dir, manifest, venv_root) is True
|
||||
|
||||
# Modify .py file
|
||||
time.sleep(0.01) # Ensure mtime changes
|
||||
py_file.write_text("# modified")
|
||||
|
||||
# Cache should now be invalid
|
||||
assert is_cache_valid(node_dir, manifest, venv_root) is False
|
||||
|
||||
|
||||
class TestLoadFromCache:
|
||||
"""Tests for load_from_cache() function."""
|
||||
|
||||
def test_returns_none_when_no_cache(self, tmp_path: Path) -> None:
|
||||
"""Returns None when cache doesn't exist."""
|
||||
from comfy.isolation.manifest_loader import load_from_cache
|
||||
|
||||
node_dir = tmp_path / "test_node"
|
||||
venv_root = tmp_path / ".pyisolate_venvs"
|
||||
|
||||
assert load_from_cache(node_dir, venv_root) is None
|
||||
|
||||
def test_returns_data_when_valid(self, tmp_path: Path) -> None:
|
||||
"""Returns cached data when file exists and is valid JSON."""
|
||||
from comfy.isolation.manifest_loader import get_cache_path, load_from_cache
|
||||
|
||||
node_dir = tmp_path / "test_node"
|
||||
venv_root = tmp_path / ".pyisolate_venvs"
|
||||
|
||||
test_data = {"TestNode": {"inputs": [], "outputs": []}}
|
||||
|
||||
_, data_file = get_cache_path(node_dir, venv_root)
|
||||
data_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
data_file.write_text(json.dumps(test_data))
|
||||
|
||||
result = load_from_cache(node_dir, venv_root)
|
||||
assert result == test_data
|
||||
|
||||
def test_returns_none_on_corrupt_json(self, tmp_path: Path) -> None:
|
||||
"""Returns None when JSON is corrupt."""
|
||||
from comfy.isolation.manifest_loader import get_cache_path, load_from_cache
|
||||
|
||||
node_dir = tmp_path / "test_node"
|
||||
venv_root = tmp_path / ".pyisolate_venvs"
|
||||
|
||||
_, data_file = get_cache_path(node_dir, venv_root)
|
||||
data_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
data_file.write_text("{ corrupt json }")
|
||||
|
||||
assert load_from_cache(node_dir, venv_root) is None
|
||||
|
||||
def test_returns_none_on_invalid_structure(self, tmp_path: Path) -> None:
|
||||
"""Returns None when data is not a dict."""
|
||||
from comfy.isolation.manifest_loader import get_cache_path, load_from_cache
|
||||
|
||||
node_dir = tmp_path / "test_node"
|
||||
venv_root = tmp_path / ".pyisolate_venvs"
|
||||
|
||||
_, data_file = get_cache_path(node_dir, venv_root)
|
||||
data_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
data_file.write_text("[1, 2, 3]") # Array, not dict
|
||||
|
||||
assert load_from_cache(node_dir, venv_root) is None
|
||||
|
||||
|
||||
class TestSaveToCache:
|
||||
"""Tests for save_to_cache() function."""
|
||||
|
||||
def test_creates_cache_directory(self, tmp_path: Path) -> None:
|
||||
"""Creates cache directory if it doesn't exist."""
|
||||
from comfy.isolation.manifest_loader import get_cache_path, save_to_cache
|
||||
|
||||
node_dir = tmp_path / "test_node"
|
||||
node_dir.mkdir()
|
||||
manifest = node_dir / "pyisolate.yaml"
|
||||
manifest.write_text("isolated: true\n")
|
||||
venv_root = tmp_path / ".pyisolate_venvs"
|
||||
|
||||
save_to_cache(node_dir, venv_root, {"TestNode": {}}, manifest)
|
||||
|
||||
key_file, data_file = get_cache_path(node_dir, venv_root)
|
||||
assert key_file.parent.exists()
|
||||
|
||||
def test_writes_both_files(self, tmp_path: Path) -> None:
|
||||
"""Writes both cache_key and node_info.json."""
|
||||
from comfy.isolation.manifest_loader import get_cache_path, save_to_cache
|
||||
|
||||
node_dir = tmp_path / "test_node"
|
||||
node_dir.mkdir()
|
||||
manifest = node_dir / "pyisolate.yaml"
|
||||
manifest.write_text("isolated: true\n")
|
||||
venv_root = tmp_path / ".pyisolate_venvs"
|
||||
|
||||
save_to_cache(node_dir, venv_root, {"TestNode": {"key": "value"}}, manifest)
|
||||
|
||||
key_file, data_file = get_cache_path(node_dir, venv_root)
|
||||
assert key_file.exists()
|
||||
assert data_file.exists()
|
||||
|
||||
def test_data_is_valid_json(self, tmp_path: Path) -> None:
|
||||
"""Written data can be parsed as JSON."""
|
||||
from comfy.isolation.manifest_loader import get_cache_path, save_to_cache
|
||||
|
||||
node_dir = tmp_path / "test_node"
|
||||
node_dir.mkdir()
|
||||
manifest = node_dir / "pyisolate.yaml"
|
||||
manifest.write_text("isolated: true\n")
|
||||
venv_root = tmp_path / ".pyisolate_venvs"
|
||||
|
||||
test_data = {"TestNode": {"inputs": ["IMAGE"], "outputs": ["IMAGE"]}}
|
||||
save_to_cache(node_dir, venv_root, test_data, manifest)
|
||||
|
||||
_, data_file = get_cache_path(node_dir, venv_root)
|
||||
loaded = json.loads(data_file.read_text())
|
||||
assert loaded == test_data
|
||||
|
||||
def test_roundtrip_with_validation(self, tmp_path: Path) -> None:
|
||||
"""Saved cache is immediately valid."""
|
||||
from comfy.isolation.manifest_loader import (
|
||||
is_cache_valid,
|
||||
load_from_cache,
|
||||
save_to_cache,
|
||||
)
|
||||
|
||||
node_dir = tmp_path / "test_node"
|
||||
node_dir.mkdir()
|
||||
manifest = node_dir / "pyisolate.yaml"
|
||||
manifest.write_text("isolated: true\n")
|
||||
(node_dir / "nodes.py").write_text("# code")
|
||||
venv_root = tmp_path / ".pyisolate_venvs"
|
||||
|
||||
test_data = {"TestNode": {"foo": "bar"}}
|
||||
save_to_cache(node_dir, venv_root, test_data, manifest)
|
||||
|
||||
assert is_cache_valid(node_dir, manifest, venv_root) is True
|
||||
assert load_from_cache(node_dir, venv_root) == test_data
|
||||
|
||||
def test_cache_not_in_custom_nodes(self, tmp_path: Path) -> None:
|
||||
"""Verify no files written to custom_nodes directory."""
|
||||
from comfy.isolation.manifest_loader import save_to_cache
|
||||
|
||||
node_dir = tmp_path / "custom_nodes" / "MyNode"
|
||||
node_dir.mkdir(parents=True)
|
||||
manifest = node_dir / "pyisolate.yaml"
|
||||
manifest.write_text("isolated: true\n")
|
||||
venv_root = tmp_path / ".pyisolate_venvs"
|
||||
|
||||
save_to_cache(node_dir, venv_root, {"TestNode": {}}, manifest)
|
||||
|
||||
# Check nothing was created under node_dir
|
||||
for item in node_dir.iterdir():
|
||||
assert item.name == "pyisolate.yaml", f"Unexpected file in node_dir: {item}"
|
||||
86
tests/isolation/test_manifest_loader_discovery.py
Normal file
86
tests/isolation/test_manifest_loader_discovery.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
|
||||
|
||||
def _write_manifest(path: Path, *, standalone: bool = False) -> None:
|
||||
lines = [
|
||||
"[project]",
|
||||
'name = "test-node"',
|
||||
'version = "0.1.0"',
|
||||
"",
|
||||
"[tool.comfy.isolation]",
|
||||
"can_isolate = true",
|
||||
"share_torch = false",
|
||||
]
|
||||
if standalone:
|
||||
lines.append("standalone = true")
|
||||
path.write_text("\n".join(lines) + "\n", encoding="utf-8")
|
||||
|
||||
|
||||
def _load_manifest_loader(custom_nodes_root: Path):
|
||||
folder_paths = ModuleType("folder_paths")
|
||||
folder_paths.base_path = str(custom_nodes_root)
|
||||
folder_paths.get_folder_paths = lambda kind: [str(custom_nodes_root)] if kind == "custom_nodes" else []
|
||||
sys.modules["folder_paths"] = folder_paths
|
||||
|
||||
if "comfy.isolation" not in sys.modules:
|
||||
iso_mod = ModuleType("comfy.isolation")
|
||||
iso_mod.__path__ = [ # type: ignore[attr-defined]
|
||||
str(Path(__file__).resolve().parent.parent.parent / "comfy" / "isolation")
|
||||
]
|
||||
iso_mod.__package__ = "comfy.isolation"
|
||||
sys.modules["comfy.isolation"] = iso_mod
|
||||
|
||||
sys.modules.pop("comfy.isolation.manifest_loader", None)
|
||||
|
||||
import comfy.isolation.manifest_loader as manifest_loader
|
||||
|
||||
return importlib.reload(manifest_loader)
|
||||
|
||||
|
||||
def test_finds_top_level_isolation_manifest(tmp_path: Path) -> None:
|
||||
node_dir = tmp_path / "TopLevelNode"
|
||||
node_dir.mkdir(parents=True)
|
||||
_write_manifest(node_dir / "pyproject.toml")
|
||||
|
||||
manifest_loader = _load_manifest_loader(tmp_path)
|
||||
manifests = manifest_loader.find_manifest_directories()
|
||||
|
||||
assert manifests == [(node_dir, node_dir / "pyproject.toml")]
|
||||
|
||||
|
||||
def test_ignores_nested_manifest_without_standalone_flag(tmp_path: Path) -> None:
|
||||
toolkit_dir = tmp_path / "ToolkitNode"
|
||||
toolkit_dir.mkdir(parents=True)
|
||||
_write_manifest(toolkit_dir / "pyproject.toml")
|
||||
|
||||
nested_dir = toolkit_dir / "packages" / "nested_fixture"
|
||||
nested_dir.mkdir(parents=True)
|
||||
_write_manifest(nested_dir / "pyproject.toml", standalone=False)
|
||||
|
||||
manifest_loader = _load_manifest_loader(tmp_path)
|
||||
manifests = manifest_loader.find_manifest_directories()
|
||||
|
||||
assert manifests == [(toolkit_dir, toolkit_dir / "pyproject.toml")]
|
||||
|
||||
|
||||
def test_finds_nested_standalone_manifest(tmp_path: Path) -> None:
|
||||
toolkit_dir = tmp_path / "ToolkitNode"
|
||||
toolkit_dir.mkdir(parents=True)
|
||||
_write_manifest(toolkit_dir / "pyproject.toml")
|
||||
|
||||
nested_dir = toolkit_dir / "packages" / "uv_sealed_worker"
|
||||
nested_dir.mkdir(parents=True)
|
||||
_write_manifest(nested_dir / "pyproject.toml", standalone=True)
|
||||
|
||||
manifest_loader = _load_manifest_loader(tmp_path)
|
||||
manifests = manifest_loader.find_manifest_directories()
|
||||
|
||||
assert manifests == [
|
||||
(toolkit_dir, toolkit_dir / "pyproject.toml"),
|
||||
(nested_dir, nested_dir / "pyproject.toml"),
|
||||
]
|
||||
50
tests/isolation/test_model_management_proxy.py
Normal file
50
tests/isolation/test_model_management_proxy.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Unit tests for ModelManagementProxy."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy
|
||||
|
||||
|
||||
class TestModelManagementProxy:
|
||||
"""Test ModelManagementProxy methods."""
|
||||
|
||||
@pytest.fixture
|
||||
def proxy(self):
|
||||
"""Create a ModelManagementProxy instance for testing."""
|
||||
return ModelManagementProxy()
|
||||
|
||||
def test_get_torch_device_returns_device(self, proxy):
|
||||
"""Verify get_torch_device returns a torch.device object."""
|
||||
result = proxy.get_torch_device()
|
||||
assert isinstance(result, torch.device), f"Expected torch.device, got {type(result)}"
|
||||
|
||||
def test_get_torch_device_is_valid(self, proxy):
|
||||
"""Verify get_torch_device returns a valid device (cpu or cuda)."""
|
||||
result = proxy.get_torch_device()
|
||||
assert result.type in ("cpu", "cuda"), f"Unexpected device type: {result.type}"
|
||||
|
||||
def test_get_torch_device_name_returns_string(self, proxy):
|
||||
"""Verify get_torch_device_name returns a non-empty string."""
|
||||
device = proxy.get_torch_device()
|
||||
result = proxy.get_torch_device_name(device)
|
||||
assert isinstance(result, str), f"Expected str, got {type(result)}"
|
||||
assert len(result) > 0, "Device name is empty"
|
||||
|
||||
def test_get_torch_device_name_with_cpu(self, proxy):
|
||||
"""Verify get_torch_device_name works with CPU device."""
|
||||
cpu_device = torch.device("cpu")
|
||||
result = proxy.get_torch_device_name(cpu_device)
|
||||
assert isinstance(result, str), f"Expected str, got {type(result)}"
|
||||
assert "cpu" in result.lower(), f"Expected 'cpu' in device name, got: {result}"
|
||||
|
||||
def test_get_torch_device_name_with_cuda_if_available(self, proxy):
|
||||
"""Verify get_torch_device_name works with CUDA device if available."""
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA not available")
|
||||
|
||||
cuda_device = torch.device("cuda:0")
|
||||
result = proxy.get_torch_device_name(cuda_device)
|
||||
assert isinstance(result, str), f"Expected str, got {type(result)}"
|
||||
# Should contain device identifier
|
||||
assert len(result) > 0, "CUDA device name is empty"
|
||||
93
tests/isolation/test_path_helpers.py
Normal file
93
tests/isolation/test_path_helpers.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from pyisolate.path_helpers import build_child_sys_path, serialize_host_snapshot
|
||||
|
||||
|
||||
def test_serialize_host_snapshot_includes_expected_keys(tmp_path: Path, monkeypatch) -> None:
|
||||
output = tmp_path / "snapshot.json"
|
||||
monkeypatch.setenv("EXTRA_FLAG", "1")
|
||||
snapshot = serialize_host_snapshot(output_path=output, extra_env_keys=["EXTRA_FLAG"])
|
||||
|
||||
assert "sys_path" in snapshot
|
||||
assert "sys_executable" in snapshot
|
||||
assert "sys_prefix" in snapshot
|
||||
assert "environment" in snapshot
|
||||
assert output.exists()
|
||||
assert snapshot["environment"].get("EXTRA_FLAG") == "1"
|
||||
|
||||
persisted = json.loads(output.read_text(encoding="utf-8"))
|
||||
assert persisted["sys_path"] == snapshot["sys_path"]
|
||||
|
||||
|
||||
def test_build_child_sys_path_preserves_host_order() -> None:
|
||||
host_paths = ["/host/root", "/host/site-packages"]
|
||||
extra_paths = ["/node/.venv/lib/python3.12/site-packages"]
|
||||
result = build_child_sys_path(host_paths, extra_paths, preferred_root=None)
|
||||
assert result == host_paths + extra_paths
|
||||
|
||||
|
||||
def test_build_child_sys_path_inserts_comfy_root_when_missing() -> None:
|
||||
host_paths = ["/host/site-packages"]
|
||||
comfy_root = os.environ.get("COMFYUI_ROOT") or str(Path.home() / "ComfyUI")
|
||||
extra_paths: list[str] = []
|
||||
result = build_child_sys_path(host_paths, extra_paths, preferred_root=comfy_root)
|
||||
assert result[0] == comfy_root
|
||||
assert result[1:] == host_paths
|
||||
|
||||
|
||||
def test_build_child_sys_path_deduplicates_entries(tmp_path: Path) -> None:
|
||||
path_a = str(tmp_path / "a")
|
||||
path_b = str(tmp_path / "b")
|
||||
host_paths = [path_a, path_b]
|
||||
extra_paths = [path_a, path_b, str(tmp_path / "c")]
|
||||
result = build_child_sys_path(host_paths, extra_paths)
|
||||
assert result == [path_a, path_b, str(tmp_path / "c")]
|
||||
|
||||
|
||||
def test_build_child_sys_path_skips_duplicate_comfy_root() -> None:
|
||||
comfy_root = os.environ.get("COMFYUI_ROOT") or str(Path.home() / "ComfyUI")
|
||||
host_paths = [comfy_root, "/host/other"]
|
||||
result = build_child_sys_path(host_paths, extra_paths=[], preferred_root=comfy_root)
|
||||
assert result == host_paths
|
||||
|
||||
|
||||
def test_child_import_succeeds_after_path_unification(tmp_path: Path, monkeypatch) -> None:
|
||||
host_root = tmp_path / "host"
|
||||
utils_pkg = host_root / "utils"
|
||||
app_pkg = host_root / "app"
|
||||
utils_pkg.mkdir(parents=True)
|
||||
app_pkg.mkdir(parents=True)
|
||||
|
||||
(utils_pkg / "__init__.py").write_text("from . import install_util\n", encoding="utf-8")
|
||||
(utils_pkg / "install_util.py").write_text("VALUE = 'hello'\n", encoding="utf-8")
|
||||
(app_pkg / "__init__.py").write_text("", encoding="utf-8")
|
||||
(app_pkg / "frontend_management.py").write_text(
|
||||
"from utils import install_util\nVALUE = install_util.VALUE\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
child_only = tmp_path / "child_only"
|
||||
child_only.mkdir()
|
||||
|
||||
target_module = "app.frontend_management"
|
||||
for name in [n for n in list(sys.modules) if n.startswith("app") or n.startswith("utils")]:
|
||||
sys.modules.pop(name)
|
||||
|
||||
monkeypatch.setattr(sys, "path", [str(child_only)])
|
||||
with pytest.raises(ModuleNotFoundError):
|
||||
__import__(target_module)
|
||||
|
||||
for name in [n for n in list(sys.modules) if n.startswith("app") or n.startswith("utils")]:
|
||||
sys.modules.pop(name)
|
||||
|
||||
unified = build_child_sys_path([], [], preferred_root=str(host_root))
|
||||
monkeypatch.setattr(sys, "path", unified)
|
||||
module = __import__(target_module, fromlist=["VALUE"])
|
||||
assert module.VALUE == "hello"
|
||||
125
tests/isolation/test_runtime_helpers_stub_contract.py
Normal file
125
tests/isolation/test_runtime_helpers_stub_contract.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""Generic runtime-helper stub contract tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
|
||||
from comfy.isolation import runtime_helpers
|
||||
from comfy_api.latest import io as latest_io
|
||||
from tests.isolation.stage_internal_probe_node import PROBE_NODE_NAME, staged_probe_node
|
||||
|
||||
|
||||
class _DummyExtension:
|
||||
def __init__(self, *, name: str, module_path: str):
|
||||
self.name = name
|
||||
self.module_path = module_path
|
||||
|
||||
async def execute_node(self, _node_name: str, **inputs):
|
||||
return {
|
||||
"__node_output__": True,
|
||||
"args": (inputs,),
|
||||
"ui": {"status": "ok"},
|
||||
"expand": False,
|
||||
"block_execution": False,
|
||||
}
|
||||
|
||||
|
||||
def _install_model_serialization_stub(monkeypatch):
|
||||
async def deserialize_from_isolation(payload, _extension):
|
||||
return payload
|
||||
|
||||
monkeypatch.setitem(
|
||||
sys.modules,
|
||||
"pyisolate._internal.model_serialization",
|
||||
SimpleNamespace(
|
||||
serialize_for_isolation=lambda payload: payload,
|
||||
deserialize_from_isolation=deserialize_from_isolation,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_stub_sets_relative_python_module(monkeypatch):
|
||||
_install_model_serialization_stub(monkeypatch)
|
||||
monkeypatch.setattr(runtime_helpers, "scan_shm_forensics", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(runtime_helpers, "_relieve_host_vram_pressure", lambda *args, **kwargs: None)
|
||||
|
||||
extension = _DummyExtension(name="internal_probe", module_path=os.getcwd())
|
||||
stub = cast(Any, runtime_helpers.build_stub_class(
|
||||
"ProbeNode",
|
||||
{
|
||||
"is_v3": True,
|
||||
"schema_v1": {},
|
||||
"input_types": {},
|
||||
},
|
||||
extension,
|
||||
{},
|
||||
logging.getLogger("test"),
|
||||
))
|
||||
|
||||
info = getattr(stub, "GET_NODE_INFO_V1")()
|
||||
assert info["python_module"] == "custom_nodes.internal_probe"
|
||||
|
||||
|
||||
def test_stub_ui_dispatch_roundtrip(monkeypatch):
|
||||
_install_model_serialization_stub(monkeypatch)
|
||||
monkeypatch.setattr(runtime_helpers, "scan_shm_forensics", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(runtime_helpers, "_relieve_host_vram_pressure", lambda *args, **kwargs: None)
|
||||
|
||||
extension = _DummyExtension(name="internal_probe", module_path=os.getcwd())
|
||||
stub = runtime_helpers.build_stub_class(
|
||||
"ProbeNode",
|
||||
{
|
||||
"is_v3": True,
|
||||
"schema_v1": {"python_module": "custom_nodes.internal_probe"},
|
||||
"input_types": {},
|
||||
},
|
||||
extension,
|
||||
{},
|
||||
logging.getLogger("test"),
|
||||
)
|
||||
|
||||
result = asyncio.run(getattr(stub, "_pyisolate_execute")(SimpleNamespace(), token="value"))
|
||||
|
||||
assert isinstance(result, latest_io.NodeOutput)
|
||||
assert result.ui == {"status": "ok"}
|
||||
|
||||
|
||||
def test_stub_class_types_align_with_extension():
|
||||
extension = SimpleNamespace(name="internal_probe", module_path="/sandbox/probe")
|
||||
running_extensions = {"internal_probe": extension}
|
||||
|
||||
specs = [
|
||||
SimpleNamespace(module_path=Path("/sandbox/probe"), node_name="ProbeImage"),
|
||||
SimpleNamespace(module_path=Path("/sandbox/probe"), node_name="ProbeAudio"),
|
||||
SimpleNamespace(module_path=Path("/sandbox/other"), node_name="OtherNode"),
|
||||
]
|
||||
|
||||
class_types = runtime_helpers.get_class_types_for_extension(
|
||||
"internal_probe", running_extensions, specs
|
||||
)
|
||||
|
||||
assert class_types == {"ProbeImage", "ProbeAudio"}
|
||||
|
||||
|
||||
def test_probe_stage_requires_explicit_root():
|
||||
script = Path(__file__).resolve().parent / "stage_internal_probe_node.py"
|
||||
result = subprocess.run([sys.executable, str(script)], capture_output=True, text=True, check=False)
|
||||
|
||||
assert result.returncode != 0
|
||||
assert "--target-root" in result.stderr
|
||||
|
||||
|
||||
def test_probe_stage_cleans_up_context():
|
||||
with staged_probe_node() as module_path:
|
||||
staged_root = module_path.parents[1]
|
||||
assert module_path.name == PROBE_NODE_NAME
|
||||
assert staged_root.exists()
|
||||
|
||||
assert not staged_root.exists()
|
||||
53
tests/isolation/test_savedimages_serialization.py
Normal file
53
tests/isolation/test_savedimages_serialization.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import logging
|
||||
import socket
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
repo_root = Path(__file__).resolve().parents[2]
|
||||
pyisolate_root = repo_root.parent / "pyisolate"
|
||||
if pyisolate_root.exists():
|
||||
sys.path.insert(0, str(pyisolate_root))
|
||||
|
||||
from comfy.isolation.adapter import ComfyUIAdapter
|
||||
from comfy_api.latest._io import FolderType
|
||||
from comfy_api.latest._ui import SavedImages, SavedResult
|
||||
from pyisolate._internal.rpc_transports import JSONSocketTransport
|
||||
from pyisolate._internal.serialization_registry import SerializerRegistry
|
||||
|
||||
|
||||
def test_savedimages_roundtrip(caplog):
|
||||
registry = SerializerRegistry.get_instance()
|
||||
registry.clear()
|
||||
ComfyUIAdapter().register_serializers(registry)
|
||||
|
||||
payload = SavedImages(
|
||||
results=[SavedResult("issue82.png", "slice2", FolderType.output)],
|
||||
is_animated=True,
|
||||
)
|
||||
|
||||
a, b = socket.socketpair()
|
||||
sender = JSONSocketTransport(a)
|
||||
receiver = JSONSocketTransport(b)
|
||||
try:
|
||||
with caplog.at_level(logging.WARNING, logger="pyisolate._internal.rpc_transports"):
|
||||
sender.send({"ui": payload})
|
||||
result = receiver.recv()
|
||||
finally:
|
||||
sender.close()
|
||||
receiver.close()
|
||||
registry.clear()
|
||||
|
||||
ui = result["ui"]
|
||||
assert isinstance(ui, SavedImages)
|
||||
assert ui.is_animated is True
|
||||
assert len(ui.results) == 1
|
||||
assert isinstance(ui.results[0], SavedResult)
|
||||
assert ui.results[0].filename == "issue82.png"
|
||||
assert ui.results[0].subfolder == "slice2"
|
||||
assert ui.results[0].type == FolderType.output
|
||||
assert ui.as_dict() == {
|
||||
"images": [SavedResult("issue82.png", "slice2", FolderType.output)],
|
||||
"animated": (True,),
|
||||
}
|
||||
assert not any("GENERIC SERIALIZER USED" in record.message for record in caplog.records)
|
||||
assert not any("GENERIC DESERIALIZER USED" in record.message for record in caplog.records)
|
||||
368
tests/isolation/test_sealed_worker_contract_matrix.py
Normal file
368
tests/isolation/test_sealed_worker_contract_matrix.py
Normal file
@@ -0,0 +1,368 @@
|
||||
"""Generic sealed-worker loader contract matrix tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
COMFYUI_ROOT = Path(__file__).resolve().parents[2]
|
||||
TEST_WORKFLOW_ROOT = COMFYUI_ROOT / "tests" / "isolation" / "workflows"
|
||||
SEALED_WORKFLOW_CLASS_TYPES: dict[str, set[str]] = {
|
||||
"quick_6_uv_sealed_worker.json": {
|
||||
"EmptyLatentImage",
|
||||
"ProxyTestSealedWorker",
|
||||
"UVSealedBoltonsSlugify",
|
||||
"UVSealedLatentEcho",
|
||||
"UVSealedRuntimeProbe",
|
||||
},
|
||||
"isolation_7_uv_sealed_worker.json": {
|
||||
"EmptyLatentImage",
|
||||
"ProxyTestSealedWorker",
|
||||
"UVSealedBoltonsSlugify",
|
||||
"UVSealedLatentEcho",
|
||||
"UVSealedRuntimeProbe",
|
||||
},
|
||||
"quick_8_conda_sealed_worker.json": {
|
||||
"CondaSealedLatentEcho",
|
||||
"CondaSealedOpenWeatherDataset",
|
||||
"CondaSealedRuntimeProbe",
|
||||
"EmptyLatentImage",
|
||||
"ProxyTestCondaSealedWorker",
|
||||
},
|
||||
"isolation_9_conda_sealed_worker.json": {
|
||||
"CondaSealedLatentEcho",
|
||||
"CondaSealedOpenWeatherDataset",
|
||||
"CondaSealedRuntimeProbe",
|
||||
"EmptyLatentImage",
|
||||
"ProxyTestCondaSealedWorker",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _workflow_class_types(path: Path) -> set[str]:
|
||||
payload = json.loads(path.read_text(encoding="utf-8"))
|
||||
return {
|
||||
node["class_type"]
|
||||
for node in payload.values()
|
||||
if isinstance(node, dict) and "class_type" in node
|
||||
}
|
||||
|
||||
|
||||
def _make_manifest(
|
||||
*,
|
||||
package_manager: str = "uv",
|
||||
execution_model: str | None = None,
|
||||
can_isolate: bool = True,
|
||||
dependencies: list[str] | None = None,
|
||||
share_torch: bool = False,
|
||||
sealed_host_ro_paths: list[str] | None = None,
|
||||
) -> dict:
|
||||
isolation: dict[str, object] = {
|
||||
"can_isolate": can_isolate,
|
||||
}
|
||||
if package_manager != "uv":
|
||||
isolation["package_manager"] = package_manager
|
||||
if execution_model is not None:
|
||||
isolation["execution_model"] = execution_model
|
||||
if share_torch:
|
||||
isolation["share_torch"] = True
|
||||
if sealed_host_ro_paths is not None:
|
||||
isolation["sealed_host_ro_paths"] = sealed_host_ro_paths
|
||||
|
||||
if package_manager == "conda":
|
||||
isolation["conda_channels"] = ["conda-forge"]
|
||||
isolation["conda_dependencies"] = ["numpy"]
|
||||
|
||||
return {
|
||||
"project": {
|
||||
"name": "contract-extension",
|
||||
"dependencies": dependencies or ["numpy"],
|
||||
},
|
||||
"tool": {"comfy": {"isolation": isolation}},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def manifest_file(tmp_path: Path) -> Path:
|
||||
path = tmp_path / "pyproject.toml"
|
||||
path.write_bytes(b"")
|
||||
return path
|
||||
|
||||
|
||||
def _loader_module(
|
||||
monkeypatch: pytest.MonkeyPatch, *, preload_extension_wrapper: bool
|
||||
):
|
||||
mock_wrapper = MagicMock()
|
||||
mock_wrapper.ComfyNodeExtension = type("ComfyNodeExtension", (), {})
|
||||
|
||||
iso_mod = types.ModuleType("comfy.isolation")
|
||||
iso_mod.__path__ = [
|
||||
str(Path(__file__).resolve().parent.parent.parent / "comfy" / "isolation")
|
||||
]
|
||||
iso_mod.__package__ = "comfy.isolation"
|
||||
|
||||
manifest_loader = types.SimpleNamespace(
|
||||
is_cache_valid=lambda *args, **kwargs: False,
|
||||
load_from_cache=lambda *args, **kwargs: None,
|
||||
save_to_cache=lambda *args, **kwargs: None,
|
||||
)
|
||||
host_policy = types.SimpleNamespace(
|
||||
load_host_policy=lambda base_path: {
|
||||
"sandbox_mode": "required",
|
||||
"allow_network": False,
|
||||
"writable_paths": [],
|
||||
"readonly_paths": [],
|
||||
"sealed_worker_ro_import_paths": [],
|
||||
}
|
||||
)
|
||||
folder_paths = types.SimpleNamespace(base_path="/fake/comfyui")
|
||||
|
||||
monkeypatch.setitem(sys.modules, "comfy.isolation", iso_mod)
|
||||
monkeypatch.setitem(sys.modules, "comfy.isolation.runtime_helpers", MagicMock())
|
||||
monkeypatch.setitem(sys.modules, "comfy.isolation.manifest_loader", manifest_loader)
|
||||
monkeypatch.setitem(sys.modules, "comfy.isolation.host_policy", host_policy)
|
||||
monkeypatch.setitem(sys.modules, "folder_paths", folder_paths)
|
||||
if preload_extension_wrapper:
|
||||
monkeypatch.setitem(sys.modules, "comfy.isolation.extension_wrapper", mock_wrapper)
|
||||
else:
|
||||
sys.modules.pop("comfy.isolation.extension_wrapper", None)
|
||||
sys.modules.pop("comfy.isolation.extension_loader", None)
|
||||
|
||||
module = importlib.import_module("comfy.isolation.extension_loader")
|
||||
try:
|
||||
yield module, mock_wrapper
|
||||
finally:
|
||||
sys.modules.pop("comfy.isolation.extension_loader", None)
|
||||
comfy_pkg = sys.modules.get("comfy")
|
||||
if comfy_pkg is not None and hasattr(comfy_pkg, "isolation"):
|
||||
delattr(comfy_pkg, "isolation")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def loader_module(monkeypatch: pytest.MonkeyPatch):
|
||||
yield from _loader_module(monkeypatch, preload_extension_wrapper=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sealed_loader_module(monkeypatch: pytest.MonkeyPatch):
|
||||
yield from _loader_module(monkeypatch, preload_extension_wrapper=False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mocked_loader(loader_module):
|
||||
module, mock_wrapper = loader_module
|
||||
mock_ext = AsyncMock()
|
||||
mock_ext.list_nodes = AsyncMock(return_value={})
|
||||
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.load_extension = MagicMock(return_value=mock_ext)
|
||||
sealed_type = type("SealedNodeExtension", (), {})
|
||||
|
||||
with patch.object(module, "pyisolate") as mock_pi:
|
||||
mock_pi.ExtensionManager = MagicMock(return_value=mock_manager)
|
||||
mock_pi.SealedNodeExtension = sealed_type
|
||||
yield module, mock_pi, mock_manager, sealed_type, mock_wrapper
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sealed_mocked_loader(sealed_loader_module):
|
||||
module, mock_wrapper = sealed_loader_module
|
||||
mock_ext = AsyncMock()
|
||||
mock_ext.list_nodes = AsyncMock(return_value={})
|
||||
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.load_extension = MagicMock(return_value=mock_ext)
|
||||
sealed_type = type("SealedNodeExtension", (), {})
|
||||
|
||||
with patch.object(module, "pyisolate") as mock_pi:
|
||||
mock_pi.ExtensionManager = MagicMock(return_value=mock_manager)
|
||||
mock_pi.SealedNodeExtension = sealed_type
|
||||
yield module, mock_pi, mock_manager, sealed_type, mock_wrapper
|
||||
|
||||
|
||||
async def _load_node(module, manifest: dict, manifest_path: Path, tmp_path: Path) -> dict:
|
||||
with patch("comfy.isolation.extension_loader.tomllib") as mock_tomllib:
|
||||
mock_tomllib.load.return_value = manifest
|
||||
await module.load_isolated_node(
|
||||
node_dir=tmp_path,
|
||||
manifest_path=manifest_path,
|
||||
logger=MagicMock(),
|
||||
build_stub_class=MagicMock(),
|
||||
venv_root=tmp_path / "venvs",
|
||||
extension_managers=[],
|
||||
)
|
||||
manager = module.pyisolate.ExtensionManager.return_value
|
||||
return manager.load_extension.call_args[0][0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uv_host_coupled_default(mocked_loader, manifest_file: Path, tmp_path: Path):
|
||||
module, mock_pi, _mock_manager, sealed_type, _ = mocked_loader
|
||||
manifest = _make_manifest(package_manager="uv")
|
||||
|
||||
config = await _load_node(module, manifest, manifest_file, tmp_path)
|
||||
|
||||
extension_type = mock_pi.ExtensionManager.call_args[0][0]
|
||||
assert extension_type is not sealed_type
|
||||
assert "execution_model" not in config
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uv_sealed_worker_opt_in(
|
||||
sealed_mocked_loader, manifest_file: Path, tmp_path: Path
|
||||
):
|
||||
module, mock_pi, _mock_manager, sealed_type, _ = sealed_mocked_loader
|
||||
manifest = _make_manifest(package_manager="uv", execution_model="sealed_worker")
|
||||
|
||||
config = await _load_node(module, manifest, manifest_file, tmp_path)
|
||||
|
||||
extension_type = mock_pi.ExtensionManager.call_args[0][0]
|
||||
assert extension_type is sealed_type
|
||||
assert config["execution_model"] == "sealed_worker"
|
||||
assert "apis" not in config
|
||||
assert "comfy.isolation.extension_wrapper" not in sys.modules
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conda_defaults_to_sealed_worker(
|
||||
sealed_mocked_loader, manifest_file: Path, tmp_path: Path
|
||||
):
|
||||
module, mock_pi, _mock_manager, sealed_type, _ = sealed_mocked_loader
|
||||
manifest = _make_manifest(package_manager="conda")
|
||||
|
||||
config = await _load_node(module, manifest, manifest_file, tmp_path)
|
||||
|
||||
extension_type = mock_pi.ExtensionManager.call_args[0][0]
|
||||
assert extension_type is sealed_type
|
||||
assert config["execution_model"] == "sealed_worker"
|
||||
assert config["package_manager"] == "conda"
|
||||
assert "comfy.isolation.extension_wrapper" not in sys.modules
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conda_never_uses_comfy_extension_type(
|
||||
mocked_loader, manifest_file: Path, tmp_path: Path
|
||||
):
|
||||
module, mock_pi, _mock_manager, sealed_type, mock_wrapper = mocked_loader
|
||||
manifest = _make_manifest(package_manager="conda")
|
||||
|
||||
await _load_node(module, manifest, manifest_file, tmp_path)
|
||||
|
||||
extension_type = mock_pi.ExtensionManager.call_args[0][0]
|
||||
assert extension_type is sealed_type
|
||||
assert extension_type is not mock_wrapper.ComfyNodeExtension
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conda_forces_share_torch_false(mocked_loader, manifest_file: Path, tmp_path: Path):
|
||||
module, _mock_pi, _mock_manager, _sealed_type, _ = mocked_loader
|
||||
manifest = _make_manifest(package_manager="conda", share_torch=True)
|
||||
|
||||
config = await _load_node(module, manifest, manifest_file, tmp_path)
|
||||
|
||||
assert config["share_torch"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conda_forces_share_cuda_ipc_false(
|
||||
mocked_loader, manifest_file: Path, tmp_path: Path
|
||||
):
|
||||
module, _mock_pi, _mock_manager, _sealed_type, _ = mocked_loader
|
||||
manifest = _make_manifest(package_manager="conda", share_torch=True)
|
||||
|
||||
config = await _load_node(module, manifest, manifest_file, tmp_path)
|
||||
|
||||
assert config["share_cuda_ipc"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conda_sandbox_policy_applied(mocked_loader, manifest_file: Path, tmp_path: Path):
|
||||
module, _mock_pi, _mock_manager, _sealed_type, _ = mocked_loader
|
||||
manifest = _make_manifest(package_manager="conda")
|
||||
|
||||
custom_policy = {
|
||||
"sandbox_mode": "required",
|
||||
"allow_network": True,
|
||||
"writable_paths": ["/data/write"],
|
||||
"readonly_paths": ["/data/read"],
|
||||
}
|
||||
|
||||
with patch("platform.system", return_value="Linux"):
|
||||
with patch.object(module, "load_host_policy", return_value=custom_policy):
|
||||
config = await _load_node(module, manifest, manifest_file, tmp_path)
|
||||
|
||||
assert config["sandbox_mode"] == "required"
|
||||
assert config["sandbox"] == {
|
||||
"network": True,
|
||||
"writable_paths": ["/data/write"],
|
||||
"readonly_paths": ["/data/read"],
|
||||
}
|
||||
|
||||
|
||||
def test_sealed_worker_workflow_templates_present() -> None:
|
||||
missing = [
|
||||
filename
|
||||
for filename in SEALED_WORKFLOW_CLASS_TYPES
|
||||
if not (TEST_WORKFLOW_ROOT / filename).is_file()
|
||||
]
|
||||
assert not missing, f"missing sealed-worker workflow templates: {missing}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"workflow_name,expected_class_types",
|
||||
SEALED_WORKFLOW_CLASS_TYPES.items(),
|
||||
)
|
||||
def test_sealed_worker_workflow_class_type_contract(
|
||||
workflow_name: str, expected_class_types: set[str]
|
||||
) -> None:
|
||||
workflow_path = TEST_WORKFLOW_ROOT / workflow_name
|
||||
assert workflow_path.is_file(), f"workflow missing: {workflow_path}"
|
||||
|
||||
assert _workflow_class_types(workflow_path) == expected_class_types
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sealed_worker_host_policy_ro_import_matrix(
|
||||
mocked_loader, manifest_file: Path, tmp_path: Path
|
||||
):
|
||||
module, _mock_pi, _mock_manager, _sealed_type, _ = mocked_loader
|
||||
manifest = _make_manifest(package_manager="uv", execution_model="sealed_worker")
|
||||
|
||||
with patch.object(
|
||||
module,
|
||||
"load_host_policy",
|
||||
return_value={
|
||||
"sandbox_mode": "required",
|
||||
"allow_network": False,
|
||||
"writable_paths": [],
|
||||
"readonly_paths": [],
|
||||
"sealed_worker_ro_import_paths": [],
|
||||
},
|
||||
):
|
||||
default_config = await _load_node(module, manifest, manifest_file, tmp_path)
|
||||
|
||||
with patch.object(
|
||||
module,
|
||||
"load_host_policy",
|
||||
return_value={
|
||||
"sandbox_mode": "required",
|
||||
"allow_network": False,
|
||||
"writable_paths": [],
|
||||
"readonly_paths": [],
|
||||
"sealed_worker_ro_import_paths": ["/home/johnj/ComfyUI"],
|
||||
},
|
||||
):
|
||||
opt_in_config = await _load_node(module, manifest, manifest_file, tmp_path)
|
||||
|
||||
assert default_config["execution_model"] == "sealed_worker"
|
||||
assert "sealed_host_ro_paths" not in default_config
|
||||
|
||||
assert opt_in_config["execution_model"] == "sealed_worker"
|
||||
assert opt_in_config["sealed_host_ro_paths"] == ["/home/johnj/ComfyUI"]
|
||||
assert "apis" not in opt_in_config
|
||||
44
tests/isolation/test_shared_model_proxy_contract.py
Normal file
44
tests/isolation/test_shared_model_proxy_contract.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
repo_root = Path(__file__).resolve().parents[2]
|
||||
pyisolate_root = repo_root.parent / "pyisolate"
|
||||
if pyisolate_root.exists():
|
||||
sys.path.insert(0, str(pyisolate_root))
|
||||
|
||||
from comfy.isolation.adapter import ComfyUIAdapter
|
||||
from comfy.isolation.runtime_helpers import _wrap_remote_handles_as_host_proxies
|
||||
from pyisolate._internal.model_serialization import deserialize_from_isolation
|
||||
from pyisolate._internal.remote_handle import RemoteObjectHandle
|
||||
from pyisolate._internal.serialization_registry import SerializerRegistry
|
||||
|
||||
|
||||
def test_shared_model_ksampler_contract():
|
||||
registry = SerializerRegistry.get_instance()
|
||||
registry.clear()
|
||||
ComfyUIAdapter().register_serializers(registry)
|
||||
|
||||
handle = RemoteObjectHandle("model_0", "ModelPatcher")
|
||||
|
||||
class FakeExtension:
|
||||
async def call_remote_object_method(self, object_id, method_name, *args, **kwargs):
|
||||
assert object_id == "model_0"
|
||||
assert method_name == "get_model_object"
|
||||
assert args == ("latent_format",)
|
||||
assert kwargs == {}
|
||||
return "resolved:latent_format"
|
||||
|
||||
wrapped = (handle,)
|
||||
assert isinstance(wrapped, tuple)
|
||||
assert isinstance(wrapped[0], RemoteObjectHandle)
|
||||
|
||||
deserialized = asyncio.run(deserialize_from_isolation(wrapped))
|
||||
proxied = _wrap_remote_handles_as_host_proxies(deserialized, FakeExtension())
|
||||
model_for_host = proxied[0]
|
||||
|
||||
assert not isinstance(model_for_host, RemoteObjectHandle)
|
||||
assert hasattr(model_for_host, "get_model_object")
|
||||
assert model_for_host.get_model_object("latent_format") == "resolved:latent_format"
|
||||
|
||||
registry.clear()
|
||||
78
tests/isolation/test_singleton_proxy_boundary_matrix.py
Normal file
78
tests/isolation/test_singleton_proxy_boundary_matrix.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from tests.isolation.singleton_boundary_helpers import (
|
||||
capture_minimal_sealed_worker_imports,
|
||||
capture_sealed_singleton_imports,
|
||||
)
|
||||
|
||||
|
||||
def test_minimal_sealed_worker_forbidden_imports() -> None:
|
||||
payload = capture_minimal_sealed_worker_imports()
|
||||
|
||||
assert payload["mode"] == "minimal_sealed_worker"
|
||||
assert payload["runtime_probe_function"] == "inspect"
|
||||
assert payload["forbidden_matches"] == []
|
||||
|
||||
|
||||
def test_torch_share_subset_scope() -> None:
|
||||
minimal = capture_minimal_sealed_worker_imports()
|
||||
|
||||
allowed_torch_share_only = {
|
||||
"torch",
|
||||
"folder_paths",
|
||||
"comfy.utils",
|
||||
"comfy.model_management",
|
||||
"main",
|
||||
"comfy.isolation.extension_wrapper",
|
||||
}
|
||||
|
||||
assert minimal["forbidden_matches"] == []
|
||||
assert all(
|
||||
module_name not in minimal["modules"] for module_name in sorted(allowed_torch_share_only)
|
||||
)
|
||||
|
||||
|
||||
def test_capture_payload_is_json_serializable() -> None:
|
||||
payload = capture_minimal_sealed_worker_imports()
|
||||
|
||||
encoded = json.dumps(payload, sort_keys=True)
|
||||
|
||||
assert "\"minimal_sealed_worker\"" in encoded
|
||||
|
||||
|
||||
def test_folder_paths_child_safe() -> None:
|
||||
payload = capture_sealed_singleton_imports()
|
||||
|
||||
assert payload["mode"] == "sealed_singletons"
|
||||
assert payload["folder_path"] == "/sandbox/input/demo.png"
|
||||
assert payload["temp_dir"] == "/sandbox/temp"
|
||||
assert payload["models_dir"] == "/sandbox/models"
|
||||
assert payload["forbidden_matches"] == []
|
||||
|
||||
|
||||
def test_utils_child_safe() -> None:
|
||||
payload = capture_sealed_singleton_imports()
|
||||
|
||||
progress_calls = [
|
||||
call
|
||||
for call in payload["rpc_calls"]
|
||||
if call["object_id"] == "UtilsProxy" and call["method"] == "progress_bar_hook"
|
||||
]
|
||||
|
||||
assert progress_calls
|
||||
assert payload["forbidden_matches"] == []
|
||||
|
||||
|
||||
def test_progress_child_safe() -> None:
|
||||
payload = capture_sealed_singleton_imports()
|
||||
|
||||
progress_calls = [
|
||||
call
|
||||
for call in payload["rpc_calls"]
|
||||
if call["object_id"] == "ProgressProxy" and call["method"] == "rpc_set_progress"
|
||||
]
|
||||
|
||||
assert progress_calls
|
||||
assert payload["forbidden_matches"] == []
|
||||
129
tests/isolation/test_web_directory_handler.py
Normal file
129
tests/isolation/test_web_directory_handler.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""Tests for WebDirectoryProxy host-side cache and aiohttp handler integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from comfy.isolation.proxies.web_directory_proxy import (
|
||||
ALLOWED_EXTENSIONS,
|
||||
WebDirectoryCache,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_proxy() -> MagicMock:
|
||||
"""Create a mock WebDirectoryProxy RPC proxy."""
|
||||
proxy = MagicMock()
|
||||
proxy.list_web_files.return_value = [
|
||||
{"relative_path": "js/app.js", "content_type": "application/javascript"},
|
||||
{"relative_path": "js/utils.js", "content_type": "application/javascript"},
|
||||
{"relative_path": "index.html", "content_type": "text/html"},
|
||||
{"relative_path": "style.css", "content_type": "text/css"},
|
||||
]
|
||||
proxy.get_web_file.return_value = {
|
||||
"content": base64.b64encode(b"console.log('hello');").decode("ascii"),
|
||||
"content_type": "application/javascript",
|
||||
}
|
||||
return proxy
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def cache_with_proxy(mock_proxy: MagicMock) -> WebDirectoryCache:
|
||||
"""Create a WebDirectoryCache with a registered mock proxy."""
|
||||
cache = WebDirectoryCache()
|
||||
cache.register_proxy("test-extension", mock_proxy)
|
||||
return cache
|
||||
|
||||
|
||||
class TestExtensionsListing:
|
||||
"""AC-2: /extensions endpoint lists proxied JS files in URL format."""
|
||||
|
||||
def test_extensions_listing_produces_url_format_paths(
|
||||
self, cache_with_proxy: WebDirectoryCache
|
||||
) -> None:
|
||||
"""Simulate what server.py does: build /extensions/{name}/{path} URLs."""
|
||||
import urllib.parse
|
||||
|
||||
ext_name = "test-extension"
|
||||
urls = []
|
||||
for entry in cache_with_proxy.list_files(ext_name):
|
||||
if entry["relative_path"].endswith(".js"):
|
||||
urls.append(
|
||||
"/extensions/" + urllib.parse.quote(ext_name)
|
||||
+ "/" + entry["relative_path"]
|
||||
)
|
||||
|
||||
# Emit the actual URL list so it appears in test log output.
|
||||
sys.stdout.write(f"\n--- Proxied JS URLs ({len(urls)}) ---\n")
|
||||
for url in urls:
|
||||
sys.stdout.write(f" {url}\n")
|
||||
sys.stdout.write("--- End URLs ---\n")
|
||||
|
||||
# At least one proxied JS URL in /extensions/{name}/{path} format
|
||||
assert len(urls) >= 1, f"Expected >= 1 proxied JS URL, got {len(urls)}"
|
||||
assert "/extensions/test-extension/js/app.js" in urls, (
|
||||
f"Expected /extensions/test-extension/js/app.js in {urls}"
|
||||
)
|
||||
|
||||
|
||||
class TestCacheHit:
|
||||
"""AC-3: Cache populated on first request, reused on second."""
|
||||
|
||||
def test_cache_hit_single_rpc_call(
|
||||
self, cache_with_proxy: WebDirectoryCache, mock_proxy: MagicMock
|
||||
) -> None:
|
||||
# First call — RPC
|
||||
result1 = cache_with_proxy.get_file("test-extension", "js/app.js")
|
||||
assert result1 is not None
|
||||
assert result1["content"] == b"console.log('hello');"
|
||||
|
||||
# Second call — cache hit
|
||||
result2 = cache_with_proxy.get_file("test-extension", "js/app.js")
|
||||
assert result2 is not None
|
||||
assert result2["content"] == b"console.log('hello');"
|
||||
|
||||
# Proxy was called exactly once
|
||||
assert mock_proxy.get_web_file.call_count == 1
|
||||
|
||||
def test_cache_returns_none_for_unknown_extension(
|
||||
self, cache_with_proxy: WebDirectoryCache
|
||||
) -> None:
|
||||
result = cache_with_proxy.get_file("nonexistent", "js/app.js")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestForbiddenType:
|
||||
"""AC-4: Disallowed file types return HTTP 403 Forbidden."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"disallowed_path,expected_status",
|
||||
[
|
||||
("backdoor.py", 403),
|
||||
("malware.exe", 403),
|
||||
("exploit.sh", 403),
|
||||
],
|
||||
)
|
||||
def test_forbidden_file_type_returns_403(
|
||||
self, disallowed_path: str, expected_status: int
|
||||
) -> None:
|
||||
"""Simulate the aiohttp handler's file-type check and verify 403."""
|
||||
import os
|
||||
suffix = os.path.splitext(disallowed_path)[1].lower()
|
||||
|
||||
# This mirrors the handler logic in server.py:
|
||||
# if suffix not in ALLOWED_EXTENSIONS: return web.Response(status=403)
|
||||
if suffix not in ALLOWED_EXTENSIONS:
|
||||
status = 403
|
||||
else:
|
||||
status = 200
|
||||
|
||||
sys.stdout.write(
|
||||
f"\n--- HTTP status for {disallowed_path} (suffix={suffix}): {status} ---\n"
|
||||
)
|
||||
assert status == expected_status, (
|
||||
f"Expected HTTP {expected_status} for {disallowed_path}, got {status}"
|
||||
)
|
||||
130
tests/isolation/test_web_directory_proxy.py
Normal file
130
tests/isolation/test_web_directory_proxy.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""Tests for WebDirectoryProxy — allow-list, traversal prevention, content serving."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from comfy.isolation.proxies.web_directory_proxy import WebDirectoryProxy
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def web_dir_with_mixed_files(tmp_path: Path) -> Path:
|
||||
"""Create a temp web directory with allowed and disallowed file types."""
|
||||
web = tmp_path / "web"
|
||||
js_dir = web / "js"
|
||||
js_dir.mkdir(parents=True)
|
||||
|
||||
# Allowed types
|
||||
(js_dir / "app.js").write_text("console.log('hello');")
|
||||
(web / "index.html").write_text("<html></html>")
|
||||
(web / "style.css").write_text("body { margin: 0; }")
|
||||
|
||||
# Disallowed types
|
||||
(web / "backdoor.py").write_text("import os; os.system('rm -rf /')")
|
||||
(web / "malware.exe").write_bytes(b"\x00" * 16)
|
||||
(web / "exploit.sh").write_text("#!/bin/bash\nrm -rf /")
|
||||
|
||||
return web
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def proxy_with_web_dir(web_dir_with_mixed_files: Path) -> WebDirectoryProxy:
|
||||
"""Create a WebDirectoryProxy with a registered test web directory."""
|
||||
proxy = WebDirectoryProxy()
|
||||
# Clear class-level state to avoid cross-test pollution
|
||||
WebDirectoryProxy._web_dirs = {}
|
||||
WebDirectoryProxy.register_web_dir("test-extension", str(web_dir_with_mixed_files))
|
||||
return proxy
|
||||
|
||||
|
||||
class TestAllowList:
|
||||
"""AC-2: list_web_files returns only allowed file types."""
|
||||
|
||||
def test_allowlist_only_safe_types(
|
||||
self, proxy_with_web_dir: WebDirectoryProxy
|
||||
) -> None:
|
||||
files = proxy_with_web_dir.list_web_files("test-extension")
|
||||
extensions = {Path(f["relative_path"]).suffix for f in files}
|
||||
|
||||
# Only .js, .html, .css should appear
|
||||
assert extensions == {".js", ".html", ".css"}
|
||||
|
||||
def test_allowlist_excludes_dangerous_types(
|
||||
self, proxy_with_web_dir: WebDirectoryProxy
|
||||
) -> None:
|
||||
files = proxy_with_web_dir.list_web_files("test-extension")
|
||||
paths = [f["relative_path"] for f in files]
|
||||
|
||||
assert not any(p.endswith(".py") for p in paths)
|
||||
assert not any(p.endswith(".exe") for p in paths)
|
||||
assert not any(p.endswith(".sh") for p in paths)
|
||||
|
||||
def test_allowlist_correct_count(
|
||||
self, proxy_with_web_dir: WebDirectoryProxy
|
||||
) -> None:
|
||||
files = proxy_with_web_dir.list_web_files("test-extension")
|
||||
# 3 allowed files: app.js, index.html, style.css
|
||||
assert len(files) == 3
|
||||
|
||||
def test_allowlist_unknown_extension_returns_empty(
|
||||
self, proxy_with_web_dir: WebDirectoryProxy
|
||||
) -> None:
|
||||
files = proxy_with_web_dir.list_web_files("nonexistent-extension")
|
||||
assert files == []
|
||||
|
||||
|
||||
class TestTraversal:
|
||||
"""AC-3: get_web_file rejects directory traversal attempts."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"malicious_path",
|
||||
[
|
||||
"../../../etc/passwd",
|
||||
"/etc/passwd",
|
||||
"../../__init__.py",
|
||||
],
|
||||
)
|
||||
def test_traversal_rejected(
|
||||
self, proxy_with_web_dir: WebDirectoryProxy, malicious_path: str
|
||||
) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
proxy_with_web_dir.get_web_file("test-extension", malicious_path)
|
||||
|
||||
|
||||
class TestContent:
|
||||
"""AC-4: get_web_file returns base64 content with correct MIME types."""
|
||||
|
||||
def test_content_js_mime_type(
|
||||
self, proxy_with_web_dir: WebDirectoryProxy
|
||||
) -> None:
|
||||
result = proxy_with_web_dir.get_web_file("test-extension", "js/app.js")
|
||||
assert result["content_type"] == "application/javascript"
|
||||
|
||||
def test_content_html_mime_type(
|
||||
self, proxy_with_web_dir: WebDirectoryProxy
|
||||
) -> None:
|
||||
result = proxy_with_web_dir.get_web_file("test-extension", "index.html")
|
||||
assert result["content_type"] == "text/html"
|
||||
|
||||
def test_content_css_mime_type(
|
||||
self, proxy_with_web_dir: WebDirectoryProxy
|
||||
) -> None:
|
||||
result = proxy_with_web_dir.get_web_file("test-extension", "style.css")
|
||||
assert result["content_type"] == "text/css"
|
||||
|
||||
def test_content_base64_roundtrip(
|
||||
self, proxy_with_web_dir: WebDirectoryProxy, web_dir_with_mixed_files: Path
|
||||
) -> None:
|
||||
result = proxy_with_web_dir.get_web_file("test-extension", "js/app.js")
|
||||
decoded = base64.b64decode(result["content"])
|
||||
source = (web_dir_with_mixed_files / "js" / "app.js").read_bytes()
|
||||
assert decoded == source
|
||||
|
||||
def test_content_disallowed_type_rejected(
|
||||
self, proxy_with_web_dir: WebDirectoryProxy
|
||||
) -> None:
|
||||
with pytest.raises(ValueError, match="Disallowed file type"):
|
||||
proxy_with_web_dir.get_web_file("test-extension", "backdoor.py")
|
||||
230
tests/isolation/uv_sealed_worker/__init__.py
Normal file
230
tests/isolation/uv_sealed_worker/__init__.py
Normal file
@@ -0,0 +1,230 @@
|
||||
# pylint: disable=import-outside-toplevel,import-error
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _artifact_dir() -> Path | None:
|
||||
raw = os.environ.get("PYISOLATE_ARTIFACT_DIR")
|
||||
if not raw:
|
||||
return None
|
||||
path = Path(raw)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
|
||||
def _write_artifact(name: str, content: str) -> None:
|
||||
artifact_dir = _artifact_dir()
|
||||
if artifact_dir is None:
|
||||
return
|
||||
(artifact_dir / name).write_text(content, encoding="utf-8")
|
||||
|
||||
|
||||
def _contains_tensor_marker(value: Any) -> bool:
|
||||
if isinstance(value, dict):
|
||||
if value.get("__type__") == "TensorValue":
|
||||
return True
|
||||
return any(_contains_tensor_marker(v) for v in value.values())
|
||||
if isinstance(value, (list, tuple)):
|
||||
return any(_contains_tensor_marker(v) for v in value)
|
||||
return False
|
||||
|
||||
|
||||
class InspectRuntimeNode:
|
||||
RETURN_TYPES = (
|
||||
"STRING",
|
||||
"STRING",
|
||||
"BOOLEAN",
|
||||
"BOOLEAN",
|
||||
"STRING",
|
||||
"STRING",
|
||||
"BOOLEAN",
|
||||
)
|
||||
RETURN_NAMES = (
|
||||
"path_dump",
|
||||
"boltons_origin",
|
||||
"saw_comfy_root",
|
||||
"imported_comfy_wrapper",
|
||||
"comfy_module_dump",
|
||||
"report",
|
||||
"saw_user_site",
|
||||
)
|
||||
FUNCTION = "inspect"
|
||||
CATEGORY = "PyIsolated/SealedWorker"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802
|
||||
return {"required": {}}
|
||||
|
||||
def inspect(self) -> tuple[str, str, bool, bool, str, str, bool]:
|
||||
import boltons
|
||||
|
||||
path_dump = "\n".join(sys.path)
|
||||
comfy_root = "/home/johnj/ComfyUI"
|
||||
saw_comfy_root = any(
|
||||
entry == comfy_root
|
||||
or entry.startswith(f"{comfy_root}/comfy")
|
||||
or entry.startswith(f"{comfy_root}/.venv")
|
||||
for entry in sys.path
|
||||
)
|
||||
imported_comfy_wrapper = "comfy.isolation.extension_wrapper" in sys.modules
|
||||
comfy_module_dump = "\n".join(
|
||||
sorted(name for name in sys.modules if name.startswith("comfy"))
|
||||
)
|
||||
saw_user_site = any("/.local/lib/" in entry for entry in sys.path)
|
||||
boltons_origin = getattr(boltons, "__file__", "<missing>")
|
||||
|
||||
report_lines = [
|
||||
"UV sealed worker runtime probe",
|
||||
f"boltons_origin={boltons_origin}",
|
||||
f"saw_comfy_root={saw_comfy_root}",
|
||||
f"imported_comfy_wrapper={imported_comfy_wrapper}",
|
||||
f"saw_user_site={saw_user_site}",
|
||||
]
|
||||
report = "\n".join(report_lines)
|
||||
|
||||
_write_artifact("child_bootstrap_paths.txt", path_dump)
|
||||
_write_artifact("child_import_trace.txt", comfy_module_dump)
|
||||
_write_artifact("child_dependency_dump.txt", boltons_origin)
|
||||
logger.warning("][ UV sealed runtime probe executed")
|
||||
logger.warning("][ boltons origin: %s", boltons_origin)
|
||||
|
||||
return (
|
||||
path_dump,
|
||||
boltons_origin,
|
||||
saw_comfy_root,
|
||||
imported_comfy_wrapper,
|
||||
comfy_module_dump,
|
||||
report,
|
||||
saw_user_site,
|
||||
)
|
||||
|
||||
|
||||
class BoltonsSlugifyNode:
|
||||
RETURN_TYPES = ("STRING", "STRING")
|
||||
RETURN_NAMES = ("slug", "boltons_origin")
|
||||
FUNCTION = "slugify_text"
|
||||
CATEGORY = "PyIsolated/SealedWorker"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802
|
||||
return {"required": {"text": ("STRING", {"default": "Sealed Worker Rocks"})}}
|
||||
|
||||
def slugify_text(self, text: str) -> tuple[str, str]:
|
||||
import boltons
|
||||
from boltons.strutils import slugify
|
||||
|
||||
slug = slugify(text)
|
||||
origin = getattr(boltons, "__file__", "<missing>")
|
||||
logger.warning("][ boltons slugify: %r -> %r", text, slug)
|
||||
return slug, origin
|
||||
|
||||
|
||||
class FilesystemBarrierNode:
|
||||
RETURN_TYPES = ("STRING", "BOOLEAN", "BOOLEAN", "BOOLEAN")
|
||||
RETURN_NAMES = (
|
||||
"report",
|
||||
"outside_blocked",
|
||||
"module_mutation_blocked",
|
||||
"artifact_write_ok",
|
||||
)
|
||||
FUNCTION = "probe"
|
||||
CATEGORY = "PyIsolated/SealedWorker"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802
|
||||
return {"required": {}}
|
||||
|
||||
def probe(self) -> tuple[str, bool, bool, bool]:
|
||||
artifact_dir = _artifact_dir()
|
||||
artifact_write_ok = False
|
||||
if artifact_dir is not None:
|
||||
probe_path = artifact_dir / "filesystem_barrier_probe.txt"
|
||||
probe_path.write_text("artifact write ok\n", encoding="utf-8")
|
||||
artifact_write_ok = probe_path.exists()
|
||||
|
||||
module_target = Path(__file__).with_name(
|
||||
"mutated_from_child_should_not_exist.txt"
|
||||
)
|
||||
module_mutation_blocked = False
|
||||
try:
|
||||
module_target.write_text("mutation should fail\n", encoding="utf-8")
|
||||
except Exception:
|
||||
module_mutation_blocked = True
|
||||
else:
|
||||
module_target.unlink(missing_ok=True)
|
||||
|
||||
outside_target = Path("/home/johnj/mysolate/.uv_sealed_worker_escape_probe")
|
||||
outside_blocked = False
|
||||
try:
|
||||
outside_target.write_text("escape should fail\n", encoding="utf-8")
|
||||
except Exception:
|
||||
outside_blocked = True
|
||||
else:
|
||||
outside_target.unlink(missing_ok=True)
|
||||
|
||||
report_lines = [
|
||||
"UV sealed worker filesystem barrier probe",
|
||||
f"artifact_write_ok={artifact_write_ok}",
|
||||
f"module_mutation_blocked={module_mutation_blocked}",
|
||||
f"outside_blocked={outside_blocked}",
|
||||
]
|
||||
report = "\n".join(report_lines)
|
||||
_write_artifact("filesystem_barrier_report.txt", report)
|
||||
logger.warning("][ filesystem barrier probe executed")
|
||||
return report, outside_blocked, module_mutation_blocked, artifact_write_ok
|
||||
|
||||
|
||||
class EchoTensorNode:
|
||||
RETURN_TYPES = ("TENSOR", "BOOLEAN")
|
||||
RETURN_NAMES = ("tensor", "saw_json_tensor")
|
||||
FUNCTION = "echo"
|
||||
CATEGORY = "PyIsolated/SealedWorker"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802
|
||||
return {"required": {"tensor": ("TENSOR",)}}
|
||||
|
||||
def echo(self, tensor: Any) -> tuple[Any, bool]:
|
||||
saw_json_tensor = _contains_tensor_marker(tensor)
|
||||
logger.warning("][ tensor echo json_marker=%s", saw_json_tensor)
|
||||
return tensor, saw_json_tensor
|
||||
|
||||
|
||||
class EchoLatentNode:
|
||||
RETURN_TYPES = ("LATENT", "BOOLEAN")
|
||||
RETURN_NAMES = ("latent", "saw_json_tensor")
|
||||
FUNCTION = "echo_latent"
|
||||
CATEGORY = "PyIsolated/SealedWorker"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802
|
||||
return {"required": {"latent": ("LATENT",)}}
|
||||
|
||||
def echo_latent(self, latent: Any) -> tuple[Any, bool]:
|
||||
saw_json_tensor = _contains_tensor_marker(latent)
|
||||
logger.warning("][ latent echo json_marker=%s", saw_json_tensor)
|
||||
return latent, saw_json_tensor
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"UVSealedRuntimeProbe": InspectRuntimeNode,
|
||||
"UVSealedBoltonsSlugify": BoltonsSlugifyNode,
|
||||
"UVSealedFilesystemBarrier": FilesystemBarrierNode,
|
||||
"UVSealedTensorEcho": EchoTensorNode,
|
||||
"UVSealedLatentEcho": EchoLatentNode,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"UVSealedRuntimeProbe": "UV Sealed Runtime Probe",
|
||||
"UVSealedBoltonsSlugify": "UV Sealed Boltons Slugify",
|
||||
"UVSealedFilesystemBarrier": "UV Sealed Filesystem Barrier",
|
||||
"UVSealedTensorEcho": "UV Sealed Tensor Echo",
|
||||
"UVSealedLatentEcho": "UV Sealed Latent Echo",
|
||||
}
|
||||
11
tests/isolation/uv_sealed_worker/pyproject.toml
Normal file
11
tests/isolation/uv_sealed_worker/pyproject.toml
Normal file
@@ -0,0 +1,11 @@
|
||||
[project]
|
||||
name = "comfyui-toolkit-uv-sealed-worker"
|
||||
version = "0.1.0"
|
||||
dependencies = ["boltons"]
|
||||
|
||||
[tool.comfy.isolation]
|
||||
can_isolate = true
|
||||
share_torch = false
|
||||
package_manager = "uv"
|
||||
execution_model = "sealed_worker"
|
||||
standalone = true
|
||||
@@ -0,0 +1,10 @@
|
||||
{
|
||||
"1": {
|
||||
"class_type": "InternalIsolationProbeImage",
|
||||
"inputs": {}
|
||||
},
|
||||
"2": {
|
||||
"class_type": "InternalIsolationProbeAudio",
|
||||
"inputs": {}
|
||||
}
|
||||
}
|
||||
6
tests/isolation/workflows/internal_probe_ui3d.json
Normal file
6
tests/isolation/workflows/internal_probe_ui3d.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"1": {
|
||||
"class_type": "InternalIsolationProbeUI3D",
|
||||
"inputs": {}
|
||||
}
|
||||
}
|
||||
22
tests/isolation/workflows/isolation_7_uv_sealed_worker.json
Normal file
22
tests/isolation/workflows/isolation_7_uv_sealed_worker.json
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"1": {
|
||||
"class_type": "EmptyLatentImage",
|
||||
"inputs": {}
|
||||
},
|
||||
"2": {
|
||||
"class_type": "ProxyTestSealedWorker",
|
||||
"inputs": {}
|
||||
},
|
||||
"3": {
|
||||
"class_type": "UVSealedBoltonsSlugify",
|
||||
"inputs": {}
|
||||
},
|
||||
"4": {
|
||||
"class_type": "UVSealedLatentEcho",
|
||||
"inputs": {}
|
||||
},
|
||||
"5": {
|
||||
"class_type": "UVSealedRuntimeProbe",
|
||||
"inputs": {}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"1": {
|
||||
"class_type": "CondaSealedLatentEcho",
|
||||
"inputs": {}
|
||||
},
|
||||
"2": {
|
||||
"class_type": "CondaSealedOpenWeatherDataset",
|
||||
"inputs": {}
|
||||
},
|
||||
"3": {
|
||||
"class_type": "CondaSealedRuntimeProbe",
|
||||
"inputs": {}
|
||||
},
|
||||
"4": {
|
||||
"class_type": "EmptyLatentImage",
|
||||
"inputs": {}
|
||||
},
|
||||
"5": {
|
||||
"class_type": "ProxyTestCondaSealedWorker",
|
||||
"inputs": {}
|
||||
}
|
||||
}
|
||||
22
tests/isolation/workflows/quick_6_uv_sealed_worker.json
Normal file
22
tests/isolation/workflows/quick_6_uv_sealed_worker.json
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"1": {
|
||||
"class_type": "EmptyLatentImage",
|
||||
"inputs": {}
|
||||
},
|
||||
"2": {
|
||||
"class_type": "ProxyTestSealedWorker",
|
||||
"inputs": {}
|
||||
},
|
||||
"3": {
|
||||
"class_type": "UVSealedBoltonsSlugify",
|
||||
"inputs": {}
|
||||
},
|
||||
"4": {
|
||||
"class_type": "UVSealedLatentEcho",
|
||||
"inputs": {}
|
||||
},
|
||||
"5": {
|
||||
"class_type": "UVSealedRuntimeProbe",
|
||||
"inputs": {}
|
||||
}
|
||||
}
|
||||
22
tests/isolation/workflows/quick_8_conda_sealed_worker.json
Normal file
22
tests/isolation/workflows/quick_8_conda_sealed_worker.json
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"1": {
|
||||
"class_type": "CondaSealedLatentEcho",
|
||||
"inputs": {}
|
||||
},
|
||||
"2": {
|
||||
"class_type": "CondaSealedOpenWeatherDataset",
|
||||
"inputs": {}
|
||||
},
|
||||
"3": {
|
||||
"class_type": "CondaSealedRuntimeProbe",
|
||||
"inputs": {}
|
||||
},
|
||||
"4": {
|
||||
"class_type": "EmptyLatentImage",
|
||||
"inputs": {}
|
||||
},
|
||||
"5": {
|
||||
"class_type": "ProxyTestCondaSealedWorker",
|
||||
"inputs": {}
|
||||
}
|
||||
}
|
||||
122
tests/test_adapter.py
Normal file
122
tests/test_adapter.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import textwrap
|
||||
import types
|
||||
from pathlib import Path
|
||||
|
||||
repo_root = Path(__file__).resolve().parents[1]
|
||||
pyisolate_root = repo_root.parent / "pyisolate"
|
||||
if pyisolate_root.exists():
|
||||
sys.path.insert(0, str(pyisolate_root))
|
||||
|
||||
from comfy.isolation.adapter import ComfyUIAdapter
|
||||
from pyisolate._internal.sandbox import build_bwrap_command
|
||||
from pyisolate._internal.sandbox_detect import RestrictionModel
|
||||
from pyisolate._internal.serialization_registry import SerializerRegistry
|
||||
|
||||
|
||||
def test_identifier():
|
||||
adapter = ComfyUIAdapter()
|
||||
assert adapter.identifier == "comfyui"
|
||||
|
||||
|
||||
def test_get_path_config_valid():
|
||||
adapter = ComfyUIAdapter()
|
||||
path = os.path.join("/opt", "ComfyUI", "custom_nodes", "demo")
|
||||
cfg = adapter.get_path_config(path)
|
||||
assert cfg is not None
|
||||
assert cfg["preferred_root"].endswith("ComfyUI")
|
||||
assert "custom_nodes" in cfg["additional_paths"][0]
|
||||
|
||||
|
||||
def test_get_path_config_invalid():
|
||||
adapter = ComfyUIAdapter()
|
||||
assert adapter.get_path_config("/random/path") is None
|
||||
|
||||
|
||||
def test_provide_rpc_services():
|
||||
adapter = ComfyUIAdapter()
|
||||
services = adapter.provide_rpc_services()
|
||||
names = {s.__name__ for s in services}
|
||||
assert "PromptServerService" in names
|
||||
assert "FolderPathsProxy" in names
|
||||
|
||||
|
||||
def test_register_serializers():
|
||||
adapter = ComfyUIAdapter()
|
||||
registry = SerializerRegistry.get_instance()
|
||||
registry.clear()
|
||||
|
||||
adapter.register_serializers(registry)
|
||||
assert registry.has_handler("ModelPatcher")
|
||||
assert registry.has_handler("CLIP")
|
||||
assert registry.has_handler("VAE")
|
||||
|
||||
registry.clear()
|
||||
|
||||
|
||||
def test_child_temp_directory_fence_uses_private_tmp(tmp_path):
|
||||
adapter = ComfyUIAdapter()
|
||||
child_script = textwrap.dedent(
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
child_temp = Path("/tmp/comfyui_temp")
|
||||
child_temp.mkdir(parents=True, exist_ok=True)
|
||||
scratch = child_temp / "child_only.txt"
|
||||
scratch.write_text("child-only", encoding="utf-8")
|
||||
print(f"CHILD_TEMP={child_temp}")
|
||||
print(f"CHILD_FILE={scratch}")
|
||||
"""
|
||||
)
|
||||
fake_folder_paths = types.SimpleNamespace(
|
||||
temp_directory="/host/tmp/should_not_survive",
|
||||
folder_names_and_paths={},
|
||||
extension_mimetypes_cache={},
|
||||
filename_list_cache={},
|
||||
)
|
||||
|
||||
class FolderPathsProxy:
|
||||
def get_temp_directory(self):
|
||||
return "/host/tmp/should_not_survive"
|
||||
|
||||
original_folder_paths = sys.modules.get("folder_paths")
|
||||
sys.modules["folder_paths"] = fake_folder_paths
|
||||
try:
|
||||
os.environ["PYISOLATE_CHILD"] = "1"
|
||||
adapter.handle_api_registration(FolderPathsProxy, rpc=None)
|
||||
finally:
|
||||
os.environ.pop("PYISOLATE_CHILD", None)
|
||||
if original_folder_paths is not None:
|
||||
sys.modules["folder_paths"] = original_folder_paths
|
||||
else:
|
||||
sys.modules.pop("folder_paths", None)
|
||||
|
||||
assert fake_folder_paths.temp_directory == "/tmp/comfyui_temp"
|
||||
|
||||
host_child_file = Path("/tmp/comfyui_temp/child_only.txt")
|
||||
if host_child_file.exists():
|
||||
host_child_file.unlink()
|
||||
|
||||
cmd = build_bwrap_command(
|
||||
python_exe=sys.executable,
|
||||
module_path=str(repo_root / "custom_nodes" / "ComfyUI-IsolationToolkit"),
|
||||
venv_path=str(repo_root / ".venv"),
|
||||
uds_address=str(tmp_path / "adapter.sock"),
|
||||
allow_gpu=False,
|
||||
restriction_model=RestrictionModel.NONE,
|
||||
sandbox_config={"writable_paths": ["/dev/shm"], "readonly_paths": [], "network": False},
|
||||
adapter=adapter,
|
||||
)
|
||||
assert "--tmpfs" in cmd and "/tmp" in cmd
|
||||
assert ["--bind", "/tmp", "/tmp"] not in [cmd[i : i + 3] for i in range(len(cmd) - 2)]
|
||||
|
||||
command_tail = cmd[-3:]
|
||||
assert command_tail[1:] == ["-m", "pyisolate._internal.uds_client"]
|
||||
cmd = cmd[:-3] + [sys.executable, "-c", child_script]
|
||||
|
||||
completed = subprocess.run(cmd, check=True, capture_output=True, text=True)
|
||||
|
||||
assert "CHILD_TEMP=/tmp/comfyui_temp" in completed.stdout
|
||||
assert not host_child_file.exists(), "Child temp file leaked into host /tmp"
|
||||
Reference in New Issue
Block a user