From 00d23535844d88bf5fa7de6dc45c06f3eceba6ac Mon Sep 17 00:00:00 2001 From: snicolast Date: Tue, 30 Sep 2025 10:23:03 +1300 Subject: [PATCH] Feat: Warn if input text contains UNK tokens (original repo) --- indextts/infer_v2.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/indextts/infer_v2.py b/indextts/infer_v2.py index cf66d36..23808a9 100644 --- a/indextts/infer_v2.py +++ b/indextts/infer_v2.py @@ -493,6 +493,12 @@ class IndexTTS2: text_tokens_list = self.tokenizer.tokenize(text) segments = self.tokenizer.split_segments(text_tokens_list, max_text_tokens_per_segment) segments_count = len(segments) + text_token_ids = self.tokenizer.convert_tokens_to_ids(text_tokens_list) + if self.tokenizer.unk_token_id in text_token_ids: + print(f">> Warning: input text contains {text_token_ids.count(self.tokenizer.unk_token_id)} unknown tokens (id={self.tokenizer.unk_token_id}):") + print(f" Tokens which can't be decoded: {[token for token, token_id in zip(text_tokens_list, text_token_ids) if token_id == self.tokenizer.unk_token_id]}") + print(" Consider updating the BPE model or modifying the text to avoid unknown tokens.") + if verbose: print("text_tokens_list:", text_tokens_list) print("segments count:", segments_count) @@ -810,3 +816,4 @@ if __name__ == "__main__": tts = IndexTTS2(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_cuda_kernel=False) tts.infer(spk_audio_prompt=prompt_wav, text=text, output_path="gen.wav", verbose=True) +