feat(isolation): DynamicVRAM compatibility for process isolation

DynamicVRAM's on-demand model loading/offloading conflicted with process isolation in three ways: RPC tensor transport stalls when a model is offloaded from the GPU mid-call, races between the model lifecycle and active RPC operations, and false-positive memory-leak detection caused by changed finalizer patterns.

- Marshal CUDA tensors to CPU before RPC transport for dynamic models
- Add operation state tracking + quiescence waits at workflow boundaries (sketched below, after this list)
- Distinguish proxy reference release from actual leaks in cleanup_models_gc
- Fix init order: DynamicVRAM must initialize before isolation proxies
- Add RPC timeouts to prevent indefinite hangs on model unavailability
- Prevent proxy-of-proxy chains from DynamicVRAM model reload cycles
- Add torch.device/torch.dtype serializers for new DynamicVRAM RPC paths (see the sketch after the diff)
- Guard isolation overhead so non-isolated workflows are unaffected
- Migrate env var to PYISOLATE_CHILD
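
The quiescence mechanism deserves a sketch. The shape below is illustrative only: OperationTracker, wait_quiescent, and the workflow-boundary hook are hypothetical names, not identifiers from this commit. It shows the intended contract, though: every proxy RPC registers itself as in-flight, and DynamicVRAM drains the in-flight count before it offloads a model.

    from __future__ import annotations

    import threading

    class OperationTracker:
        """Hypothetical in-flight RPC counter (sketch, not the real class)."""

        def __init__(self) -> None:
            self._lock = threading.Lock()
            self._inflight = 0
            self._idle = threading.Event()
            self._idle.set()  # nothing in flight at startup

        def __enter__(self) -> OperationTracker:
            # An RPC operation begins: offload must wait until it finishes.
            with self._lock:
                self._inflight += 1
                self._idle.clear()
            return self

        def __exit__(self, *exc: object) -> None:
            # Operation done; signal quiescence once the count drains to zero.
            with self._lock:
                self._inflight -= 1
                if self._inflight == 0:
                    self._idle.set()

        def wait_quiescent(self, timeout: float | None = None) -> bool:
            # Called at a workflow boundary before DynamicVRAM offloads a
            # model; returns False if the timeout expires first.
            return self._idle.wait(timeout)

Each proxy RPC would run inside "with tracker:", and the offload path would call tracker.wait_quiescent(timeout) before touching GPU memory.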
Author: John Pollock
Date:   2026-03-04 23:48:02 -06:00
Parent: a0f8784e9f
Commit: 9250191c65
38 changed files with 94595 additions and 307 deletions


@@ -3,6 +3,9 @@ from __future__ import annotations
 import asyncio
 import logging
+import os
+import threading
+import time
 from typing import Any
 from comfy.isolation.proxies.base import (
@@ -16,6 +19,22 @@ from comfy.isolation.proxies.base import (
 logger = logging.getLogger(__name__)
 
+def _describe_value(obj: Any) -> str:
+    try:
+        import torch
+    except Exception:
+        torch = None
+    try:
+        if torch is not None and isinstance(obj, torch.Tensor):
+            return (
+                "Tensor(shape=%s,dtype=%s,device=%s,id=%s)"
+                % (tuple(obj.shape), obj.dtype, obj.device, id(obj))
+            )
+    except Exception:
+        pass
+    return "%s(id=%s)" % (type(obj).__name__, id(obj))
+
 def _prefer_device(*tensors: Any) -> Any:
     try:
         import torch
@@ -49,6 +68,24 @@ def _to_device(obj: Any, device: Any) -> Any:
     return obj
 
+def _to_cpu_for_rpc(obj: Any) -> Any:
+    try:
+        import torch
+    except Exception:
+        return obj
+    if isinstance(obj, torch.Tensor):
+        t = obj.detach() if obj.requires_grad else obj
+        if t.is_cuda:
+            return t.to("cpu")
+        return t
+    if isinstance(obj, (list, tuple)):
+        converted = [_to_cpu_for_rpc(x) for x in obj]
+        return type(obj)(converted) if isinstance(obj, tuple) else converted
+    if isinstance(obj, dict):
+        return {k: _to_cpu_for_rpc(v) for k, v in obj.items()}
+    return obj
+
 class ModelSamplingRegistry(BaseRegistry[Any]):
     _type_prefix = "modelsampling"
@@ -199,14 +236,70 @@ class ModelSamplingProxy(BaseProxy[ModelSamplingRegistry]):
         rpc = self._get_rpc()
         method = getattr(rpc, method_name)
         result = method(self._instance_id, *args)
+        timeout_ms = self._rpc_timeout_ms()
+        start_epoch = time.time()
+        start_perf = time.perf_counter()
+        thread_id = threading.get_ident()
+        call_id = "%s:%s:%s:%.6f" % (
+            self._instance_id,
+            method_name,
+            thread_id,
+            start_perf,
+        )
+        logger.debug(
+            "ISO:modelsampling_rpc_start method=%s instance_id=%s call_id=%s start_ts=%.6f thread=%s timeout_ms=%s",
+            method_name,
+            self._instance_id,
+            call_id,
+            start_epoch,
+            thread_id,
+            timeout_ms,
+        )
         if asyncio.iscoroutine(result):
+            result = asyncio.wait_for(result, timeout=timeout_ms / 1000.0)
             try:
                 asyncio.get_running_loop()
-                return run_coro_in_new_loop(result)
+                out = run_coro_in_new_loop(result)
             except RuntimeError:
                 loop = get_thread_loop()
-                return loop.run_until_complete(result)
-        return result
+                out = loop.run_until_complete(result)
+        else:
+            out = result
+        logger.debug(
+            "ISO:modelsampling_rpc_after_await method=%s instance_id=%s call_id=%s out=%s",
+            method_name,
+            self._instance_id,
+            call_id,
+            _describe_value(out),
+        )
+        elapsed_ms = (time.perf_counter() - start_perf) * 1000.0
+        logger.debug(
+            "ISO:modelsampling_rpc_end method=%s instance_id=%s call_id=%s elapsed_ms=%.3f thread=%s",
+            method_name,
+            self._instance_id,
+            call_id,
+            elapsed_ms,
+            thread_id,
+        )
+        logger.debug(
+            "ISO:modelsampling_rpc_return method=%s instance_id=%s call_id=%s",
+            method_name,
+            self._instance_id,
+            call_id,
+        )
+        return out
+
+    @staticmethod
+    def _rpc_timeout_ms() -> int:
+        raw = os.environ.get(
+            "COMFY_ISOLATION_MODEL_SAMPLING_RPC_TIMEOUT_MS",
+            os.environ.get("COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS", "30000"),
+        )
+        try:
+            timeout_ms = int(raw)
+        except ValueError:
+            timeout_ms = 30000
+        return max(1, timeout_ms)
 
     @property
     def sigma_min(self) -> Any:
@@ -235,10 +328,24 @@ class ModelSamplingProxy(BaseProxy[ModelSamplingRegistry]):
     def noise_scaling(
         self, sigma: Any, noise: Any, latent_image: Any, max_denoise: bool = False
     ) -> Any:
-        return self._call("noise_scaling", sigma, noise, latent_image, max_denoise)
+        preferred_device = _prefer_device(noise, latent_image)
+        out = self._call(
+            "noise_scaling",
+            _to_cpu_for_rpc(sigma),
+            _to_cpu_for_rpc(noise),
+            _to_cpu_for_rpc(latent_image),
+            max_denoise,
+        )
+        return _to_device(out, preferred_device)
 
     def inverse_noise_scaling(self, sigma: Any, latent: Any) -> Any:
-        return self._call("inverse_noise_scaling", sigma, latent)
+        preferred_device = _prefer_device(latent)
+        out = self._call(
+            "inverse_noise_scaling",
+            _to_cpu_for_rpc(sigma),
+            _to_cpu_for_rpc(latent),
+        )
+        return _to_device(out, preferred_device)
 
     def timestep(self, sigma: Any) -> Any:
         return self._call("timestep", sigma)
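
For reference, the torch.device / torch.dtype serializers mentioned in the commit message reduce to a string round-trip. The sketch below is illustrative, not this commit's code: the encode/decode pair is the substance, and however the RPC layer actually registers serializers, the round-trip must hold.

    import torch

    def encode_device(dev: torch.device) -> str:
        # str() round-trips cleanly: "cpu", "cuda:0", ...
        return str(dev)

    def decode_device(raw: str) -> torch.device:
        return torch.device(raw)

    def encode_dtype(dt: torch.dtype) -> str:
        # str(torch.float16) == "torch.float16"; keep only the attribute name.
        return str(dt).removeprefix("torch.")

    def decode_dtype(raw: str) -> torch.dtype:
        dt = getattr(torch, raw, None)
        if not isinstance(dt, torch.dtype):
            raise ValueError("not a torch.dtype: %r" % raw)
        return dt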