clear GPU cache

snicolast
2025-10-03 10:52:10 +13:00
parent 00d2353584
commit 9d3e4f0817

@@ -1,11 +1,17 @@
+import gc
 import os
 import sys
 import tempfile
+import threading
+from functools import wraps
 from typing import Any, Dict, Tuple

 import numpy as np

 # simple in-memory cache for loaded models to avoid re-initializing weights
 _MODEL_CACHE: Dict[Tuple[str, str, str, bool, bool], Any] = {}
+_CACHE_LOCK = threading.RLock()
+_UNLOAD_HOOK_INSTALLED = False


 def _resolve_device(device: str):
     try:
@@ -28,10 +34,14 @@ def _get_tts2_model(config_path: str,
                     device: str,
                     use_cuda_kernel: bool,
                     use_fp16: bool):
+    _install_unload_hook()
     key = (os.path.abspath(config_path), os.path.abspath(model_dir), device, bool(use_cuda_kernel), bool(use_fp16))
-    model = _MODEL_CACHE.get(key)
-    if model is not None:
-        return model
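+    # Fast path: return an already-loaded instance. Taking the lock keeps
+    # this check consistent with concurrent inserts.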
+    with _CACHE_LOCK:
+        cached_model = _MODEL_CACHE.get(key)
+        if cached_model is not None:
+            return cached_model
     base_dir = os.path.dirname(os.path.abspath(__file__))
     ext_root = os.path.dirname(base_dir)
@@ -57,8 +67,165 @@ def _get_tts2_model(config_path: str,
         use_cuda_kernel=use_cuda_kernel,
         use_deepspeed=False,
     )
-    _MODEL_CACHE[key] = model
-    return model
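+    # Publish under the lock. If another thread finished loading first, keep
+    # its instance and let this duplicate be garbage collected.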
+    with _CACHE_LOCK:
+        existing = _MODEL_CACHE.get(key)
+        if existing is None:
+            _MODEL_CACHE[key] = model
+            cached_model = model
+        else:
+            cached_model = existing
+    return cached_model
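+
+
+# Best-effort flush of allocator caches on every accelerator backend this
+# torch build might expose (CUDA, XPU, MPS, NPU, MLU). Each probe is wrapped
+# in try/except so an absent backend never breaks unloading.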
+def _flush_device_caches():
+    try:
+        import torch
+    except Exception:
+        torch = None
+    if torch is not None:
+        try:
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+                torch.cuda.ipc_collect()
+        except Exception:
+            pass
+        try:
+            if hasattr(torch, "xpu") and getattr(torch.xpu, "is_available", lambda: False)():
+                torch.xpu.empty_cache()
+        except Exception:
+            pass
+        try:
+            if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+                torch.mps.empty_cache()
+        except Exception:
+            pass
+        try:
+            if hasattr(torch, "npu") and getattr(torch.npu, "is_available", lambda: False)():
+                torch.npu.empty_cache()
+        except Exception:
+            pass
+        try:
+            if hasattr(torch, "mlu") and getattr(torch.mlu, "is_available", lambda: False)():
+                torch.mlu.empty_cache()
+        except Exception:
+            pass
+    gc.collect()
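+
+
+# Drop references to the model's large submodules and cached tensors so the
+# garbage collector can reclaim them. Modules are moved to CPU first so their
+# device memory is released even if something else still holds a reference.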
+def _teardown_model(model):
+    try:
+        import torch
+    except Exception:
+        torch = None
+    module_attrs = [
+        "gpt",
+        "semantic_model",
+        "semantic_codec",
+        "s2mel",
+        "campplus_model",
+        "bigvgan",
+        "qwen_emo",
+    ]
+    for attr in module_attrs:
+        comp = getattr(model, attr, None)
+        if comp is None:
+            continue
+        if torch is not None and hasattr(comp, "to"):
+            try:
+                comp.to("cpu")
+            except Exception:
+                pass
+        try:
+            delattr(model, attr)
+        except Exception:
+            setattr(model, attr, None)
+    tensor_attrs = [
+        "semantic_mean",
+        "semantic_std",
+    ]
+    for attr in tensor_attrs:
+        value = getattr(model, attr, None)
+        if value is None:
+            continue
+        if torch is not None and hasattr(value, "detach"):
+            try:
+                value = value.detach().cpu()
+            except Exception:
+                pass
+        setattr(model, attr, None)
+    for attr in ("emo_matrix", "spk_matrix"):
+        if hasattr(model, attr):
+            setattr(model, attr, None)
+    cache_attrs = [
+        "cache_spk_cond",
+        "cache_s2mel_style",
+        "cache_s2mel_prompt",
+        "cache_spk_audio_prompt",
+        "cache_emo_cond",
+        "cache_emo_audio_prompt",
+        "cache_mel",
+    ]
+    for attr in cache_attrs:
+        if hasattr(model, attr):
+            setattr(model, attr, None)
+    for attr in ("extract_features", "normalizer", "tokenizer", "mel_fn"):
+        if hasattr(model, attr):
+            setattr(model, attr, None)
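+
+
+# Snapshot and empty the cache while holding the lock, then run the slow
+# per-model teardown and the device flush outside it so concurrent loaders
+# are not blocked.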
+def _dispose_cached_models() -> bool:
+    with _CACHE_LOCK:
+        if not _MODEL_CACHE:
+            return False
+        cached_items = list(_MODEL_CACHE.items())
+        _MODEL_CACHE.clear()
+    for _, model in cached_items:
+        try:
+            _teardown_model(model)
+        except Exception:
+            pass
+    _flush_device_caches()
+    return True
+
+
+def unload_cached_models() -> bool:
+    """Expose manual cache invalidation for other extensions."""
+    return _dispose_cached_models()
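+
+
+# Wrap ComfyUI's model_management.unload_all_models so this extension's cache
+# is disposed whenever ComfyUI unloads models. The _indextts2_hook marker
+# prevents the wrapper from being installed twice.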
+def _install_unload_hook():
+    global _UNLOAD_HOOK_INSTALLED
+    if _UNLOAD_HOOK_INSTALLED:
+        return
+    try:
+        import comfy.model_management as mm
+    except Exception:
+        return
+    if getattr(mm.unload_all_models, "_indextts2_hook", False):
+        _UNLOAD_HOOK_INSTALLED = True
+        return
+    original = mm.unload_all_models
+
+    @wraps(original)
+    def wrapper(*args, **kwargs):
+        _dispose_cached_models()
+        return original(*args, **kwargs)
+
+    wrapper._indextts2_hook = True
+    mm.unload_all_models = wrapper
+    _UNLOAD_HOOK_INSTALLED = True
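+
+
+# Example (module path assumed for illustration): another extension can force
+# a reload by calling the public helper directly:
+#
+#     from custom_nodes.ComfyUI_IndexTTS2.nodes import unload_cached_models
+#     unload_cached_models()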


 def _audio_to_temp_wav(audio: Any) -> Tuple[str, int, bool]:
@@ -159,6 +326,8 @@ def _save_wav(path: str, wav_cn: np.ndarray, sr: int):
         wf.writeframes(interleaved)

+
+_install_unload_hook()

 class IndexTTS2Simple:
     @classmethod
     def INPUT_TYPES(cls):