mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-13 00:59:59 +00:00
feat(isolation): DynamicVRAM compatibility for process isolation
DynamicVRAM's on-demand model loading/offloading conflicted with process isolation in three ways: RPC tensor transport stalls from mid-call GPU offload, race conditions between model lifecycle and active RPC operations, and false positive memory leak detection from changed finalizer patterns. - Marshal CUDA tensors to CPU before RPC transport for dynamic models - Add operation state tracking + quiescence waits at workflow boundaries - Distinguish proxy reference release from actual leaks in cleanup_models_gc - Fix init order: DynamicVRAM must initialize before isolation proxies - Add RPC timeouts to prevent indefinite hangs on model unavailability - Prevent proxy-of-proxy chains from DynamicVRAM model reload cycles - Add torch.device/torch.dtype serializers for new DynamicVRAM RPC paths - Guard isolation overhead so non-isolated workflows are unaffected - Migrate env var to PYISOLATE_CHILD
This commit is contained in:
@@ -3,6 +3,9 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from comfy.isolation.proxies.base import (
|
||||
@@ -16,6 +19,22 @@ from comfy.isolation.proxies.base import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _describe_value(obj: Any) -> str:
|
||||
try:
|
||||
import torch
|
||||
except Exception:
|
||||
torch = None
|
||||
try:
|
||||
if torch is not None and isinstance(obj, torch.Tensor):
|
||||
return (
|
||||
"Tensor(shape=%s,dtype=%s,device=%s,id=%s)"
|
||||
% (tuple(obj.shape), obj.dtype, obj.device, id(obj))
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return "%s(id=%s)" % (type(obj).__name__, id(obj))
|
||||
|
||||
|
||||
def _prefer_device(*tensors: Any) -> Any:
|
||||
try:
|
||||
import torch
|
||||
@@ -49,6 +68,24 @@ def _to_device(obj: Any, device: Any) -> Any:
|
||||
return obj
|
||||
|
||||
|
||||
def _to_cpu_for_rpc(obj: Any) -> Any:
|
||||
try:
|
||||
import torch
|
||||
except Exception:
|
||||
return obj
|
||||
if isinstance(obj, torch.Tensor):
|
||||
t = obj.detach() if obj.requires_grad else obj
|
||||
if t.is_cuda:
|
||||
return t.to("cpu")
|
||||
return t
|
||||
if isinstance(obj, (list, tuple)):
|
||||
converted = [_to_cpu_for_rpc(x) for x in obj]
|
||||
return type(obj)(converted) if isinstance(obj, tuple) else converted
|
||||
if isinstance(obj, dict):
|
||||
return {k: _to_cpu_for_rpc(v) for k, v in obj.items()}
|
||||
return obj
|
||||
|
||||
|
||||
class ModelSamplingRegistry(BaseRegistry[Any]):
|
||||
_type_prefix = "modelsampling"
|
||||
|
||||
@@ -199,14 +236,70 @@ class ModelSamplingProxy(BaseProxy[ModelSamplingRegistry]):
|
||||
rpc = self._get_rpc()
|
||||
method = getattr(rpc, method_name)
|
||||
result = method(self._instance_id, *args)
|
||||
timeout_ms = self._rpc_timeout_ms()
|
||||
start_epoch = time.time()
|
||||
start_perf = time.perf_counter()
|
||||
thread_id = threading.get_ident()
|
||||
call_id = "%s:%s:%s:%.6f" % (
|
||||
self._instance_id,
|
||||
method_name,
|
||||
thread_id,
|
||||
start_perf,
|
||||
)
|
||||
logger.debug(
|
||||
"ISO:modelsampling_rpc_start method=%s instance_id=%s call_id=%s start_ts=%.6f thread=%s timeout_ms=%s",
|
||||
method_name,
|
||||
self._instance_id,
|
||||
call_id,
|
||||
start_epoch,
|
||||
thread_id,
|
||||
timeout_ms,
|
||||
)
|
||||
if asyncio.iscoroutine(result):
|
||||
result = asyncio.wait_for(result, timeout=timeout_ms / 1000.0)
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
return run_coro_in_new_loop(result)
|
||||
out = run_coro_in_new_loop(result)
|
||||
except RuntimeError:
|
||||
loop = get_thread_loop()
|
||||
return loop.run_until_complete(result)
|
||||
return result
|
||||
out = loop.run_until_complete(result)
|
||||
else:
|
||||
out = result
|
||||
logger.debug(
|
||||
"ISO:modelsampling_rpc_after_await method=%s instance_id=%s call_id=%s out=%s",
|
||||
method_name,
|
||||
self._instance_id,
|
||||
call_id,
|
||||
_describe_value(out),
|
||||
)
|
||||
elapsed_ms = (time.perf_counter() - start_perf) * 1000.0
|
||||
logger.debug(
|
||||
"ISO:modelsampling_rpc_end method=%s instance_id=%s call_id=%s elapsed_ms=%.3f thread=%s",
|
||||
method_name,
|
||||
self._instance_id,
|
||||
call_id,
|
||||
elapsed_ms,
|
||||
thread_id,
|
||||
)
|
||||
logger.debug(
|
||||
"ISO:modelsampling_rpc_return method=%s instance_id=%s call_id=%s",
|
||||
method_name,
|
||||
self._instance_id,
|
||||
call_id,
|
||||
)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def _rpc_timeout_ms() -> int:
|
||||
raw = os.environ.get(
|
||||
"COMFY_ISOLATION_MODEL_SAMPLING_RPC_TIMEOUT_MS",
|
||||
os.environ.get("COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS", "30000"),
|
||||
)
|
||||
try:
|
||||
timeout_ms = int(raw)
|
||||
except ValueError:
|
||||
timeout_ms = 30000
|
||||
return max(1, timeout_ms)
|
||||
|
||||
@property
|
||||
def sigma_min(self) -> Any:
|
||||
@@ -235,10 +328,24 @@ class ModelSamplingProxy(BaseProxy[ModelSamplingRegistry]):
|
||||
def noise_scaling(
|
||||
self, sigma: Any, noise: Any, latent_image: Any, max_denoise: bool = False
|
||||
) -> Any:
|
||||
return self._call("noise_scaling", sigma, noise, latent_image, max_denoise)
|
||||
preferred_device = _prefer_device(noise, latent_image)
|
||||
out = self._call(
|
||||
"noise_scaling",
|
||||
_to_cpu_for_rpc(sigma),
|
||||
_to_cpu_for_rpc(noise),
|
||||
_to_cpu_for_rpc(latent_image),
|
||||
max_denoise,
|
||||
)
|
||||
return _to_device(out, preferred_device)
|
||||
|
||||
def inverse_noise_scaling(self, sigma: Any, latent: Any) -> Any:
|
||||
return self._call("inverse_noise_scaling", sigma, latent)
|
||||
preferred_device = _prefer_device(latent)
|
||||
out = self._call(
|
||||
"inverse_noise_scaling",
|
||||
_to_cpu_for_rpc(sigma),
|
||||
_to_cpu_for_rpc(latent),
|
||||
)
|
||||
return _to_device(out, preferred_device)
|
||||
|
||||
def timestep(self, sigma: Any) -> Any:
|
||||
return self._call("timestep", sigma)
|
||||
|
||||
Reference in New Issue
Block a user