feat(isolation): DynamicVRAM compatibility for process isolation

DynamicVRAM's on-demand model loading/offloading conflicted with process isolation in three ways: RPC tensor transport stalls from mid-call GPU offload, race conditions between the model lifecycle and active RPC operations, and false-positive memory-leak detection from changed finalizer patterns.

- Marshal CUDA tensors to CPU before RPC transport for dynamic models
- Add operation state tracking + quiescence waits at workflow boundaries
- Distinguish proxy reference release from actual leaks in cleanup_models_gc
- Fix init order: DynamicVRAM must initialize before isolation proxies
- Add RPC timeouts to prevent indefinite hangs on model unavailability
- Prevent proxy-of-proxy chains from DynamicVRAM model reload cycles
- Add torch.device/torch.dtype serializers for new DynamicVRAM RPC paths
- Guard isolation overhead so non-isolated workflows are unaffected
- Migrate env var to PYISOLATE_CHILD
This commit is contained in:
John Pollock
2026-03-04 23:48:02 -06:00
parent a0f8784e9f
commit 9250191c65
38 changed files with 94595 additions and 307 deletions

View File

@@ -26,6 +26,7 @@ PYISOLATE_VENV_ROOT.mkdir(parents=True, exist_ok=True)
logger = logging.getLogger(__name__)
_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
_MODEL_PATCHER_IDLE_TIMEOUT_MS = 120000
def initialize_proxies() -> None:
@@ -168,6 +169,11 @@ def _get_class_types_for_extension(extension_name: str) -> Set[str]:
async def notify_execution_graph(needed_class_types: Set[str]) -> None:
"""Evict running extensions not needed for current execution."""
await wait_for_model_patcher_quiescence(
timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS,
fail_loud=True,
marker="ISO:notify_graph_wait_idle",
)
async def _stop_extension(
ext_name: str, extension: "ComfyNodeExtension", reason: str
@@ -182,22 +188,33 @@ async def notify_execution_graph(needed_class_types: Set[str]) -> None:
scan_shm_forensics("ISO:stop_extension", refresh_model_context=True)
scan_shm_forensics("ISO:notify_graph_start", refresh_model_context=True)
isolated_class_types_in_graph = needed_class_types.intersection(
{spec.node_name for spec in _ISOLATED_NODE_SPECS}
)
graph_uses_isolation = bool(isolated_class_types_in_graph)
logger.debug(
"%s ISO:notify_graph_start running=%d needed=%d",
LOG_PREFIX,
len(_RUNNING_EXTENSIONS),
len(needed_class_types),
)
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
ext_class_types = _get_class_types_for_extension(ext_name)
if graph_uses_isolation:
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
ext_class_types = _get_class_types_for_extension(ext_name)
# If NONE of this extension's nodes are in the execution graph evict
if not ext_class_types.intersection(needed_class_types):
await _stop_extension(
ext_name,
extension,
"isolated custom_node not in execution graph, evicting",
)
# If NONE of this extension's nodes are in the execution graph -> evict.
if not ext_class_types.intersection(needed_class_types):
await _stop_extension(
ext_name,
extension,
"isolated custom_node not in execution graph, evicting",
)
else:
logger.debug(
"%s ISO:notify_graph_skip_evict running=%d reason=no isolated nodes in graph",
LOG_PREFIX,
len(_RUNNING_EXTENSIONS),
)
# Isolated child processes add steady VRAM pressure; reclaim host-side models
# at workflow boundaries so subsequent host nodes (e.g. CLIP encode) keep headroom.
@@ -211,7 +228,7 @@ async def notify_execution_graph(needed_class_types: Set[str]) -> None:
_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES,
)
free_before = model_management.get_free_memory(device)
if free_before < required and _RUNNING_EXTENSIONS:
if free_before < required and _RUNNING_EXTENSIONS and graph_uses_isolation:
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
await _stop_extension(
ext_name,
@@ -237,6 +254,11 @@ async def notify_execution_graph(needed_class_types: Set[str]) -> None:
async def flush_running_extensions_transport_state() -> int:
await wait_for_model_patcher_quiescence(
timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS,
fail_loud=True,
marker="ISO:flush_transport_wait_idle",
)
total_flushed = 0
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
flush_fn = getattr(extension, "flush_transport_state", None)
@@ -263,6 +285,50 @@ async def flush_running_extensions_transport_state() -> int:
return total_flushed
async def wait_for_model_patcher_quiescence(
    timeout_ms: int = _MODEL_PATCHER_IDLE_TIMEOUT_MS,
    *,
    fail_loud: bool = False,
    marker: str = "ISO:wait_model_patcher_idle",
) -> bool:
    """Wait until the ModelPatcherRegistry reports no active operations.

    Args:
        timeout_ms: Maximum time to wait for all tracked ModelPatcher
            operations to finish.
        fail_loud: When True, a quiescence timeout raises TimeoutError and
            any internal failure is re-raised instead of being swallowed.
        marker: Tag embedded in log lines to identify the call site.

    Returns:
        True if the registry went idle within the timeout; False on timeout
        or internal failure when fail_loud is False.

    Raises:
        TimeoutError: If the registry fails to quiesce and fail_loud is True.
    """
    try:
        # Imported lazily inside the function; NOTE(review): presumably to
        # avoid an import cycle at module load time — confirm.
        from comfy.isolation.model_patcher_proxy_registry import ModelPatcherRegistry

        registry = ModelPatcherRegistry()
        start = time.perf_counter()
        idle = await registry.wait_all_idle(timeout_ms)
        elapsed_ms = (time.perf_counter() - start) * 1000.0
        if idle:
            logger.debug(
                "%s %s idle=1 timeout_ms=%d elapsed_ms=%.3f",
                LOG_PREFIX,
                marker,
                timeout_ms,
                elapsed_ms,
            )
            return True
        # Timed out: dump every instance's operation state for diagnosis.
        states = await registry.get_all_operation_states()
        logger.error(
            "%s %s idle_timeout timeout_ms=%d elapsed_ms=%.3f states=%s",
            LOG_PREFIX,
            marker,
            timeout_ms,
            elapsed_ms,
            states,
        )
        if fail_loud:
            raise TimeoutError(
                f"ModelPatcherRegistry did not quiesce within {timeout_ms} ms"
            )
        return False
    except Exception:
        # Best-effort by default: quiescence checking must not break the
        # caller's workflow unless fail_loud explicitly requests it.
        if fail_loud:
            raise
        logger.debug("%s %s failed", LOG_PREFIX, marker, exc_info=True)
        return False
def get_claimed_paths() -> Set[Path]:
    """Return the module-level set of claimed paths (by reference, not a copy)."""
    return _CLAIMED_PATHS
@@ -320,6 +386,7 @@ __all__ = [
"await_isolation_loading",
"notify_execution_graph",
"flush_running_extensions_transport_state",
"wait_for_model_patcher_quiescence",
"get_claimed_paths",
"update_rpc_event_loops",
"IsolatedNodeSpec",

View File

@@ -83,6 +83,33 @@ class ComfyUIAdapter(IsolationAdapter):
logging.getLogger(pkg_name).setLevel(logging.ERROR)
def register_serializers(self, registry: SerializerRegistryProtocol) -> None:
import torch
def serialize_device(obj: Any) -> Dict[str, Any]:
return {"__type__": "device", "device_str": str(obj)}
def deserialize_device(data: Dict[str, Any]) -> Any:
return torch.device(data["device_str"])
registry.register("device", serialize_device, deserialize_device)
_VALID_DTYPES = {
"float16", "float32", "float64", "bfloat16",
"int8", "int16", "int32", "int64",
"uint8", "bool",
}
def serialize_dtype(obj: Any) -> Dict[str, Any]:
return {"__type__": "dtype", "dtype_str": str(obj)}
def deserialize_dtype(data: Dict[str, Any]) -> Any:
dtype_name = data["dtype_str"].replace("torch.", "")
if dtype_name not in _VALID_DTYPES:
raise ValueError(f"Invalid dtype: {data['dtype_str']}")
return getattr(torch, dtype_name)
registry.register("dtype", serialize_dtype, deserialize_dtype)
def serialize_model_patcher(obj: Any) -> Dict[str, Any]:
# Child-side: must already have _instance_id (proxy)
if os.environ.get("PYISOLATE_CHILD") == "1":
@@ -193,6 +220,10 @@ class ComfyUIAdapter(IsolationAdapter):
f"ModelSampling in child lacks _instance_id: "
f"{type(obj).__module__}.{type(obj).__name__}"
)
# Host-side pass-through for proxies: do not re-register a proxy as a
# new ModelSamplingRef, or we create proxy-of-proxy indirection.
if hasattr(obj, "_instance_id"):
return {"__type__": "ModelSamplingRef", "ms_id": obj._instance_id}
# Host-side: register with ModelSamplingRegistry and return JSON-safe dict
ms_id = ModelSamplingRegistry().register(obj)
return {"__type__": "ModelSamplingRef", "ms_id": ms_id}
@@ -211,22 +242,21 @@ class ComfyUIAdapter(IsolationAdapter):
else:
return ModelSamplingRegistry()._get_instance(data["ms_id"])
# Register ModelSampling type and proxy
registry.register(
"ModelSamplingDiscrete",
serialize_model_sampling,
deserialize_model_sampling,
)
registry.register(
"ModelSamplingContinuousEDM",
serialize_model_sampling,
deserialize_model_sampling,
)
registry.register(
"ModelSamplingContinuousV",
serialize_model_sampling,
deserialize_model_sampling,
)
# Register all ModelSampling* and StableCascadeSampling classes dynamically
import comfy.model_sampling
for ms_cls in vars(comfy.model_sampling).values():
if not isinstance(ms_cls, type):
continue
if not issubclass(ms_cls, torch.nn.Module):
continue
if not (ms_cls.__name__.startswith("ModelSampling") or ms_cls.__name__ == "StableCascadeSampling"):
continue
registry.register(
ms_cls.__name__,
serialize_model_sampling,
deserialize_model_sampling,
)
registry.register(
"ModelSamplingProxy", serialize_model_sampling, deserialize_model_sampling
)

View File

@@ -389,7 +389,8 @@ class ComfyNodeExtension(ExtensionBase):
getattr(self, "name", "?"),
node_name,
)
return self._wrap_unpicklable_objects(result)
wrapped = self._wrap_unpicklable_objects(result)
return wrapped
if not isinstance(result, tuple):
result = (result,)
@@ -400,7 +401,8 @@ class ComfyNodeExtension(ExtensionBase):
node_name,
len(result),
)
return self._wrap_unpicklable_objects(result)
wrapped = self._wrap_unpicklable_objects(result)
return wrapped
async def flush_transport_state(self) -> int:
if os.environ.get("PYISOLATE_ISOLATION_ACTIVE") != "1":
@@ -443,7 +445,10 @@ class ComfyNodeExtension(ExtensionBase):
if isinstance(data, (str, int, float, bool, type(None))):
return data
if isinstance(data, torch.Tensor):
return data.detach() if data.requires_grad else data
tensor = data.detach() if data.requires_grad else data
if os.environ.get("PYISOLATE_CHILD") == "1" and tensor.device.type != "cpu":
return tensor.cpu()
return tensor
# Special-case clip vision outputs: preserve attribute access by packing fields
if hasattr(data, "penultimate_hidden_states") or hasattr(

View File

@@ -338,6 +338,34 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
def apply_model(self, *args, **kwargs) -> Any:
import torch
def _preferred_device() -> Any:
for value in args:
if isinstance(value, torch.Tensor):
return value.device
for value in kwargs.values():
if isinstance(value, torch.Tensor):
return value.device
return None
def _move_result_to_device(obj: Any, device: Any) -> Any:
if device is None:
return obj
if isinstance(obj, torch.Tensor):
return obj.to(device) if obj.device != device else obj
if isinstance(obj, dict):
return {k: _move_result_to_device(v, device) for k, v in obj.items()}
if isinstance(obj, list):
return [_move_result_to_device(v, device) for v in obj]
if isinstance(obj, tuple):
return tuple(_move_result_to_device(v, device) for v in obj)
return obj
# DynamicVRAM models must keep load/offload decisions in host process.
# Child-side CUDA staging here can deadlock before first inference RPC.
if self.is_dynamic():
out = self._call_rpc("inner_model_apply_model", args, kwargs)
return _move_result_to_device(out, _preferred_device())
required_bytes = self._cpu_tensor_bytes(args) + self._cpu_tensor_bytes(kwargs)
self._ensure_apply_model_headroom(required_bytes)
@@ -360,7 +388,8 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
args_cuda = _to_cuda(args)
kwargs_cuda = _to_cuda(kwargs)
return self._call_rpc("inner_model_apply_model", args_cuda, kwargs_cuda)
out = self._call_rpc("inner_model_apply_model", args_cuda, kwargs_cuda)
return _move_result_to_device(out, _preferred_device())
def model_state_dict(self, filter_prefix: Optional[str] = None) -> Any:
keys = self._call_rpc("model_state_dict", filter_prefix)
@@ -526,6 +555,13 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
def memory_required(self, input_shape: Any) -> Any:
return self._call_rpc("memory_required", input_shape)
def get_operation_state(self) -> Dict[str, Any]:
state = self._call_rpc("get_operation_state")
return state if isinstance(state, dict) else {}
def wait_for_idle(self, timeout_ms: int = 0) -> bool:
return bool(self._call_rpc("wait_for_idle", timeout_ms))
def is_dynamic(self) -> bool:
return bool(self._call_rpc("is_dynamic"))
@@ -771,6 +807,7 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
class _InnerModelProxy:
def __init__(self, parent: ModelPatcherProxy):
self._parent = parent
self._model_sampling = None
def __getattr__(self, name: str) -> Any:
if name.startswith("_"):
@@ -793,7 +830,11 @@ class _InnerModelProxy:
manage_lifecycle=False,
)
if name == "model_sampling":
return self._parent._call_rpc("get_model_object", "model_sampling")
if self._model_sampling is None:
self._model_sampling = self._parent._call_rpc(
"get_model_object", "model_sampling"
)
return self._model_sampling
if name == "extra_conds_shapes":
return lambda *a, **k: self._parent._call_rpc(
"inner_model_extra_conds_shapes", a, k

View File

@@ -2,8 +2,12 @@
# RPC server for ModelPatcher isolation (child process)
from __future__ import annotations
import asyncio
import gc
import logging
import threading
import time
from dataclasses import dataclass, field
from typing import Any, Optional, List
try:
@@ -43,12 +47,191 @@ from comfy.isolation.proxies.base import (
logger = logging.getLogger(__name__)
@dataclass
class _OperationState:
    """Per-instance bookkeeping for in-flight ModelPatcher registry operations."""

    # Serializes operations on one instance (taken by _run_operation_with_lease).
    lease: threading.Lock = field(default_factory=threading.Lock)
    # Number of operations currently executing on this instance.
    active_count: int = 0
    # Method name -> count of currently-active calls of that method.
    active_by_method: dict[str, int] = field(default_factory=dict)
    # Total operations ever started on this instance.
    total_operations: int = 0
    # Diagnostics describing the most recently started/finished operation.
    last_method: Optional[str] = None
    last_started_ts: Optional[float] = None   # epoch seconds at start
    last_ended_ts: Optional[float] = None     # epoch seconds at end
    last_elapsed_ms: Optional[float] = None
    last_error: Optional[str] = None          # repr() of the raised exception
    last_thread_id: Optional[int] = None
    last_loop_id: Optional[int] = None        # id() of the running asyncio loop, if any
class ModelPatcherRegistry(BaseRegistry[Any]):
_type_prefix = "model"
    def __init__(self) -> None:
        """Initialize the registry plus operation-state tracking structures."""
        super().__init__()
        # Instance ids whose cleanup() ran and that await removal by sweep.
        self._pending_cleanup_ids: set[str] = set()
        # Per-instance operation bookkeeping, guarded by _operation_state_cv.
        self._operation_states: dict[str, _OperationState] = {}
        # Condition sharing the base registry lock; notified whenever an
        # instance's active_count drops to zero (see _end_operation).
        self._operation_state_cv = threading.Condition(self._lock)
    def _get_or_create_operation_state(self, instance_id: str) -> _OperationState:
        """Return the _OperationState for *instance_id*, creating it on first use.

        All call sites in this file invoke this while holding
        _operation_state_cv.
        """
        state = self._operation_states.get(instance_id)
        if state is None:
            state = _OperationState()
            self._operation_states[instance_id] = state
        return state
    def _begin_operation(self, instance_id: str, method_name: str) -> tuple[float, float]:
        """Record the start of an operation on *instance_id*.

        Increments the active counters under the condition variable and
        captures diagnostic context (thread id, running event loop id).

        Returns:
            (start_epoch, start_perf): wall-clock start time and the
            perf_counter() reading to hand back to _end_operation for
            elapsed-time measurement.
        """
        start_epoch = time.time()
        start_perf = time.perf_counter()
        with self._operation_state_cv:
            state = self._get_or_create_operation_state(instance_id)
            state.active_count += 1
            state.active_by_method[method_name] = (
                state.active_by_method.get(method_name, 0) + 1
            )
            state.total_operations += 1
            state.last_method = method_name
            state.last_started_ts = start_epoch
            state.last_thread_id = threading.get_ident()
            try:
                # Best-effort: worker threads may have no running loop.
                state.last_loop_id = id(asyncio.get_running_loop())
            except RuntimeError:
                state.last_loop_id = None
            logger.debug(
                "ISO:registry_op_start instance_id=%s method=%s start_ts=%.6f thread=%s loop=%s",
                instance_id,
                method_name,
                start_epoch,
                threading.get_ident(),
                state.last_loop_id,
            )
        return start_epoch, start_perf
    def _end_operation(
        self,
        instance_id: str,
        method_name: str,
        start_perf: float,
        error: Optional[BaseException] = None,
    ) -> None:
        """Record completion of an operation started by _begin_operation.

        Args:
            start_perf: perf_counter() value returned by _begin_operation.
            error: Exception raised by the operation, if any (stored as repr).

        Wakes quiescence waiters on the condition variable when the
        instance's active_count reaches zero.
        """
        end_epoch = time.time()
        elapsed_ms = (time.perf_counter() - start_perf) * 1000.0
        with self._operation_state_cv:
            state = self._get_or_create_operation_state(instance_id)
            # Clamp at zero so an unmatched end cannot underflow the counter.
            state.active_count = max(0, state.active_count - 1)
            if method_name in state.active_by_method:
                remaining = state.active_by_method[method_name] - 1
                if remaining <= 0:
                    state.active_by_method.pop(method_name, None)
                else:
                    state.active_by_method[method_name] = remaining
            state.last_ended_ts = end_epoch
            state.last_elapsed_ms = elapsed_ms
            state.last_error = None if error is None else repr(error)
            if state.active_count == 0:
                # Wake wait_for_idle / wait_all_idle waiters.
                self._operation_state_cv.notify_all()
            logger.debug(
                "ISO:registry_op_end instance_id=%s method=%s end_ts=%.6f elapsed_ms=%.3f error=%s",
                instance_id,
                method_name,
                end_epoch,
                elapsed_ms,
                None if error is None else type(error).__name__,
            )
    def _run_operation_with_lease(self, instance_id: str, method_name: str, fn):
        """Run *fn* under the per-instance lease with begin/end bookkeeping.

        The lease lock serializes operations per instance, so at most one
        operation executes on a given model at a time. *fn* runs
        synchronously while the lease is held; exceptions are recorded via
        _end_operation and re-raised to the caller.
        """
        with self._operation_state_cv:
            state = self._get_or_create_operation_state(instance_id)
            lease = state.lease
        with lease:
            _, start_perf = self._begin_operation(instance_id, method_name)
            try:
                result = fn()
            except Exception as exc:
                self._end_operation(instance_id, method_name, start_perf, error=exc)
                raise
            self._end_operation(instance_id, method_name, start_perf)
            return result
def _snapshot_operation_state(self, instance_id: str) -> dict[str, Any]:
with self._operation_state_cv:
state = self._operation_states.get(instance_id)
if state is None:
return {
"instance_id": instance_id,
"active_count": 0,
"active_methods": [],
"total_operations": 0,
"last_method": None,
"last_started_ts": None,
"last_ended_ts": None,
"last_elapsed_ms": None,
"last_error": None,
"last_thread_id": None,
"last_loop_id": None,
}
return {
"instance_id": instance_id,
"active_count": state.active_count,
"active_methods": sorted(state.active_by_method.keys()),
"total_operations": state.total_operations,
"last_method": state.last_method,
"last_started_ts": state.last_started_ts,
"last_ended_ts": state.last_ended_ts,
"last_elapsed_ms": state.last_elapsed_ms,
"last_error": state.last_error,
"last_thread_id": state.last_thread_id,
"last_loop_id": state.last_loop_id,
}
    def unregister_sync(self, instance_id: str) -> None:
        """Synchronously drop *instance_id* and its operation state.

        Also clears any pending-cleanup mark and wakes quiescence waiters,
        since a removed instance can no longer signal idleness itself.
        """
        with self._operation_state_cv:
            instance = self._registry.pop(instance_id, None)
            if instance is not None:
                self._id_map.pop(id(instance), None)
            self._pending_cleanup_ids.discard(instance_id)
            self._operation_states.pop(instance_id, None)
            self._operation_state_cv.notify_all()
    async def get_operation_state(self, instance_id: str) -> dict[str, Any]:
        """RPC: return the operation-state snapshot for one instance."""
        return self._snapshot_operation_state(instance_id)
    async def get_all_operation_states(self) -> dict[str, dict[str, Any]]:
        """RPC: return snapshots for every tracked instance.

        The id list is captured under the lock, but each snapshot
        re-acquires it, so the combined result is not one atomic view.
        """
        with self._operation_state_cv:
            ids = sorted(self._operation_states.keys())
        return {instance_id: self._snapshot_operation_state(instance_id) for instance_id in ids}
async def wait_for_idle(self, instance_id: str, timeout_ms: int = 0) -> bool:
timeout_s = None if timeout_ms <= 0 else (timeout_ms / 1000.0)
deadline = None if timeout_s is None else (time.monotonic() + timeout_s)
with self._operation_state_cv:
while True:
active = self._operation_states.get(instance_id)
if active is None or active.active_count == 0:
return True
if deadline is None:
self._operation_state_cv.wait()
continue
remaining = deadline - time.monotonic()
if remaining <= 0:
return False
self._operation_state_cv.wait(timeout=remaining)
async def wait_all_idle(self, timeout_ms: int = 0) -> bool:
timeout_s = None if timeout_ms <= 0 else (timeout_ms / 1000.0)
deadline = None if timeout_s is None else (time.monotonic() + timeout_s)
with self._operation_state_cv:
while True:
has_active = any(
state.active_count > 0 for state in self._operation_states.values()
)
if not has_active:
return True
if deadline is None:
self._operation_state_cv.wait()
continue
remaining = deadline - time.monotonic()
if remaining <= 0:
return False
self._operation_state_cv.wait(timeout=remaining)
async def clone(self, instance_id: str) -> str:
instance = self._get_instance(instance_id)
@@ -73,8 +256,14 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
)
registry = ModelSamplingRegistry()
sampling_id = registry.register(result)
# Preserve identity when upstream already returned a proxy. Re-registering
# a proxy object creates proxy-of-proxy call chains.
if isinstance(result, ModelSamplingProxy):
sampling_id = result._instance_id
else:
sampling_id = registry.register(result)
return ModelSamplingProxy(sampling_id, registry)
return detach_if_grad(result)
async def get_model_options(self, instance_id: str) -> dict:
@@ -163,7 +352,11 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
return self._get_instance(instance_id).lowvram_patch_counter()
async def memory_required(self, instance_id: str, input_shape: Any) -> Any:
return self._get_instance(instance_id).memory_required(input_shape)
return self._run_operation_with_lease(
instance_id,
"memory_required",
lambda: self._get_instance(instance_id).memory_required(input_shape),
)
async def is_dynamic(self, instance_id: str) -> bool:
instance = self._get_instance(instance_id)
@@ -186,7 +379,11 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
return None
async def model_dtype(self, instance_id: str) -> Any:
return self._get_instance(instance_id).model_dtype()
return self._run_operation_with_lease(
instance_id,
"model_dtype",
lambda: self._get_instance(instance_id).model_dtype(),
)
async def model_patches_to(self, instance_id: str, device: Any) -> Any:
return self._get_instance(instance_id).model_patches_to(device)
@@ -198,8 +395,12 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
extra_memory: Any,
force_patch_weights: bool = False,
) -> Any:
return self._get_instance(instance_id).partially_load(
device, extra_memory, force_patch_weights=force_patch_weights
return self._run_operation_with_lease(
instance_id,
"partially_load",
lambda: self._get_instance(instance_id).partially_load(
device, extra_memory, force_patch_weights=force_patch_weights
),
)
async def partially_unload(
@@ -209,8 +410,12 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
memory_to_free: int = 0,
force_patch_weights: bool = False,
) -> int:
return self._get_instance(instance_id).partially_unload(
device_to, memory_to_free, force_patch_weights
return self._run_operation_with_lease(
instance_id,
"partially_unload",
lambda: self._get_instance(instance_id).partially_unload(
device_to, memory_to_free, force_patch_weights
),
)
async def load(
@@ -221,8 +426,12 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
force_patch_weights: bool = False,
full_load: bool = False,
) -> None:
self._get_instance(instance_id).load(
device_to, lowvram_model_memory, force_patch_weights, full_load
self._run_operation_with_lease(
instance_id,
"load",
lambda: self._get_instance(instance_id).load(
device_to, lowvram_model_memory, force_patch_weights, full_load
),
)
async def patch_model(
@@ -233,20 +442,29 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
load_weights: bool = True,
force_patch_weights: bool = False,
) -> None:
try:
self._get_instance(instance_id).patch_model(
device_to, lowvram_model_memory, load_weights, force_patch_weights
)
except AttributeError as e:
logger.error(
f"Isolation Error: Failed to patch model attribute: {e}. Skipping."
)
return
def _invoke() -> None:
try:
self._get_instance(instance_id).patch_model(
device_to, lowvram_model_memory, load_weights, force_patch_weights
)
except AttributeError as e:
logger.error(
f"Isolation Error: Failed to patch model attribute: {e}. Skipping."
)
return
self._run_operation_with_lease(instance_id, "patch_model", _invoke)
async def unpatch_model(
self, instance_id: str, device_to: Any = None, unpatch_weights: bool = True
) -> None:
self._get_instance(instance_id).unpatch_model(device_to, unpatch_weights)
self._run_operation_with_lease(
instance_id,
"unpatch_model",
lambda: self._get_instance(instance_id).unpatch_model(
device_to, unpatch_weights
),
)
async def detach(self, instance_id: str, unpatch_all: bool = True) -> None:
self._get_instance(instance_id).detach(unpatch_all)
@@ -262,26 +480,29 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
self._get_instance(instance_id).pre_run()
async def cleanup(self, instance_id: str) -> None:
try:
instance = self._get_instance(instance_id)
except Exception:
logger.debug(
"ModelPatcher cleanup requested for missing instance %s",
instance_id,
exc_info=True,
)
return
def _invoke() -> None:
try:
instance = self._get_instance(instance_id)
except Exception:
logger.debug(
"ModelPatcher cleanup requested for missing instance %s",
instance_id,
exc_info=True,
)
return
try:
instance.cleanup()
finally:
with self._lock:
self._pending_cleanup_ids.add(instance_id)
gc.collect()
try:
instance.cleanup()
finally:
with self._lock:
self._pending_cleanup_ids.add(instance_id)
gc.collect()
self._run_operation_with_lease(instance_id, "cleanup", _invoke)
def sweep_pending_cleanup(self) -> int:
removed = 0
with self._lock:
with self._operation_state_cv:
pending_ids = list(self._pending_cleanup_ids)
self._pending_cleanup_ids.clear()
for instance_id in pending_ids:
@@ -289,17 +510,21 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
if instance is None:
continue
self._id_map.pop(id(instance), None)
self._operation_states.pop(instance_id, None)
removed += 1
self._operation_state_cv.notify_all()
gc.collect()
return removed
def purge_all(self) -> int:
with self._lock:
with self._operation_state_cv:
removed = len(self._registry)
self._registry.clear()
self._id_map.clear()
self._pending_cleanup_ids.clear()
self._operation_states.clear()
self._operation_state_cv.notify_all()
gc.collect()
return removed
@@ -743,17 +968,52 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
async def inner_model_memory_required(
self, instance_id: str, args: tuple, kwargs: dict
) -> Any:
return self._get_instance(instance_id).model.memory_required(*args, **kwargs)
return self._run_operation_with_lease(
instance_id,
"inner_model_memory_required",
lambda: self._get_instance(instance_id).model.memory_required(
*args, **kwargs
),
)
async def inner_model_extra_conds_shapes(
self, instance_id: str, args: tuple, kwargs: dict
) -> Any:
return self._get_instance(instance_id).model.extra_conds_shapes(*args, **kwargs)
return self._run_operation_with_lease(
instance_id,
"inner_model_extra_conds_shapes",
lambda: self._get_instance(instance_id).model.extra_conds_shapes(
*args, **kwargs
),
)
async def inner_model_extra_conds(
self, instance_id: str, args: tuple, kwargs: dict
) -> Any:
return self._get_instance(instance_id).model.extra_conds(*args, **kwargs)
def _invoke() -> Any:
result = self._get_instance(instance_id).model.extra_conds(*args, **kwargs)
try:
import torch
import comfy.conds
except Exception:
return result
def _to_cpu(obj: Any) -> Any:
if torch.is_tensor(obj):
return obj.detach().cpu() if obj.device.type != "cpu" else obj
if isinstance(obj, dict):
return {k: _to_cpu(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_to_cpu(v) for v in obj]
if isinstance(obj, tuple):
return tuple(_to_cpu(v) for v in obj)
if isinstance(obj, comfy.conds.CONDRegular):
return type(obj)(_to_cpu(obj.cond))
return obj
return _to_cpu(result)
return self._run_operation_with_lease(instance_id, "inner_model_extra_conds", _invoke)
async def inner_model_state_dict(
self, instance_id: str, args: tuple, kwargs: dict
@@ -767,82 +1027,177 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
async def inner_model_apply_model(
self, instance_id: str, args: tuple, kwargs: dict
) -> Any:
instance = self._get_instance(instance_id)
target = getattr(instance, "load_device", None)
if target is None and args and hasattr(args[0], "device"):
target = args[0].device
elif target is None:
for v in kwargs.values():
if hasattr(v, "device"):
target = v.device
break
def _invoke() -> Any:
import torch
def _move(obj):
if target is None:
instance = self._get_instance(instance_id)
target = getattr(instance, "load_device", None)
if target is None and args and hasattr(args[0], "device"):
target = args[0].device
elif target is None:
for v in kwargs.values():
if hasattr(v, "device"):
target = v.device
break
def _move(obj):
if target is None:
return obj
if isinstance(obj, (tuple, list)):
return type(obj)(_move(o) for o in obj)
if hasattr(obj, "to"):
return obj.to(target)
return obj
if isinstance(obj, (tuple, list)):
return type(obj)(_move(o) for o in obj)
if hasattr(obj, "to"):
return obj.to(target)
return obj
moved_args = tuple(_move(a) for a in args)
moved_kwargs = {k: _move(v) for k, v in kwargs.items()}
result = instance.model.apply_model(*moved_args, **moved_kwargs)
return detach_if_grad(_move(result))
moved_args = tuple(_move(a) for a in args)
moved_kwargs = {k: _move(v) for k, v in kwargs.items()}
result = instance.model.apply_model(*moved_args, **moved_kwargs)
moved_result = detach_if_grad(_move(result))
# DynamicVRAM + isolation: returning CUDA tensors across RPC can stall
# at the transport boundary. Marshal dynamic-path results as CPU and let
# the proxy restore device placement in the child process.
is_dynamic_fn = getattr(instance, "is_dynamic", None)
if callable(is_dynamic_fn) and is_dynamic_fn():
def _to_cpu(obj: Any) -> Any:
if torch.is_tensor(obj):
return obj.detach().cpu() if obj.device.type != "cpu" else obj
if isinstance(obj, dict):
return {k: _to_cpu(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_to_cpu(v) for v in obj]
if isinstance(obj, tuple):
return tuple(_to_cpu(v) for v in obj)
return obj
return _to_cpu(moved_result)
return moved_result
return self._run_operation_with_lease(instance_id, "inner_model_apply_model", _invoke)
async def process_latent_in(
self, instance_id: str, args: tuple, kwargs: dict
) -> Any:
return detach_if_grad(
self._get_instance(instance_id).model.process_latent_in(*args, **kwargs)
)
import torch
def _invoke() -> Any:
instance = self._get_instance(instance_id)
result = detach_if_grad(instance.model.process_latent_in(*args, **kwargs))
# DynamicVRAM + isolation: returning CUDA tensors across RPC can stall
# at the transport boundary. Marshal dynamic-path results as CPU and let
# the proxy restore placement when needed.
is_dynamic_fn = getattr(instance, "is_dynamic", None)
if callable(is_dynamic_fn) and is_dynamic_fn():
def _to_cpu(obj: Any) -> Any:
if torch.is_tensor(obj):
return obj.detach().cpu() if obj.device.type != "cpu" else obj
if isinstance(obj, dict):
return {k: _to_cpu(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_to_cpu(v) for v in obj]
if isinstance(obj, tuple):
return tuple(_to_cpu(v) for v in obj)
return obj
return _to_cpu(result)
return result
return self._run_operation_with_lease(instance_id, "process_latent_in", _invoke)
async def process_latent_out(
self, instance_id: str, args: tuple, kwargs: dict
) -> Any:
instance = self._get_instance(instance_id)
result = instance.model.process_latent_out(*args, **kwargs)
try:
target = None
if args and hasattr(args[0], "device"):
target = args[0].device
elif kwargs:
for v in kwargs.values():
if hasattr(v, "device"):
target = v.device
break
if target is not None and hasattr(result, "to"):
return detach_if_grad(result.to(target))
except Exception:
logger.debug(
"process_latent_out: failed to move result to target device",
exc_info=True,
)
return detach_if_grad(result)
import torch
def _invoke() -> Any:
instance = self._get_instance(instance_id)
result = instance.model.process_latent_out(*args, **kwargs)
moved_result = None
try:
target = None
if args and hasattr(args[0], "device"):
target = args[0].device
elif kwargs:
for v in kwargs.values():
if hasattr(v, "device"):
target = v.device
break
if target is not None and hasattr(result, "to"):
moved_result = detach_if_grad(result.to(target))
except Exception:
logger.debug(
"process_latent_out: failed to move result to target device",
exc_info=True,
)
if moved_result is None:
moved_result = detach_if_grad(result)
is_dynamic_fn = getattr(instance, "is_dynamic", None)
if callable(is_dynamic_fn) and is_dynamic_fn():
def _to_cpu(obj: Any) -> Any:
if torch.is_tensor(obj):
return obj.detach().cpu() if obj.device.type != "cpu" else obj
if isinstance(obj, dict):
return {k: _to_cpu(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_to_cpu(v) for v in obj]
if isinstance(obj, tuple):
return tuple(_to_cpu(v) for v in obj)
return obj
return _to_cpu(moved_result)
return moved_result
return self._run_operation_with_lease(instance_id, "process_latent_out", _invoke)
async def scale_latent_inpaint(
    self, instance_id: str, args: tuple, kwargs: dict
) -> Any:
    """Run ``scale_latent_inpaint`` on the hosted model under an operation lease.

    The lease (``_run_operation_with_lease``) tracks the call so DynamicVRAM
    does not offload the model mid-operation. The result is moved back to the
    device of the first tensor-like argument when possible; for dynamic
    models the result is additionally marshalled to CPU so it can cross the
    RPC boundary even if the model is offloaded afterwards.

    Note: the stale pre-lease implementation that previously shadowed this
    body (direct call + early return) has been removed.
    """
    import torch

    def _invoke() -> Any:
        instance = self._get_instance(instance_id)
        result = instance.model.scale_latent_inpaint(*args, **kwargs)
        moved_result = None
        try:
            # Prefer the device of the first argument that carries one.
            target = None
            if args and hasattr(args[0], "device"):
                target = args[0].device
            elif kwargs:
                for v in kwargs.values():
                    if hasattr(v, "device"):
                        target = v.device
                        break
            if target is not None and hasattr(result, "to"):
                moved_result = detach_if_grad(result.to(target))
        except Exception:
            # Best-effort device placement; fall back to the raw result.
            logger.debug(
                "scale_latent_inpaint: failed to move result to target device",
                exc_info=True,
            )
        if moved_result is None:
            moved_result = detach_if_grad(result)
        is_dynamic_fn = getattr(instance, "is_dynamic", None)
        if callable(is_dynamic_fn) and is_dynamic_fn():
            # Dynamic models may be offloaded at any time: ship tensors on CPU.
            def _to_cpu(obj: Any) -> Any:
                if torch.is_tensor(obj):
                    return obj.detach().cpu() if obj.device.type != "cpu" else obj
                if isinstance(obj, dict):
                    return {k: _to_cpu(v) for k, v in obj.items()}
                if isinstance(obj, list):
                    return [_to_cpu(v) for v in obj]
                if isinstance(obj, tuple):
                    return tuple(_to_cpu(v) for v in obj)
                return obj

            return _to_cpu(moved_result)
        return moved_result

    return self._run_operation_with_lease(
        instance_id, "scale_latent_inpaint", _invoke
    )
async def load_lora(
self,

View File

@@ -3,6 +3,9 @@ from __future__ import annotations
import asyncio
import logging
import os
import threading
import time
from typing import Any
from comfy.isolation.proxies.base import (
@@ -16,6 +19,22 @@ from comfy.isolation.proxies.base import (
logger = logging.getLogger(__name__)
def _describe_value(obj: Any) -> str:
try:
import torch
except Exception:
torch = None
try:
if torch is not None and isinstance(obj, torch.Tensor):
return (
"Tensor(shape=%s,dtype=%s,device=%s,id=%s)"
% (tuple(obj.shape), obj.dtype, obj.device, id(obj))
)
except Exception:
pass
return "%s(id=%s)" % (type(obj).__name__, id(obj))
def _prefer_device(*tensors: Any) -> Any:
try:
import torch
@@ -49,6 +68,24 @@ def _to_device(obj: Any, device: Any) -> Any:
return obj
def _to_cpu_for_rpc(obj: Any) -> Any:
try:
import torch
except Exception:
return obj
if isinstance(obj, torch.Tensor):
t = obj.detach() if obj.requires_grad else obj
if t.is_cuda:
return t.to("cpu")
return t
if isinstance(obj, (list, tuple)):
converted = [_to_cpu_for_rpc(x) for x in obj]
return type(obj)(converted) if isinstance(obj, tuple) else converted
if isinstance(obj, dict):
return {k: _to_cpu_for_rpc(v) for k, v in obj.items()}
return obj
class ModelSamplingRegistry(BaseRegistry[Any]):
_type_prefix = "modelsampling"
@@ -199,14 +236,70 @@ class ModelSamplingProxy(BaseProxy[ModelSamplingRegistry]):
rpc = self._get_rpc()
method = getattr(rpc, method_name)
result = method(self._instance_id, *args)
timeout_ms = self._rpc_timeout_ms()
start_epoch = time.time()
start_perf = time.perf_counter()
thread_id = threading.get_ident()
call_id = "%s:%s:%s:%.6f" % (
self._instance_id,
method_name,
thread_id,
start_perf,
)
logger.debug(
"ISO:modelsampling_rpc_start method=%s instance_id=%s call_id=%s start_ts=%.6f thread=%s timeout_ms=%s",
method_name,
self._instance_id,
call_id,
start_epoch,
thread_id,
timeout_ms,
)
if asyncio.iscoroutine(result):
result = asyncio.wait_for(result, timeout=timeout_ms / 1000.0)
try:
asyncio.get_running_loop()
return run_coro_in_new_loop(result)
out = run_coro_in_new_loop(result)
except RuntimeError:
loop = get_thread_loop()
return loop.run_until_complete(result)
return result
out = loop.run_until_complete(result)
else:
out = result
logger.debug(
"ISO:modelsampling_rpc_after_await method=%s instance_id=%s call_id=%s out=%s",
method_name,
self._instance_id,
call_id,
_describe_value(out),
)
elapsed_ms = (time.perf_counter() - start_perf) * 1000.0
logger.debug(
"ISO:modelsampling_rpc_end method=%s instance_id=%s call_id=%s elapsed_ms=%.3f thread=%s",
method_name,
self._instance_id,
call_id,
elapsed_ms,
thread_id,
)
logger.debug(
"ISO:modelsampling_rpc_return method=%s instance_id=%s call_id=%s",
method_name,
self._instance_id,
call_id,
)
return out
@staticmethod
def _rpc_timeout_ms() -> int:
raw = os.environ.get(
"COMFY_ISOLATION_MODEL_SAMPLING_RPC_TIMEOUT_MS",
os.environ.get("COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS", "30000"),
)
try:
timeout_ms = int(raw)
except ValueError:
timeout_ms = 30000
return max(1, timeout_ms)
@property
def sigma_min(self) -> Any:
@@ -235,10 +328,24 @@ class ModelSamplingProxy(BaseProxy[ModelSamplingRegistry]):
def noise_scaling(
    self, sigma: Any, noise: Any, latent_image: Any, max_denoise: bool = False
) -> Any:
    """Apply noise scaling on the remote sampler, marshalling tensors via CPU.

    Arguments are moved to CPU before the RPC (DynamicVRAM may offload the
    model's device mid-call) and the result is moved back to the preferred
    device derived from *noise*/*latent_image*.

    Note: the stale pre-marshalling `return self._call(...)` line that made
    this body unreachable has been removed.
    """
    preferred_device = _prefer_device(noise, latent_image)
    out = self._call(
        "noise_scaling",
        _to_cpu_for_rpc(sigma),
        _to_cpu_for_rpc(noise),
        _to_cpu_for_rpc(latent_image),
        max_denoise,
    )
    return _to_device(out, preferred_device)
def inverse_noise_scaling(self, sigma: Any, latent: Any) -> Any:
    """Undo noise scaling on the remote sampler, marshalling tensors via CPU.

    Mirrors ``noise_scaling``: inputs cross the RPC boundary on CPU and the
    result is moved back to *latent*'s device when one is available.

    Note: the stale pre-marshalling `return self._call(...)` line that made
    this body unreachable has been removed.
    """
    preferred_device = _prefer_device(latent)
    out = self._call(
        "inverse_noise_scaling",
        _to_cpu_for_rpc(sigma),
        _to_cpu_for_rpc(latent),
    )
    return _to_device(out, preferred_device)
def timestep(self, sigma: Any) -> Any:
    """Forward ``timestep`` to the host-side model sampling object over RPC."""
    result = self._call("timestep", sigma)
    return result

View File

@@ -2,9 +2,11 @@
from __future__ import annotations
import asyncio
import concurrent.futures
import logging
import os
import threading
import time
import weakref
from typing import Any, Callable, Dict, Generic, Optional, TypeVar
@@ -119,6 +121,24 @@ def set_global_loop(loop: asyncio.AbstractEventLoop) -> None:
class BaseProxy(Generic[T]):
_registry_class: type = BaseRegistry # type: ignore[type-arg]
__module__: str = "comfy.isolation.proxies.base"
# RPC methods that receive an explicit timeout in _call_rpc (resolved by
# _rpc_timeout_ms_for_method, default 120000 ms): these hit model
# load/offload paths and could otherwise block indefinitely if the
# isolated child stalls — presumably the DynamicVRAM hang scenario; see
# the commit notes. Methods not listed here run without a timeout.
_TIMEOUT_RPC_METHODS = frozenset(
    {
        "partially_load",
        "partially_unload",
        "load",
        "patch_model",
        "unpatch_model",
        "inner_model_apply_model",
        "memory_required",
        "model_dtype",
        "inner_model_memory_required",
        "inner_model_extra_conds_shapes",
        "inner_model_extra_conds",
        "process_latent_in",
        "process_latent_out",
        "scale_latent_inpaint",
    }
)
def __init__(
self,
@@ -148,39 +168,89 @@ class BaseProxy(Generic[T]):
)
return self._rpc_caller
def _rpc_timeout_ms_for_method(self, method_name: str) -> Optional[int]:
if method_name not in self._TIMEOUT_RPC_METHODS:
return None
try:
timeout_ms = int(
os.environ.get("COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS", "120000")
)
except ValueError:
timeout_ms = 120000
return max(1, timeout_ms)
def _call_rpc(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
    """Dispatch *method_name* to the remote proxy target and block for the result.

    A timeout (see ``_rpc_timeout_ms_for_method``) wraps methods that touch
    model load/offload paths so an unavailable model raises ``TimeoutError``
    instead of hanging indefinitely. Start/end of every call is traced at
    DEBUG level with thread/loop identifiers for forensics.

    Dispatch order:
      1. If a global (main-thread) loop is running and we are on a worker
         thread, submit via ``run_coroutine_threadsafe`` and block.
      2. Otherwise, if some loop is running on this thread, run the coroutine
         on a fresh loop (``run_coro_in_new_loop``).
      3. Otherwise, drive it on this thread's cached loop.

    Note: the stale pre-instrumentation dispatch block and its exploratory
    comments (diff residue that duplicated steps 1-3) have been removed.
    """
    rpc = self._get_rpc()
    method = getattr(rpc, method_name)
    timeout_ms = self._rpc_timeout_ms_for_method(method_name)
    coro = method(self._instance_id, *args, **kwargs)
    if timeout_ms is not None:
        coro = asyncio.wait_for(coro, timeout=timeout_ms / 1000.0)
    start_epoch = time.time()
    start_perf = time.perf_counter()
    thread_id = threading.get_ident()
    try:
        running_loop = asyncio.get_running_loop()
        loop_id: Optional[int] = id(running_loop)
    except RuntimeError:
        loop_id = None
    logger.debug(
        "ISO:rpc_start proxy=%s method=%s instance_id=%s start_ts=%.6f "
        "thread=%s loop=%s timeout_ms=%s",
        self.__class__.__name__,
        method_name,
        self._instance_id,
        start_epoch,
        thread_id,
        loop_id,
        timeout_ms,
    )
    try:
        # If we have a global loop (Main Thread Loop), use it for dispatch from worker threads
        if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running():
            try:
                curr_loop = asyncio.get_running_loop()
                if curr_loop is _GLOBAL_LOOP:
                    # Already on the global loop: blocking here would
                    # deadlock, so fall through to the new-loop path below.
                    pass
            except RuntimeError:
                # No running loop - we are in a worker thread.
                future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP)
                return future.result(
                    timeout=(timeout_ms / 1000.0) if timeout_ms is not None else None
                )
        try:
            asyncio.get_running_loop()
            return run_coro_in_new_loop(coro)
        except RuntimeError:
            loop = get_thread_loop()
            return loop.run_until_complete(coro)
    except asyncio.TimeoutError as exc:
        raise TimeoutError(
            f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} "
            f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})"
        ) from exc
    except concurrent.futures.TimeoutError as exc:
        raise TimeoutError(
            f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} "
            f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})"
        ) from exc
    finally:
        end_epoch = time.time()
        elapsed_ms = (time.perf_counter() - start_perf) * 1000.0
        logger.debug(
            "ISO:rpc_end proxy=%s method=%s instance_id=%s end_ts=%.6f "
            "elapsed_ms=%.3f thread=%s loop=%s",
            self.__class__.__name__,
            method_name,
            self._instance_id,
            end_epoch,
            elapsed_ms,
            thread_id,
            loop_id,
        )
def __getstate__(self) -> Dict[str, Any]:
return {"_instance_id": self._instance_id}