diff --git a/indextts/infer_v2.py b/indextts/infer_v2.py index a12105b..cf66d36 100644 --- a/indextts/infer_v2.py +++ b/indextts/infer_v2.py @@ -413,6 +413,13 @@ class IndexTTS2: # 如果参考音频改变了,才需要重新生成, 提升速度 if self.cache_spk_cond is None or self.cache_spk_audio_prompt != spk_audio_prompt: + if self.cache_spk_cond is not None: + self.cache_spk_cond = None + self.cache_s2mel_style = None + self.cache_s2mel_prompt = None + self.cache_mel = None + if torch.cuda.is_available(): + torch.cuda.empty_cache() audio, sr = librosa.load(spk_audio_prompt) audio = torch.tensor(audio).unsqueeze(0) audio_22k = torchaudio.transforms.Resample(sr, 22050)(audio) @@ -465,6 +472,10 @@ class IndexTTS2: emovec_mat = emovec_mat.unsqueeze(0) if self.cache_emo_cond is None or self.cache_emo_audio_prompt != emo_audio_prompt: + if self.cache_emo_cond is not None: + self.cache_emo_cond = None + if torch.cuda.is_available(): + torch.cuda.empty_cache() emo_audio, _ = librosa.load(emo_audio_prompt, sr=16000) emo_inputs = self.extract_features(emo_audio, sampling_rate=16000, return_tensors="pt") emo_input_features = emo_inputs["input_features"]