diff --git a/README.md b/README.md
index 5d4237a..e545386 100644
--- a/README.md
+++ b/README.md
@@ -5,6 +5,9 @@ Lightweight ComfyUI wrapper for IndexTTS 2 (voice cloning + emotion control). Th
 Original repo: https://github.com/index-tts/index-tts
 
+## Updates
+- Update 2025-09-14: Added the IndexTTS2 Advanced node, exposing sampling, speech speed, seed, and other generation controls.
+
 Install
 - Clone this repository to: ComfyUI/custom_nodes/
 - In your ComfyUI Python environment:
 
@@ -62,6 +65,10 @@ Nodes
 - Notes: device auto-detected, FP16 on CUDA, 200 ms pause between segments (fixed), emotion precedence = vector > second audio > original audio
 
+- IndexTTS2 Advanced
+  - Inputs: same as Simple, plus optional overrides for sampling (temperature, top-p, top-k, beams), max tokens, speech speed, interval silence, typical sampling, and seed.
+  - Notes: defaults mirror the Simple node; override them only when you need reproducible output (fixed seed) or want to experiment with generation behavior.
+
 - IndexTTS2 Emotion Vector
   - 8 sliders (0.0-1.4) for: happy, angry, sad, afraid, disgusted, melancholic, surprised, calm
   - Constraint: sum of sliders must be <= 1.5 (no auto-scaling)
 
diff --git a/__init__.py b/__init__.py
index e7e1636..4e0bac3 100644
--- a/__init__.py
+++ b/__init__.py
@@ -1,15 +1,18 @@
 from .nodes.indextts2_node import IndexTTS2Simple
+from .nodes.indextts2_node_advanced import IndexTTS2Advanced
 from .nodes.indextts2_node_emovec import IndexTTS2EmotionVector
 from .nodes.indextts2_node_emotext import IndexTTS2EmotionFromText
 
 NODE_CLASS_MAPPINGS = {
     "IndexTTS2Simple": IndexTTS2Simple,
+    "IndexTTS2Advanced": IndexTTS2Advanced,
     "IndexTTS2EmotionVector": IndexTTS2EmotionVector,
     "IndexTTS2EmotionFromText": IndexTTS2EmotionFromText,
 }
 
 NODE_DISPLAY_NAME_MAPPINGS = {
     "IndexTTS2Simple": "IndexTTS2 Simple",
+    "IndexTTS2Advanced": "IndexTTS2 Advanced",
     "IndexTTS2EmotionVector": "IndexTTS2 Emotion Vector",
     "IndexTTS2EmotionFromText": "IndexTTS2 Emotion From Text",
 }
diff --git a/images/overview.png b/images/overview.png
index a343f99..72b02a1 100644
Binary files a/images/overview.png and b/images/overview.png differ
diff --git a/indextts/infer_v2.py b/indextts/infer_v2.py
index 92087f5..a12105b 100644
--- a/indextts/infer_v2.py
+++ b/indextts/infer_v2.py
@@ -496,6 +496,7 @@ class IndexTTS2:
         num_beams = generation_kwargs.pop("num_beams", 3)
         repetition_penalty = generation_kwargs.pop("repetition_penalty", 10.0)
         max_mel_tokens = generation_kwargs.pop("max_mel_tokens", 1500)
+        speech_speed = float(generation_kwargs.pop("speech_speed", 1.0))
 
         sampling_rate = 22050
         wavs = []
@@ -539,7 +540,7 @@ class IndexTTS2:
                 cond_lengths=torch.tensor([spk_cond_emb.shape[-1]], device=text_tokens.device),
                 emo_cond_lengths=torch.tensor([emo_cond_emb.shape[-1]], device=text_tokens.device),
                 emo_vec=emovec,
-                do_sample=True,
+                do_sample=do_sample,
                 top_p=top_p,
                 top_k=top_k,
                 temperature=temperature,
@@ -610,7 +611,8 @@ class IndexTTS2:
                 S_infer = self.semantic_codec.quantizer.vq2emb(codes.unsqueeze(1))
                 S_infer = S_infer.transpose(1, 2)
                 S_infer = S_infer + latent
-                target_lengths = (code_lens * 1.72).long()
+                speed_scale = max(0.1, min(3.0, float(speech_speed)))
+                target_lengths = torch.clamp((code_lens.float() * 1.72 * speed_scale).long(), min=1)
                 cond = self.s2mel.models['length_regulator'](S_infer,
                                                              ylens=target_lengths,
diff --git a/nodes/indextts2_node_advanced.py b/nodes/indextts2_node_advanced.py
new file mode 100644
index 0000000..d569d19
--- /dev/null
+++ b/nodes/indextts2_node_advanced.py
@@ -0,0 +1,350 @@
+import os
+import random
+import numpy as np
+
+from .indextts2_node import (
+    _audio_to_temp_wav,
+    _get_tts2_model,
+    _resolve_device,
+)
+
+
+def _coerce_float(value, default, clamp=None, random_bounds=None):
+    def _maybe_random():
+        if not random_bounds:
+            base = default if isinstance(default, (int, float)) else 0.0
+            low = max(0.0, base * 0.5) if base else 0.0
+            high = base * 1.5 + 1.0 if base else 1.0
+            return random.uniform(low, high)
+        low, high = random_bounds
+        return random.uniform(low, high)
+
+    if value is None:
+        return default
+    if isinstance(value, (int, float)):
+        result = float(value)
+    elif isinstance(value, str):
+        token = value.strip().lower()
+        if not token:
+            return default
+        if token in {"random", "rand", "randomize"}:
+            result = float(_maybe_random())
+        else:
+            try:
+                result = float(token)
+            except ValueError:
+                return default
+    else:
+        return default
+    if clamp is not None:
+        low, high = clamp
+        if low is not None:
+            result = max(low, result)
+        if high is not None:
+            result = min(high, result)
+    return result
+
+
+def _coerce_int(value, default, clamp=None, random_bounds=None):
+    if value is None:
+        result = default
+    elif isinstance(value, (int, float)):
+        result = int(value)
+    elif isinstance(value, str):
+        token = value.strip().lower()
+        if not token:
+            return default
+        if token in {"random", "rand", "randomize"}:
+            if random_bounds is not None:
+                low, high = random_bounds
+            else:
+                base = default if isinstance(default, (int, float)) else 0
+                low = max(0, int(base * 0.5))
+                high = max(low + 1, int(base * 1.5) + 1)
+            if high <= low:
+                high = low + 1
+            result = random.randint(int(low), int(high))
+        else:
+            try:
+                result = int(float(token))
+            except ValueError:
+                return default
+    else:
+        return default
+    if clamp is not None:
+        low, high = clamp
+        if low is not None:
+            result = max(int(low), result)
+        if high is not None:
+            result = min(int(high), result)
+    return result
+
+
+def _coerce_bool(value, default=False):
+    if isinstance(value, bool):
+        return value
+    if value is None:
+        return default
+    if isinstance(value, str):
+        token = value.strip().lower()
+        if token in {"", "none"}:
+            return default
+        if token in {"true", "1", "yes", "on"}:
+            return True
+        if token in {"false", "0", "no", "off"}:
+            return False
+    return bool(value)
+
+
+class IndexTTS2Advanced:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "audio": ("AUDIO",),
+                "text": ("STRING", {"multiline": True}),
+                "emotion_control_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.05}),
+            },
+            "optional": {
+                "emotion_audio": ("AUDIO",),
+                "emotion_vector": ("EMOTION_VECTOR",),
+                "use_random_style": ("BOOLEAN", {"default": False}),
+                "interval_silence_ms": ("INT", {"default": 200, "min": 0, "max": 12000, "step": 10}),
+                "max_text_tokens_per_segment": ("INT", {"default": 120, "min": 0, "max": 2048, "step": 8}),
+                "seed": ("INT", {"default": -1, "min": -1, "max": 2147483647}),
+                "do_sample": ("BOOLEAN", {"default": True}),
+                "temperature": ("FLOAT", {"default": 0.8, "min": 0.0, "max": 5.0, "step": 0.05}),
+                "top_p": ("FLOAT", {"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.01}),
+                "top_k": ("INT", {"default": 30, "min": 0, "max": 2048, "step": 1}),
+                "repetition_penalty": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 50.0, "step": 0.1}),
+                "length_penalty": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 50.0, "step": 0.1}),
+                "num_beams": ("INT", {"default": 3, "min": 1, "max": 10, "step": 1}),
+                "max_mel_tokens": ("INT", {"default": 1500, "min": 0, "max": 8192, "step": 8}),
+                "typical_sampling": ("BOOLEAN", {"default": False}),
{"default": False}), + "typical_mass": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 2000.0, "step": 0.01}), + "speech_speed": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 4.0, "step": 0.05}), + }, + } + + RETURN_TYPES = ("AUDIO", "STRING") + FUNCTION = "synthesize" + CATEGORY = "Audio/IndexTTS" + + def synthesize(self, + audio, + text: str, + emotion_control_weight: float, + emotion_audio=None, + emotion_vector=None, + use_random_style: bool = False, + interval_silence_ms: int = 200, + max_text_tokens_per_segment: int = 120, + seed: int = -1, + do_sample: bool = True, + temperature: float = 0.8, + top_p: float = 0.8, + top_k: int = 30, + repetition_penalty: float = 10.0, + length_penalty: float = 0.0, + num_beams: int = 3, + max_mel_tokens: int = 1500, + typical_sampling: bool = False, + typical_mass: float = 0.9, + speech_speed: float = 1.0): + + if not isinstance(text, str) or len(text.strip()) == 0: + raise ValueError("Text is empty. Please provide text to synthesize.") + + prompt_path, _, need_cleanup = _audio_to_temp_wav(audio) + emo_path = None + emo_need_cleanup = False + if emotion_audio is not None: + try: + emo_path, _, emo_need_cleanup = _audio_to_temp_wav(emotion_audio) + except Exception: + emo_path, emo_need_cleanup = None, False + + seed_value = None + if isinstance(seed, (int, np.integer)) and int(seed) >= 0: + seed_value = int(seed) + random.seed(seed_value) + np.random.seed(seed_value) + + base_dir = os.path.dirname(os.path.abspath(__file__)) + ext_root = os.path.dirname(base_dir) + resolved_model_dir = os.path.join(ext_root, "checkpoints") + resolved_config = os.path.join(resolved_model_dir, "config.yaml") + + if not os.path.isfile(resolved_config): + raise FileNotFoundError(f"Config file not found: {resolved_config}") + if not os.path.isdir(resolved_model_dir): + raise FileNotFoundError(f"Model directory not found: {resolved_model_dir}") + + resolved_device = _resolve_device("auto") + tts2 = _get_tts2_model( + config_path=resolved_config, + model_dir=resolved_model_dir, + device=resolved_device, + use_cuda_kernel=False, + use_fp16=True, + ) + + torch_mod = None + + def _ensure_torch(): + nonlocal torch_mod + if torch_mod is None: + try: + import torch as torch_lib # type: ignore + except Exception as exc: + raise RuntimeError(f"PyTorch is required for IndexTTS2 Advanced: {exc}") + torch_mod = torch_lib + return torch_mod + + seed_info = "random" + if seed_value is not None: + torch_lib = _ensure_torch() + torch_lib.manual_seed(seed_value) + if torch_lib.cuda.is_available(): + torch_lib.cuda.manual_seed_all(seed_value) + if hasattr(torch_lib, "xpu") and callable(getattr(torch_lib.xpu, "is_available", None)) and torch_lib.xpu.is_available(): + torch_lib.xpu.manual_seed_all(seed_value) + if hasattr(torch_lib.backends, "mps") and torch_lib.backends.mps.is_available(): + try: + torch_lib.manual_seed(seed_value) + except Exception: + pass + seed_info = str(seed_value) + + emo_alpha = max(0.0, min(1.0, float(emotion_control_weight))) + emo_audio_prompt = emo_path if emo_path else prompt_path + ui_msgs = [] + + emo_vector_arg = None + if emotion_vector is not None: + try: + vec = [max(0.0, float(v)) for v in list(emotion_vector)[:8]] + while len(vec) < 8: + vec.append(0.0) + emo_vector_arg = vec + emo_audio_prompt = prompt_path + if emo_path is not None: + ui_msgs.append("Emotion source: vectors (second audio ignored)") + else: + ui_msgs.append("Emotion source: vectors") + except Exception: + emo_vector_arg = None + + if emo_vector_arg is None: + if emo_path is not None: + 
ui_msgs.append("Emotion source: second audio") + else: + ui_msgs.append("Emotion source: original audio") + + use_random_style = _coerce_bool(use_random_style, False) + if use_random_style: + ui_msgs.append("Emotion source: random preset mix") + + do_sample = _coerce_bool(do_sample, True) + typical_sampling = _coerce_bool(typical_sampling, False) + + interval_silence_ms = _coerce_int(interval_silence_ms, 200, clamp=(0, 12000)) + max_text_tokens_per_segment = _coerce_int(max_text_tokens_per_segment, 120, clamp=(0, 2048)) + if max_text_tokens_per_segment <= 0: + max_text_tokens_per_segment = 120 + + top_k = _coerce_int(top_k, 30, clamp=(0, 2048)) + num_beams = max(1, _coerce_int(num_beams, 3, clamp=(1, 128))) + max_mel_tokens = _coerce_int(max_mel_tokens, 1500, clamp=(1, 8192)) + + temperature = _coerce_float(temperature, 0.8, clamp=(0.0, 5.0), random_bounds=(0.6, 1.4)) + if temperature < 1e-4: + temperature = 1e-4 + top_p = _coerce_float(top_p, 0.8, clamp=(0.0, 1.0), random_bounds=(0.5, 0.95)) + repetition_penalty = _coerce_float(repetition_penalty, 10.0, clamp=(0.0, 50.0)) + length_penalty = _coerce_float(length_penalty, 0.0, clamp=(-10.0, 50.0)) + typical_mass = _coerce_float(typical_mass, 0.9, clamp=(0.0, 0.99), random_bounds=(0.5, 0.95)) + if typical_mass <= 0.0: + typical_mass = 0.9 + + speech_speed = _coerce_float(speech_speed, 1.0, clamp=(0.25, 4.0), random_bounds=(0.6, 1.4)) + + generation_kwargs = { + "do_sample": bool(do_sample), + "top_p": top_p, + "top_k": top_k, + "temperature": temperature, + "length_penalty": float(length_penalty), + "num_beams": num_beams, + "repetition_penalty": repetition_penalty, + "max_mel_tokens": max_mel_tokens, + "typical_sampling": bool(typical_sampling), + "typical_mass": typical_mass, + "speech_speed": speech_speed, + } + try: + result = tts2.infer( + spk_audio_prompt=prompt_path, + text=text, + output_path=None, + emo_audio_prompt=emo_audio_prompt, + emo_alpha=emo_alpha, + emo_vector=emo_vector_arg, + use_random=bool(use_random_style), + interval_silence=interval_silence_ms, + verbose=False, + max_text_tokens_per_segment=max_text_tokens_per_segment, + **generation_kwargs, + ) + finally: + try: + if need_cleanup and prompt_path and os.path.exists(prompt_path): + os.remove(prompt_path) + if emo_need_cleanup and emo_path and os.path.exists(emo_path): + os.remove(emo_path) + except Exception: + pass + + if not isinstance(result, (tuple, list)) or len(result) != 2: + raise RuntimeError("IndexTTS2 returned an unexpected result format") + + sr, wav = result + torch_lib = _ensure_torch() + if hasattr(wav, "cpu"): + wav = wav.cpu().numpy() + wav = np.asarray(wav) + + if wav.dtype == np.int16: + wav = wav.astype(np.float32) / 32767.0 + elif wav.dtype != np.float32: + wav = wav.astype(np.float32) + + mono = wav + if mono.ndim == 2: + if mono.shape[0] <= 8 and mono.shape[1] > mono.shape[0]: + mono = mono.mean(axis=0) + else: + mono = mono.mean(axis=-1) + elif mono.ndim > 2: + mono = mono.reshape(-1, mono.shape[-1]).mean(axis=0) + if mono.ndim != 1: + mono = mono.flatten() + + waveform = torch_lib.from_numpy(mono[None, None, :].astype(np.float32)) + + info_lines = [] + if ui_msgs: + info_lines.extend(ui_msgs) + info_lines.append(f"Seed: {seed_info}") + if do_sample: + info_lines.append(f"Sampling: temp={temperature:.2f}, top_p={top_p:.2f}, top_k={top_k}") + else: + info_lines.append(f"Beam search: num_beams={num_beams}") + info_lines.append(f"Repetition penalty={repetition_penalty:.2f}, max_mel_tokens={max_mel_tokens}") + info_lines.append(f"Speech speed 
+        info_lines.append(f"Speech speed scale={speech_speed:.2f}, interval_silence_ms={interval_silence_ms}")
+        if typical_sampling:
+            info_lines.append(f"Typical sampling mass={typical_mass:.2f}")
+
+        info_text = "\n".join(info_lines)
+        return ({"sample_rate": int(sr), "waveform": waveform}, info_text)
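
The sketches below are illustrations for reviewers, not part of the patch. First, the infer_v2.py hunk scales the length-regulator targets by a clamped speed factor; a minimal standalone sketch of that arithmetic (the 1.72 code-to-mel ratio and the 0.1-3.0 clamp are copied from the hunk, tensor values are illustrative):

    import torch

    def scaled_target_lengths(code_lens: torch.Tensor, speech_speed: float) -> torch.Tensor:
        # Clamp the user-facing speed to the same bounds used in infer_v2.py.
        speed_scale = max(0.1, min(3.0, float(speech_speed)))
        # 1.72 is the fixed code-to-mel length ratio; clamping to >= 1 frame keeps
        # the length regulator from ever receiving a zero-length target.
        return torch.clamp((code_lens.float() * 1.72 * speed_scale).long(), min=1)

    print(scaled_target_lengths(torch.tensor([100]), 1.0))   # tensor([172])
    print(scaled_target_lengths(torch.tensor([100]), 0.05))  # speed clamped to 0.1 -> tensor([17])

Note that the factor multiplies the mel target, so values above 1.0 produce proportionally longer targets.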
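
The _coerce_* helpers in the new node accept raw numbers, numeric strings, or the literal string "random"; a few example calls showing the intended behavior (input values here are arbitrary):

    # Numeric strings are parsed, out-of-range values are clamped, and
    # unparseable input falls back to the default.
    _coerce_float("1.2", 0.8, clamp=(0.0, 1.0))             # -> 1.0 (clamped)
    _coerce_float("random", 0.8, random_bounds=(0.6, 1.4))  # -> uniform draw in [0.6, 1.4]
    _coerce_int("64", 120, clamp=(0, 2048))                 # -> 64
    _coerce_int("oops", 30)                                 # -> 30 (default)
    _coerce_bool("off", True)                               # -> False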
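
Reproducibility comes from seeding every RNG source before generation, as the node does for non-negative seeds; a condensed sketch (seed value arbitrary; GPU kernels may still introduce slight nondeterminism):

    import random
    import numpy as np
    import torch

    def seed_everything(seed: int) -> None:
        # The same RNGs the Advanced node seeds when seed >= 0.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    seed_everything(1234)  # two runs with the same seed and inputs should match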
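
Finally, the node returns audio in ComfyUI's AUDIO convention: a dict carrying a sample rate and a float32 waveform tensor shaped (batch, channels, samples). The downmix-and-wrap step in isolation (shapes illustrative):

    import numpy as np
    import torch

    wav = np.random.rand(2, 22050).astype(np.float32)   # e.g. stereo, 1 s at 22.05 kHz
    mono = wav.mean(axis=0) if wav.ndim == 2 else wav   # collapse channels, as the node does
    audio_out = {"sample_rate": 22050, "waveform": torch.from_numpy(mono[None, None, :])}
    print(audio_out["waveform"].shape)                  # torch.Size([1, 1, 22050])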