mirror of
https://github.com/wildminder/ComfyUI-VibeVoice.git
synced 2026-04-30 19:51:37 +00:00
voice bleeding fix, audio quality, input speakers tags, zero-shot voices
This commit is contained in:
@@ -147,8 +147,10 @@ class VibeVoiceProcessor:
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: Optional[Union[str, List[str], TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
|
||||
voice_samples: Optional[Union[List[Union[str, np.ndarray]], List[List[Union[str, np.ndarray]]]]] = None,
|
||||
text: Optional[List[str]] = None,
|
||||
parsed_scripts: Optional[List[List[Tuple[int, str]]]] = None, # <-- ADDED
|
||||
voice_samples: Optional[List[List[Optional[Union[str, np.ndarray]]]]] = None,
|
||||
speaker_ids_for_prompt: Optional[List[List[int]]] = None,
|
||||
padding: Union[bool, str, PaddingStrategy] = True,
|
||||
truncation: Union[bool, str, TruncationStrategy] = False,
|
||||
max_length: Optional[int] = None,
|
||||
@@ -189,31 +191,26 @@ class VibeVoiceProcessor:
|
||||
- **speech_masks** -- Speech masks (if voice_samples provided)
|
||||
- **speech_input_mask** -- Boolean masks indicating speech token positions
|
||||
"""
|
||||
# Handle single vs batch input
|
||||
if isinstance(text, str) or (isinstance(text, list) and len(text) > 0 and not isinstance(text[0], str)):
|
||||
# Single input
|
||||
texts = [text]
|
||||
is_batched = False
|
||||
else:
|
||||
# Batch input
|
||||
texts = text
|
||||
is_batched = True
|
||||
|
||||
# Handle voice samples
|
||||
if voice_samples is not None:
|
||||
if not is_batched or (isinstance(voice_samples[0], (str, np.ndarray))):
|
||||
# Single set of voice samples
|
||||
voice_samples_list = [voice_samples]
|
||||
else:
|
||||
# Batch of voice samples
|
||||
voice_samples_list = voice_samples
|
||||
else:
|
||||
voice_samples_list = [None] * len(texts)
|
||||
|
||||
# Process each input
|
||||
if parsed_scripts is None:
|
||||
if text is None:
|
||||
raise ValueError("Either 'text' or 'parsed_scripts' must be provided.")
|
||||
# Fallback for raw text input (though the node won't use this path)
|
||||
from ..modules.utils import parse_script_1_based
|
||||
parsed_scripts = [parse_script_1_based(t)[0] for t in text]
|
||||
|
||||
num_scripts = len(parsed_scripts)
|
||||
voice_samples_list = voice_samples if voice_samples is not None else [[] for _ in range(num_scripts)]
|
||||
speaker_ids_list = speaker_ids_for_prompt if speaker_ids_for_prompt is not None else [[] for _ in range(num_scripts)]
|
||||
|
||||
all_encodings = []
|
||||
for text_input, voice_input in zip(texts, voice_samples_list):
|
||||
encoding = self._process_single(text_input, voice_input)
|
||||
for i in range(num_scripts):
|
||||
# Pass all three corresponding items to _process_single
|
||||
encoding = self._process_single(
|
||||
parsed_scripts[i],
|
||||
voice_samples_list[i],
|
||||
speaker_ids_list[i]
|
||||
)
|
||||
all_encodings.append(encoding)
|
||||
|
||||
# Combine batch
|
||||
@@ -230,62 +227,38 @@ class VibeVoiceProcessor:
|
||||
|
||||
def _process_single(
|
||||
self,
|
||||
text: Union[str, TextInput],
|
||||
voice_samples: Optional[List[Union[str, np.ndarray]]] = None,
|
||||
parsed_script: List[Tuple[int, str]],
|
||||
voice_samples: List[Optional[Union[str, np.ndarray]]],
|
||||
speaker_ids: List[int],
|
||||
) -> Dict[str, Any]:
|
||||
"""Process a single podcast script."""
|
||||
# Determine if text is a file path or direct script
|
||||
script = None
|
||||
if isinstance(text, str):
|
||||
# Check if it's a file path
|
||||
if text.endswith('.json') and os.path.exists(text):
|
||||
script = self._convert_json_to_script(text)
|
||||
elif text.endswith('.txt') and os.path.exists(text):
|
||||
script = self._convert_text_to_script(text)
|
||||
else:
|
||||
# Assume it's the script content directly
|
||||
script = text
|
||||
|
||||
if script is None:
|
||||
raise ValueError(f"Could not process input text: {text}")
|
||||
|
||||
# Parse the script
|
||||
parsed_lines = self._parse_script(script)
|
||||
all_speakers = list(set(speaker_id for speaker_id, _ in parsed_lines))
|
||||
|
||||
# Create system prompt
|
||||
# system_tokens = self.tokenizer.encode(self.system_prompt, add_special_tokens=False)
|
||||
|
||||
system_tokens = self.tokenizer.encode(self.system_prompt)
|
||||
|
||||
# Process voice samples if provided
|
||||
if voice_samples:
|
||||
voice_tokens, voice_speech_inputs, voice_speech_masks = self._create_voice_prompt(voice_samples[:len(all_speakers)])
|
||||
else:
|
||||
voice_tokens, voice_speech_inputs, voice_speech_masks = [], [], []
|
||||
|
||||
# Build full token sequence
|
||||
voice_tokens, voice_speech_inputs, voice_speech_masks = self._create_voice_prompt(
|
||||
voice_samples, speaker_ids
|
||||
)
|
||||
|
||||
full_tokens = system_tokens + voice_tokens
|
||||
speech_input_mask = [False] * len(system_tokens) + voice_speech_masks
|
||||
|
||||
# Add text input section
|
||||
full_tokens += self.tokenizer.encode(' Text input:\n', add_special_tokens=False)
|
||||
speech_input_mask += [False] * len(self.tokenizer.encode(' Text input:\n', add_special_tokens=False))
|
||||
|
||||
for speaker_id, speaker_text in parsed_lines:
|
||||
speaker_text_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:{speaker_text}\n", add_special_tokens=False)
|
||||
full_tokens += speaker_text_tokens
|
||||
speech_input_mask += [False] * len(speaker_text_tokens)
|
||||
|
||||
# Add speech output section
|
||||
full_tokens += self.tokenizer.encode(' Speech output:\n', add_special_tokens=False) + [self.tokenizer.speech_start_id]
|
||||
speech_input_mask += [False] * (len(self.tokenizer.encode(' Speech output:\n', add_special_tokens=False)) + 1)
|
||||
dialogue_lines = []
|
||||
for speaker_id_0_based, text_chunk in parsed_script:
|
||||
speaker_id_1_based = speaker_id_0_based + 1
|
||||
dialogue_lines.append(f"Speaker {speaker_id_1_based}: : {text_chunk}")
|
||||
|
||||
full_dialogue_script = "\n".join(dialogue_lines)
|
||||
|
||||
final_prompt_text = f" Text input:\n{full_dialogue_script}\n Speech output:\n"
|
||||
|
||||
prompt_tokens = self.tokenizer.encode(final_prompt_text, add_special_tokens=False)
|
||||
|
||||
full_tokens += prompt_tokens + [self.tokenizer.speech_start_id]
|
||||
speech_input_mask += [False] * (len(prompt_tokens) + 1)
|
||||
|
||||
return {
|
||||
"input_ids": full_tokens,
|
||||
"speech_inputs": voice_speech_inputs if voice_speech_inputs else None,
|
||||
"speech_input_mask": speech_input_mask,
|
||||
"parsed_script": parsed_lines,
|
||||
"all_speakers": all_speakers,
|
||||
}
|
||||
|
||||
def _batch_encode(
|
||||
@@ -298,11 +271,9 @@ class VibeVoiceProcessor:
|
||||
return_attention_mask: bool = True,
|
||||
) -> BatchEncoding:
|
||||
"""Combine multiple encodings into a batch with padding."""
|
||||
# Extract input_ids and create attention_mask
|
||||
input_ids_list = [enc["input_ids"] for enc in encodings]
|
||||
speech_input_masks_list = [enc["speech_input_mask"] for enc in encodings]
|
||||
|
||||
# Determine padding strategy
|
||||
if isinstance(padding, bool):
|
||||
padding_strategy = PaddingStrategy.LONGEST if padding else PaddingStrategy.DO_NOT_PAD
|
||||
elif isinstance(padding, str):
|
||||
@@ -347,15 +318,11 @@ class VibeVoiceProcessor:
|
||||
# No padding, just create attention masks
|
||||
attention_masks = [[1] * len(ids) for ids in input_ids_list] if return_attention_mask else None
|
||||
|
||||
# Process speech inputs
|
||||
all_speech_inputs = []
|
||||
has_speech = False
|
||||
for enc in encodings:
|
||||
if enc["speech_inputs"] is not None:
|
||||
if enc.get("speech_inputs"):
|
||||
all_speech_inputs.extend(enc["speech_inputs"])
|
||||
has_speech = True
|
||||
|
||||
# Prepare batch encoding
|
||||
|
||||
batch_encoding = BatchEncoding()
|
||||
|
||||
# Handle tensor conversion
|
||||
@@ -370,79 +337,79 @@ class VibeVoiceProcessor:
|
||||
batch_encoding["attention_mask"] = attention_masks
|
||||
batch_encoding["speech_input_mask"] = speech_input_masks_list
|
||||
|
||||
# Process speech tensors if present
|
||||
if has_speech:
|
||||
speech_dict = self.prepare_speech_inputs(
|
||||
all_speech_inputs,
|
||||
return_tensors=return_tensors,
|
||||
)
|
||||
if all_speech_inputs:
|
||||
speech_dict = self.prepare_speech_inputs(all_speech_inputs, return_tensors=return_tensors)
|
||||
batch_encoding["speech_tensors"] = speech_dict["padded_speeches"]
|
||||
batch_encoding["speech_masks"] = speech_dict["speech_masks"]
|
||||
else:
|
||||
batch_encoding["speech_tensors"] = None
|
||||
batch_encoding["speech_masks"] = None
|
||||
|
||||
# Add metadata
|
||||
batch_encoding["parsed_scripts"] = [enc["parsed_script"] for enc in encodings]
|
||||
batch_encoding["all_speakers_list"] = [enc["all_speakers"] for enc in encodings]
|
||||
|
||||
return batch_encoding
|
||||
|
||||
def _create_voice_prompt(
|
||||
self,
|
||||
speaker_samples: List[Union[str, np.ndarray]]
|
||||
speaker_samples: List[Optional[Union[str, np.ndarray]]],
|
||||
speaker_ids: List[int]
|
||||
) -> Tuple[List[int], List[np.ndarray], List[bool]]:
|
||||
"""
|
||||
Create voice prompt tokens and process audio samples.
|
||||
This function now handles `None` in the speaker_samples list for zero-shot speakers.
|
||||
|
||||
Returns:
|
||||
tuple: (voice_tokens, voice_speech_inputs, voice_speech_masks)
|
||||
"""
|
||||
if not any(s is not None for s in speaker_samples):
|
||||
return [], [], []
|
||||
|
||||
vae_token_id = self.tokenizer.speech_diffusion_id
|
||||
|
||||
voice_full_tokens = self.tokenizer.encode(' Voice input:\n', add_special_tokens=False)
|
||||
voice_speech_inputs = []
|
||||
voice_speech_masks = [False] * len(voice_full_tokens)
|
||||
|
||||
for speaker_id, speaker_audio in enumerate(speaker_samples):
|
||||
prefix_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:", add_special_tokens=False)
|
||||
for speaker_id, speaker_audio in zip(speaker_ids, speaker_samples):
|
||||
|
||||
# Process audio
|
||||
if isinstance(speaker_audio, str):
|
||||
# Load audio from file
|
||||
wav = self.audio_processor._load_audio_from_path(speaker_audio)
|
||||
if speaker_audio is not None:
|
||||
logger.info(f"Creating voice prompt for Speaker {speaker_id} from reference audio.")
|
||||
prefix_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:", add_special_tokens=False)
|
||||
newline_tokens = self.tokenizer.encode('\n', add_special_tokens=False)
|
||||
|
||||
if isinstance(speaker_audio, str):
|
||||
wav = self.audio_processor._load_audio_from_path(speaker_audio)
|
||||
else:
|
||||
wav = np.array(speaker_audio, dtype=np.float32)
|
||||
|
||||
if self.db_normalize and self.audio_normalizer:
|
||||
wav = self.audio_normalizer(wav)
|
||||
|
||||
|
||||
vae_tok_len = math.ceil(wav.shape[0] / self.speech_tok_compress_ratio)
|
||||
speaker_tokens = (
|
||||
prefix_tokens +
|
||||
[self.tokenizer.speech_start_id] +
|
||||
[vae_token_id] * vae_tok_len +
|
||||
[self.tokenizer.speech_end_id] +
|
||||
newline_tokens
|
||||
)
|
||||
|
||||
vae_input_mask = (
|
||||
[False] * len(prefix_tokens) +
|
||||
[False] + # for speech_start_id
|
||||
[True] * vae_tok_len +
|
||||
[False] + # for speech_end_id
|
||||
[False] * len(newline_tokens)
|
||||
)
|
||||
voice_speech_inputs.append(wav)
|
||||
voice_full_tokens.extend(speaker_tokens)
|
||||
voice_speech_masks.extend(vae_input_mask)
|
||||
else:
|
||||
wav = np.array(speaker_audio, dtype=np.float32)
|
||||
|
||||
# Apply normalization if needed
|
||||
if self.db_normalize and self.audio_normalizer:
|
||||
wav = self.audio_normalizer(wav)
|
||||
|
||||
# Calculate token length based on compression ratio
|
||||
# if speaker_audio.endswith('.pt') or speaker_audio.endswith('.npy'):
|
||||
# vae_tok_len = wav.shape[0]
|
||||
# else:
|
||||
vae_tok_len = math.ceil(wav.shape[0] / self.speech_tok_compress_ratio)
|
||||
|
||||
# Build tokens and masks
|
||||
speaker_tokens = (prefix_tokens +
|
||||
[self.tokenizer.speech_start_id] +
|
||||
[vae_token_id] * vae_tok_len +
|
||||
[self.tokenizer.speech_end_id] +
|
||||
self.tokenizer.encode('\n', add_special_tokens=False))
|
||||
|
||||
vae_input_mask = ([False] * len(prefix_tokens) +
|
||||
[False] +
|
||||
[True] * vae_tok_len +
|
||||
[False] +
|
||||
[False])
|
||||
|
||||
voice_full_tokens.extend(speaker_tokens)
|
||||
voice_speech_masks.extend(vae_input_mask)
|
||||
voice_speech_inputs.append(wav)
|
||||
|
||||
logger.info(f"Skipping voice prompt for Speaker {speaker_id} (zero-shot).")
|
||||
|
||||
|
||||
return voice_full_tokens, voice_speech_inputs, voice_speech_masks
|
||||
|
||||
|
||||
def prepare_speech_inputs(
|
||||
self,
|
||||
speech_inputs: List[np.ndarray],
|
||||
@@ -481,10 +448,7 @@ class VibeVoiceProcessor:
|
||||
padded_speeches[i, :len(speech)] = speech
|
||||
speech_masks[i, :vae_tok_length] = True
|
||||
|
||||
result = {
|
||||
"padded_speeches": padded_speeches,
|
||||
"speech_masks": speech_masks,
|
||||
}
|
||||
result = {"padded_speeches": padded_speeches, "speech_masks": speech_masks}
|
||||
|
||||
# Convert to tensors if requested
|
||||
if return_tensors == "pt":
|
||||
@@ -584,12 +548,10 @@ class VibeVoiceProcessor:
|
||||
parsed_lines = []
|
||||
speaker_ids = []
|
||||
|
||||
# First pass: parse all lines and collect speaker IDs
|
||||
for line in lines:
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
# Use regex to handle edge cases like multiple colons
|
||||
match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line.strip(), re.IGNORECASE)
|
||||
|
||||
if match:
|
||||
|
||||
Reference in New Issue
Block a user