import os
import sys
import tempfile
from typing import Any, Dict, Tuple

import numpy as np

# Simple in-memory cache for loaded models to avoid re-initializing weights.
_MODEL_CACHE: Dict[Tuple[str, str, str, bool, bool], Any] = {}


def _resolve_device(device: str):
    """Resolve "auto" to the best available torch device, falling back to CPU."""
    try:
        import torch
    except Exception:
        return "cpu"
    if device and device not in ("auto", ""):
        return device
    if torch.cuda.is_available():
        return "cuda:0"
    if hasattr(torch, "xpu") and getattr(torch.xpu, "is_available", lambda: False)():
        return "xpu"
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def _get_tts2_model(config_path: str, model_dir: str, device: str, use_cuda_kernel: bool, use_fp16: bool):
    """Load an IndexTTS2 model, reusing a cached instance when the settings match."""
    key = (
        os.path.abspath(config_path),
        os.path.abspath(model_dir),
        device,
        bool(use_cuda_kernel),
        bool(use_fp16),
    )
    model = _MODEL_CACHE.get(key)
    if model is not None:
        return model
    base_dir = os.path.dirname(os.path.abspath(__file__))
    ext_root = os.path.dirname(base_dir)
    if ext_root not in sys.path:
        sys.path.insert(0, ext_root)
    # Quiet down transformers advisory warnings (e.g., the GenerationMixin notice).
    try:
        from transformers.utils import logging as hf_logging
        hf_logging.set_verbosity_error()
    except Exception:
        pass
    from indextts.infer_v2 import IndexTTS2

    # fp16 only makes sense on CUDA devices.
    eff_fp16 = use_fp16 and device.startswith("cuda")
    model = IndexTTS2(
        cfg_path=config_path,
        model_dir=model_dir,
        use_fp16=eff_fp16,
        device=device,
        use_cuda_kernel=use_cuda_kernel,
        use_deepspeed=False,
    )
    _MODEL_CACHE[key] = model
    return model


def _audio_to_temp_wav(audio: Any) -> Tuple[str, int, bool]:
    """Coerce an AUDIO value (path, (sr, array) pair, or dict) to a WAV file on disk.

    Returns (path, sample_rate, needs_cleanup).
    """
    sr = None
    data = None
    if isinstance(audio, str) and os.path.exists(audio):
        return audio, 0, False  # use the existing path; no cleanup needed
    if isinstance(audio, (tuple, list)):
        cand_ints = [x for x in audio if isinstance(x, (int, np.integer))]
        cand_arrays = [x for x in audio if hasattr(x, "shape")]
        if len(cand_ints) >= 1 and len(cand_arrays) >= 1:
            sr = int(cand_ints[0])
            data = cand_arrays[0]
        elif len(audio) == 2:
            a, b = audio
            if isinstance(a, (int, np.integer)) and hasattr(b, "shape"):
                sr, data = int(a), b
            elif isinstance(b, (int, np.integer)) and hasattr(a, "shape"):
                sr, data = int(b), a
    if sr is None and isinstance(audio, dict):
        sr = audio.get("sample_rate") or audio.get("sr") or audio.get("rate")
        for key in ("waveform", "samples", "audio", "data"):
            if key in audio:
                data = audio[key]
                break
    if sr is None or data is None:
        raise ValueError("Invalid AUDIO input. Expected (sample_rate:int, numpy_array)")
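    # Normalize whatever arrived above into a float32 (channels, samples)
    # array in [-1.0, 1.0] before writing it out as a temp WAV.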
    if hasattr(data, "cpu"):
        data = data.cpu().numpy()
    wav = np.asarray(data)
    if wav.ndim == 1:
        wav = wav[None, :]  # (N,) -> (1, N)
    elif wav.ndim == 2:
        # Guess which axis holds channels: a "small" axis (<= 8) no longer than the other.
        ch_dim = 0 if wav.shape[0] <= 8 and wav.shape[0] <= wav.shape[1] else 1 if wav.shape[1] <= 8 else 0
        if ch_dim == 1:
            wav = np.transpose(wav, (1, 0))  # (N, C) -> (C, N)
    elif wav.ndim >= 3:
        # Treat the longest axis as samples and flatten everything else into channels.
        sizes = list(wav.shape)
        sample_axis = int(np.argmax(sizes))
        axes = [i for i in range(wav.ndim) if i != sample_axis] + [sample_axis]
        wav = np.transpose(wav, axes)
        c = int(np.prod(wav.shape[:-1]))
        wav = np.reshape(wav, (c, wav.shape[-1]))
    else:
        raise ValueError("AUDIO array must have at least one dimension (samples[, channels])")
    if np.issubdtype(wav.dtype, np.integer):
        # Scale integer PCM to [-1.0, 1.0] by the dtype's full range.
        info = np.iinfo(wav.dtype)
        denom = float(max(abs(info.min), abs(info.max))) or 32767.0
        wav = wav.astype(np.float32) / denom
    else:
        wav = np.clip(wav.astype(np.float32), -1.0, 1.0)
    fd, tmp_path = tempfile.mkstemp(suffix=".wav", prefix="indextts2_prompt_")
    os.close(fd)
    _save_wav(tmp_path, wav, int(sr))
    return tmp_path, int(sr), True


def _save_wav(path: str, wav_cn: np.ndarray, sr: int):
    """Save a numpy waveform as WAV PCM16 without requiring torchaudio.

    Expects wav_cn as (channels, samples) float32 in [-1, 1].
    """
    wav_cn = np.clip(wav_cn, -1.0, 1.0)
    pcm = (wav_cn * 32767.0).astype(np.int16)
    # Prefer soundfile when available; fall back to the stdlib wave module.
    try:
        import soundfile as sf
        if pcm.ndim == 1:
            sf.write(path, pcm, sr, subtype="PCM_16")
        else:
            sf.write(path, np.transpose(pcm, (1, 0)), sr, subtype="PCM_16")
        return
    except Exception:
        pass
    import contextlib
    import wave
    if pcm.ndim == 1:
        n_channels = 1
        interleaved = pcm.tobytes()
    else:
        n_channels = int(pcm.shape[0])
        interleaved = np.transpose(pcm, (1, 0)).tobytes()  # interleave as (N, C)
    with contextlib.closing(wave.open(path, "wb")) as wf:
        wf.setnchannels(n_channels)
        wf.setsampwidth(2)  # 16-bit
        wf.setframerate(int(sr))
        wf.writeframes(interleaved)


class IndexTTS2Simple:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "audio": ("AUDIO",),
                "text": ("STRING", {"multiline": True}),
                "emotion_control_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.05}),
            },
            "optional": {
                "emotion_audio": ("AUDIO",),
                "emotion_vector": ("EMOTION_VECTOR",),
            },
        }

    RETURN_TYPES = ("AUDIO", "STRING")
    FUNCTION = "synthesize"
    CATEGORY = "Audio/IndexTTS"

    def synthesize(self, audio, text: str, emotion_control_weight: float, emotion_audio=None, emotion_vector=None):
        if not isinstance(text, str) or len(text.strip()) == 0:
            raise ValueError("Text is empty. Please provide text to synthesize.")
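        # Render the speaker prompt (and optional emotion prompt) to temp WAV
        # files, since IndexTTS2's infer() consumes audio as file paths.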
        prompt_path, _, need_cleanup = _audio_to_temp_wav(audio)
        emo_path = None
        emo_need_cleanup = False
        if emotion_audio is not None:
            try:
                emo_path, _, emo_need_cleanup = _audio_to_temp_wav(emotion_audio)
            except Exception:
                emo_path, emo_need_cleanup = None, False
        base_dir = os.path.dirname(os.path.abspath(__file__))
        ext_root = os.path.dirname(base_dir)
        resolved_model_dir = os.path.join(ext_root, "checkpoints")
        resolved_config = os.path.join(resolved_model_dir, "config.yaml")
        if not os.path.isfile(resolved_config):
            raise FileNotFoundError(f"Config file not found: {resolved_config}")
        if not os.path.isdir(resolved_model_dir):
            raise FileNotFoundError(f"Model directory not found: {resolved_model_dir}")
        resolved_device = _resolve_device("auto")
        tts2 = _get_tts2_model(
            config_path=resolved_config,
            model_dir=resolved_model_dir,
            device=resolved_device,
            use_cuda_kernel=False,
            use_fp16=True,
        )
        emo_alpha = max(0.0, min(1.0, float(emotion_control_weight)))
        emo_vector = None
        ui_msgs = []
        if emotion_vector is not None:
            try:
                vec = list(emotion_vector)
                vec = [max(0.0, float(v)) for v in vec][:8]
                while len(vec) < 8:
                    vec.append(0.0)
                emo_vector = vec
                emo_audio_prompt = prompt_path
                if emo_path is not None:
                    ui_msgs.append("Emotion source: vectors (second audio ignored)")
                else:
                    ui_msgs.append("Emotion source: vectors")
            except Exception:
                emo_vector = None
                emo_audio_prompt = emo_path if emo_path else prompt_path
        else:
            emo_audio_prompt = emo_path if emo_path else prompt_path
            if emo_path is not None:
                ui_msgs.append("Emotion source: second audio")
            else:
                ui_msgs.append("Emotion source: original audio")
        try:
            result = tts2.infer(
                spk_audio_prompt=prompt_path,
                text=text,
                output_path=None,
                emo_audio_prompt=emo_audio_prompt,
                emo_alpha=emo_alpha,
                emo_vector=emo_vector,
                verbose=False,
                interval_silence=200,
            )
        finally:
            # Clean up temp files.
            try:
                if need_cleanup and prompt_path and os.path.exists(prompt_path):
                    os.remove(prompt_path)
                if emo_need_cleanup and emo_path and os.path.exists(emo_path):
                    os.remove(emo_path)
            except Exception:
                pass
        if not isinstance(result, (tuple, list)) or len(result) != 2:
            # Defensive: if the upstream API changes unexpectedly.
            raise RuntimeError("IndexTTS2 returned an unexpected result format")
        sr, wav = result
        if hasattr(wav, "cpu"):
            wav = wav.cpu().numpy()
        wav = np.asarray(wav)
        if wav.dtype == np.int16:
            wav = wav.astype(np.float32) / 32767.0
        elif wav.dtype != np.float32:
            wav = wav.astype(np.float32)
        try:
            import torch
        except Exception as e:
            raise RuntimeError(f"PyTorch is required to return AUDIO to ComfyUI: {e}")
        mono = wav
        if mono.ndim == 2:
            if mono.shape[0] <= 8 and mono.shape[1] > mono.shape[0]:
                mono = mono.mean(axis=0)
            else:
                mono = mono.mean(axis=-1)
        elif mono.ndim > 2:
            mono = mono.reshape(-1, mono.shape[-1]).mean(axis=0)
        if mono.ndim != 1:
            mono = mono.flatten()
        waveform = torch.from_numpy(mono[None, None, :].astype(np.float32))  # (B=1, C=1, N)
        info_text = "\n".join(ui_msgs) if ui_msgs else ""
        return ({"sample_rate": int(sr), "waveform": waveform}, info_text)
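

# Assumption: this module is imported directly by ComfyUI, which discovers
# custom nodes via NODE_CLASS_MAPPINGS. If the extension's __init__.py already
# defines these mappings, keep that copy as the single source of truth and
# drop this block. The display name below is illustrative.
NODE_CLASS_MAPPINGS = {"IndexTTS2Simple": IndexTTS2Simple}
NODE_DISPLAY_NAME_MAPPINGS = {"IndexTTS2Simple": "IndexTTS2 Simple"}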