ComfyUI-IndexTTS2/nodes/indextts2_node.py

import os
import sys
import tempfile
from typing import Any, Dict, Tuple
import numpy as np
# Simple in-memory cache for loaded models to avoid re-initializing weights.
_MODEL_CACHE: Dict[Tuple[str, str, str, bool, bool], Any] = {}

def _resolve_device(device: str) -> str:
    """Resolve "auto" (or empty) to the best available torch backend."""
    try:
        import torch
    except Exception:
        return "cpu"
    if device and device not in ("auto", ""):
        return device
    if torch.cuda.is_available():
        return "cuda:0"
    if hasattr(torch, "xpu") and getattr(torch.xpu, "is_available", lambda: False)():
        return "xpu"
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    return "cpu"

def _get_tts2_model(config_path: str,
                    model_dir: str,
                    device: str,
                    use_cuda_kernel: bool,
                    use_fp16: bool):
    """Load (or fetch from cache) an IndexTTS2 instance for the given settings."""
    key = (os.path.abspath(config_path), os.path.abspath(model_dir), device,
           bool(use_cuda_kernel), bool(use_fp16))
    model = _MODEL_CACHE.get(key)
    if model is not None:
        return model
    # Make the extension root importable so `indextts` resolves when ComfyUI
    # loads this node from a nested package.
    base_dir = os.path.dirname(os.path.abspath(__file__))
    ext_root = os.path.dirname(base_dir)
    if ext_root not in sys.path:
        sys.path.insert(0, ext_root)
    # Quiet down transformers advisory warnings (e.g., the GenerationMixin notice).
    try:
        from transformers.utils import logging as hf_logging
        hf_logging.set_verbosity_error()
    except Exception:
        pass
    from indextts.infer_v2 import IndexTTS2
    # fp16 only makes sense on CUDA devices.
    eff_fp16 = use_fp16 and device.startswith("cuda")
    model = IndexTTS2(
        cfg_path=config_path,
        model_dir=model_dir,
        use_fp16=eff_fp16,
        device=device,
        use_cuda_kernel=use_cuda_kernel,
        use_deepspeed=False,
    )
    _MODEL_CACHE[key] = model
    return model
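
# Optional sketch (not part of the original node or the IndexTTS2 API): cached
# models stay resident for the life of the process; a helper like this could be
# wired in if VRAM pressure becomes a problem. The name is hypothetical.
def _clear_tts2_cache() -> None:
    _MODEL_CACHE.clear()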

def _audio_to_temp_wav(audio: Any) -> Tuple[str, int, bool]:
    """Normalize an AUDIO input (file path, (sr, array) pair, or ComfyUI dict) to a WAV file.

    Returns (path, sample_rate, needs_cleanup); needs_cleanup is True when a
    temporary file was written and should be deleted by the caller.
    """
    sr = None
    data = None
    if isinstance(audio, str) and os.path.exists(audio):
        return audio, 0, False  # use the existing path, no cleanup
    if isinstance(audio, (tuple, list)):
        cand_ints = [x for x in audio if isinstance(x, (int, np.integer))]
        cand_arrays = [x for x in audio if hasattr(x, "shape")]
        if len(cand_ints) >= 1 and len(cand_arrays) >= 1:
            sr = int(cand_ints[0])
            data = cand_arrays[0]
        elif len(audio) == 2:
            a, b = audio
            if isinstance(a, (int, np.integer)) and hasattr(b, "shape"):
                sr, data = int(a), b
            elif isinstance(b, (int, np.integer)) and hasattr(a, "shape"):
                sr, data = int(b), a
    if sr is None and isinstance(audio, dict):
        sr = audio.get("sample_rate") or audio.get("sr") or audio.get("rate")
        for key in ("waveform", "samples", "audio", "data"):
            if key in audio:
                data = audio[key]
                break
    if sr is None or data is None:
        raise ValueError(
            "Invalid AUDIO input. Expected a file path, a (sample_rate, array) "
            "pair, or a dict with 'sample_rate' and 'waveform' keys.")
    if hasattr(data, "cpu"):
        data = data.cpu().numpy()
    wav = np.asarray(data)
    if wav.ndim == 1:
        wav = wav[None, :]  # (N,) -> (1, N)
    elif wav.ndim == 2:
        # Heuristic: treat the axis with at most 8 entries as channels.
        ch_dim = 0 if wav.shape[0] <= 8 and wav.shape[0] <= wav.shape[1] else 1 if wav.shape[1] <= 8 else 0
        if ch_dim == 1:
            wav = np.transpose(wav, (1, 0))  # (N, C) -> (C, N)
    elif wav.ndim >= 3:
        # Move the longest axis (assumed to be samples) last, then flatten the
        # leading batch/channel axes into a single channel axis.
        sizes = list(wav.shape)
        sample_axis = int(np.argmax(sizes))
        axes = [i for i in range(wav.ndim) if i != sample_axis] + [sample_axis]
        wav = np.transpose(wav, axes)
        c = int(np.prod(wav.shape[:-1]))
        wav = np.reshape(wav, (c, wav.shape[-1]))
    else:
        raise ValueError("AUDIO array must be at least 1-dimensional")
    if np.issubdtype(wav.dtype, np.integer):
        # Scale integer PCM to [-1, 1] floats.
        info = np.iinfo(wav.dtype)
        denom = float(max(abs(info.min), abs(info.max))) or 32767.0
        wav = wav.astype(np.float32) / denom
    else:
        wav = np.clip(wav.astype(np.float32), -1.0, 1.0)
    fd, tmp_path = tempfile.mkstemp(suffix=".wav", prefix="indextts2_prompt_")
    os.close(fd)
    _save_wav(tmp_path, wav, int(sr))
    return tmp_path, int(sr), True
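
# Example (values assumed for illustration): a ComfyUI-style AUDIO dict
# round-trips through the helper like this:
#   path, sr, cleanup = _audio_to_temp_wav(
#       {"sample_rate": 22050, "waveform": np.zeros((1, 22050), dtype=np.float32)})
#   -> writes a one-second silent mono WAV and returns (path, 22050, True)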

def _save_wav(path: str, wav_cn: np.ndarray, sr: int):
    """Save a numpy waveform to 16-bit PCM WAV without requiring torchaudio.

    Expects wav_cn as (channels, samples) float32 in [-1, 1].
    """
    wav_cn = np.clip(wav_cn, -1.0, 1.0)
    pcm = (wav_cn * 32767.0).astype(np.int16)
    try:
        import soundfile as sf
        if pcm.ndim == 1:
            sf.write(path, pcm, sr, subtype="PCM_16")
        else:
            # soundfile expects (samples, channels)
            sf.write(path, np.transpose(pcm, (1, 0)), sr, subtype="PCM_16")
        return
    except Exception:
        pass
    # Fallback: the stdlib wave module.
    import wave
    import contextlib
    if pcm.ndim == 1:
        n_channels = 1
        interleaved = pcm.tobytes()
        n_frames = pcm.shape[0]
    else:
        n_channels = int(pcm.shape[0])
        n_frames = int(pcm.shape[1])
        interleaved = np.transpose(pcm, (1, 0)).tobytes()
    with contextlib.closing(wave.open(path, "wb")) as wf:
        wf.setnchannels(n_channels)
        wf.setsampwidth(2)  # 16-bit
        wf.setframerate(int(sr))
        wf.writeframes(interleaved)
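
# Design note: soundfile is preferred when available because it handles subtypes
# and channel-last layout directly; the stdlib `wave` fallback keeps the node
# working on installations without libsndfile.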

class IndexTTS2Simple:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "audio": ("AUDIO",),
                "text": ("STRING", {"multiline": True}),
                "emotion_control_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.05}),
            },
            "optional": {
                "emotion_audio": ("AUDIO",),
                "emotion_vector": ("EMOTION_VECTOR",),
            },
        }

    RETURN_TYPES = ("AUDIO", "STRING")
    FUNCTION = "synthesize"
    CATEGORY = "Audio/IndexTTS"
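
    # ComfyUI calls the method named by FUNCTION and maps the returned tuple onto
    # RETURN_TYPES in order: the AUDIO dict first, then the info STRING.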
    def synthesize(self,
                   audio,
                   text: str,
                   emotion_control_weight: float,
                   emotion_audio=None,
                   emotion_vector=None):
        if not isinstance(text, str) or len(text.strip()) == 0:
            raise ValueError("Text is empty. Please provide text to synthesize.")
        prompt_path, _, need_cleanup = _audio_to_temp_wav(audio)
        emo_path = None
        emo_need_cleanup = False
        if emotion_audio is not None:
            try:
                emo_path, _, emo_need_cleanup = _audio_to_temp_wav(emotion_audio)
            except Exception:
                emo_path, emo_need_cleanup = None, False
        # Model assets live in <extension root>/checkpoints.
        base_dir = os.path.dirname(os.path.abspath(__file__))
        ext_root = os.path.dirname(base_dir)
        resolved_model_dir = os.path.join(ext_root, "checkpoints")
        resolved_config = os.path.join(resolved_model_dir, "config.yaml")
        if not os.path.isfile(resolved_config):
            raise FileNotFoundError(f"Config file not found: {resolved_config}")
        if not os.path.isdir(resolved_model_dir):
            raise FileNotFoundError(f"Model directory not found: {resolved_model_dir}")
        resolved_device = _resolve_device("auto")
        tts2 = _get_tts2_model(
            config_path=resolved_config,
            model_dir=resolved_model_dir,
            device=resolved_device,
            use_cuda_kernel=False,
            use_fp16=True,
        )
        emo_alpha = max(0.0, min(1.0, float(emotion_control_weight)))
        emo_vector = None
        ui_msgs = []
        if emotion_vector is not None:
            try:
                # Clamp to 8 non-negative components, padding with zeros.
                vec = list(emotion_vector)
                vec = [max(0.0, float(v)) for v in vec][:8]
                while len(vec) < 8:
                    vec.append(0.0)
                emo_vector = vec
                emo_audio_prompt = prompt_path
                if emo_path is not None:
                    ui_msgs.append("Emotion source: vectors (second audio ignored)")
                else:
                    ui_msgs.append("Emotion source: vectors")
            except Exception:
                emo_vector = None
                emo_audio_prompt = emo_path if emo_path else prompt_path
        else:
            emo_audio_prompt = emo_path if emo_path else prompt_path
            if emo_path is not None:
                ui_msgs.append("Emotion source: second audio")
            else:
                ui_msgs.append("Emotion source: original audio")
        try:
            result = tts2.infer(
                spk_audio_prompt=prompt_path,
                text=text,
                output_path=None,
                emo_audio_prompt=emo_audio_prompt,
                emo_alpha=emo_alpha,
                emo_vector=emo_vector,
                verbose=False,
                interval_silence=200,
            )
        finally:
            # Clean up temp files regardless of whether inference succeeded.
            try:
                if need_cleanup and prompt_path and os.path.exists(prompt_path):
                    os.remove(prompt_path)
                if emo_need_cleanup and emo_path and os.path.exists(emo_path):
                    os.remove(emo_path)
            except Exception:
                pass
        if not isinstance(result, (tuple, list)) or len(result) != 2:
            # Defensive: fail loudly if the upstream API changes its return format.
            raise RuntimeError("IndexTTS2 returned an unexpected result format")
        sr, wav = result
        if hasattr(wav, "cpu"):
            wav = wav.cpu().numpy()
        wav = np.asarray(wav)
        if wav.dtype == np.int16:
            wav = wav.astype(np.float32) / 32767.0
        elif wav.dtype != np.float32:
            wav = wav.astype(np.float32)
        try:
            import torch
        except Exception as e:
            raise RuntimeError(f"PyTorch is required to return AUDIO to ComfyUI: {e}")
        # Downmix to mono: average across whichever axis looks like channels.
        mono = wav
        if mono.ndim == 2:
            if mono.shape[0] <= 8 and mono.shape[1] > mono.shape[0]:
                mono = mono.mean(axis=0)
            else:
                mono = mono.mean(axis=-1)
        elif mono.ndim > 2:
            mono = mono.reshape(-1, mono.shape[-1]).mean(axis=0)
        if mono.ndim != 1:
            mono = mono.flatten()
        waveform = torch.from_numpy(mono[None, None, :].astype(np.float32))  # (B=1, C=1, N)
        info_text = "\n".join(ui_msgs) if ui_msgs else ""
        return ({"sample_rate": int(sr), "waveform": waveform}, info_text)
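
# Registration sketch: ComfyUI discovers nodes through NODE_CLASS_MAPPINGS, which
# this extension presumably exports from its __init__.py rather than this module.
# A minimal mapping would look like this (display name assumed):
#
#   NODE_CLASS_MAPPINGS = {"IndexTTS2Simple": IndexTTS2Simple}
#   NODE_DISPLAY_NAME_MAPPINGS = {"IndexTTS2Simple": "IndexTTS2 (Simple)"}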