clear GPU cache

Author: snicolast
Date: 2025-10-03 10:52:10 +13:00
parent 00d2353584
commit 9d3e4f0817


@@ -1,11 +1,17 @@
+import gc
 import os
 import sys
 import tempfile
+import threading
+from functools import wraps
 from typing import Any, Dict, Tuple
 import numpy as np
 
 #simple in-memory cache for loaded models to avoid re-initializing weights
 _MODEL_CACHE: Dict[Tuple[str, str, str, bool, bool], Any] = {}
+_CACHE_LOCK = threading.RLock()
+_UNLOAD_HOOK_INSTALLED = False
+
 
 def _resolve_device(device: str):
     try:
@@ -28,10 +34,14 @@ def _get_tts2_model(config_path: str,
                     device: str,
                     use_cuda_kernel: bool,
                     use_fp16: bool):
+    _install_unload_hook()
+
     key = (os.path.abspath(config_path), os.path.abspath(model_dir), device, bool(use_cuda_kernel), bool(use_fp16))
-    model = _MODEL_CACHE.get(key)
-    if model is not None:
-        return model
+    with _CACHE_LOCK:
+        cached_model = _MODEL_CACHE.get(key)
+        if cached_model is not None:
+            return cached_model
 
     base_dir = os.path.dirname(os.path.abspath(__file__))
     ext_root = os.path.dirname(base_dir)
@@ -57,8 +67,165 @@ def _get_tts2_model(config_path: str,
         use_cuda_kernel=use_cuda_kernel,
         use_deepspeed=False,
     )
-    _MODEL_CACHE[key] = model
-    return model
+    with _CACHE_LOCK:
+        existing = _MODEL_CACHE.get(key)
+        if existing is None:
+            _MODEL_CACHE[key] = model
+            cached_model = model
+        else:
+            cached_model = existing
+    return cached_model
+
+
+def _flush_device_caches():
+    try:
+        import torch
+    except Exception:
+        torch = None
+    if torch is not None:
+        try:
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+                torch.cuda.ipc_collect()
+        except Exception:
+            pass
+        try:
+            if hasattr(torch, "xpu") and getattr(torch.xpu, "is_available", lambda: False)():
+                torch.xpu.empty_cache()
+        except Exception:
+            pass
+        try:
+            if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+                torch.mps.empty_cache()
+        except Exception:
+            pass
+        try:
+            if hasattr(torch, "npu") and getattr(torch.npu, "is_available", lambda: False)():
+                torch.npu.empty_cache()
+        except Exception:
+            pass
+        try:
+            if hasattr(torch, "mlu") and getattr(torch.mlu, "is_available", lambda: False)():
+                torch.mlu.empty_cache()
+        except Exception:
+            pass
+    gc.collect()
+
+
+def _teardown_model(model):
+    try:
+        import torch
+    except Exception:
+        torch = None
+    module_attrs = [
+        "gpt",
+        "semantic_model",
+        "semantic_codec",
+        "s2mel",
+        "campplus_model",
+        "bigvgan",
+        "qwen_emo",
+    ]
+    for attr in module_attrs:
+        comp = getattr(model, attr, None)
+        if comp is None:
+            continue
+        if torch is not None and hasattr(comp, "to"):
+            try:
+                comp.to("cpu")
+            except Exception:
+                pass
+        try:
+            delattr(model, attr)
+        except Exception:
+            setattr(model, attr, None)
+    tensor_attrs = [
+        "semantic_mean",
+        "semantic_std",
+    ]
+    for attr in tensor_attrs:
+        value = getattr(model, attr, None)
+        if value is None:
+            continue
+        if torch is not None and hasattr(value, "detach"):
+            try:
+                value = value.detach().cpu()
+            except Exception:
+                pass
+        setattr(model, attr, None)
+    for attr in ("emo_matrix", "spk_matrix"):
+        if hasattr(model, attr):
+            setattr(model, attr, None)
+    cache_attrs = [
+        "cache_spk_cond",
+        "cache_s2mel_style",
+        "cache_s2mel_prompt",
+        "cache_spk_audio_prompt",
+        "cache_emo_cond",
+        "cache_emo_audio_prompt",
+        "cache_mel",
+    ]
+    for attr in cache_attrs:
+        if hasattr(model, attr):
+            setattr(model, attr, None)
+    for attr in ("extract_features", "normalizer", "tokenizer", "mel_fn"):
+        if hasattr(model, attr):
+            setattr(model, attr, None)
+
+
+def _dispose_cached_models() -> bool:
+    with _CACHE_LOCK:
+        if not _MODEL_CACHE:
+            return False
+        cached_items = list(_MODEL_CACHE.items())
+        _MODEL_CACHE.clear()
+    for _, model in cached_items:
+        try:
+            _teardown_model(model)
+        except Exception:
+            pass
+    _flush_device_caches()
+    return True
+
+
+def unload_cached_models() -> bool:
+    """Expose manual cache invalidation for other extensions."""
+    return _dispose_cached_models()
+
+
+def _install_unload_hook():
+    global _UNLOAD_HOOK_INSTALLED
+    if _UNLOAD_HOOK_INSTALLED:
+        return
+    try:
+        import comfy.model_management as mm
+    except Exception:
+        return
+    if getattr(mm.unload_all_models, "_indextts2_hook", False):
+        _UNLOAD_HOOK_INSTALLED = True
+        return
+    original = mm.unload_all_models
+
+    @wraps(original)
+    def wrapper(*args, **kwargs):
+        _dispose_cached_models()
+        return original(*args, **kwargs)
+
+    wrapper._indextts2_hook = True
+    mm.unload_all_models = wrapper
+    _UNLOAD_HOOK_INSTALLED = True
+
+
 def _audio_to_temp_wav(audio: Any) -> Tuple[str, int, bool]:
@@ -159,6 +326,8 @@ def _save_wav(path: str, wav_cn: np.ndarray, sr: int):
         wf.writeframes(interleaved)
 
+
+_install_unload_hook()
 
 class IndexTTS2Simple:
     @classmethod
     def INPUT_TYPES(cls):
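
The entry point this commit exposes for other node packs is unload_cached_models(). A minimal sketch of calling it from another extension follows; the module path comfyui_index_tts2.nodes is an assumption and depends on the folder name this pack is installed under in ComfyUI's custom_nodes directory:

    import importlib

    def free_index_tts2_vram() -> bool:
        # Hypothetical module path; adjust to the installed folder name.
        try:
            nodes = importlib.import_module("comfyui_index_tts2.nodes")
        except ImportError:
            return False
        # True only if a cached model was actually torn down.
        return nodes.unload_cached_models()

Because _install_unload_hook() also wraps comfy.model_management.unload_all_models, any code path that unloads all models through ComfyUI's model management should dispose of this cache automatically; the explicit call above is only needed for targeted invalidation.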