mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-28 02:11:31 +00:00
feat(isolation): DynamicVRAM compatibility for process isolation
DynamicVRAM's on-demand model loading/offloading conflicted with process isolation in three ways: RPC tensor transport stalls from mid-call GPU offload, race conditions between model lifecycle and active RPC operations, and false positive memory leak detection from changed finalizer patterns. - Marshal CUDA tensors to CPU before RPC transport for dynamic models - Add operation state tracking + quiescence waits at workflow boundaries - Distinguish proxy reference release from actual leaks in cleanup_models_gc - Fix init order: DynamicVRAM must initialize before isolation proxies - Add RPC timeouts to prevent indefinite hangs on model unavailability - Prevent proxy-of-proxy chains from DynamicVRAM model reload cycles - Add torch.device/torch.dtype serializers for new DynamicVRAM RPC paths - Guard isolation overhead so non-isolated workflows are unaffected - Migrate env var to PYISOLATE_CHILD
This commit is contained in:
@@ -26,6 +26,7 @@ PYISOLATE_VENV_ROOT.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
|
||||
_MODEL_PATCHER_IDLE_TIMEOUT_MS = 120000
|
||||
|
||||
|
||||
def initialize_proxies() -> None:
|
||||
@@ -168,6 +169,11 @@ def _get_class_types_for_extension(extension_name: str) -> Set[str]:
|
||||
|
||||
async def notify_execution_graph(needed_class_types: Set[str]) -> None:
|
||||
"""Evict running extensions not needed for current execution."""
|
||||
await wait_for_model_patcher_quiescence(
|
||||
timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS,
|
||||
fail_loud=True,
|
||||
marker="ISO:notify_graph_wait_idle",
|
||||
)
|
||||
|
||||
async def _stop_extension(
|
||||
ext_name: str, extension: "ComfyNodeExtension", reason: str
|
||||
@@ -182,22 +188,33 @@ async def notify_execution_graph(needed_class_types: Set[str]) -> None:
|
||||
scan_shm_forensics("ISO:stop_extension", refresh_model_context=True)
|
||||
|
||||
scan_shm_forensics("ISO:notify_graph_start", refresh_model_context=True)
|
||||
isolated_class_types_in_graph = needed_class_types.intersection(
|
||||
{spec.node_name for spec in _ISOLATED_NODE_SPECS}
|
||||
)
|
||||
graph_uses_isolation = bool(isolated_class_types_in_graph)
|
||||
logger.debug(
|
||||
"%s ISO:notify_graph_start running=%d needed=%d",
|
||||
LOG_PREFIX,
|
||||
len(_RUNNING_EXTENSIONS),
|
||||
len(needed_class_types),
|
||||
)
|
||||
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
|
||||
ext_class_types = _get_class_types_for_extension(ext_name)
|
||||
if graph_uses_isolation:
|
||||
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
|
||||
ext_class_types = _get_class_types_for_extension(ext_name)
|
||||
|
||||
# If NONE of this extension's nodes are in the execution graph → evict
|
||||
if not ext_class_types.intersection(needed_class_types):
|
||||
await _stop_extension(
|
||||
ext_name,
|
||||
extension,
|
||||
"isolated custom_node not in execution graph, evicting",
|
||||
)
|
||||
# If NONE of this extension's nodes are in the execution graph -> evict.
|
||||
if not ext_class_types.intersection(needed_class_types):
|
||||
await _stop_extension(
|
||||
ext_name,
|
||||
extension,
|
||||
"isolated custom_node not in execution graph, evicting",
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"%s ISO:notify_graph_skip_evict running=%d reason=no isolated nodes in graph",
|
||||
LOG_PREFIX,
|
||||
len(_RUNNING_EXTENSIONS),
|
||||
)
|
||||
|
||||
# Isolated child processes add steady VRAM pressure; reclaim host-side models
|
||||
# at workflow boundaries so subsequent host nodes (e.g. CLIP encode) keep headroom.
|
||||
@@ -211,7 +228,7 @@ async def notify_execution_graph(needed_class_types: Set[str]) -> None:
|
||||
_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES,
|
||||
)
|
||||
free_before = model_management.get_free_memory(device)
|
||||
if free_before < required and _RUNNING_EXTENSIONS:
|
||||
if free_before < required and _RUNNING_EXTENSIONS and graph_uses_isolation:
|
||||
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
|
||||
await _stop_extension(
|
||||
ext_name,
|
||||
@@ -237,6 +254,11 @@ async def notify_execution_graph(needed_class_types: Set[str]) -> None:
|
||||
|
||||
|
||||
async def flush_running_extensions_transport_state() -> int:
|
||||
await wait_for_model_patcher_quiescence(
|
||||
timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS,
|
||||
fail_loud=True,
|
||||
marker="ISO:flush_transport_wait_idle",
|
||||
)
|
||||
total_flushed = 0
|
||||
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
|
||||
flush_fn = getattr(extension, "flush_transport_state", None)
|
||||
@@ -263,6 +285,50 @@ async def flush_running_extensions_transport_state() -> int:
|
||||
return total_flushed
|
||||
|
||||
|
||||
async def wait_for_model_patcher_quiescence(
|
||||
timeout_ms: int = _MODEL_PATCHER_IDLE_TIMEOUT_MS,
|
||||
*,
|
||||
fail_loud: bool = False,
|
||||
marker: str = "ISO:wait_model_patcher_idle",
|
||||
) -> bool:
|
||||
try:
|
||||
from comfy.isolation.model_patcher_proxy_registry import ModelPatcherRegistry
|
||||
|
||||
registry = ModelPatcherRegistry()
|
||||
start = time.perf_counter()
|
||||
idle = await registry.wait_all_idle(timeout_ms)
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000.0
|
||||
if idle:
|
||||
logger.debug(
|
||||
"%s %s idle=1 timeout_ms=%d elapsed_ms=%.3f",
|
||||
LOG_PREFIX,
|
||||
marker,
|
||||
timeout_ms,
|
||||
elapsed_ms,
|
||||
)
|
||||
return True
|
||||
|
||||
states = await registry.get_all_operation_states()
|
||||
logger.error(
|
||||
"%s %s idle_timeout timeout_ms=%d elapsed_ms=%.3f states=%s",
|
||||
LOG_PREFIX,
|
||||
marker,
|
||||
timeout_ms,
|
||||
elapsed_ms,
|
||||
states,
|
||||
)
|
||||
if fail_loud:
|
||||
raise TimeoutError(
|
||||
f"ModelPatcherRegistry did not quiesce within {timeout_ms} ms"
|
||||
)
|
||||
return False
|
||||
except Exception:
|
||||
if fail_loud:
|
||||
raise
|
||||
logger.debug("%s %s failed", LOG_PREFIX, marker, exc_info=True)
|
||||
return False
|
||||
|
||||
|
||||
def get_claimed_paths() -> Set[Path]:
|
||||
return _CLAIMED_PATHS
|
||||
|
||||
@@ -320,6 +386,7 @@ __all__ = [
|
||||
"await_isolation_loading",
|
||||
"notify_execution_graph",
|
||||
"flush_running_extensions_transport_state",
|
||||
"wait_for_model_patcher_quiescence",
|
||||
"get_claimed_paths",
|
||||
"update_rpc_event_loops",
|
||||
"IsolatedNodeSpec",
|
||||
|
||||
Reference in New Issue
Block a user