mirror of
https://github.com/wildminder/ComfyUI-VibeVoice.git
synced 2026-05-01 04:01:37 +00:00
major refactoring
This commit is contained in:
83
__init__.py
83
__init__.py
@@ -1,6 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import logging
|
import logging
|
||||||
|
import folder_paths
|
||||||
|
import json
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import sageattention
|
import sageattention
|
||||||
@@ -12,34 +14,95 @@ current_dir = os.path.dirname(os.path.abspath(__file__))
|
|||||||
if current_dir not in sys.path:
|
if current_dir not in sys.path:
|
||||||
sys.path.append(current_dir)
|
sys.path.append(current_dir)
|
||||||
|
|
||||||
import folder_paths
|
from .modules.model_info import AVAILABLE_VIBEVOICE_MODELS, MODEL_CONFIGS
|
||||||
|
|
||||||
from .vibevoice_nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
|
# Configure a logger
|
||||||
|
|
||||||
# Configure a logger for the entire custom node package
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
logger.propagate = False
|
logger.propagate = False
|
||||||
|
|
||||||
if not logger.hasHandlers():
|
if not logger.hasHandlers():
|
||||||
handler = logging.StreamHandler()
|
handler = logging.StreamHandler()
|
||||||
formatter = logging.Formatter(f"[ComfyUI-VibeVoice] %(message)s")
|
formatter = logging.Formatter(f"[ComfyUI-VibeVoice] %(message)s")
|
||||||
handler.setFormatter(formatter)
|
handler.setFormatter(formatter)
|
||||||
logger.addHandler(handler)
|
logger.addHandler(handler)
|
||||||
|
|
||||||
|
# This is just the *name* of the subdirectory, not the full path.
|
||||||
|
VIBEVOICE_SUBDIR_NAME = "VibeVoice"
|
||||||
|
|
||||||
VIBEVOICE_MODEL_SUBDIR = os.path.join("tts", "VibeVoice")
|
# This is the *primary* path where official models will be downloaded.
|
||||||
|
primary_vibevoice_models_path = os.path.join(folder_paths.models_dir, "tts", VIBEVOICE_SUBDIR_NAME)
|
||||||
|
os.makedirs(primary_vibevoice_models_path, exist_ok=True)
|
||||||
|
|
||||||
vibevoice_models_full_path = os.path.join(folder_paths.models_dir, VIBEVOICE_MODEL_SUBDIR)
|
# Register the tts path type with ComfyUI so get_folder_paths works
|
||||||
os.makedirs(vibevoice_models_full_path, exist_ok=True)
|
|
||||||
|
|
||||||
# Register the tts/VibeVoice path with ComfyUI
|
|
||||||
tts_path = os.path.join(folder_paths.models_dir, "tts")
|
tts_path = os.path.join(folder_paths.models_dir, "tts")
|
||||||
if "tts" not in folder_paths.folder_names_and_paths:
|
if "tts" not in folder_paths.folder_names_and_paths:
|
||||||
supported_exts = folder_paths.supported_pt_extensions.union({".safetensors", ".json"})
|
supported_exts = folder_paths.supported_pt_extensions.union({".safetensors", ".json"})
|
||||||
folder_paths.folder_names_and_paths["tts"] = ([tts_path], supported_exts)
|
folder_paths.folder_names_and_paths["tts"] = ([tts_path], supported_exts)
|
||||||
else:
|
else:
|
||||||
|
# Ensure the default path is in the list if it's not already
|
||||||
if tts_path not in folder_paths.folder_names_and_paths["tts"][0]:
|
if tts_path not in folder_paths.folder_names_and_paths["tts"][0]:
|
||||||
folder_paths.folder_names_and_paths["tts"][0].append(tts_path)
|
folder_paths.folder_names_and_paths["tts"][0].append(tts_path)
|
||||||
|
|
||||||
|
# The logic for dynamic model discovery
|
||||||
|
# ToDo: optimize finding
|
||||||
|
|
||||||
|
# official models that can be auto-downloaded
|
||||||
|
for model_name, config in MODEL_CONFIGS.items():
|
||||||
|
AVAILABLE_VIBEVOICE_MODELS[model_name] = {
|
||||||
|
"type": "official",
|
||||||
|
"repo_id": config["repo_id"],
|
||||||
|
"tokenizer_repo": "Qwen/Qwen2.5-7B" if "Large" in model_name else "Qwen/Qwen2.5-1.5B"
|
||||||
|
}
|
||||||
|
|
||||||
|
# just workaround, default + custom
|
||||||
|
vibevoice_search_paths = []
|
||||||
|
# Use ComfyUI's API to get all registered 'tts' folders
|
||||||
|
for tts_folder in folder_paths.get_folder_paths("tts"):
|
||||||
|
potential_path = os.path.join(tts_folder, VIBEVOICE_SUBDIR_NAME)
|
||||||
|
if os.path.isdir(potential_path) and potential_path not in vibevoice_search_paths:
|
||||||
|
vibevoice_search_paths.append(potential_path)
|
||||||
|
|
||||||
|
# Add the primary path just in case it wasn't registered for some reason
|
||||||
|
if primary_vibevoice_models_path not in vibevoice_search_paths:
|
||||||
|
vibevoice_search_paths.insert(0, primary_vibevoice_models_path)
|
||||||
|
|
||||||
|
# Messy... Discover all local models in the search paths
|
||||||
|
for search_path in vibevoice_search_paths:
|
||||||
|
logger.info(f"Scanning for VibeVoice models in: {search_path}")
|
||||||
|
if not os.path.exists(search_path): continue
|
||||||
|
for item in os.listdir(search_path):
|
||||||
|
item_path = os.path.join(search_path, item)
|
||||||
|
|
||||||
|
# Case 1: we have a standard HF directory
|
||||||
|
if os.path.isdir(item_path):
|
||||||
|
model_name = item
|
||||||
|
if model_name in AVAILABLE_VIBEVOICE_MODELS: continue
|
||||||
|
|
||||||
|
config_exists = os.path.exists(os.path.join(item_path, "config.json"))
|
||||||
|
weights_exist = os.path.exists(os.path.join(item_path, "model.safetensors.index.json")) or any(f.endswith(('.safetensors', '.bin')) for f in os.listdir(item_path))
|
||||||
|
|
||||||
|
if config_exists and weights_exist:
|
||||||
|
tokenizer_repo = "Qwen/Qwen2.5-7B" if "large" in model_name.lower() else "Qwen/Qwen2.5-1.5B"
|
||||||
|
AVAILABLE_VIBEVOICE_MODELS[model_name] = {
|
||||||
|
"type": "local_dir",
|
||||||
|
"path": item_path,
|
||||||
|
"tokenizer_repo": tokenizer_repo
|
||||||
|
}
|
||||||
|
|
||||||
|
# Case 2: Item is a standalone file
|
||||||
|
elif os.path.isfile(item_path) and any(item.endswith(ext) for ext in folder_paths.supported_pt_extensions):
|
||||||
|
model_name = os.path.splitext(item)[0]
|
||||||
|
if model_name in AVAILABLE_VIBEVOICE_MODELS: continue
|
||||||
|
|
||||||
|
tokenizer_repo = "Qwen/Qwen2.5-7B" if "large" in model_name.lower() else "Qwen/Qwen2.5-1.5B"
|
||||||
|
AVAILABLE_VIBEVOICE_MODELS[model_name] = {
|
||||||
|
"type": "standalone",
|
||||||
|
"path": item_path,
|
||||||
|
"tokenizer_repo": tokenizer_repo
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"Discovered VibeVoice models: {sorted(list(AVAILABLE_VIBEVOICE_MODELS.keys()))}")
|
||||||
|
|
||||||
|
from .vibevoice_nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
|
||||||
|
|
||||||
__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
|
__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
|
||||||
0
modules/__init__.py
Normal file
0
modules/__init__.py
Normal file
212
modules/loader.py
Normal file
212
modules/loader.py
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import gc
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from huggingface_hub import hf_hub_download, snapshot_download
|
||||||
|
|
||||||
|
import comfy.utils
|
||||||
|
import folder_paths
|
||||||
|
import comfy.model_management as model_management
|
||||||
|
|
||||||
|
import transformers
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
|
_transformers_version = version.parse(transformers.__version__)
|
||||||
|
_DTYPE_ARG_SUPPORTED = _transformers_version >= version.parse("4.56.0")
|
||||||
|
|
||||||
|
from transformers import BitsAndBytesConfig
|
||||||
|
from ..vibevoice.modular.configuration_vibevoice import VibeVoiceConfig
|
||||||
|
from ..vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
|
||||||
|
from ..vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
|
||||||
|
from ..vibevoice.processor.vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor
|
||||||
|
from ..vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceTextTokenizerFast
|
||||||
|
|
||||||
|
from .model_info import AVAILABLE_VIBEVOICE_MODELS, MODEL_CONFIGS
|
||||||
|
from .. import SAGE_ATTENTION_AVAILABLE
|
||||||
|
if SAGE_ATTENTION_AVAILABLE:
|
||||||
|
from ..vibevoice.modular.sage_attention_patch import set_sage_attention
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
LOADED_MODELS = {}
|
||||||
|
VIBEVOICE_PATCHER_CACHE = {}
|
||||||
|
|
||||||
|
ATTENTION_MODES = ["eager", "sdpa", "flash_attention_2"]
|
||||||
|
if SAGE_ATTENTION_AVAILABLE:
|
||||||
|
ATTENTION_MODES.append("sage")
|
||||||
|
|
||||||
|
def cleanup_old_models(keep_cache_key=None):
|
||||||
|
global LOADED_MODELS, VIBEVOICE_PATCHER_CACHE
|
||||||
|
keys_to_remove = []
|
||||||
|
for key in list(LOADED_MODELS.keys()):
|
||||||
|
if key != keep_cache_key:
|
||||||
|
keys_to_remove.append(key)
|
||||||
|
del LOADED_MODELS[key]
|
||||||
|
for key in list(VIBEVOICE_PATCHER_CACHE.keys()):
|
||||||
|
if key != keep_cache_key:
|
||||||
|
try:
|
||||||
|
patcher = VIBEVOICE_PATCHER_CACHE[key]
|
||||||
|
if hasattr(patcher, 'model') and patcher.model:
|
||||||
|
patcher.model.model = None
|
||||||
|
patcher.model.processor = None
|
||||||
|
del VIBEVOICE_PATCHER_CACHE[key]
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error cleaning up patcher {key}: {e}")
|
||||||
|
if keys_to_remove:
|
||||||
|
logger.info(f"Cleaned up cached models: {keys_to_remove}")
|
||||||
|
gc.collect()
|
||||||
|
model_management.soft_empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
class VibeVoiceModelHandler(torch.nn.Module):
|
||||||
|
def __init__(self, model_pack_name, attention_mode="eager", use_llm_4bit=False):
|
||||||
|
super().__init__()
|
||||||
|
self.model_pack_name = model_pack_name
|
||||||
|
self.attention_mode = attention_mode
|
||||||
|
self.use_llm_4bit = use_llm_4bit
|
||||||
|
self.cache_key = f"{self.model_pack_name}_attn_{attention_mode}_q4_{int(use_llm_4bit)}"
|
||||||
|
self.model = None
|
||||||
|
self.processor = None
|
||||||
|
info = AVAILABLE_VIBEVOICE_MODELS.get(model_pack_name, {})
|
||||||
|
size_gb = MODEL_CONFIGS.get(model_pack_name, {}).get("size_gb", 4.0)
|
||||||
|
self.size = int(size_gb * (1024**3))
|
||||||
|
def load_model(self, device, attention_mode="eager"):
|
||||||
|
self.model, self.processor = VibeVoiceLoader.load_model(self.model_pack_name, device, attention_mode, use_llm_4bit=self.use_llm_4bit)
|
||||||
|
if self.model.device != device:
|
||||||
|
self.model.to(device)
|
||||||
|
|
||||||
|
class VibeVoiceLoader:
|
||||||
|
@staticmethod
|
||||||
|
def _check_gpu_for_sage_attention():
|
||||||
|
if not SAGE_ATTENTION_AVAILABLE: return False
|
||||||
|
if not torch.cuda.is_available(): return False
|
||||||
|
major, _ = torch.cuda.get_device_capability()
|
||||||
|
if major < 8:
|
||||||
|
logger.warning(f"Your GPU (compute capability {major}.x) does not support SageAttention, which requires CC 8.0+. Sage option will be disabled.")
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_model(model_name: str, device, attention_mode: str = "eager", use_llm_4bit: bool = False):
|
||||||
|
if model_name not in AVAILABLE_VIBEVOICE_MODELS:
|
||||||
|
raise ValueError(f"Unknown VibeVoice model: {model_name}. Available models: {list(AVAILABLE_VIBEVOICE_MODELS.keys())}")
|
||||||
|
|
||||||
|
if use_llm_4bit and attention_mode in ["eager", "flash_attention_2"]:
|
||||||
|
logger.warning(f"Attention mode '{attention_mode}' is not recommended with 4-bit quantization. Falling back to 'sdpa' for stability and performance.")
|
||||||
|
attention_mode = "sdpa"
|
||||||
|
if attention_mode not in ATTENTION_MODES:
|
||||||
|
logger.warning(f"Unknown attention mode '{attention_mode}', falling back to eager")
|
||||||
|
attention_mode = "eager"
|
||||||
|
|
||||||
|
cache_key = f"{model_name}_attn_{attention_mode}_q4_{int(use_llm_4bit)}"
|
||||||
|
if cache_key in LOADED_MODELS:
|
||||||
|
logger.info(f"Using cached model with {attention_mode} attention and q4={use_llm_4bit}")
|
||||||
|
return LOADED_MODELS[cache_key]
|
||||||
|
|
||||||
|
model_info = AVAILABLE_VIBEVOICE_MODELS[model_name]
|
||||||
|
model_type = model_info["type"]
|
||||||
|
vibevoice_base_path = os.path.join(folder_paths.get_folder_paths("tts")[0], "VibeVoice")
|
||||||
|
|
||||||
|
model_path_or_none = None
|
||||||
|
config_path = None
|
||||||
|
preprocessor_config_path = None
|
||||||
|
tokenizer_dir = None
|
||||||
|
|
||||||
|
if model_type == "official":
|
||||||
|
model_path_or_none = os.path.join(vibevoice_base_path, model_name)
|
||||||
|
if not os.path.exists(os.path.join(model_path_or_none, "model.safetensors.index.json")):
|
||||||
|
logger.info(f"Downloading official VibeVoice model: {model_name}...")
|
||||||
|
snapshot_download(repo_id=model_info["repo_id"], local_dir=model_path_or_none, local_dir_use_symlinks=False)
|
||||||
|
config_path = os.path.join(model_path_or_none, "config.json")
|
||||||
|
preprocessor_config_path = os.path.join(model_path_or_none, "preprocessor_config.json")
|
||||||
|
tokenizer_dir = model_path_or_none
|
||||||
|
elif model_type == "local_dir":
|
||||||
|
model_path_or_none = model_info["path"]
|
||||||
|
config_path = os.path.join(model_path_or_none, "config.json")
|
||||||
|
preprocessor_config_path = os.path.join(model_path_or_none, "preprocessor_config.json")
|
||||||
|
tokenizer_dir = model_path_or_none
|
||||||
|
elif model_type == "standalone":
|
||||||
|
model_path_or_none = None # IMPORTANT: This must be None when loading from state_dict
|
||||||
|
config_path = os.path.splitext(model_info["path"])[0] + ".config.json"
|
||||||
|
preprocessor_config_path = os.path.splitext(model_info["path"])[0] + ".preprocessor.json"
|
||||||
|
tokenizer_dir = os.path.dirname(model_info["path"])
|
||||||
|
|
||||||
|
if os.path.exists(config_path):
|
||||||
|
config = VibeVoiceConfig.from_pretrained(config_path)
|
||||||
|
else:
|
||||||
|
fallback_name = "default_VibeVoice-Large_config.json" if "large" in model_name.lower() else "default_VibeVoice-1.5B_config.json"
|
||||||
|
fallback_path = os.path.join(os.path.dirname(__file__), "..", "vibevoice", "configs", fallback_name)
|
||||||
|
logger.warning(f"Config not found for '{model_name}'. Using fallback: {fallback_name}")
|
||||||
|
config = VibeVoiceConfig.from_pretrained(fallback_path)
|
||||||
|
|
||||||
|
# Processor & Tokenizer setup
|
||||||
|
tokenizer_repo = model_info["tokenizer_repo"]
|
||||||
|
tokenizer_file_path = os.path.join(tokenizer_dir, "tokenizer.json")
|
||||||
|
if not os.path.exists(tokenizer_file_path):
|
||||||
|
logger.info(f"tokenizer.json not found. Downloading from '{tokenizer_repo}'...")
|
||||||
|
hf_hub_download(repo_id=tokenizer_repo, filename="tokenizer.json", local_dir=tokenizer_dir, local_dir_use_symlinks=False)
|
||||||
|
vibevoice_tokenizer = VibeVoiceTextTokenizerFast(tokenizer_file=tokenizer_file_path)
|
||||||
|
|
||||||
|
processor_config_data = {}
|
||||||
|
if os.path.exists(preprocessor_config_path):
|
||||||
|
with open(preprocessor_config_path, 'r', encoding='utf-8') as f: processor_config_data = json.load(f)
|
||||||
|
|
||||||
|
audio_processor = VibeVoiceTokenizerProcessor()
|
||||||
|
processor = VibeVoiceProcessor(tokenizer=vibevoice_tokenizer, audio_processor=audio_processor, speech_tok_compress_ratio=processor_config_data.get("speech_tok_compress_ratio", 3200), db_normalize=processor_config_data.get("db_normalize", True))
|
||||||
|
|
||||||
|
# Model Loading Prep
|
||||||
|
if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): model_dtype = torch.bfloat16
|
||||||
|
else: model_dtype = torch.float16
|
||||||
|
quant_config = None
|
||||||
|
final_load_dtype = model_dtype
|
||||||
|
|
||||||
|
if use_llm_4bit:
|
||||||
|
bnb_compute_dtype = model_dtype
|
||||||
|
if attention_mode == 'sage': bnb_compute_dtype, final_load_dtype = torch.float32, torch.float32
|
||||||
|
quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=bnb_compute_dtype)
|
||||||
|
|
||||||
|
attn_implementation_for_load = "sdpa" if attention_mode == "sage" else attention_mode
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info(f"Loading model '{model_name}' with dtype: {final_load_dtype} and attention: '{attn_implementation_for_load}'")
|
||||||
|
|
||||||
|
# UNIFIED MODEL LOADING LOGIC
|
||||||
|
from_pretrained_kwargs = {
|
||||||
|
"config": config,
|
||||||
|
"attn_implementation": attn_implementation_for_load,
|
||||||
|
"device_map": "auto" if quant_config else device,
|
||||||
|
"quantization_config": quant_config,
|
||||||
|
}
|
||||||
|
if _DTYPE_ARG_SUPPORTED:
|
||||||
|
from_pretrained_kwargs['dtype'] = final_load_dtype
|
||||||
|
else:
|
||||||
|
from_pretrained_kwargs['torch_dtype'] = final_load_dtype
|
||||||
|
|
||||||
|
if model_type == "standalone":
|
||||||
|
logger.info(f"Loading standalone model state_dict directly to device: {device}")
|
||||||
|
# loading the state dict directly to the target device
|
||||||
|
state_dict = comfy.utils.load_torch_file(model_info["path"], device=device)
|
||||||
|
from_pretrained_kwargs["state_dict"] = state_dict
|
||||||
|
|
||||||
|
model = VibeVoiceForConditionalGenerationInference.from_pretrained(model_path_or_none, **from_pretrained_kwargs)
|
||||||
|
|
||||||
|
if attention_mode == "sage":
|
||||||
|
if VibeVoiceLoader._check_gpu_for_sage_attention():
|
||||||
|
set_sage_attention(model)
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Incompatible hardware/setup for SageAttention.")
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
setattr(model, "_llm_4bit", bool(quant_config))
|
||||||
|
LOADED_MODELS[cache_key] = (model, processor)
|
||||||
|
logger.info(f"Successfully configured model '{model_name}' with {attention_mode} attention")
|
||||||
|
return model, processor
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# It's not ideal to automatically reload the model. Let the user decide what to do in case of an error.
|
||||||
|
logger.error(f"Failed to load model '{model_name}' with {attention_mode} attention: {e}")
|
||||||
|
# if attention_mode in ["sage", "flash_attention_2"]: return VibeVoiceLoader.load_model(model_name, device, "sdpa", use_llm_4bit)
|
||||||
|
# elif attention_mode == "sdpa": return VibeVoiceLoader.load_model(model_name, device, "eager", use_llm_4bit)
|
||||||
|
# else:
|
||||||
|
raise RuntimeError(f"Failed to load model even with eager attention: {e}")
|
||||||
13
modules/model_info.py
Normal file
13
modules/model_info.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# This dictionary contains the configurations for official, downloadable models.
|
||||||
|
MODEL_CONFIGS = {
|
||||||
|
"VibeVoice-1.5B": {
|
||||||
|
"repo_id": "microsoft/VibeVoice-1.5B",
|
||||||
|
"size_gb": 3.0,
|
||||||
|
},
|
||||||
|
"VibeVoice-Large": {
|
||||||
|
"repo_id": "microsoft/VibeVoice-Large",
|
||||||
|
"size_gb": 17.4,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
AVAILABLE_VIBEVOICE_MODELS = {}
|
||||||
49
modules/patcher.py
Normal file
49
modules/patcher.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
import torch
|
||||||
|
import gc
|
||||||
|
import logging
|
||||||
|
import comfy.model_patcher
|
||||||
|
import comfy.model_management as model_management
|
||||||
|
|
||||||
|
from .loader import LOADED_MODELS, logger
|
||||||
|
|
||||||
|
class VibeVoicePatcher(comfy.model_patcher.ModelPatcher):
|
||||||
|
"""Custom ModelPatcher for managing VibeVoice models in ComfyUI."""
|
||||||
|
def __init__(self, model, attention_mode="eager", *args, **kwargs):
|
||||||
|
super().__init__(model, *args, **kwargs)
|
||||||
|
self.attention_mode = attention_mode
|
||||||
|
self.cache_key = model.cache_key
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_loaded(self):
|
||||||
|
"""Check if the model is currently loaded in memory."""
|
||||||
|
return hasattr(self, 'model') and self.model is not None and hasattr(self.model, 'model') and self.model.model is not None
|
||||||
|
|
||||||
|
def patch_model(self, device_to=None, *args, **kwargs):
|
||||||
|
target_device = self.load_device
|
||||||
|
if self.model.model is None:
|
||||||
|
logger.info(f"Loading VibeVoice models for '{self.model.model_pack_name}' to {target_device}...")
|
||||||
|
mode_names = {
|
||||||
|
"eager": "Eager (Most Compatible)",
|
||||||
|
"sdpa": "SDPA (Balanced Speed/Compatibility)",
|
||||||
|
"flash_attention_2": "Flash Attention 2 (Fastest)",
|
||||||
|
"sage": "SageAttention (Quantized High-Performance)",
|
||||||
|
}
|
||||||
|
logger.info(f"Attention Mode: {mode_names.get(self.attention_mode, self.attention_mode)}")
|
||||||
|
self.model.load_model(target_device, self.attention_mode)
|
||||||
|
self.model.model.to(target_device)
|
||||||
|
return super().patch_model(device_to=target_device, *args, **kwargs)
|
||||||
|
|
||||||
|
def unpatch_model(self, device_to=None, unpatch_weights=True, *args, **kwargs):
|
||||||
|
if unpatch_weights:
|
||||||
|
logger.info(f"Offloading VibeVoice models for '{self.model.model_pack_name}' ({self.attention_mode}) to {device_to}...")
|
||||||
|
self.model.model = None
|
||||||
|
self.model.processor = None
|
||||||
|
|
||||||
|
if self.cache_key in LOADED_MODELS:
|
||||||
|
del LOADED_MODELS[self.cache_key]
|
||||||
|
logger.info(f"Cleared LOADED_MODELS cache for: {self.cache_key}")
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
model_management.soft_empty_cache()
|
||||||
|
|
||||||
|
return super().unpatch_model(device_to, unpatch_weights, *args, **kwargs)
|
||||||
105
modules/utils.py
Normal file
105
modules/utils.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
import re
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from comfy.utils import ProgressBar
|
||||||
|
from comfy.model_management import throw_exception_if_processing_interrupted
|
||||||
|
|
||||||
|
try:
|
||||||
|
import librosa
|
||||||
|
except ImportError:
|
||||||
|
print("VibeVoice Node: `librosa` is not installed. Resampling of reference audio will not be available.")
|
||||||
|
librosa = None
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def set_vibevoice_seed(seed: int):
|
||||||
|
"""Sets the seed for torch, numpy, and random, handling large seeds for numpy."""
|
||||||
|
if seed == 0:
|
||||||
|
seed = random.randint(1, 0xffffffffffffffff)
|
||||||
|
|
||||||
|
MAX_NUMPY_SEED = 2**32 - 1
|
||||||
|
numpy_seed = seed % MAX_NUMPY_SEED
|
||||||
|
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
np.random.seed(numpy_seed)
|
||||||
|
random.seed(seed)
|
||||||
|
|
||||||
|
def parse_script_1_based(script: str) -> tuple[list[tuple[int, str]], list[int]]:
|
||||||
|
"""
|
||||||
|
Parses a 1-based speaker script into a list of (speaker_id, text) tuples
|
||||||
|
and a list of unique speaker IDs in the order of their first appearance.
|
||||||
|
Internally, it converts speaker IDs to 0-based for the model.
|
||||||
|
"""
|
||||||
|
parsed_lines = []
|
||||||
|
speaker_ids_in_script = [] # This will store the 1-based IDs from the script
|
||||||
|
for line in script.strip().split("\n"):
|
||||||
|
if not (line := line.strip()): continue
|
||||||
|
match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line, re.IGNORECASE)
|
||||||
|
if match:
|
||||||
|
speaker_id = int(match.group(1))
|
||||||
|
if speaker_id < 1:
|
||||||
|
logger.warning(f"Speaker ID must be 1 or greater. Skipping line: '{line}'")
|
||||||
|
continue
|
||||||
|
text = ' ' + match.group(2).strip()
|
||||||
|
# Internally, the model expects 0-based indexing for speakers
|
||||||
|
internal_speaker_id = speaker_id - 1
|
||||||
|
parsed_lines.append((internal_speaker_id, text))
|
||||||
|
if speaker_id not in speaker_ids_in_script:
|
||||||
|
speaker_ids_in_script.append(speaker_id)
|
||||||
|
else:
|
||||||
|
logger.warning(f"Could not parse line, skipping: '{line}'")
|
||||||
|
return parsed_lines, sorted(list(set(speaker_ids_in_script)))
|
||||||
|
|
||||||
|
def preprocess_comfy_audio(audio_dict: dict, target_sr: int = 24000) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Converts a ComfyUI AUDIO dict to a mono NumPy array, resampling if necessary.
|
||||||
|
"""
|
||||||
|
if not audio_dict: return None
|
||||||
|
waveform_tensor = audio_dict.get('waveform')
|
||||||
|
if waveform_tensor is None or waveform_tensor.numel() == 0: return None
|
||||||
|
|
||||||
|
waveform = waveform_tensor[0].cpu().numpy()
|
||||||
|
original_sr = audio_dict['sample_rate']
|
||||||
|
|
||||||
|
if waveform.ndim > 1:
|
||||||
|
waveform = np.mean(waveform, axis=0)
|
||||||
|
|
||||||
|
# Check for invalid values
|
||||||
|
if np.any(np.isnan(waveform)) or np.any(np.isinf(waveform)):
|
||||||
|
logger.error("Audio contains NaN or Inf values, replacing with zeros")
|
||||||
|
waveform = np.nan_to_num(waveform, nan=0.0, posinf=0.0, neginf=0.0)
|
||||||
|
|
||||||
|
# Ensure audio is not completely silent or has extreme values
|
||||||
|
if np.all(waveform == 0):
|
||||||
|
logger.warning("Audio waveform is completely silent")
|
||||||
|
|
||||||
|
# Normalize extreme values
|
||||||
|
max_val = np.abs(waveform).max()
|
||||||
|
if max_val > 10.0:
|
||||||
|
logger.warning(f"Audio values are very large (max: {max_val}), normalizing")
|
||||||
|
waveform = waveform / max_val
|
||||||
|
|
||||||
|
if original_sr != target_sr:
|
||||||
|
if librosa is None:
|
||||||
|
raise ImportError("`librosa` package is required for audio resampling. Please install it with `pip install librosa`.")
|
||||||
|
logger.warning(f"Resampling reference audio from {original_sr}Hz to {target_sr}Hz.")
|
||||||
|
waveform = librosa.resample(y=waveform, orig_sr=original_sr, target_sr=target_sr)
|
||||||
|
|
||||||
|
# Final check after resampling
|
||||||
|
if np.any(np.isnan(waveform)) or np.any(np.isinf(waveform)):
|
||||||
|
logger.error("Audio contains NaN or Inf after resampling, replacing with zeros")
|
||||||
|
waveform = np.nan_to_num(waveform, nan=0.0, posinf=0.0, neginf=0.0)
|
||||||
|
|
||||||
|
return waveform.astype(np.float32)
|
||||||
|
|
||||||
|
def check_for_interrupt():
|
||||||
|
try:
|
||||||
|
throw_exception_if_processing_interrupted()
|
||||||
|
return False
|
||||||
|
except:
|
||||||
|
return True
|
||||||
115
vibevoice/configs/default_VibeVoice-1.5B_config.json
Normal file
115
vibevoice/configs/default_VibeVoice-1.5B_config.json
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
{
|
||||||
|
"acoustic_vae_dim": 64,
|
||||||
|
"acoustic_tokenizer_config": {
|
||||||
|
"causal": true,
|
||||||
|
"channels": 1,
|
||||||
|
"conv_bias": true,
|
||||||
|
"conv_norm": "none",
|
||||||
|
"corpus_normalize": 0.0,
|
||||||
|
"decoder_depths": null,
|
||||||
|
"decoder_n_filters": 32,
|
||||||
|
"decoder_ratios": [
|
||||||
|
8,
|
||||||
|
5,
|
||||||
|
5,
|
||||||
|
4,
|
||||||
|
2,
|
||||||
|
2
|
||||||
|
],
|
||||||
|
"disable_last_norm": true,
|
||||||
|
"encoder_depths": "3-3-3-3-3-3-8",
|
||||||
|
"encoder_n_filters": 32,
|
||||||
|
"encoder_ratios": [
|
||||||
|
8,
|
||||||
|
5,
|
||||||
|
5,
|
||||||
|
4,
|
||||||
|
2,
|
||||||
|
2
|
||||||
|
],
|
||||||
|
"fix_std": 0.5,
|
||||||
|
"layer_scale_init_value": 1e-06,
|
||||||
|
"layernorm": "RMSNorm",
|
||||||
|
"layernorm_elementwise_affine": true,
|
||||||
|
"layernorm_eps": 1e-05,
|
||||||
|
"mixer_layer": "depthwise_conv",
|
||||||
|
"model_type": "vibevoice_acoustic_tokenizer",
|
||||||
|
"pad_mode": "constant",
|
||||||
|
"std_dist_type": "gaussian",
|
||||||
|
"vae_dim": 64,
|
||||||
|
"weight_init_value": 0.01
|
||||||
|
},
|
||||||
|
"architectures": [
|
||||||
|
"VibeVoiceForConditionalGeneration"
|
||||||
|
],
|
||||||
|
"decoder_config": {
|
||||||
|
"attention_dropout": 0.0,
|
||||||
|
"hidden_act": "silu",
|
||||||
|
"hidden_size": 1536,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"intermediate_size": 8960,
|
||||||
|
"max_position_embeddings": 65536,
|
||||||
|
"max_window_layers": 28,
|
||||||
|
"model_type": "qwen2",
|
||||||
|
"num_attention_heads": 12,
|
||||||
|
"num_hidden_layers": 28,
|
||||||
|
"num_key_value_heads": 2,
|
||||||
|
"rms_norm_eps": 1e-06,
|
||||||
|
"rope_scaling": null,
|
||||||
|
"rope_theta": 1000000.0,
|
||||||
|
"sliding_window": null,
|
||||||
|
"tie_word_embeddings": true,
|
||||||
|
"torch_dtype": "bfloat16",
|
||||||
|
"use_cache": true,
|
||||||
|
"use_sliding_window": false,
|
||||||
|
"vocab_size": 151936
|
||||||
|
},
|
||||||
|
"diffusion_head_config": {
|
||||||
|
"ddpm_batch_mul": 4,
|
||||||
|
"ddpm_beta_schedule": "cosine",
|
||||||
|
"ddpm_num_inference_steps": 20,
|
||||||
|
"ddpm_num_steps": 1000,
|
||||||
|
"diffusion_type": "ddpm",
|
||||||
|
"head_ffn_ratio": 3.0,
|
||||||
|
"head_layers": 4,
|
||||||
|
"hidden_size": 1536,
|
||||||
|
"latent_size": 64,
|
||||||
|
"model_type": "vibevoice_diffusion_head",
|
||||||
|
"prediction_type": "v_prediction",
|
||||||
|
"rms_norm_eps": 1e-05,
|
||||||
|
"speech_vae_dim": 64
|
||||||
|
},
|
||||||
|
"model_type": "vibevoice",
|
||||||
|
"semantic_tokenizer_config": {
|
||||||
|
"causal": true,
|
||||||
|
"channels": 1,
|
||||||
|
"conv_bias": true,
|
||||||
|
"conv_norm": "none",
|
||||||
|
"corpus_normalize": 0.0,
|
||||||
|
"disable_last_norm": true,
|
||||||
|
"encoder_depths": "3-3-3-3-3-3-8",
|
||||||
|
"encoder_n_filters": 32,
|
||||||
|
"encoder_ratios": [
|
||||||
|
8,
|
||||||
|
5,
|
||||||
|
5,
|
||||||
|
4,
|
||||||
|
2,
|
||||||
|
2
|
||||||
|
],
|
||||||
|
"fix_std": 0,
|
||||||
|
"layer_scale_init_value": 1e-06,
|
||||||
|
"layernorm": "RMSNorm",
|
||||||
|
"layernorm_elementwise_affine": true,
|
||||||
|
"layernorm_eps": 1e-05,
|
||||||
|
"mixer_layer": "depthwise_conv",
|
||||||
|
"model_type": "vibevoice_semantic_tokenizer",
|
||||||
|
"pad_mode": "constant",
|
||||||
|
"std_dist_type": "none",
|
||||||
|
"vae_dim": 128,
|
||||||
|
"weight_init_value": 0.01
|
||||||
|
},
|
||||||
|
"semantic_vae_dim": 128,
|
||||||
|
"torch_dtype": "bfloat16",
|
||||||
|
"transformers_version": "4.51.3"
|
||||||
|
}
|
||||||
116
vibevoice/configs/default_VibeVoice-Large_config.json
Normal file
116
vibevoice/configs/default_VibeVoice-Large_config.json
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
{
|
||||||
|
"acostic_vae_dim": 64,
|
||||||
|
"acoustic_tokenizer_config": {
|
||||||
|
"causal": true,
|
||||||
|
"channels": 1,
|
||||||
|
"conv_bias": true,
|
||||||
|
"conv_norm": "none",
|
||||||
|
"corpus_normalize": 0.0,
|
||||||
|
"decoder_depths": null,
|
||||||
|
"decoder_n_filters": 32,
|
||||||
|
"decoder_ratios": [
|
||||||
|
8,
|
||||||
|
5,
|
||||||
|
5,
|
||||||
|
4,
|
||||||
|
2,
|
||||||
|
2
|
||||||
|
],
|
||||||
|
"disable_last_norm": true,
|
||||||
|
"encoder_depths": "3-3-3-3-3-3-8",
|
||||||
|
"encoder_n_filters": 32,
|
||||||
|
"encoder_ratios": [
|
||||||
|
8,
|
||||||
|
5,
|
||||||
|
5,
|
||||||
|
4,
|
||||||
|
2,
|
||||||
|
2
|
||||||
|
],
|
||||||
|
"fix_std": 0.5,
|
||||||
|
"layer_scale_init_value": 1e-06,
|
||||||
|
"layernorm": "RMSNorm",
|
||||||
|
"layernorm_elementwise_affine": true,
|
||||||
|
"layernorm_eps": 1e-05,
|
||||||
|
"mixer_layer": "depthwise_conv",
|
||||||
|
"model_type": "vibevoice_acoustic_tokenizer",
|
||||||
|
"pad_mode": "constant",
|
||||||
|
"std_dist_type": "gaussian",
|
||||||
|
"vae_dim": 64,
|
||||||
|
"weight_init_value": 0.01
|
||||||
|
},
|
||||||
|
"architectures": [
|
||||||
|
"VibeVoiceForConditionalGeneration"
|
||||||
|
],
|
||||||
|
"decoder_config": {
|
||||||
|
"attention_dropout": 0.0,
|
||||||
|
"hidden_act": "silu",
|
||||||
|
"hidden_size": 3584,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"intermediate_size": 18944,
|
||||||
|
"max_position_embeddings": 32768,
|
||||||
|
"max_window_layers": 28,
|
||||||
|
"model_type": "qwen2",
|
||||||
|
"num_attention_heads": 28,
|
||||||
|
"num_hidden_layers": 28,
|
||||||
|
"num_key_value_heads": 4,
|
||||||
|
"rms_norm_eps": 1e-06,
|
||||||
|
"rope_scaling": null,
|
||||||
|
"rope_theta": 1000000.0,
|
||||||
|
"sliding_window": null,
|
||||||
|
"torch_dtype": "bfloat16",
|
||||||
|
"use_cache": true,
|
||||||
|
"use_mrope": false,
|
||||||
|
"use_sliding_window": false,
|
||||||
|
"vocab_size": 152064
|
||||||
|
},
|
||||||
|
"diffusion_head_config": {
|
||||||
|
"ddpm_batch_mul": 4,
|
||||||
|
"ddpm_beta_schedule": "cosine",
|
||||||
|
"ddpm_num_inference_steps": 20,
|
||||||
|
"ddpm_num_steps": 1000,
|
||||||
|
"diffusion_type": "ddpm",
|
||||||
|
"head_ffn_ratio": 3.0,
|
||||||
|
"head_layers": 4,
|
||||||
|
"hidden_size": 3584,
|
||||||
|
"latent_size": 64,
|
||||||
|
"model_type": "vibevoice_diffusion_head",
|
||||||
|
"prediction_type": "v_prediction",
|
||||||
|
"rms_norm_eps": 1e-05,
|
||||||
|
"speech_vae_dim": 64
|
||||||
|
},
|
||||||
|
"model_type": "vibevoice",
|
||||||
|
"semantic_tokenizer_config": {
|
||||||
|
"causal": true,
|
||||||
|
"channels": 1,
|
||||||
|
"conv_bias": true,
|
||||||
|
"conv_norm": "none",
|
||||||
|
"corpus_normalize": 0.0,
|
||||||
|
"disable_last_norm": true,
|
||||||
|
"encoder_depths": "3-3-3-3-3-3-8",
|
||||||
|
"encoder_n_filters": 32,
|
||||||
|
"encoder_ratios": [
|
||||||
|
8,
|
||||||
|
5,
|
||||||
|
5,
|
||||||
|
4,
|
||||||
|
2,
|
||||||
|
2
|
||||||
|
],
|
||||||
|
"fix_std": 0,
|
||||||
|
"layer_scale_init_value": 1e-06,
|
||||||
|
"layernorm": "RMSNorm",
|
||||||
|
"layernorm_elementwise_affine": true,
|
||||||
|
"layernorm_eps": 1e-05,
|
||||||
|
"mixer_layer": "depthwise_conv",
|
||||||
|
"model_type": "vibevoice_semantic_tokenizer",
|
||||||
|
"pad_mode": "constant",
|
||||||
|
"std_dist_type": "none",
|
||||||
|
"vae_dim": 128,
|
||||||
|
"weight_init_value": 0.01
|
||||||
|
},
|
||||||
|
"semantic_vae_dim": 128,
|
||||||
|
"tie_word_embeddings": false,
|
||||||
|
"torch_dtype": "bfloat16",
|
||||||
|
"transformers_version": "4.51.3"
|
||||||
|
}
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
# Author: Wildminder
|
# Author: Wildminder
|
||||||
# Desc: SageAttention and patcher
|
# Desc: SageAttention and patcher
|
||||||
|
# License: Apache 2.0
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|||||||
@@ -1,395 +1,29 @@
|
|||||||
import os
|
|
||||||
import re
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import gc
|
||||||
import random
|
|
||||||
from huggingface_hub import hf_hub_download, snapshot_download
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import gc
|
|
||||||
|
|
||||||
import folder_paths
|
|
||||||
import comfy.model_management as model_management
|
import comfy.model_management as model_management
|
||||||
import comfy.model_patcher
|
|
||||||
from comfy.utils import ProgressBar
|
from comfy.utils import ProgressBar
|
||||||
from comfy.model_management import throw_exception_if_processing_interrupted
|
|
||||||
|
|
||||||
# Import transformers and packaging to handle different library versions.
|
# Import from the dedicated model_info module
|
||||||
import transformers
|
from .modules.model_info import AVAILABLE_VIBEVOICE_MODELS
|
||||||
from packaging import version
|
from .modules.loader import VibeVoiceModelHandler, ATTENTION_MODES, VIBEVOICE_PATCHER_CACHE, cleanup_old_models
|
||||||
|
from .modules.patcher import VibeVoicePatcher
|
||||||
_transformers_version = version.parse(transformers.__version__)
|
from .modules.utils import parse_script_1_based, preprocess_comfy_audio, set_vibevoice_seed, check_for_interrupt
|
||||||
_DTYPE_ARG_SUPPORTED = _transformers_version >= version.parse("4.56.0")
|
|
||||||
|
|
||||||
from transformers import set_seed, AutoTokenizer, BitsAndBytesConfig
|
|
||||||
from .vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
|
|
||||||
from .vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
|
|
||||||
from .vibevoice.processor.vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor
|
|
||||||
from .vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceTextTokenizerFast
|
|
||||||
|
|
||||||
from . import SAGE_ATTENTION_AVAILABLE
|
|
||||||
if SAGE_ATTENTION_AVAILABLE:
|
|
||||||
from .vibevoice.modular.sage_attention_patch import set_sage_attention
|
|
||||||
|
|
||||||
try:
|
|
||||||
import librosa
|
|
||||||
except ImportError:
|
|
||||||
print("VibeVoice Node: `librosa` is not installed. Resampling of reference audio will not be available.")
|
|
||||||
librosa = None
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
LOADED_MODELS = {}
|
|
||||||
VIBEVOICE_PATCHER_CACHE = {}
|
|
||||||
|
|
||||||
MODEL_CONFIGS = {
|
|
||||||
"VibeVoice-1.5B": {
|
|
||||||
"repo_id": "microsoft/VibeVoice-1.5B",
|
|
||||||
"size_gb": 3.0,
|
|
||||||
"tokenizer_repo": "Qwen/Qwen2.5-1.5B"
|
|
||||||
},
|
|
||||||
"VibeVoice-Large": {
|
|
||||||
"repo_id": "aoi-ot/VibeVoice-Large",
|
|
||||||
"size_gb": 17.4,
|
|
||||||
"tokenizer_repo": "Qwen/Qwen2.5-7B"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ATTENTION_MODES = ["eager", "sdpa", "flash_attention_2"]
|
|
||||||
if SAGE_ATTENTION_AVAILABLE:
|
|
||||||
ATTENTION_MODES.append("sage")
|
|
||||||
|
|
||||||
def cleanup_old_models(keep_cache_key=None):
|
|
||||||
"""Clean up old models, optionally keeping one specific model loaded"""
|
|
||||||
global LOADED_MODELS, VIBEVOICE_PATCHER_CACHE
|
|
||||||
|
|
||||||
keys_to_remove = []
|
|
||||||
|
|
||||||
# Clear LOADED_MODELS
|
|
||||||
for key in list(LOADED_MODELS.keys()):
|
|
||||||
if key != keep_cache_key:
|
|
||||||
keys_to_remove.append(key)
|
|
||||||
del LOADED_MODELS[key]
|
|
||||||
|
|
||||||
# Clear VIBEVOICE_PATCHER_CACHE - but more carefully
|
|
||||||
for key in list(VIBEVOICE_PATCHER_CACHE.keys()):
|
|
||||||
if key != keep_cache_key:
|
|
||||||
try:
|
|
||||||
patcher = VIBEVOICE_PATCHER_CACHE[key]
|
|
||||||
if hasattr(patcher, 'model') and patcher.model:
|
|
||||||
patcher.model.model = None
|
|
||||||
patcher.model.processor = None
|
|
||||||
del VIBEVOICE_PATCHER_CACHE[key]
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Error cleaning up patcher {key}: {e}")
|
|
||||||
|
|
||||||
if keys_to_remove:
|
|
||||||
logger.info(f"Cleaned up cached models: {keys_to_remove}")
|
|
||||||
gc.collect()
|
|
||||||
model_management.soft_empty_cache()
|
|
||||||
|
|
||||||
class VibeVoiceModelHandler(torch.nn.Module):
|
|
||||||
"""A torch.nn.Module wrapper to hold the VibeVoice model and processor."""
|
|
||||||
def __init__(self, model_pack_name, attention_mode="eager", use_llm_4bit=False):
|
|
||||||
super().__init__()
|
|
||||||
self.model_pack_name = model_pack_name
|
|
||||||
self.attention_mode = attention_mode
|
|
||||||
self.use_llm_4bit = use_llm_4bit
|
|
||||||
self.cache_key = f"{model_pack_name}_attn_{attention_mode}_q4_{int(use_llm_4bit)}"
|
|
||||||
self.model = None
|
|
||||||
self.processor = None
|
|
||||||
self.size = int(MODEL_CONFIGS[model_pack_name].get("size_gb", 4.0) * (1024**3))
|
|
||||||
|
|
||||||
def load_model(self, device, attention_mode="eager"):
|
|
||||||
self.model, self.processor = VibeVoiceLoader.load_model(self.model_pack_name, device, attention_mode, use_llm_4bit=self.use_llm_4bit)
|
|
||||||
self.model.to(device)
|
|
||||||
|
|
||||||
class VibeVoicePatcher(comfy.model_patcher.ModelPatcher):
|
|
||||||
"""Custom ModelPatcher for managing VibeVoice models in ComfyUI."""
|
|
||||||
def __init__(self, model, attention_mode="eager", *args, **kwargs):
|
|
||||||
super().__init__(model, *args, **kwargs)
|
|
||||||
self.attention_mode = attention_mode
|
|
||||||
self.cache_key = model.cache_key
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_loaded(self):
|
|
||||||
"""Check if the model is currently loaded in memory."""
|
|
||||||
return hasattr(self, 'model') and self.model is not None and hasattr(self.model, 'model') and self.model.model is not None
|
|
||||||
|
|
||||||
def patch_model(self, device_to=None, *args, **kwargs):
|
|
||||||
target_device = self.load_device
|
|
||||||
if self.model.model is None:
|
|
||||||
logger.info(f"Loading VibeVoice models for '{self.model.model_pack_name}' to {target_device}...")
|
|
||||||
mode_names = {
|
|
||||||
"eager": "Eager (Most Compatible)",
|
|
||||||
"sdpa": "SDPA (Balanced Speed/Compatibility)",
|
|
||||||
"flash_attention_2": "Flash Attention 2 (Fastest)",
|
|
||||||
"sage": "SageAttention (Quantized High-Performance)",
|
|
||||||
}
|
|
||||||
logger.info(f"Attention Mode: {mode_names.get(self.attention_mode, self.attention_mode)}")
|
|
||||||
self.model.load_model(target_device, self.attention_mode)
|
|
||||||
self.model.model.to(target_device)
|
|
||||||
return super().patch_model(device_to=target_device, *args, **kwargs)
|
|
||||||
|
|
||||||
def unpatch_model(self, device_to=None, unpatch_weights=True, *args, **kwargs):
|
|
||||||
if unpatch_weights:
|
|
||||||
logger.info(f"Offloading VibeVoice models for '{self.model.model_pack_name}' ({self.attention_mode}) to {device_to}...")
|
|
||||||
self.model.model = None
|
|
||||||
self.model.processor = None
|
|
||||||
|
|
||||||
if self.cache_key in LOADED_MODELS:
|
|
||||||
del LOADED_MODELS[self.cache_key]
|
|
||||||
logger.info(f"Cleared LOADED_MODELS cache for: {self.cache_key}")
|
|
||||||
|
|
||||||
gc.collect()
|
|
||||||
model_management.soft_empty_cache()
|
|
||||||
|
|
||||||
return super().unpatch_model(device_to, unpatch_weights, *args, **kwargs)
|
|
||||||
|
|
||||||
class VibeVoiceLoader:
|
|
||||||
@staticmethod
|
|
||||||
def get_model_path(model_name: str):
|
|
||||||
if model_name not in MODEL_CONFIGS:
|
|
||||||
raise ValueError(f"Unknown VibeVoice model: {model_name}")
|
|
||||||
|
|
||||||
vibevoice_path = os.path.join(folder_paths.get_folder_paths("tts")[0], "VibeVoice")
|
|
||||||
model_path = os.path.join(vibevoice_path, model_name)
|
|
||||||
|
|
||||||
index_file = os.path.join(model_path, "model.safetensors.index.json")
|
|
||||||
if not os.path.exists(index_file):
|
|
||||||
print(f"Downloading VibeVoice model: {model_name}...")
|
|
||||||
repo_id = MODEL_CONFIGS[model_name]["repo_id"]
|
|
||||||
snapshot_download(repo_id=repo_id, local_dir=model_path)
|
|
||||||
return model_path
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _check_gpu_for_sage_attention():
|
|
||||||
"""Check if the current GPU is compatible with SageAttention."""
|
|
||||||
if not SAGE_ATTENTION_AVAILABLE:
|
|
||||||
return False
|
|
||||||
if not torch.cuda.is_available():
|
|
||||||
return False
|
|
||||||
major, _ = torch.cuda.get_device_capability()
|
|
||||||
if major < 8:
|
|
||||||
logger.warning(f"Your GPU (compute capability {major}.x) does not support SageAttention, which requires CC 8.0+. Sage option will be disabled.")
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def load_model(model_name: str, device, attention_mode: str = "eager", use_llm_4bit: bool = False):
|
|
||||||
|
|
||||||
if use_llm_4bit and attention_mode in ["eager", "flash_attention_2"]:
|
|
||||||
logger.warning(f"Attention mode '{attention_mode}' is not recommended with 4-bit quantization. Falling back to 'sdpa' for stability and performance.")
|
|
||||||
attention_mode = "sdpa"
|
|
||||||
|
|
||||||
if attention_mode not in ATTENTION_MODES:
|
|
||||||
logger.warning(f"Unknown attention mode '{attention_mode}', falling back to eager")
|
|
||||||
attention_mode = "eager"
|
|
||||||
|
|
||||||
cache_key = f"{model_name}_attn_{attention_mode}_q4_{int(use_llm_4bit)}"
|
|
||||||
|
|
||||||
if cache_key in LOADED_MODELS:
|
|
||||||
logger.info(f"Using cached model with {attention_mode} attention and q4={use_llm_4bit}")
|
|
||||||
return LOADED_MODELS[cache_key]
|
|
||||||
|
|
||||||
model_path = VibeVoiceLoader.get_model_path(model_name)
|
|
||||||
|
|
||||||
logger.info(f"Loading VibeVoice model components from: {model_path}")
|
|
||||||
|
|
||||||
tokenizer_repo = MODEL_CONFIGS[model_name].get("tokenizer_repo")
|
|
||||||
tokenizer_file_path = os.path.join(model_path, "tokenizer.json")
|
|
||||||
# Check if tokenizer.json exists locally. If not, download it directly to the model folder.
|
|
||||||
if not os.path.exists(tokenizer_file_path):
|
|
||||||
logger.info(f"tokenizer.json not found in {model_path}. Downloading from '{tokenizer_repo}'...")
|
|
||||||
try:
|
|
||||||
hf_hub_download(
|
|
||||||
repo_id=tokenizer_repo,
|
|
||||||
filename="tokenizer.json",
|
|
||||||
local_dir=model_path,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to download tokenizer.json: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
vibevoice_tokenizer = VibeVoiceTextTokenizerFast(tokenizer_file=tokenizer_file_path)
|
|
||||||
audio_processor = VibeVoiceTokenizerProcessor()
|
|
||||||
processor = VibeVoiceProcessor(tokenizer=vibevoice_tokenizer, audio_processor=audio_processor)
|
|
||||||
|
|
||||||
# Base dtype for full precision and memory-optimized 4-bit
|
|
||||||
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
|
|
||||||
model_dtype = torch.bfloat16
|
|
||||||
else:
|
|
||||||
model_dtype = torch.float16
|
|
||||||
|
|
||||||
quant_config = None
|
|
||||||
final_load_dtype = model_dtype
|
|
||||||
|
|
||||||
if use_llm_4bit:
|
|
||||||
# Default to bfloat16/float16 for memory savings
|
|
||||||
bnb_compute_dtype = model_dtype
|
|
||||||
|
|
||||||
# SageAttention is numerically sensitive and requires fp32 compute dtype for stability
|
|
||||||
# SDPA is more robust and can use bf16.
|
|
||||||
if attention_mode == 'sage':
|
|
||||||
logger.info("Using SageAttention with 4-bit quant. Forcing fp32 compute dtype for maximum stability.")
|
|
||||||
bnb_compute_dtype = torch.float32
|
|
||||||
final_load_dtype = torch.float32
|
|
||||||
else:
|
|
||||||
logger.info(f"Using {attention_mode} with 4-bit quant. Using {model_dtype} compute dtype for memory efficiency.")
|
|
||||||
|
|
||||||
quant_config = BitsAndBytesConfig(
|
|
||||||
load_in_4bit=True,
|
|
||||||
bnb_4bit_quant_type="nf4",
|
|
||||||
bnb_4bit_use_double_quant=True,
|
|
||||||
bnb_4bit_compute_dtype=bnb_compute_dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_implementation_for_load = "sdpa" if attention_mode == "sage" else attention_mode
|
|
||||||
|
|
||||||
try:
|
|
||||||
logger.info(f"Loading model with dtype: {final_load_dtype} and attention: '{attn_implementation_for_load}'")
|
|
||||||
# Build a dictionary of keyword arguments for from_pretrained.
|
|
||||||
from_pretrained_kwargs = {
|
|
||||||
"attn_implementation": attn_implementation_for_load,
|
|
||||||
"device_map": "auto" if quant_config else device,
|
|
||||||
"quantization_config": quant_config,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Use the correct dtype argument based on the transformers version.
|
|
||||||
if _DTYPE_ARG_SUPPORTED:
|
|
||||||
from_pretrained_kwargs['dtype'] = final_load_dtype
|
|
||||||
else:
|
|
||||||
from_pretrained_kwargs['torch_dtype'] = final_load_dtype
|
|
||||||
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
|
||||||
model_path,
|
|
||||||
**from_pretrained_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mode == "sage":
|
|
||||||
if VibeVoiceLoader._check_gpu_for_sage_attention():
|
|
||||||
logger.info("Applying SageAttention patch to the model...")
|
|
||||||
set_sage_attention(model)
|
|
||||||
else:
|
|
||||||
logger.error("Cannot apply SageAttention due to incompatible GPU. Falling back.")
|
|
||||||
raise RuntimeError("Incompatible hardware/setup for SageAttention.")
|
|
||||||
|
|
||||||
model.eval()
|
|
||||||
setattr(model, "_llm_4bit", bool(quant_config))
|
|
||||||
|
|
||||||
LOADED_MODELS[cache_key] = (model, processor)
|
|
||||||
logger.info(f"Successfully configured model with {attention_mode} attention")
|
|
||||||
return model, processor
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to load model with {attention_mode} attention: {e}")
|
|
||||||
# Fallback logic
|
|
||||||
if attention_mode in ["sage", "flash_attention_2"]:
|
|
||||||
logger.info("Attempting fallback to SDPA...")
|
|
||||||
return VibeVoiceLoader.load_model(model_name, device, "sdpa", use_llm_4bit)
|
|
||||||
elif attention_mode == "sdpa":
|
|
||||||
logger.info("Attempting fallback to eager...")
|
|
||||||
return VibeVoiceLoader.load_model(model_name, device, "eager", use_llm_4bit)
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f"Failed to load model even with eager attention: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
def set_vibevoice_seed(seed: int):
|
|
||||||
"""Sets the seed for torch, numpy, and random, handling large seeds for numpy."""
|
|
||||||
if seed == 0:
|
|
||||||
seed = random.randint(1, 0xffffffffffffffff)
|
|
||||||
|
|
||||||
MAX_NUMPY_SEED = 2**32 - 1
|
|
||||||
numpy_seed = seed % MAX_NUMPY_SEED
|
|
||||||
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.manual_seed_all(seed)
|
|
||||||
np.random.seed(numpy_seed)
|
|
||||||
random.seed(seed)
|
|
||||||
|
|
||||||
def parse_script_1_based(script: str) -> tuple[list[tuple[int, str]], list[int]]:
|
|
||||||
"""
|
|
||||||
Parses a 1-based speaker script into a list of (speaker_id, text) tuples
|
|
||||||
and a list of unique speaker IDs in the order of their first appearance.
|
|
||||||
Internally, it converts speaker IDs to 0-based for the model.
|
|
||||||
"""
|
|
||||||
parsed_lines = []
|
|
||||||
speaker_ids_in_script = [] # This will store the 1-based IDs from the script
|
|
||||||
for line in script.strip().split("\n"):
|
|
||||||
if not (line := line.strip()): continue
|
|
||||||
match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line, re.IGNORECASE)
|
|
||||||
if match:
|
|
||||||
speaker_id = int(match.group(1))
|
|
||||||
if speaker_id < 1:
|
|
||||||
logger.warning(f"Speaker ID must be 1 or greater. Skipping line: '{line}'")
|
|
||||||
continue
|
|
||||||
text = ' ' + match.group(2).strip()
|
|
||||||
# Internally, the model expects 0-based indexing for speakers
|
|
||||||
internal_speaker_id = speaker_id - 1
|
|
||||||
parsed_lines.append((internal_speaker_id, text))
|
|
||||||
if speaker_id not in speaker_ids_in_script:
|
|
||||||
speaker_ids_in_script.append(speaker_id)
|
|
||||||
else:
|
|
||||||
logger.warning(f"Could not parse line, skipping: '{line}'")
|
|
||||||
return parsed_lines, sorted(list(set(speaker_ids_in_script)))
|
|
||||||
|
|
||||||
def preprocess_comfy_audio(audio_dict: dict, target_sr: int = 24000) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
Converts a ComfyUI AUDIO dict to a mono NumPy array, resampling if necessary.
|
|
||||||
"""
|
|
||||||
if not audio_dict: return None
|
|
||||||
waveform_tensor = audio_dict.get('waveform')
|
|
||||||
if waveform_tensor is None or waveform_tensor.numel() == 0: return None
|
|
||||||
|
|
||||||
waveform = waveform_tensor[0].cpu().numpy()
|
|
||||||
original_sr = audio_dict['sample_rate']
|
|
||||||
|
|
||||||
if waveform.ndim > 1:
|
|
||||||
waveform = np.mean(waveform, axis=0)
|
|
||||||
|
|
||||||
# Check for invalid values
|
|
||||||
if np.any(np.isnan(waveform)) or np.any(np.isinf(waveform)):
|
|
||||||
logger.error("Audio contains NaN or Inf values, replacing with zeros")
|
|
||||||
waveform = np.nan_to_num(waveform, nan=0.0, posinf=0.0, neginf=0.0)
|
|
||||||
|
|
||||||
# Ensure audio is not completely silent or has extreme values
|
|
||||||
if np.all(waveform == 0):
|
|
||||||
logger.warning("Audio waveform is completely silent")
|
|
||||||
|
|
||||||
# Normalize extreme values
|
|
||||||
max_val = np.abs(waveform).max()
|
|
||||||
if max_val > 10.0:
|
|
||||||
logger.warning(f"Audio values are very large (max: {max_val}), normalizing")
|
|
||||||
waveform = waveform / max_val
|
|
||||||
|
|
||||||
if original_sr != target_sr:
|
|
||||||
if librosa is None:
|
|
||||||
raise ImportError("`librosa` package is required for audio resampling. Please install it with `pip install librosa`.")
|
|
||||||
logger.warning(f"Resampling reference audio from {original_sr}Hz to {target_sr}Hz.")
|
|
||||||
waveform = librosa.resample(y=waveform, orig_sr=original_sr, target_sr=target_sr)
|
|
||||||
|
|
||||||
# Final check after resampling
|
|
||||||
if np.any(np.isnan(waveform)) or np.any(np.isinf(waveform)):
|
|
||||||
logger.error("Audio contains NaN or Inf after resampling, replacing with zeros")
|
|
||||||
waveform = np.nan_to_num(waveform, nan=0.0, posinf=0.0, neginf=0.0)
|
|
||||||
|
|
||||||
return waveform.astype(np.float32)
|
|
||||||
|
|
||||||
def check_for_interrupt():
|
|
||||||
try:
|
|
||||||
throw_exception_if_processing_interrupted()
|
|
||||||
return False
|
|
||||||
except:
|
|
||||||
return True
|
|
||||||
|
|
||||||
class VibeVoiceTTSNode:
|
class VibeVoiceTTSNode:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls):
|
def INPUT_TYPES(cls):
|
||||||
|
model_names = list(AVAILABLE_VIBEVOICE_MODELS.keys())
|
||||||
|
if not model_names:
|
||||||
|
model_names.append("No models found in models/tts/VibeVoice")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"required": {
|
"required": {
|
||||||
"model_name": (list(MODEL_CONFIGS.keys()), {
|
"model_name": (model_names, {
|
||||||
"tooltip": "Select the VibeVoice model to use. Models will be downloaded automatically if not present."
|
"tooltip": "Select the VibeVoice model to use. Official models will be downloaded automatically."
|
||||||
}),
|
}),
|
||||||
"text": ("STRING", {
|
"text": ("STRING", {
|
||||||
"multiline": True,
|
"multiline": True,
|
||||||
@@ -405,7 +39,7 @@ class VibeVoiceTTSNode:
|
|||||||
"tooltip": "Attention implementation: Eager (safest), SDPA (balanced), Flash Attention 2 (fastest), Sage (quantized)"
|
"tooltip": "Attention implementation: Eager (safest), SDPA (balanced), Flash Attention 2 (fastest), Sage (quantized)"
|
||||||
}),
|
}),
|
||||||
"cfg_scale": ("FLOAT", {
|
"cfg_scale": ("FLOAT", {
|
||||||
"default": 1.3, "min": 1.0, "max": 3.0, "step": 0.05,
|
"default": 1.3, "min": 1.0, "max": 10.0, "step": 0.05,
|
||||||
"tooltip": "Classifier-Free Guidance scale. Higher values increase adherence to the voice prompt but may reduce naturalness. Recommended: 1.3"
|
"tooltip": "Classifier-Free Guidance scale. Higher values increase adherence to the voice prompt but may reduce naturalness. Recommended: 1.3"
|
||||||
}),
|
}),
|
||||||
"inference_steps": ("INT", {
|
"inference_steps": ("INT", {
|
||||||
@@ -450,16 +84,13 @@ class VibeVoiceTTSNode:
|
|||||||
CATEGORY = "audio/tts"
|
CATEGORY = "audio/tts"
|
||||||
|
|
||||||
def generate_audio(self, model_name, text, attention_mode, cfg_scale, inference_steps, seed, do_sample, temperature, top_p, top_k, quantize_llm_4bit, force_offload, **kwargs):
|
def generate_audio(self, model_name, text, attention_mode, cfg_scale, inference_steps, seed, do_sample, temperature, top_p, top_k, quantize_llm_4bit, force_offload, **kwargs):
|
||||||
|
|
||||||
actual_attention_mode = attention_mode
|
actual_attention_mode = attention_mode
|
||||||
if quantize_llm_4bit and attention_mode in ["eager", "flash_attention_2"]:
|
if quantize_llm_4bit and attention_mode in ["eager", "flash_attention_2"]:
|
||||||
actual_attention_mode = "sdpa"
|
actual_attention_mode = "sdpa"
|
||||||
|
|
||||||
cache_key = f"{model_name}_attn_{actual_attention_mode}_q4_{int(quantize_llm_4bit)}"
|
cache_key = f"{model_name}_attn_{actual_attention_mode}_q4_{int(quantize_llm_4bit)}"
|
||||||
|
|
||||||
# Clean up old models when switching to a different model
|
|
||||||
if cache_key not in VIBEVOICE_PATCHER_CACHE:
|
if cache_key not in VIBEVOICE_PATCHER_CACHE:
|
||||||
# Only keep models that are currently being requested
|
|
||||||
cleanup_old_models(keep_cache_key=cache_key)
|
cleanup_old_models(keep_cache_key=cache_key)
|
||||||
|
|
||||||
model_handler = VibeVoiceModelHandler(model_name, attention_mode, use_llm_4bit=quantize_llm_4bit)
|
model_handler = VibeVoiceModelHandler(model_name, attention_mode, use_llm_4bit=quantize_llm_4bit)
|
||||||
@@ -501,7 +132,6 @@ class VibeVoiceTTSNode:
|
|||||||
return_tensors="pt", return_attention_mask=True
|
return_tensors="pt", return_attention_mask=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate inputs before moving to GPU
|
|
||||||
for key, value in inputs.items():
|
for key, value in inputs.items():
|
||||||
if isinstance(value, torch.Tensor):
|
if isinstance(value, torch.Tensor):
|
||||||
if torch.any(torch.isnan(value)) or torch.any(torch.isinf(value)):
|
if torch.any(torch.isnan(value)) or torch.any(torch.isinf(value)):
|
||||||
@@ -519,9 +149,6 @@ class VibeVoiceTTSNode:
|
|||||||
if top_k > 0:
|
if top_k > 0:
|
||||||
generation_config['top_k'] = top_k
|
generation_config['top_k'] = top_k
|
||||||
|
|
||||||
# cause float() error for q4+eager
|
|
||||||
# model = model.float() IS REMOVED
|
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pbar = ProgressBar(inference_steps)
|
pbar = ProgressBar(inference_steps)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user