Mirror of https://github.com/snicolast/ComfyUI-IndexTTS2.git (synced 2026-04-28 18:51:34 +00:00)
Sync infer_v2 with upstream streaming updates
@@ -28,13 +28,36 @@ from indextts.s2mel.modules.campplus.DTDNN import CAMPPlus
 from indextts.s2mel.modules.audio import mel_spectrogram

 from transformers import AutoTokenizer
-try:
-    from modelscope import AutoModelForCausalLM
-except Exception:
-    AutoModelForCausalLM = None
+from modelscope import AutoModelForCausalLM
 from huggingface_hub import hf_hub_download
+from huggingface_hub.utils import HfHubHTTPError
 import safetensors
 from transformers import SeamlessM4TFeatureExtractor

+HF_AUTH_TOKEN = (
+    os.getenv("HUGGINGFACE_TOKEN")
+    or os.getenv("HUGGINGFACEHUB_API_TOKEN")
+    or os.getenv("HF_TOKEN")
+    or os.getenv("HF_HUB_TOKEN")
+    or os.getenv("HUGGINGFACEHUB_TOKEN")
+)
+
+
+def _hf_download(repo_id, filename, **kwargs):
+    if HF_AUTH_TOKEN and "token" not in kwargs:
+        kwargs["token"] = HF_AUTH_TOKEN
+    try:
+        return hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
+    except HfHubHTTPError as err:
+        status = getattr(err.response, "status_code", None)
+        if status == 401:
+            print(f"[IndexTTS2] Download for {repo_id}/{filename} failed with 401; retrying anonymously.")
+            kwargs = dict(kwargs)
+            kwargs.pop("token", None)
+            kwargs["token"] = False
+            return hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
+        raise
+
+
 import random
 import torch.nn.functional as F

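The new _hf_download wrapper centralizes Hugging Face fetches: it injects whichever of the five token environment variables is set, and on an HTTP 401 (typically a stale or revoked token) it retries once anonymously before giving up. A minimal sketch of calling it; the checkpoints/hf_cache path is illustrative, not part of the commit:

    import os

    # Any hf_hub_download kwarg passes straight through, so cache_dir still works.
    ckpt = _hf_download(
        "amphion/MaskGCT",
        "semantic_codec/model.safetensors",
        cache_dir=os.path.join("checkpoints", "hf_cache"),
    )
    print("resolved checkpoint path:", ckpt)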
@@ -79,9 +102,7 @@ class IndexTTS2:
         self.dtype = torch.float16 if self.use_fp16 else None
         self.stop_mel_token = self.cfg.gpt.stop_mel_token

-        # Lazy init for QwenEmotion to avoid requiring `modelscope` when not using emo_text
-        self.qwen_emo = None
-        self.qwen_emo_path = os.path.join(self.model_dir, self.cfg.qwen_emo_path)
+        self.qwen_emo = QwenEmotion(os.path.join(self.model_dir, self.cfg.qwen_emo_path))

         self.gpt = UnifiedVoice(**self.cfg.gpt)
         self.gpt_path = os.path.join(self.model_dir, self.cfg.gpt_checkpoint)
@@ -105,34 +126,15 @@ class IndexTTS2:
         if self.use_cuda_kernel:
             # preload the CUDA kernel for BigVGAN
             try:
-                from indextts.BigVGAN.alias_free_activation.cuda import load
-                anti_alias_activation_cuda = load.load()
-                print(">> Preload custom CUDA kernel for BigVGAN", anti_alias_activation_cuda)
-            except:
+                from indextts.s2mel.modules.bigvgan.alias_free_activation.cuda import activation1d
+
+                print(">> Preload custom CUDA kernel for BigVGAN", activation1d.anti_alias_activation_cuda)
+            except Exception as e:
                 print(">> Failed to load custom CUDA kernel for BigVGAN. Falling back to torch.")
                 print(f"{e!r}")
                 self.use_cuda_kernel = False

-        #Prefer local w2v-bert-2.0 if present; otherwise allow HF download
-        try:
-            local_w2v_dir = os.path.join(self.model_dir, "w2v-bert-2.0")
-            if os.path.isdir(local_w2v_dir):
-                self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained(
-                    local_w2v_dir,
-                    local_files_only=True,
-                )
-                print(">> W2V-BERT feature extractor loaded from:", local_w2v_dir)
-            else:
-                self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained(
-                    "facebook/w2v-bert-2.0"
-                )
-                print(">> W2V-BERT feature extractor loaded from HF repo: facebook/w2v-bert-2.0")
-        except Exception as e:
-            # Fallback to HF repo if local load fails unexpectedly
-            self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained(
-                "facebook/w2v-bert-2.0"
-            )
-            print(">> W2V-BERT feature extractor: local load failed; using HF repo. Error:", e)
+        self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
         self.semantic_model, self.semantic_mean, self.semantic_std = build_semantic_model(
             os.path.join(self.model_dir, self.cfg.w2v_stat))
         self.semantic_model = self.semantic_model.to(self.device)
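The extractor loaded here is the same one the emotion path calls later in this diff: it turns a 16 kHz waveform into input_features plus an attention_mask. A standalone sketch, with a sine wave standing in for real prompt audio:

    import numpy as np
    from transformers import SeamlessM4TFeatureExtractor

    extractor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")

    # One second of dummy 16 kHz audio in place of a speaker prompt.
    audio = np.sin(2 * np.pi * 440 * np.arange(16000) / 16000).astype(np.float32)

    inputs = extractor(audio, sampling_rate=16000, return_tensors="pt")
    print(inputs["input_features"].shape)  # (batch, frames, feature_dim)
    print(inputs["attention_mask"].shape)  # (batch, frames)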
@@ -141,28 +143,11 @@ class IndexTTS2:
         self.semantic_std = self.semantic_std.to(self.device)

         semantic_codec = build_semantic_codec(self.cfg.semantic_codec)
-        semantic_code_ckpt = os.path.join(self.model_dir, "semantic_codec/model.safetensors")
-        ckpt_to_load = semantic_code_ckpt
-        if not os.path.isfile(semantic_code_ckpt):
-            try:
-                # Attempt to download into HF cache and use from there
-                hf_cache_dir = os.path.join(self.model_dir, 'hf_cache')
-                os.makedirs(hf_cache_dir, exist_ok=True)
-                ckpt_to_load = hf_hub_download(
-                    repo_id="amphion/MaskGCT",
-                    filename="semantic_codec/model.safetensors",
-                    cache_dir=hf_cache_dir,
-                    local_files_only=False,
-                )
-                print(">> semantic_codec weights downloaded to cache:", ckpt_to_load)
-            except Exception as e:
-                raise FileNotFoundError(
-                    f"semantic_codec/model.safetensors not found and download failed: {e}"
-                )
-        safetensors.torch.load_model(semantic_codec, ckpt_to_load)
+        semantic_code_ckpt = _hf_download("amphion/MaskGCT", "semantic_codec/model.safetensors")
+        safetensors.torch.load_model(semantic_codec, semantic_code_ckpt)
         self.semantic_codec = semantic_codec.to(self.device)
         self.semantic_codec.eval()
-        print('>> semantic_codec weights restored from: {}'.format(ckpt_to_load))
+        print('>> semantic_codec weights restored from: {}'.format(semantic_code_ckpt))

         s2mel_path = os.path.join(self.model_dir, self.cfg.s2mel_checkpoint)
         s2mel = MyModel(self.cfg.s2mel, use_gpt_latent=True)
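safetensors.torch.load_model copies the file's tensors into an already constructed module, which is why the codec is built first and the checkpoint applied afterwards. A self-contained round-trip showing the same pattern; the tiny Linear model is illustrative only:

    import torch
    import safetensors.torch

    model = torch.nn.Linear(4, 2)
    safetensors.torch.save_model(model, "demo.safetensors")

    restored = torch.nn.Linear(4, 2)
    missing, unexpected = safetensors.torch.load_model(restored, "demo.safetensors")
    print(missing, unexpected)  # empty lists when every key lines up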
@@ -179,31 +164,14 @@ class IndexTTS2:
         self.s2mel.eval()
         print(">> s2mel weights restored from:", s2mel_path)

-        # load campplus_model (local first; fallback to HF cache)
-        campplus_ckpt_path = os.path.join(self.model_dir, "campplus_cn_common.bin")
-        campplus_ckpt_to_load = campplus_ckpt_path
-        if not os.path.isfile(campplus_ckpt_path):
-            try:
-                hf_cache_dir = os.path.join(self.model_dir, 'hf_cache')
-                os.makedirs(hf_cache_dir, exist_ok=True)
-                campplus_ckpt_to_load = hf_hub_download(
-                    repo_id="funasr/campplus",
-                    filename="campplus_cn_common.bin",
-                    cache_dir=hf_cache_dir,
-                    local_files_only=False,
-                )
-                print(">> campplus_model weights downloaded to cache:", campplus_ckpt_to_load)
-            except Exception as e:
-                raise FileNotFoundError(
-                    f"campplus_cn_common.bin not found and download failed: {e}"
-                )
+        # load campplus_model
+        campplus_ckpt_path = _hf_download("funasr/campplus", "campplus_cn_common.bin")
         campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
-        campplus_model.load_state_dict(torch.load(campplus_ckpt_to_load, map_location="cpu"))
+        campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu"))
         self.campplus_model = campplus_model.to(self.device)
         self.campplus_model.eval()
-        print(">> campplus_model weights restored from:", campplus_ckpt_to_load)
+        print(">> campplus_model weights restored from:", campplus_ckpt_path)

         # BigVGAN: prefer local dir if complete; otherwise fallback to HF repo
         bigvgan_local_dir = os.path.join(self.model_dir, self.cfg.vocoder.name)
         bigvgan_config = os.path.join(bigvgan_local_dir, "config.json")
         bigvgan_weights = os.path.join(bigvgan_local_dir, "bigvgan_generator.pt")
@@ -212,15 +180,26 @@ class IndexTTS2:
             self.bigvgan = bigvgan.BigVGAN.from_pretrained(bigvgan_local_dir, use_cuda_kernel=self.use_cuda_kernel)
             bigvgan_source = bigvgan_local_dir
         else:
             # fallback to HF repo id (cached under HF_HUB_CACHE)
             repo_id = "nvidia/bigvgan_v2_22khz_80band_256x"
-            print(">> BigVGAN local files missing or incomplete; loading from HF repo:", repo_id)
-            self.bigvgan = bigvgan.BigVGAN.from_pretrained(repo_id, use_cuda_kernel=self.use_cuda_kernel)
+            print('>> BigVGAN local files missing or incomplete; loading from HF repo:', repo_id)
+            kwargs = {}
+            if HF_AUTH_TOKEN:
+                kwargs['token'] = HF_AUTH_TOKEN
+            try:
+                self.bigvgan = bigvgan.BigVGAN.from_pretrained(repo_id, use_cuda_kernel=self.use_cuda_kernel, **kwargs)
+            except HfHubHTTPError as err:
+                status = getattr(err.response, "status_code", None)
+                if status == 401:
+                    print(f"[IndexTTS2] BigVGAN download failed with 401; retrying anonymously.")
+                    kwargs = {"token": False}
+                    self.bigvgan = bigvgan.BigVGAN.from_pretrained(repo_id, use_cuda_kernel=self.use_cuda_kernel, **kwargs)
+                else:
+                    raise
             bigvgan_source = repo_id
         self.bigvgan = self.bigvgan.to(self.device)
         self.bigvgan.remove_weight_norm()
         self.bigvgan.eval()
-        print(">> bigvgan weights restored from:", bigvgan_source)
+        print('>> bigvgan weights restored from:', bigvgan_source)

         self.bpe_path = os.path.join(self.model_dir, self.cfg.dataset["bpe_model"])
         self.normalizer = TextNormalizer()
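BigVGAN cannot reuse _hf_download because it loads through from_pretrained rather than a single-file fetch, so the token-plus-401-retry dance is inlined a second time. If a third call site appears, the pattern could be factored out; a sketch under that assumption (the helper name _with_hf_auth_retry is hypothetical, not in the commit):

    from huggingface_hub.utils import HfHubHTTPError

    def _with_hf_auth_retry(load_fn, **kwargs):
        # Hypothetical helper: try with the ambient token, retry anonymously on 401.
        if HF_AUTH_TOKEN and "token" not in kwargs:
            kwargs["token"] = HF_AUTH_TOKEN
        try:
            return load_fn(**kwargs)
        except HfHubHTTPError as err:
            if getattr(err.response, "status_code", None) == 401:
                kwargs["token"] = False  # force an anonymous request
                return load_fn(**kwargs)
            raise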
@@ -332,6 +311,20 @@ class IndexTTS2:
         code_lens = torch.tensor(code_lens, dtype=torch.long, device=device)
         return codes, code_lens

+    def interval_silence(self, wavs, sampling_rate=22050, interval_silence=200):
+        """
+        Build the silence tensor to be inserted between generated segments.
+        """
+
+        if not wavs or interval_silence <= 0:
+            return wavs
+
+        # get channel_size
+        channel_size = wavs[0].size(0)
+        # get silence tensor
+        sil_dur = int(sampling_rate * interval_silence / 1000.0)
+        return torch.zeros(channel_size, sil_dur)
+
     def insert_interval_silence(self, wavs, sampling_rate=22050, interval_silence=200):
         """
         Insert silences between generated segments.
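At the defaults (22050 Hz, 200 ms) the gap works out to int(22050 * 200 / 1000) = 4410 zero samples, shaped to the channel count of the first segment so it concatenates directly:

    import torch

    # Two mono segments at 22.05 kHz, as produced per text segment.
    wavs = [torch.randn(1, 8000), torch.randn(1, 12000)]

    silence = torch.zeros(wavs[0].size(0), int(22050 * 200 / 1000.0))  # (1, 4410)
    stitched = torch.cat([wavs[0], silence, wavs[1]], dim=1)
    print(stitched.shape)  # torch.Size([1, 24410])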
@@ -359,12 +352,67 @@ class IndexTTS2:
         if self.gr_progress is not None:
             self.gr_progress(value, desc=desc)

+    def _load_and_cut_audio(self, audio_path, max_audio_length_seconds, verbose=False, sr=None):
+        if not sr:
+            audio, sr = librosa.load(audio_path)
+        else:
+            audio, _ = librosa.load(audio_path, sr=sr)
+        audio = torch.tensor(audio).unsqueeze(0)
+        max_audio_samples = int(max_audio_length_seconds * sr)
+
+        if audio.shape[1] > max_audio_samples:
+            if verbose:
+                print(f"Audio too long ({audio.shape[1]} samples), truncating to {max_audio_samples} samples")
+            audio = audio[:, :max_audio_samples]
+        return audio, sr
+
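Both prompt paths below now route through this 15-second cap. The cutoff is plain sample arithmetic, e.g.:

    # 15-second cap at the two sample rates used in this file.
    for sr in (22050, 16000):
        print(sr, "->", int(15 * sr))  # 22050 -> 330750, 16000 -> 240000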
+    def normalize_emo_vec(self, emo_vector, apply_bias=True):
+        # apply biased emotion factors for better user experience,
+        # by de-emphasizing emotions that can cause strange results
+        if apply_bias:
+            # [happy, angry, sad, afraid, disgusted, melancholic, surprised, calm]
+            emo_bias = [0.9375, 0.875, 1.0, 1.0, 0.9375, 0.9375, 0.6875, 0.5625]
+            emo_vector = [vec * bias for vec, bias in zip(emo_vector, emo_bias)]
+
+        # the total emotion sum must be 0.8 or less
+        emo_sum = sum(emo_vector)
+        if emo_sum > 0.8:
+            scale_factor = 0.8 / emo_sum
+            emo_vector = [vec * scale_factor for vec in emo_vector]
+
+        return emo_vector
+
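The two stages compound: the bias is applied first, then a rescale only if the biased sum still exceeds 0.8. Worked through for a mostly-happy vector with some surprise:

    emo_vector = [0.8, 0, 0, 0, 0, 0, 0.4, 0]  # happy + surprised
    emo_bias = [0.9375, 0.875, 1.0, 1.0, 0.9375, 0.9375, 0.6875, 0.5625]

    biased = [v * b for v, b in zip(emo_vector, emo_bias)]  # happy 0.75, surprised 0.275
    total = sum(biased)                                     # 1.025 > 0.8, so rescale
    final = [v * 0.8 / total for v in biased]
    print(round(final[0], 4), round(final[6], 4), round(sum(final), 4))
    # 0.5854 0.2146 0.8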
     # original inference mode
     def infer(self, spk_audio_prompt, text, output_path,
               emo_audio_prompt=None, emo_alpha=1.0,
               emo_vector=None,
               use_emo_text=False, emo_text=None, use_random=False, interval_silence=200,
-              verbose=False, max_text_tokens_per_segment=120, **generation_kwargs):
+              verbose=False, max_text_tokens_per_segment=120, stream_return=False, more_segment_before=0, **generation_kwargs):
+        if stream_return:
+            return self.infer_generator(
+                spk_audio_prompt, text, output_path,
+                emo_audio_prompt, emo_alpha,
+                emo_vector,
+                use_emo_text, emo_text, use_random, interval_silence,
+                verbose, max_text_tokens_per_segment, stream_return, more_segment_before, **generation_kwargs
+            )
+        else:
+            try:
+                return list(self.infer_generator(
+                    spk_audio_prompt, text, output_path,
+                    emo_audio_prompt, emo_alpha,
+                    emo_vector,
+                    use_emo_text, emo_text, use_random, interval_silence,
+                    verbose, max_text_tokens_per_segment, stream_return, more_segment_before, **generation_kwargs
+                ))[0]
+            except IndexError:
+                return None
+
+    def infer_generator(self, spk_audio_prompt, text, output_path,
+                        emo_audio_prompt=None, emo_alpha=1.0,
+                        emo_vector=None,
+                        use_emo_text=False, emo_text=None, use_random=False, interval_silence=200,
+                        verbose=False, max_text_tokens_per_segment=120, stream_return=False, quick_streaming_tokens=0, **generation_kwargs):
         print(">> starting inference...")
         self._set_gr_progress(0, "starting inference...")
         if verbose:
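With stream_return=True, infer now returns the generator itself, yielding one CPU tensor per synthesized segment with interleaved silence; with the default stream_return=False it drains the generator and returns only its first yielded value (the saved path or the Gradio tuple), or None if nothing was yielded. Note that infer's new more_segment_before argument lands positionally in infer_generator's quick_streaming_tokens slot. A consumption sketch; the file paths are placeholders:

    tts = IndexTTS2(cfg_path="checkpoints/config.yaml", model_dir="checkpoints")

    # Streaming: handle each chunk as it is produced.
    chunks = [c for c in tts.infer("speaker.wav", "Hello there.", None, stream_return=True)]

    # Blocking: the same call without streaming returns the output path.
    result = tts.infer("speaker.wav", "Hello there.", "out.wav")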
@@ -383,12 +431,6 @@ class IndexTTS2:
             # automatically generate emotion vectors from text prompt
             if emo_text is None:
                 emo_text = text  # use main text prompt
-            if self.qwen_emo is None:
-                if AutoModelForCausalLM is None:
-                    raise ImportError(
-                        "`modelscope` is required to use emo_text. Install `modelscope` or disable 'use_emo_text'."
-                    )
-                self.qwen_emo = QwenEmotion(self.qwen_emo_path)
             emo_dict = self.qwen_emo.inference(emo_text)
             print(f"detected emotion vectors from text: {emo_dict}")
             # convert ordered dict to list of vectors; the order is VERY important!
@@ -418,10 +460,8 @@ class IndexTTS2:
             self.cache_s2mel_style = None
             self.cache_s2mel_prompt = None
             self.cache_mel = None
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-            audio, sr = librosa.load(spk_audio_prompt)
-            audio = torch.tensor(audio).unsqueeze(0)
+            torch.cuda.empty_cache()
+            audio, sr = self._load_and_cut_audio(spk_audio_prompt, 15, verbose)
             audio_22k = torchaudio.transforms.Resample(sr, 22050)(audio)
             audio_16k = torchaudio.transforms.Resample(sr, 16000)(audio)

@@ -474,9 +514,8 @@ class IndexTTS2:
         if self.cache_emo_cond is None or self.cache_emo_audio_prompt != emo_audio_prompt:
             if self.cache_emo_cond is not None:
                 self.cache_emo_cond = None
-                if torch.cuda.is_available():
-                    torch.cuda.empty_cache()
-            emo_audio, _ = librosa.load(emo_audio_prompt, sr=16000)
+                torch.cuda.empty_cache()
+            emo_audio, _ = self._load_and_cut_audio(emo_audio_prompt, 15, verbose, sr=16000)
             emo_inputs = self.extract_features(emo_audio, sampling_rate=16000, return_tensors="pt")
             emo_input_features = emo_inputs["input_features"]
             emo_attention_mask = emo_inputs["attention_mask"]
@@ -491,14 +530,18 @@ class IndexTTS2:

         self._set_gr_progress(0.1, "text processing...")
         text_tokens_list = self.tokenizer.tokenize(text)
-        segments = self.tokenizer.split_segments(text_tokens_list, max_text_tokens_per_segment)
+        try:
+            segments = self.tokenizer.split_segments(text_tokens_list, max_text_tokens_per_segment, quick_streaming_tokens=quick_streaming_tokens)
+        except TypeError:
+            segments = self.tokenizer.split_segments(text_tokens_list, max_text_tokens_per_segment)
         segments_count = len(segments)

         text_token_ids = self.tokenizer.convert_tokens_to_ids(text_tokens_list)
         if self.tokenizer.unk_token_id in text_token_ids:
-            print(f">> Warning: input text contains {text_token_ids.count(self.tokenizer.unk_token_id)} unknown tokens (id={self.tokenizer.unk_token_id}):")
-            print(f"   Tokens which can't be decoded: {[token for token, token_id in zip(text_tokens_list, text_token_ids) if token_id == self.tokenizer.unk_token_id]}")
-            print("   Consider updating the BPE model or modifying the text to avoid unknown tokens.")
+            print(f" >> Warning: input text contains {text_token_ids.count(self.tokenizer.unk_token_id)} unknown tokens (id={self.tokenizer.unk_token_id}):")
+            print("    Tokens which can't be encoded: ", [t for t, id in zip(text_tokens_list, text_token_ids) if id == self.tokenizer.unk_token_id])
+            print("    Consider updating the BPE model or modifying the text to avoid unknown tokens.")

         if verbose:
             print("text_tokens_list:", text_tokens_list)
             print("segments count:", segments_count)
@@ -522,6 +565,7 @@ class IndexTTS2:
         s2mel_time = 0
         bigvgan_time = 0
         has_warned = False
+        silence = None  # for stream_return
         for seg_idx, sent in enumerate(segments):
             self._set_gr_progress(0.2 + 0.7 * seg_idx / segments_count,
                                   f"speech synthesis {seg_idx + 1}/{segments_count}...")
@@ -557,7 +601,7 @@ class IndexTTS2:
                 cond_lengths=torch.tensor([spk_cond_emb.shape[-1]], device=text_tokens.device),
                 emo_cond_lengths=torch.tensor([emo_cond_emb.shape[-1]], device=text_tokens.device),
                 emo_vec=emovec,
-                do_sample=do_sample,
+                do_sample=True,
                 top_p=top_p,
                 top_k=top_k,
                 temperature=temperature,
@@ -655,6 +699,11 @@ class IndexTTS2:
                 print(f"wav shape: {wav.shape}", "min:", wav.min(), "max:", wav.max())
             # wavs.append(wav[:, :-512])
             wavs.append(wav.cpu())  # to cpu before saving
+            if stream_return:
+                yield wav.cpu()
+                if silence is None:
+                    silence = self.interval_silence(wavs, sampling_rate=sampling_rate, interval_silence=interval_silence)
+                yield silence
             end_time = time.perf_counter()

         self._set_gr_progress(0.9, "saving audio...")
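In streaming mode each loop iteration therefore yields the segment and then a shared silence tensor, so a consumer can rebuild the full waveform with one concatenation. As written, the silence also follows the final segment, which a consumer may want to trim:

    import torch

    received = list(tts.infer("speaker.wav", "A long text...", None, stream_return=True))
    full = torch.cat(received, dim=1)  # segments and silences, in yield order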
@@ -680,12 +729,16 @@ class IndexTTS2:
             os.makedirs(os.path.dirname(output_path), exist_ok=True)
             torchaudio.save(output_path, wav.type(torch.int16), sampling_rate)
             print(">> wav file saved to:", output_path)
-            return output_path
+            if stream_return:
+                return None
+            yield output_path
         else:
+            if stream_return:
+                return None
             # return in the format Gradio expects
             wav_data = wav.type(torch.int16)
             wav_data = wav_data.numpy().T
-            return (sampling_rate, wav_data)
+            yield (sampling_rate, wav_data)


 def find_most_similar_cosine(query_vector, matrix):
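The Gradio branch yields a (sample_rate, ndarray) pair, with the tensor transposed so samples come first. The same conversion in isolation:

    import torch

    wav = torch.zeros(1, 22050)                 # (channels, samples) float tensor
    wav_data = wav.type(torch.int16).numpy().T  # (samples, channels) int16 array
    result = (22050, wav_data)                  # the tuple format gr.Audio accepts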
@@ -816,4 +869,3 @@ if __name__ == "__main__":

     tts = IndexTTS2(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_cuda_kernel=False)
     tts.infer(spk_audio_prompt=prompt_wav, text=text, output_path="gen.wav", verbose=True)
-