Advanced node + generation seed

This commit is contained in:
snicolast
2025-09-22 12:51:09 +12:00
parent fe1098ed8a
commit 13887983b1
5 changed files with 365 additions and 2 deletions

View File

@@ -496,6 +496,7 @@ class IndexTTS2:
num_beams = generation_kwargs.pop("num_beams", 3)
repetition_penalty = generation_kwargs.pop("repetition_penalty", 10.0)
max_mel_tokens = generation_kwargs.pop("max_mel_tokens", 1500)
+speech_speed = float(generation_kwargs.pop("speech_speed", 1.0))
sampling_rate = 22050
wavs = []
@@ -539,7 +540,7 @@ class IndexTTS2:
cond_lengths=torch.tensor([spk_cond_emb.shape[-1]], device=text_tokens.device),
emo_cond_lengths=torch.tensor([emo_cond_emb.shape[-1]], device=text_tokens.device),
emo_vec=emovec,
-do_sample=True,
+do_sample=do_sample,
top_p=top_p,
top_k=top_k,
temperature=temperature,
@@ -610,7 +611,8 @@ class IndexTTS2:
S_infer = self.semantic_codec.quantizer.vq2emb(codes.unsqueeze(1))
S_infer = S_infer.transpose(1, 2)
S_infer = S_infer + latent
-target_lengths = (code_lens * 1.72).long()
+speed_scale = max(0.1, min(3.0, float(speech_speed)))
+target_lengths = torch.clamp((code_lens.float() * 1.72 * speed_scale).long(), min=1)
cond = self.s2mel.models['length_regulator'](S_infer,
ylens=target_lengths,