voice bleeding fix, audio quality, input speakers tags, zero-shot voices

This commit is contained in:
WildAi
2025-09-24 17:42:30 +03:00
parent d04665d073
commit 696ef69152
6 changed files with 260 additions and 247 deletions

View File

@@ -147,8 +147,10 @@ class VibeVoiceProcessor:
def __call__(
self,
text: Optional[Union[str, List[str], TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
voice_samples: Optional[Union[List[Union[str, np.ndarray]], List[List[Union[str, np.ndarray]]]]] = None,
text: Optional[List[str]] = None,
parsed_scripts: Optional[List[List[Tuple[int, str]]]] = None, # <-- ADDED
voice_samples: Optional[List[List[Optional[Union[str, np.ndarray]]]]] = None,
speaker_ids_for_prompt: Optional[List[List[int]]] = None,
padding: Union[bool, str, PaddingStrategy] = True,
truncation: Union[bool, str, TruncationStrategy] = False,
max_length: Optional[int] = None,
@@ -189,31 +191,26 @@ class VibeVoiceProcessor:
- **speech_masks** -- Speech masks (if voice_samples provided)
- **speech_input_mask** -- Boolean masks indicating speech token positions
"""
# Handle single vs batch input
if isinstance(text, str) or (isinstance(text, list) and len(text) > 0 and not isinstance(text[0], str)):
# Single input
texts = [text]
is_batched = False
else:
# Batch input
texts = text
is_batched = True
# Handle voice samples
if voice_samples is not None:
if not is_batched or (isinstance(voice_samples[0], (str, np.ndarray))):
# Single set of voice samples
voice_samples_list = [voice_samples]
else:
# Batch of voice samples
voice_samples_list = voice_samples
else:
voice_samples_list = [None] * len(texts)
# Process each input
if parsed_scripts is None:
if text is None:
raise ValueError("Either 'text' or 'parsed_scripts' must be provided.")
# Fallback for raw text input (though the node won't use this path)
from ..modules.utils import parse_script_1_based
parsed_scripts = [parse_script_1_based(t)[0] for t in text]
num_scripts = len(parsed_scripts)
voice_samples_list = voice_samples if voice_samples is not None else [[] for _ in range(num_scripts)]
speaker_ids_list = speaker_ids_for_prompt if speaker_ids_for_prompt is not None else [[] for _ in range(num_scripts)]
all_encodings = []
for text_input, voice_input in zip(texts, voice_samples_list):
encoding = self._process_single(text_input, voice_input)
for i in range(num_scripts):
# Pass all three corresponding items to _process_single
encoding = self._process_single(
parsed_scripts[i],
voice_samples_list[i],
speaker_ids_list[i]
)
all_encodings.append(encoding)
# Combine batch
@@ -230,62 +227,38 @@ class VibeVoiceProcessor:
def _process_single(
self,
text: Union[str, TextInput],
voice_samples: Optional[List[Union[str, np.ndarray]]] = None,
parsed_script: List[Tuple[int, str]],
voice_samples: List[Optional[Union[str, np.ndarray]]],
speaker_ids: List[int],
) -> Dict[str, Any]:
"""Process a single podcast script."""
# Determine if text is a file path or direct script
script = None
if isinstance(text, str):
# Check if it's a file path
if text.endswith('.json') and os.path.exists(text):
script = self._convert_json_to_script(text)
elif text.endswith('.txt') and os.path.exists(text):
script = self._convert_text_to_script(text)
else:
# Assume it's the script content directly
script = text
if script is None:
raise ValueError(f"Could not process input text: {text}")
# Parse the script
parsed_lines = self._parse_script(script)
all_speakers = list(set(speaker_id for speaker_id, _ in parsed_lines))
# Create system prompt
# system_tokens = self.tokenizer.encode(self.system_prompt, add_special_tokens=False)
system_tokens = self.tokenizer.encode(self.system_prompt)
# Process voice samples if provided
if voice_samples:
voice_tokens, voice_speech_inputs, voice_speech_masks = self._create_voice_prompt(voice_samples[:len(all_speakers)])
else:
voice_tokens, voice_speech_inputs, voice_speech_masks = [], [], []
# Build full token sequence
voice_tokens, voice_speech_inputs, voice_speech_masks = self._create_voice_prompt(
voice_samples, speaker_ids
)
full_tokens = system_tokens + voice_tokens
speech_input_mask = [False] * len(system_tokens) + voice_speech_masks
# Add text input section
full_tokens += self.tokenizer.encode(' Text input:\n', add_special_tokens=False)
speech_input_mask += [False] * len(self.tokenizer.encode(' Text input:\n', add_special_tokens=False))
for speaker_id, speaker_text in parsed_lines:
speaker_text_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:{speaker_text}\n", add_special_tokens=False)
full_tokens += speaker_text_tokens
speech_input_mask += [False] * len(speaker_text_tokens)
# Add speech output section
full_tokens += self.tokenizer.encode(' Speech output:\n', add_special_tokens=False) + [self.tokenizer.speech_start_id]
speech_input_mask += [False] * (len(self.tokenizer.encode(' Speech output:\n', add_special_tokens=False)) + 1)
dialogue_lines = []
for speaker_id_0_based, text_chunk in parsed_script:
speaker_id_1_based = speaker_id_0_based + 1
dialogue_lines.append(f"Speaker {speaker_id_1_based}: : {text_chunk}")
full_dialogue_script = "\n".join(dialogue_lines)
final_prompt_text = f" Text input:\n{full_dialogue_script}\n Speech output:\n"
prompt_tokens = self.tokenizer.encode(final_prompt_text, add_special_tokens=False)
full_tokens += prompt_tokens + [self.tokenizer.speech_start_id]
speech_input_mask += [False] * (len(prompt_tokens) + 1)
return {
"input_ids": full_tokens,
"speech_inputs": voice_speech_inputs if voice_speech_inputs else None,
"speech_input_mask": speech_input_mask,
"parsed_script": parsed_lines,
"all_speakers": all_speakers,
}
def _batch_encode(
@@ -298,11 +271,9 @@ class VibeVoiceProcessor:
return_attention_mask: bool = True,
) -> BatchEncoding:
"""Combine multiple encodings into a batch with padding."""
# Extract input_ids and create attention_mask
input_ids_list = [enc["input_ids"] for enc in encodings]
speech_input_masks_list = [enc["speech_input_mask"] for enc in encodings]
# Determine padding strategy
if isinstance(padding, bool):
padding_strategy = PaddingStrategy.LONGEST if padding else PaddingStrategy.DO_NOT_PAD
elif isinstance(padding, str):
@@ -347,15 +318,11 @@ class VibeVoiceProcessor:
# No padding, just create attention masks
attention_masks = [[1] * len(ids) for ids in input_ids_list] if return_attention_mask else None
# Process speech inputs
all_speech_inputs = []
has_speech = False
for enc in encodings:
if enc["speech_inputs"] is not None:
if enc.get("speech_inputs"):
all_speech_inputs.extend(enc["speech_inputs"])
has_speech = True
# Prepare batch encoding
batch_encoding = BatchEncoding()
# Handle tensor conversion
@@ -370,79 +337,79 @@ class VibeVoiceProcessor:
batch_encoding["attention_mask"] = attention_masks
batch_encoding["speech_input_mask"] = speech_input_masks_list
# Process speech tensors if present
if has_speech:
speech_dict = self.prepare_speech_inputs(
all_speech_inputs,
return_tensors=return_tensors,
)
if all_speech_inputs:
speech_dict = self.prepare_speech_inputs(all_speech_inputs, return_tensors=return_tensors)
batch_encoding["speech_tensors"] = speech_dict["padded_speeches"]
batch_encoding["speech_masks"] = speech_dict["speech_masks"]
else:
batch_encoding["speech_tensors"] = None
batch_encoding["speech_masks"] = None
# Add metadata
batch_encoding["parsed_scripts"] = [enc["parsed_script"] for enc in encodings]
batch_encoding["all_speakers_list"] = [enc["all_speakers"] for enc in encodings]
return batch_encoding
def _create_voice_prompt(
self,
speaker_samples: List[Union[str, np.ndarray]]
speaker_samples: List[Optional[Union[str, np.ndarray]]],
speaker_ids: List[int]
) -> Tuple[List[int], List[np.ndarray], List[bool]]:
"""
Create voice prompt tokens and process audio samples.
This function now handles `None` in the speaker_samples list for zero-shot speakers.
Returns:
tuple: (voice_tokens, voice_speech_inputs, voice_speech_masks)
"""
if not any(s is not None for s in speaker_samples):
return [], [], []
vae_token_id = self.tokenizer.speech_diffusion_id
voice_full_tokens = self.tokenizer.encode(' Voice input:\n', add_special_tokens=False)
voice_speech_inputs = []
voice_speech_masks = [False] * len(voice_full_tokens)
for speaker_id, speaker_audio in enumerate(speaker_samples):
prefix_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:", add_special_tokens=False)
for speaker_id, speaker_audio in zip(speaker_ids, speaker_samples):
# Process audio
if isinstance(speaker_audio, str):
# Load audio from file
wav = self.audio_processor._load_audio_from_path(speaker_audio)
if speaker_audio is not None:
logger.info(f"Creating voice prompt for Speaker {speaker_id} from reference audio.")
prefix_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:", add_special_tokens=False)
newline_tokens = self.tokenizer.encode('\n', add_special_tokens=False)
if isinstance(speaker_audio, str):
wav = self.audio_processor._load_audio_from_path(speaker_audio)
else:
wav = np.array(speaker_audio, dtype=np.float32)
if self.db_normalize and self.audio_normalizer:
wav = self.audio_normalizer(wav)
vae_tok_len = math.ceil(wav.shape[0] / self.speech_tok_compress_ratio)
speaker_tokens = (
prefix_tokens +
[self.tokenizer.speech_start_id] +
[vae_token_id] * vae_tok_len +
[self.tokenizer.speech_end_id] +
newline_tokens
)
vae_input_mask = (
[False] * len(prefix_tokens) +
[False] + # for speech_start_id
[True] * vae_tok_len +
[False] + # for speech_end_id
[False] * len(newline_tokens)
)
voice_speech_inputs.append(wav)
voice_full_tokens.extend(speaker_tokens)
voice_speech_masks.extend(vae_input_mask)
else:
wav = np.array(speaker_audio, dtype=np.float32)
# Apply normalization if needed
if self.db_normalize and self.audio_normalizer:
wav = self.audio_normalizer(wav)
# Calculate token length based on compression ratio
# if speaker_audio.endswith('.pt') or speaker_audio.endswith('.npy'):
# vae_tok_len = wav.shape[0]
# else:
vae_tok_len = math.ceil(wav.shape[0] / self.speech_tok_compress_ratio)
# Build tokens and masks
speaker_tokens = (prefix_tokens +
[self.tokenizer.speech_start_id] +
[vae_token_id] * vae_tok_len +
[self.tokenizer.speech_end_id] +
self.tokenizer.encode('\n', add_special_tokens=False))
vae_input_mask = ([False] * len(prefix_tokens) +
[False] +
[True] * vae_tok_len +
[False] +
[False])
voice_full_tokens.extend(speaker_tokens)
voice_speech_masks.extend(vae_input_mask)
voice_speech_inputs.append(wav)
logger.info(f"Skipping voice prompt for Speaker {speaker_id} (zero-shot).")
return voice_full_tokens, voice_speech_inputs, voice_speech_masks
def prepare_speech_inputs(
self,
speech_inputs: List[np.ndarray],
@@ -481,10 +448,7 @@ class VibeVoiceProcessor:
padded_speeches[i, :len(speech)] = speech
speech_masks[i, :vae_tok_length] = True
result = {
"padded_speeches": padded_speeches,
"speech_masks": speech_masks,
}
result = {"padded_speeches": padded_speeches, "speech_masks": speech_masks}
# Convert to tensors if requested
if return_tensors == "pt":
@@ -584,12 +548,10 @@ class VibeVoiceProcessor:
parsed_lines = []
speaker_ids = []
# First pass: parse all lines and collect speaker IDs
for line in lines:
if not line.strip():
continue
# Use regex to handle edge cases like multiple colons
match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line.strip(), re.IGNORECASE)
if match: