import torch
import gc
import logging

import comfy.model_management as model_management
from comfy.utils import ProgressBar

# Import from the dedicated model_info module
from .modules.model_info import AVAILABLE_VIBEVOICE_MODELS
from .modules.loader import VibeVoiceModelHandler, ATTENTION_MODES, VIBEVOICE_PATCHER_CACHE, cleanup_old_models
from .modules.patcher import VibeVoicePatcher
from .modules.utils import parse_script_1_based, preprocess_comfy_audio, set_vibevoice_seed, check_for_interrupt

logger = logging.getLogger(__name__)

class VibeVoiceTTSNode:
    @classmethod
    def INPUT_TYPES(cls):
        model_names = list(AVAILABLE_VIBEVOICE_MODELS.keys())
        if not model_names:
            model_names.append("No models found in models/tts/VibeVoice")

        return {
            "required": {
                "model_name": (model_names, {
                    "tooltip": "Select the VibeVoice model to use. Official models will be downloaded automatically."
                }),
                "text": ("STRING", {
                    "multiline": True,
                    "default": "Speaker 1: Hello from ComfyUI!\nSpeaker 2: VibeVoice sounds amazing.",
                    "tooltip": "The script for the conversation. Use 'Speaker 1:', 'Speaker 2:', etc. to assign lines to different voices. Each speaker line should be on a new line."
                }),
                "quantize_llm_4bit": ("BOOLEAN", {
                    "default": False, "label_on": "Q4 (LLM only)", "label_off": "Full precision",
                    "tooltip": "Quantize the Qwen2.5 LLM to 4-bit NF4 via bitsandbytes. Diffusion head stays BF16/FP32."
                }),
                "attention_mode": (ATTENTION_MODES, {
                    "default": "sdpa",
                    "tooltip": "Attention implementation: Eager (safest), SDPA (balanced), Flash Attention 2 (fastest), Sage (quantized)"
                }),
                "cfg_scale": ("FLOAT", {
                    "default": 1.3, "min": 1.0, "max": 10.0, "step": 0.05,
                    "tooltip": "Classifier-Free Guidance scale. Higher values increase adherence to the voice prompt but may reduce naturalness. Recommended: 1.3"
                }),
                "inference_steps": ("INT", {
                    "default": 10, "min": 1, "max": 50,
                    "tooltip": "Number of diffusion steps for audio generation. More steps can improve quality but take longer. Recommended: 10"
                }),
                "seed": ("INT", {
                    "default": 42, "min": 0, "max": 0xFFFFFFFFFFFFFFFF, "control_after_generate": True,
                    "tooltip": "Seed for reproducibility. Set to 0 for a random seed on each run."
                }),
                "do_sample": ("BOOLEAN", {
                    "default": True, "label_on": "Enabled (Sampling)", "label_off": "Disabled (Greedy)",
                    "tooltip": "Enable to use sampling methods (like temperature and top_p) for more varied output. Disable for deterministic (greedy) decoding."
                }),
                "temperature": ("FLOAT", {
                    "default": 0.95, "min": 0.0, "max": 2.0, "step": 0.01,
                    "tooltip": "Controls randomness. Higher values make the output more random and creative, while lower values make it more focused and deterministic. Active only if 'do_sample' is enabled."
                }),
                "top_p": ("FLOAT", {
                    "default": 0.95, "min": 0.0, "max": 1.0, "step": 0.01,
                    "tooltip": "Nucleus sampling (Top-P). The model samples from the smallest set of tokens whose cumulative probability exceeds this value. Active only if 'do_sample' is enabled."
                }),
                "top_k": ("INT", {
                    "default": 0, "min": 0, "max": 500, "step": 1,
                    "tooltip": "Top-K sampling. Restricts sampling to the K most likely next tokens. Set to 0 to disable. Active only if 'do_sample' is enabled."
                }),
                "force_offload": ("BOOLEAN", {
                    "default": False, "label_on": "Force Offload", "label_off": "Keep in VRAM",
                    "tooltip": "Force model to be offloaded from VRAM after generation. Useful to free up memory between generations but may slow down subsequent runs."
                }),
            },
            "optional": {
                "speaker_1_voice": ("AUDIO", {"tooltip": "Reference audio for 'Speaker 1' in the script."}),
                "speaker_2_voice": ("AUDIO", {"tooltip": "Reference audio for 'Speaker 2' in the script."}),
                "speaker_3_voice": ("AUDIO", {"tooltip": "Reference audio for 'Speaker 3' in the script."}),
                "speaker_4_voice": ("AUDIO", {"tooltip": "Reference audio for 'Speaker 4' in the script."}),
            }
        }

    RETURN_TYPES = ("AUDIO",)
    FUNCTION = "generate_audio"
    CATEGORY = "audio/tts"

    def generate_audio(self, model_name, text, attention_mode, cfg_scale, inference_steps, seed, do_sample, temperature, top_p, top_k, quantize_llm_4bit, force_offload, **kwargs):
        # When the LLM is quantized to 4-bit, eager/flash_attention_2 fall back to SDPA for caching purposes.
        actual_attention_mode = attention_mode
        if quantize_llm_4bit and attention_mode in ["eager", "flash_attention_2"]:
            actual_attention_mode = "sdpa"

        # One cached patcher per (model, attention mode, quantization) combination.
        cache_key = f"{model_name}_attn_{actual_attention_mode}_q4_{int(quantize_llm_4bit)}"

        if cache_key not in VIBEVOICE_PATCHER_CACHE:
            cleanup_old_models(keep_cache_key=cache_key)

            model_handler = VibeVoiceModelHandler(model_name, attention_mode, use_llm_4bit=quantize_llm_4bit)
            patcher = VibeVoicePatcher(
                model_handler,
                attention_mode=attention_mode,
                load_device=model_management.get_torch_device(),
                offload_device=model_management.unet_offload_device(),
                size=model_handler.size
            )
            VIBEVOICE_PATCHER_CACHE[cache_key] = patcher

        patcher = VIBEVOICE_PATCHER_CACHE[cache_key]
        model_management.load_model_gpu(patcher)
        model = patcher.model.model
        processor = patcher.model.processor

        if model is None or processor is None:
            raise RuntimeError("VibeVoice model and processor could not be loaded. Check logs for errors.")

        parsed_lines_0_based, speaker_ids_1_based = parse_script_1_based(text)
        if not parsed_lines_0_based:
            raise ValueError("Script is empty or invalid. Use 'Speaker 1:', 'Speaker 2:', etc. format.")

        full_script = "\n".join([f"Speaker {spk+1}: {txt}" for spk, txt in parsed_lines_0_based])

        speaker_inputs = {i: kwargs.get(f"speaker_{i}_voice") for i in range(1, 5)}
        voice_samples_np = [preprocess_comfy_audio(speaker_inputs[sid]) for sid in speaker_ids_1_based]

        if any(v is None for v in voice_samples_np):
            missing_ids = [sid for sid, v in zip(speaker_ids_1_based, voice_samples_np) if v is None]
            raise ValueError(f"Script requires voices for Speakers {missing_ids}, but they were not provided.")

        set_vibevoice_seed(seed)

        try:
            inputs = processor(
                text=[full_script], voice_samples=[voice_samples_np], padding=True,
                return_tensors="pt", return_attention_mask=True
            )

            for key, value in inputs.items():
                if isinstance(value, torch.Tensor):
                    if torch.any(torch.isnan(value)) or torch.any(torch.isinf(value)):
                        logger.error(f"Input tensor '{key}' contains NaN or Inf values")
                        raise ValueError(f"Invalid values in input tensor: {key}")

            inputs = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}

            model.set_ddpm_inference_steps(num_steps=inference_steps)

            generation_config = {'do_sample': do_sample}
            if do_sample:
                generation_config['temperature'] = temperature
                generation_config['top_p'] = top_p
                if top_k > 0:
                    generation_config['top_k'] = top_k

            with torch.no_grad():
                pbar = ProgressBar(inference_steps)

                def progress_callback(step, total_steps):
                    pbar.update(1)
                    if model_management.processing_interrupted():
                        raise model_management.InterruptProcessingException()

                try:
                    outputs = model.generate(
                        **inputs, max_new_tokens=None, cfg_scale=cfg_scale,
                        tokenizer=processor.tokenizer, generation_config=generation_config,
                        verbose=False, stop_check_fn=check_for_interrupt
                    )
                    pbar.update(inference_steps - pbar.current)

                except RuntimeError as e:
                    error_msg = str(e).lower()
                    if "assertion" in error_msg or "cuda" in error_msg:
                        logger.error(f"CUDA assertion failed with {attention_mode} attention: {e}")
                        logger.error("This might be due to invalid input data, GPU memory issues, or an incompatible attention mode.")
                        logger.error("Try restarting ComfyUI, using different audio files, or switching to 'eager' attention mode.")
                    raise e
                except model_management.InterruptProcessingException:
                    logger.info("VibeVoice generation interrupted by user")
                    raise
                finally:
                    pbar.update_absolute(inference_steps)

        except model_management.InterruptProcessingException:
            logger.info("VibeVoice TTS generation was cancelled")
            return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},)

        except Exception as e:
            logger.error(f"Error during VibeVoice generation with {attention_mode} attention: {e}")
            if "interrupt" in str(e).lower() or "cancel" in str(e).lower():
                logger.info("Generation was interrupted")
                return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},)
            raise

        # ComfyUI AUDIO expects a [batch, channels, samples] waveform.
        output_waveform = outputs.speech_outputs[0]
        if output_waveform.ndim == 1: output_waveform = output_waveform.unsqueeze(0)
        if output_waveform.ndim == 2: output_waveform = output_waveform.unsqueeze(0)

        if force_offload:
            logger.info(f"Force offloading VibeVoice model '{model_name}' from VRAM...")
            if patcher.is_loaded:
                patcher.unpatch_model(unpatch_weights=True)
            model_management.unload_all_models()
            gc.collect()
            model_management.soft_empty_cache()
            logger.info("Model force offload completed")

        return ({"waveform": output_waveform.detach().cpu(), "sample_rate": 24000},)


NODE_CLASS_MAPPINGS = {"VibeVoiceTTS": VibeVoiceTTSNode}
NODE_DISPLAY_NAME_MAPPINGS = {"VibeVoiceTTS": "VibeVoice TTS"}
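
# Example (hypothetical, for illustration only; not part of the node's API surface):
# invoking the node outside a ComfyUI graph would look roughly like this, assuming a
# model is present under models/tts/VibeVoice and `my_reference_audio` is a ComfyUI
# AUDIO dict ({"waveform": Tensor[batch, channels, samples], "sample_rate": int}):
#
#   node = VibeVoiceTTSNode()
#   (audio,) = node.generate_audio(
#       model_name=list(AVAILABLE_VIBEVOICE_MODELS.keys())[0],
#       text="Speaker 1: Hello from ComfyUI!",
#       attention_mode="sdpa", cfg_scale=1.3, inference_steps=10, seed=42,
#       do_sample=True, temperature=0.95, top_p=0.95, top_k=0,
#       quantize_llm_4bit=False, force_offload=False,
#       speaker_1_voice=my_reference_audio,
#   )
#   # audio -> {"waveform": Tensor[1, 1, N], "sample_rate": 24000}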