mirror of
https://github.com/snicolast/ComfyUI-IndexTTS2.git
synced 2026-05-01 12:11:22 +00:00
clear GPU cache
This commit is contained in:
@@ -1,11 +1,17 @@
|
|||||||
|
import gc
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import threading
|
||||||
|
from functools import wraps
|
||||||
from typing import Any, Dict, Tuple
|
from typing import Any, Dict, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
#simple in-memory cache for loaded models to avoid re-initializing weights
|
#simple in-memory cache for loaded models to avoid re-initializing weights
|
||||||
_MODEL_CACHE: Dict[Tuple[str, str, str, bool, bool], Any] = {}
|
_MODEL_CACHE: Dict[Tuple[str, str, str, bool, bool], Any] = {}
|
||||||
|
_CACHE_LOCK = threading.RLock()
|
||||||
|
_UNLOAD_HOOK_INSTALLED = False
|
||||||
|
|
||||||
def _resolve_device(device: str):
|
def _resolve_device(device: str):
|
||||||
try:
|
try:
|
||||||
@@ -28,10 +34,14 @@ def _get_tts2_model(config_path: str,
|
|||||||
device: str,
|
device: str,
|
||||||
use_cuda_kernel: bool,
|
use_cuda_kernel: bool,
|
||||||
use_fp16: bool):
|
use_fp16: bool):
|
||||||
|
_install_unload_hook()
|
||||||
|
|
||||||
key = (os.path.abspath(config_path), os.path.abspath(model_dir), device, bool(use_cuda_kernel), bool(use_fp16))
|
key = (os.path.abspath(config_path), os.path.abspath(model_dir), device, bool(use_cuda_kernel), bool(use_fp16))
|
||||||
model = _MODEL_CACHE.get(key)
|
|
||||||
if model is not None:
|
with _CACHE_LOCK:
|
||||||
return model
|
cached_model = _MODEL_CACHE.get(key)
|
||||||
|
if cached_model is not None:
|
||||||
|
return cached_model
|
||||||
|
|
||||||
base_dir = os.path.dirname(os.path.abspath(__file__))
|
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
ext_root = os.path.dirname(base_dir)
|
ext_root = os.path.dirname(base_dir)
|
||||||
@@ -57,8 +67,165 @@ def _get_tts2_model(config_path: str,
|
|||||||
use_cuda_kernel=use_cuda_kernel,
|
use_cuda_kernel=use_cuda_kernel,
|
||||||
use_deepspeed=False,
|
use_deepspeed=False,
|
||||||
)
|
)
|
||||||
_MODEL_CACHE[key] = model
|
with _CACHE_LOCK:
|
||||||
return model
|
existing = _MODEL_CACHE.get(key)
|
||||||
|
if existing is None:
|
||||||
|
_MODEL_CACHE[key] = model
|
||||||
|
cached_model = model
|
||||||
|
else:
|
||||||
|
cached_model = existing
|
||||||
|
return cached_model
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _flush_device_caches():
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
except Exception:
|
||||||
|
torch = None
|
||||||
|
|
||||||
|
if torch is not None:
|
||||||
|
try:
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.ipc_collect()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
if hasattr(torch, "xpu") and getattr(torch.xpu, "is_available", lambda: False)():
|
||||||
|
torch.xpu.empty_cache()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||||
|
torch.mps.empty_cache()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
if hasattr(torch, "npu") and getattr(torch.npu, "is_available", lambda: False)():
|
||||||
|
torch.npu.empty_cache()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
if hasattr(torch, "mlu") and getattr(torch.mlu, "is_available", lambda: False)():
|
||||||
|
torch.mlu.empty_cache()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
|
||||||
|
def _teardown_model(model):
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
except Exception:
|
||||||
|
torch = None
|
||||||
|
|
||||||
|
module_attrs = [
|
||||||
|
"gpt",
|
||||||
|
"semantic_model",
|
||||||
|
"semantic_codec",
|
||||||
|
"s2mel",
|
||||||
|
"campplus_model",
|
||||||
|
"bigvgan",
|
||||||
|
"qwen_emo",
|
||||||
|
]
|
||||||
|
for attr in module_attrs:
|
||||||
|
comp = getattr(model, attr, None)
|
||||||
|
if comp is None:
|
||||||
|
continue
|
||||||
|
if torch is not None and hasattr(comp, "to"):
|
||||||
|
try:
|
||||||
|
comp.to("cpu")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
delattr(model, attr)
|
||||||
|
except Exception:
|
||||||
|
setattr(model, attr, None)
|
||||||
|
|
||||||
|
tensor_attrs = [
|
||||||
|
"semantic_mean",
|
||||||
|
"semantic_std",
|
||||||
|
]
|
||||||
|
for attr in tensor_attrs:
|
||||||
|
value = getattr(model, attr, None)
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
|
if torch is not None and hasattr(value, "detach"):
|
||||||
|
try:
|
||||||
|
value = value.detach().cpu()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
setattr(model, attr, None)
|
||||||
|
|
||||||
|
for attr in ("emo_matrix", "spk_matrix"):
|
||||||
|
if hasattr(model, attr):
|
||||||
|
setattr(model, attr, None)
|
||||||
|
|
||||||
|
cache_attrs = [
|
||||||
|
"cache_spk_cond",
|
||||||
|
"cache_s2mel_style",
|
||||||
|
"cache_s2mel_prompt",
|
||||||
|
"cache_spk_audio_prompt",
|
||||||
|
"cache_emo_cond",
|
||||||
|
"cache_emo_audio_prompt",
|
||||||
|
"cache_mel",
|
||||||
|
]
|
||||||
|
for attr in cache_attrs:
|
||||||
|
if hasattr(model, attr):
|
||||||
|
setattr(model, attr, None)
|
||||||
|
|
||||||
|
for attr in ("extract_features", "normalizer", "tokenizer", "mel_fn"):
|
||||||
|
if hasattr(model, attr):
|
||||||
|
setattr(model, attr, None)
|
||||||
|
|
||||||
|
|
||||||
|
def _dispose_cached_models() -> bool:
|
||||||
|
with _CACHE_LOCK:
|
||||||
|
if not _MODEL_CACHE:
|
||||||
|
return False
|
||||||
|
cached_items = list(_MODEL_CACHE.items())
|
||||||
|
_MODEL_CACHE.clear()
|
||||||
|
|
||||||
|
for _, model in cached_items:
|
||||||
|
try:
|
||||||
|
_teardown_model(model)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
_flush_device_caches()
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def unload_cached_models() -> bool:
|
||||||
|
"""Expose manual cache invalidation for other extensions."""
|
||||||
|
return _dispose_cached_models()
|
||||||
|
|
||||||
|
|
||||||
|
def _install_unload_hook():
|
||||||
|
global _UNLOAD_HOOK_INSTALLED
|
||||||
|
if _UNLOAD_HOOK_INSTALLED:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
import comfy.model_management as mm
|
||||||
|
except Exception:
|
||||||
|
return
|
||||||
|
if getattr(mm.unload_all_models, "_indextts2_hook", False):
|
||||||
|
_UNLOAD_HOOK_INSTALLED = True
|
||||||
|
return
|
||||||
|
|
||||||
|
original = mm.unload_all_models
|
||||||
|
|
||||||
|
@wraps(original)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
_dispose_cached_models()
|
||||||
|
return original(*args, **kwargs)
|
||||||
|
|
||||||
|
wrapper._indextts2_hook = True
|
||||||
|
mm.unload_all_models = wrapper
|
||||||
|
_UNLOAD_HOOK_INSTALLED = True
|
||||||
|
|
||||||
|
|
||||||
def _audio_to_temp_wav(audio: Any) -> Tuple[str, int, bool]:
|
def _audio_to_temp_wav(audio: Any) -> Tuple[str, int, bool]:
|
||||||
|
|
||||||
@@ -159,6 +326,8 @@ def _save_wav(path: str, wav_cn: np.ndarray, sr: int):
|
|||||||
wf.writeframes(interleaved)
|
wf.writeframes(interleaved)
|
||||||
|
|
||||||
|
|
||||||
|
_install_unload_hook()
|
||||||
|
|
||||||
class IndexTTS2Simple:
|
class IndexTTS2Simple:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls):
|
def INPUT_TYPES(cls):
|
||||||
|
|||||||
Reference in New Issue
Block a user