mirror of
https://github.com/snicolast/ComfyUI-IndexTTS2.git
synced 2026-01-26 14:39:44 +00:00
clear GPU cache
@@ -1,11 +1,17 @@
import gc
import os
import sys
import tempfile
import threading
from functools import wraps
from typing import Any, Dict, Tuple

import numpy as np

# simple in-memory cache for loaded models to avoid re-initializing weights
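# keyed by (config_path, model_dir, device, use_cuda_kernel, use_fp16)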
_MODEL_CACHE: Dict[Tuple[str, str, str, bool, bool], Any] = {}
_CACHE_LOCK = threading.RLock()
_UNLOAD_HOOK_INSTALLED = False


def _resolve_device(device: str):
    try:
@@ -28,10 +34,14 @@ def _get_tts2_model(config_path: str,
                    device: str,
                    use_cuda_kernel: bool,
                    use_fp16: bool):
    _install_unload_hook()

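    # cache key captures every argument that changes how the model is built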
    key = (os.path.abspath(config_path), os.path.abspath(model_dir), device, bool(use_cuda_kernel), bool(use_fp16))
    model = _MODEL_CACHE.get(key)
    if model is not None:
        return model

    with _CACHE_LOCK:
        cached_model = _MODEL_CACHE.get(key)
        if cached_model is not None:
            return cached_model

    base_dir = os.path.dirname(os.path.abspath(__file__))
    ext_root = os.path.dirname(base_dir)
@@ -57,8 +67,165 @@ def _get_tts2_model(config_path: str,
        use_cuda_kernel=use_cuda_kernel,
        use_deepspeed=False,
    )
    _MODEL_CACHE[key] = model
    return model
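    # publish the new model; if another worker already cached one for this key, reuse that instance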
    with _CACHE_LOCK:
        existing = _MODEL_CACHE.get(key)
        if existing is None:
            _MODEL_CACHE[key] = model
            cached_model = model
        else:
            cached_model = existing
    return cached_model


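# best-effort VRAM release: empty the allocator caches on every torch backend that is present (CUDA, XPU, MPS, NPU, MLU), then run the garbage collector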
def _flush_device_caches():
    try:
        import torch
    except Exception:
        torch = None

    if torch is not None:
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()
        except Exception:
            pass
        try:
            if hasattr(torch, "xpu") and getattr(torch.xpu, "is_available", lambda: False)():
                torch.xpu.empty_cache()
        except Exception:
            pass
        try:
            if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                torch.mps.empty_cache()
        except Exception:
            pass
        try:
            if hasattr(torch, "npu") and getattr(torch.npu, "is_available", lambda: False)():
                torch.npu.empty_cache()
        except Exception:
            pass
        try:
            if hasattr(torch, "mlu") and getattr(torch.mlu, "is_available", lambda: False)():
                torch.mlu.empty_cache()
        except Exception:
            pass
    gc.collect()


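# move the heavy submodules to CPU and drop every reference held by the model so the VRAM they occupy can actually be reclaimed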
def _teardown_model(model):
    try:
        import torch
    except Exception:
        torch = None

    module_attrs = [
        "gpt",
        "semantic_model",
        "semantic_codec",
        "s2mel",
        "campplus_model",
        "bigvgan",
        "qwen_emo",
    ]
    for attr in module_attrs:
        comp = getattr(model, attr, None)
        if comp is None:
            continue
        if torch is not None and hasattr(comp, "to"):
            try:
                comp.to("cpu")
            except Exception:
                pass
        try:
            delattr(model, attr)
        except Exception:
            setattr(model, attr, None)

    tensor_attrs = [
        "semantic_mean",
        "semantic_std",
    ]
    for attr in tensor_attrs:
        value = getattr(model, attr, None)
        if value is None:
            continue
        if torch is not None and hasattr(value, "detach"):
            try:
                value = value.detach().cpu()
            except Exception:
                pass
        setattr(model, attr, None)

    for attr in ("emo_matrix", "spk_matrix"):
        if hasattr(model, attr):
            setattr(model, attr, None)

    cache_attrs = [
        "cache_spk_cond",
        "cache_s2mel_style",
        "cache_s2mel_prompt",
        "cache_spk_audio_prompt",
        "cache_emo_cond",
        "cache_emo_audio_prompt",
        "cache_mel",
    ]
    for attr in cache_attrs:
        if hasattr(model, attr):
            setattr(model, attr, None)

    for attr in ("extract_features", "normalizer", "tokenizer", "mel_fn"):
        if hasattr(model, attr):
            setattr(model, attr, None)


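# empty the cache under the lock, then tear each model down and flush the device caches; returns True if anything was unloaded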
def _dispose_cached_models() -> bool:
    with _CACHE_LOCK:
        if not _MODEL_CACHE:
            return False
        cached_items = list(_MODEL_CACHE.items())
        _MODEL_CACHE.clear()

    for _, model in cached_items:
        try:
            _teardown_model(model)
        except Exception:
            pass

    _flush_device_caches()
    return True


def unload_cached_models() -> bool:
    """Expose manual cache invalidation for other extensions."""
    return _dispose_cached_models()


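# the hook below wraps ComfyUI's model_management.unload_all_models so a global
# "unload models" also disposes of the IndexTTS2 cache; other code can call
# unload_cached_models() directly for the same effect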
def _install_unload_hook():
    global _UNLOAD_HOOK_INSTALLED
    if _UNLOAD_HOOK_INSTALLED:
        return
    try:
        import comfy.model_management as mm
    except Exception:
        return
    if getattr(mm.unload_all_models, "_indextts2_hook", False):
        _UNLOAD_HOOK_INSTALLED = True
        return

    original = mm.unload_all_models

    @wraps(original)
    def wrapper(*args, **kwargs):
        _dispose_cached_models()
        return original(*args, **kwargs)

    wrapper._indextts2_hook = True
    mm.unload_all_models = wrapper
    _UNLOAD_HOOK_INSTALLED = True


def _audio_to_temp_wav(audio: Any) -> Tuple[str, int, bool]:

@@ -159,6 +326,8 @@ def _save_wav(path: str, wav_cn: np.ndarray, sr: int):
        wf.writeframes(interleaved)


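# install the unload hook at import time so cached models are released whenever ComfyUI unloads all models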
_install_unload_hook()


class IndexTTS2Simple:
    @classmethod
    def INPUT_TYPES(cls):