mirror of
https://github.com/snicolast/ComfyUI-IndexTTS2.git
synced 2026-04-30 11:41:33 +00:00
Feat: Warn if input text contains UNK tokens (original repo)
This commit is contained in:
@@ -493,6 +493,12 @@ class IndexTTS2:
|
|||||||
text_tokens_list = self.tokenizer.tokenize(text)
|
text_tokens_list = self.tokenizer.tokenize(text)
|
||||||
segments = self.tokenizer.split_segments(text_tokens_list, max_text_tokens_per_segment)
|
segments = self.tokenizer.split_segments(text_tokens_list, max_text_tokens_per_segment)
|
||||||
segments_count = len(segments)
|
segments_count = len(segments)
|
||||||
|
text_token_ids = self.tokenizer.convert_tokens_to_ids(text_tokens_list)
|
||||||
|
if self.tokenizer.unk_token_id in text_token_ids:
|
||||||
|
print(f">> Warning: input text contains {text_token_ids.count(self.tokenizer.unk_token_id)} unknown tokens (id={self.tokenizer.unk_token_id}):")
|
||||||
|
print(f" Tokens which can't be decoded: {[token for token, token_id in zip(text_tokens_list, text_token_ids) if token_id == self.tokenizer.unk_token_id]}")
|
||||||
|
print(" Consider updating the BPE model or modifying the text to avoid unknown tokens.")
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
print("text_tokens_list:", text_tokens_list)
|
print("text_tokens_list:", text_tokens_list)
|
||||||
print("segments count:", segments_count)
|
print("segments count:", segments_count)
|
||||||
@@ -810,3 +816,4 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
tts = IndexTTS2(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_cuda_kernel=False)
|
tts = IndexTTS2(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_cuda_kernel=False)
|
||||||
tts.infer(spk_audio_prompt=prompt_wav, text=text, output_path="gen.wav", verbose=True)
|
tts.infer(spk_audio_prompt=prompt_wav, text=text, output_path="gen.wav", verbose=True)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user