Sync infer_v2 with upstream streaming updates

snicolast
2025-10-03 19:08:48 +13:00
parent 9d3e4f0817
commit 586bd77efe


@@ -28,13 +28,36 @@ from indextts.s2mel.modules.campplus.DTDNN import CAMPPlus
from indextts.s2mel.modules.audio import mel_spectrogram
from transformers import AutoTokenizer
try:
from modelscope import AutoModelForCausalLM
except Exception:
AutoModelForCausalLM = None
from modelscope import AutoModelForCausalLM
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HfHubHTTPError
import safetensors
from transformers import SeamlessM4TFeatureExtractor
HF_AUTH_TOKEN = (
os.getenv("HUGGINGFACE_TOKEN")
or os.getenv("HUGGINGFACEHUB_API_TOKEN")
or os.getenv("HF_TOKEN")
or os.getenv("HF_HUB_TOKEN")
or os.getenv("HUGGINGFACEHUB_TOKEN")
)
def _hf_download(repo_id, filename, **kwargs):
if HF_AUTH_TOKEN and "token" not in kwargs:
kwargs["token"] = HF_AUTH_TOKEN
try:
return hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
except HfHubHTTPError as err:
status = getattr(err.response, "status_code", None)
if status == 401:
print(f"[IndexTTS2] Download for {repo_id}/{filename} failed with 401; retrying anonymously.")
kwargs = dict(kwargs)
kwargs.pop("token", None)
kwargs["token"] = False
return hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
raise
import random
import torch.nn.functional as F
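# Usage sketch (illustrative): resolving a checkpoint through the helper above. If one
# of the token environment variables is set it is forwarded to hf_hub_download; a 401
# response triggers a single anonymous retry with token=False, so public files still
# resolve when a stored token is invalid. The cache_dir value is an optional, assumed
# location, not something the helper requires.
ckpt_path = _hf_download(
    "amphion/MaskGCT",
    "semantic_codec/model.safetensors",
    cache_dir="checkpoints/hf_cache",
)
print("resolved checkpoint:", ckpt_path)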
@@ -79,9 +102,7 @@ class IndexTTS2:
self.dtype = torch.float16 if self.use_fp16 else None
self.stop_mel_token = self.cfg.gpt.stop_mel_token
# Lazy init for QwenEmotion to avoid requiring `modelscope` when not using emo_text
self.qwen_emo = None
self.qwen_emo_path = os.path.join(self.model_dir, self.cfg.qwen_emo_path)
self.qwen_emo = QwenEmotion(os.path.join(self.model_dir, self.cfg.qwen_emo_path))
self.gpt = UnifiedVoice(**self.cfg.gpt)
self.gpt_path = os.path.join(self.model_dir, self.cfg.gpt_checkpoint)
@@ -105,34 +126,15 @@ class IndexTTS2:
if self.use_cuda_kernel:
# preload the CUDA kernel for BigVGAN
try:
from indextts.BigVGAN.alias_free_activation.cuda import load
from indextts.s2mel.modules.bigvgan.alias_free_activation.cuda import activation1d
anti_alias_activation_cuda = load.load()
print(">> Preload custom CUDA kernel for BigVGAN", anti_alias_activation_cuda)
except:
print(">> Preload custom CUDA kernel for BigVGAN", activation1d.anti_alias_activation_cuda)
except Exception as e:
print(">> Failed to load custom CUDA kernel for BigVGAN. Falling back to torch.")
print(f"{e!r}")
self.use_cuda_kernel = False
#Prefer local w2v-bert-2.0 if present; otherwise allow HF download
try:
local_w2v_dir = os.path.join(self.model_dir, "w2v-bert-2.0")
if os.path.isdir(local_w2v_dir):
self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained(
local_w2v_dir,
local_files_only=True,
)
print(">> W2V-BERT feature extractor loaded from:", local_w2v_dir)
else:
self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained(
"facebook/w2v-bert-2.0"
)
print(">> W2V-BERT feature extractor loaded from HF repo: facebook/w2v-bert-2.0")
except Exception as e:
# Fallback to HF repo if local load fails unexpectedly
self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained(
"facebook/w2v-bert-2.0"
)
print(">> W2V-BERT feature extractor: local load failed; using HF repo. Error:", e)
self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
self.semantic_model, self.semantic_mean, self.semantic_std = build_semantic_model(
os.path.join(self.model_dir, self.cfg.w2v_stat))
self.semantic_model = self.semantic_model.to(self.device)
@@ -141,28 +143,11 @@ class IndexTTS2:
self.semantic_std = self.semantic_std.to(self.device)
semantic_codec = build_semantic_codec(self.cfg.semantic_codec)
semantic_code_ckpt = os.path.join(self.model_dir, "semantic_codec/model.safetensors")
ckpt_to_load = semantic_code_ckpt
if not os.path.isfile(semantic_code_ckpt):
try:
# Attempt to download into HF cache and use from there
hf_cache_dir = os.path.join(self.model_dir, 'hf_cache')
os.makedirs(hf_cache_dir, exist_ok=True)
ckpt_to_load = hf_hub_download(
repo_id="amphion/MaskGCT",
filename="semantic_codec/model.safetensors",
cache_dir=hf_cache_dir,
local_files_only=False,
)
print(">> semantic_codec weights downloaded to cache:", ckpt_to_load)
except Exception as e:
raise FileNotFoundError(
f"semantic_codec/model.safetensors not found and download failed: {e}"
)
safetensors.torch.load_model(semantic_codec, ckpt_to_load)
semantic_code_ckpt = _hf_download("amphion/MaskGCT", "semantic_codec/model.safetensors")
safetensors.torch.load_model(semantic_codec, semantic_code_ckpt)
self.semantic_codec = semantic_codec.to(self.device)
self.semantic_codec.eval()
print('>> semantic_codec weights restored from: {}'.format(ckpt_to_load))
print('>> semantic_codec weights restored from: {}'.format(semantic_code_ckpt))
s2mel_path = os.path.join(self.model_dir, self.cfg.s2mel_checkpoint)
s2mel = MyModel(self.cfg.s2mel, use_gpt_latent=True)
@@ -179,31 +164,14 @@ class IndexTTS2:
self.s2mel.eval()
print(">> s2mel weights restored from:", s2mel_path)
# load campplus_model (local first; fallback to HF cache)
campplus_ckpt_path = os.path.join(self.model_dir, "campplus_cn_common.bin")
campplus_ckpt_to_load = campplus_ckpt_path
if not os.path.isfile(campplus_ckpt_path):
try:
hf_cache_dir = os.path.join(self.model_dir, 'hf_cache')
os.makedirs(hf_cache_dir, exist_ok=True)
campplus_ckpt_to_load = hf_hub_download(
repo_id="funasr/campplus",
filename="campplus_cn_common.bin",
cache_dir=hf_cache_dir,
local_files_only=False,
)
print(">> campplus_model weights downloaded to cache:", campplus_ckpt_to_load)
except Exception as e:
raise FileNotFoundError(
f"campplus_cn_common.bin not found and download failed: {e}"
)
# load campplus_model
campplus_ckpt_path = _hf_download("funasr/campplus", "campplus_cn_common.bin")
campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
campplus_model.load_state_dict(torch.load(campplus_ckpt_to_load, map_location="cpu"))
campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu"))
self.campplus_model = campplus_model.to(self.device)
self.campplus_model.eval()
print(">> campplus_model weights restored from:", campplus_ckpt_to_load)
print(">> campplus_model weights restored from:", campplus_ckpt_path)
# BigVGAN: prefer local dir if complete; otherwise fallback to HF repo
bigvgan_local_dir = os.path.join(self.model_dir, self.cfg.vocoder.name)
bigvgan_config = os.path.join(bigvgan_local_dir, "config.json")
bigvgan_weights = os.path.join(bigvgan_local_dir, "bigvgan_generator.pt")
@@ -212,15 +180,26 @@ class IndexTTS2:
self.bigvgan = bigvgan.BigVGAN.from_pretrained(bigvgan_local_dir, use_cuda_kernel=self.use_cuda_kernel)
bigvgan_source = bigvgan_local_dir
else:
# fallback to HF repo id (cached under HF_HUB_CACHE)
repo_id = "nvidia/bigvgan_v2_22khz_80band_256x"
print(">> BigVGAN local files missing or incomplete; loading from HF repo:", repo_id)
self.bigvgan = bigvgan.BigVGAN.from_pretrained(repo_id, use_cuda_kernel=self.use_cuda_kernel)
print('>> BigVGAN local files missing or incomplete; loading from HF repo:', repo_id)
kwargs = {}
if HF_AUTH_TOKEN:
kwargs['token'] = HF_AUTH_TOKEN
try:
self.bigvgan = bigvgan.BigVGAN.from_pretrained(repo_id, use_cuda_kernel=self.use_cuda_kernel, **kwargs)
except HfHubHTTPError as err:
status = getattr(err.response, "status_code", None)
if status == 401:
print(f"[IndexTTS2] BigVGAN download failed with 401; retrying anonymously.")
kwargs = {"token": False}
self.bigvgan = bigvgan.BigVGAN.from_pretrained(repo_id, use_cuda_kernel=self.use_cuda_kernel, **kwargs)
else:
raise
bigvgan_source = repo_id
self.bigvgan = self.bigvgan.to(self.device)
self.bigvgan.remove_weight_norm()
self.bigvgan.eval()
print(">> bigvgan weights restored from:", bigvgan_source)
print('>> bigvgan weights restored from:', bigvgan_source)
self.bpe_path = os.path.join(self.model_dir, self.cfg.dataset["bpe_model"])
self.normalizer = TextNormalizer()
@@ -332,6 +311,20 @@ class IndexTTS2:
code_lens = torch.tensor(code_lens, dtype=torch.long, device=device)
return codes, code_lens
def interval_silence(self, wavs, sampling_rate=22050, interval_silence=200):
"""
Silence tensor to be inserted between generated segments.
"""
if not wavs or interval_silence <= 0:
return wavs
# get channel_size
channel_size = wavs[0].size(0)
# get silence tensor
sil_dur = int(sampling_rate * interval_silence / 1000.0)
return torch.zeros(channel_size, sil_dur)
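# Worked example of the silence math above (illustrative tensors): 200 ms at 22050 Hz
# is int(22050 * 200 / 1000) = 4410 zero samples per channel, shaped to match the
# channel count of the generated segments before being stitched between them.
import torch

segments = [torch.randn(1, 22050), torch.randn(1, 11025)]  # two fake mono chunks
sil = torch.zeros(segments[0].size(0), int(22050 * 200 / 1000.0))
stitched = torch.cat([segments[0], sil, segments[1]], dim=1)
print(sil.shape, stitched.shape)  # torch.Size([1, 4410]) torch.Size([1, 37485])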
def insert_interval_silence(self, wavs, sampling_rate=22050, interval_silence=200):
"""
Insert silences between generated segments.
@@ -359,12 +352,67 @@ class IndexTTS2:
if self.gr_progress is not None:
self.gr_progress(value, desc=desc)
def _load_and_cut_audio(self,audio_path,max_audio_length_seconds,verbose=False,sr=None):
if not sr:
audio, sr = librosa.load(audio_path)
else:
audio, _ = librosa.load(audio_path,sr=sr)
audio = torch.tensor(audio).unsqueeze(0)
max_audio_samples = int(max_audio_length_seconds * sr)
if audio.shape[1] > max_audio_samples:
if verbose:
print(f"Audio too long ({audio.shape[1]} samples), truncating to {max_audio_samples} samples")
audio = audio[:, :max_audio_samples]
return audio, sr
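# Worked numbers for the 15-second cap applied below to the speaker and emotion
# prompts (illustrative): at librosa's default 22050 Hz the cut keeps at most
# int(15 * 22050) = 330750 samples, and at the forced sr=16000 path at most 240000.
for cap_sr in (22050, 16000):
    print(cap_sr, int(15 * cap_sr))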
def normalize_emo_vec(self, emo_vector, apply_bias=True):
# apply biased emotion factors for better user experience,
# by de-emphasizing emotions that can cause strange results
if apply_bias:
# [happy, angry, sad, afraid, disgusted, melancholic, surprised, calm]
emo_bias = [0.9375, 0.875, 1.0, 1.0, 0.9375, 0.9375, 0.6875, 0.5625]
emo_vector = [vec * bias for vec, bias in zip(emo_vector, emo_bias)]
# the total emotion sum must be 0.8 or less
emo_sum = sum(emo_vector)
if emo_sum > 0.8:
scale_factor = 0.8 / emo_sum
emo_vector = [vec * scale_factor for vec in emo_vector]
return emo_vector
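# Worked example of the normalization above (illustrative input): full-strength
# "happy" + "angry" is biased to 0.9375 + 0.875 = 1.8125, which exceeds the 0.8 cap,
# so every entry is rescaled by 0.8 / 1.8125 (about 0.44) and the final sum is 0.8.
emo_bias = [0.9375, 0.875, 1.0, 1.0, 0.9375, 0.9375, 0.6875, 0.5625]
example = [1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]  # happy + angry at full strength
biased = [v * b for v, b in zip(example, emo_bias)]
scaled = [v * 0.8 / sum(biased) for v in biased]
print(sum(biased), round(sum(scaled), 6))  # 1.8125 0.8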
# Original (non-streaming) inference mode
def infer(self, spk_audio_prompt, text, output_path,
emo_audio_prompt=None, emo_alpha=1.0,
emo_vector=None,
use_emo_text=False, emo_text=None, use_random=False, interval_silence=200,
verbose=False, max_text_tokens_per_segment=120, **generation_kwargs):
verbose=False, max_text_tokens_per_segment=120, stream_return=False, more_segment_before=0, **generation_kwargs):
if stream_return:
return self.infer_generator(
spk_audio_prompt, text, output_path,
emo_audio_prompt, emo_alpha,
emo_vector,
use_emo_text, emo_text, use_random, interval_silence,
verbose, max_text_tokens_per_segment, stream_return, more_segment_before, **generation_kwargs
)
else:
try:
return list(self.infer_generator(
spk_audio_prompt, text, output_path,
emo_audio_prompt, emo_alpha,
emo_vector,
use_emo_text, emo_text, use_random, interval_silence,
verbose, max_text_tokens_per_segment, stream_return, more_segment_before, **generation_kwargs
))[0]
except IndexError:
return None
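# Usage sketch (illustrative prompt/output paths; mirrors the __main__ block at the
# bottom of the file): a blocking call returns the saved wav path, while
# stream_return=True returns the underlying generator so chunks can be consumed as
# they are produced. Each yielded chunk is a (channels, samples) float tensor on CPU,
# with short silence tensors interleaved between segments.
tts = IndexTTS2(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_cuda_kernel=False)
out_path = tts.infer("examples/voice_01.wav", "Streaming synthesis test.", "gen.wav")

chunks = []
for chunk in tts.infer("examples/voice_01.wav", "Streaming synthesis test.", None, stream_return=True):
    chunks.append(chunk)  # feed to a player or websocket instead of buffering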
def infer_generator(self, spk_audio_prompt, text, output_path,
emo_audio_prompt=None, emo_alpha=1.0,
emo_vector=None,
use_emo_text=False, emo_text=None, use_random=False, interval_silence=200,
verbose=False, max_text_tokens_per_segment=120, stream_return=False, quick_streaming_tokens=0, **generation_kwargs):
print(">> starting inference...")
self._set_gr_progress(0, "starting inference...")
if verbose:
@@ -383,12 +431,6 @@ class IndexTTS2:
# automatically generate emotion vectors from text prompt
if emo_text is None:
emo_text = text # use main text prompt
if self.qwen_emo is None:
if AutoModelForCausalLM is None:
raise ImportError(
"`modelscope` is required to use emo_text. Install `modelscope` or disable 'use_emo_text'."
)
self.qwen_emo = QwenEmotion(self.qwen_emo_path)
emo_dict = self.qwen_emo.inference(emo_text)
print(f"detected emotion vectors from text: {emo_dict}")
# convert ordered dict to list of vectors; the order is VERY important!
@@ -418,10 +460,8 @@ class IndexTTS2:
self.cache_s2mel_style = None
self.cache_s2mel_prompt = None
self.cache_mel = None
if torch.cuda.is_available():
torch.cuda.empty_cache()
audio, sr = librosa.load(spk_audio_prompt)
audio = torch.tensor(audio).unsqueeze(0)
torch.cuda.empty_cache()
audio,sr = self._load_and_cut_audio(spk_audio_prompt,15,verbose)
audio_22k = torchaudio.transforms.Resample(sr, 22050)(audio)
audio_16k = torchaudio.transforms.Resample(sr, 16000)(audio)
@@ -474,9 +514,8 @@ class IndexTTS2:
if self.cache_emo_cond is None or self.cache_emo_audio_prompt != emo_audio_prompt:
if self.cache_emo_cond is not None:
self.cache_emo_cond = None
if torch.cuda.is_available():
torch.cuda.empty_cache()
emo_audio, _ = librosa.load(emo_audio_prompt, sr=16000)
torch.cuda.empty_cache()
emo_audio, _ = self._load_and_cut_audio(emo_audio_prompt,15,verbose,sr=16000)
emo_inputs = self.extract_features(emo_audio, sampling_rate=16000, return_tensors="pt")
emo_input_features = emo_inputs["input_features"]
emo_attention_mask = emo_inputs["attention_mask"]
@@ -491,14 +530,18 @@ class IndexTTS2:
self._set_gr_progress(0.1, "text processing...")
text_tokens_list = self.tokenizer.tokenize(text)
segments = self.tokenizer.split_segments(text_tokens_list, max_text_tokens_per_segment)
try:
segments = self.tokenizer.split_segments(text_tokens_list, max_text_tokens_per_segment, quick_streaming_tokens=quick_streaming_tokens)
except TypeError:
segments = self.tokenizer.split_segments(text_tokens_list, max_text_tokens_per_segment)
segments_count = len(segments)
text_token_ids = self.tokenizer.convert_tokens_to_ids(text_tokens_list)
if self.tokenizer.unk_token_id in text_token_ids:
print(f">> Warning: input text contains {text_token_ids.count(self.tokenizer.unk_token_id)} unknown tokens (id={self.tokenizer.unk_token_id}):")
print(f" Tokens which can't be decoded: {[token for token, token_id in zip(text_tokens_list, text_token_ids) if token_id == self.tokenizer.unk_token_id]}")
print(" Consider updating the BPE model or modifying the text to avoid unknown tokens.")
print(f" >> Warning: input text contains {text_token_ids.count(self.tokenizer.unk_token_id)} unknown tokens (id={self.tokenizer.unk_token_id}):")
print( " Tokens which can't be encoded: ", [t for t, id in zip(text_tokens_list, text_token_ids) if id == self.tokenizer.unk_token_id])
print(f" Consider updating the BPE model or modifying the text to avoid unknown tokens.")
if verbose:
print("text_tokens_list:", text_tokens_list)
print("segments count:", segments_count)
@@ -522,6 +565,7 @@ class IndexTTS2:
s2mel_time = 0
bigvgan_time = 0
has_warned = False
silence = None # for stream_return
for seg_idx, sent in enumerate(segments):
self._set_gr_progress(0.2 + 0.7 * seg_idx / segments_count,
f"speech synthesis {seg_idx + 1}/{segments_count}...")
@@ -557,7 +601,7 @@ class IndexTTS2:
cond_lengths=torch.tensor([spk_cond_emb.shape[-1]], device=text_tokens.device),
emo_cond_lengths=torch.tensor([emo_cond_emb.shape[-1]], device=text_tokens.device),
emo_vec=emovec,
do_sample=do_sample,
do_sample=True,
top_p=top_p,
top_k=top_k,
temperature=temperature,
@@ -655,6 +699,11 @@ class IndexTTS2:
print(f"wav shape: {wav.shape}", "min:", wav.min(), "max:", wav.max())
# wavs.append(wav[:, :-512])
wavs.append(wav.cpu()) # to cpu before saving
if stream_return:
yield wav.cpu()
if silence is None:
silence = self.interval_silence(wavs, sampling_rate=sampling_rate, interval_silence=interval_silence)
yield silence
end_time = time.perf_counter()
self._set_gr_progress(0.9, "saving audio...")
@@ -680,12 +729,16 @@ class IndexTTS2:
os.makedirs(os.path.dirname(output_path), exist_ok=True)
torchaudio.save(output_path, wav.type(torch.int16), sampling_rate)
print(">> wav file saved to:", output_path)
return output_path
if stream_return:
return None
yield output_path
else:
if stream_return:
return None
# Return in the format expected by Gradio
wav_data = wav.type(torch.int16)
wav_data = wav_data.numpy().T
return (sampling_rate, wav_data)
yield (sampling_rate, wav_data)
def find_most_similar_cosine(query_vector, matrix):
@@ -816,4 +869,3 @@ if __name__ == "__main__":
tts = IndexTTS2(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_cuda_kernel=False)
tts.infer(spk_audio_prompt=prompt_wav, text=text, output_path="gen.wav", verbose=True)