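"""Model loading utilities for the ComfyUI-VibeVoice node pack.

Resolves, downloads, and caches VibeVoice checkpoints, builds the
processor/tokenizer stack, and instantiates models with a selectable
attention implementation (eager, sdpa, flash_attention_2, or sage) and
optional 4-bit (bitsandbytes) quantization of the LLM component.
"""
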
import os
import torch
import gc
import json
import logging
from huggingface_hub import hf_hub_download, snapshot_download

import comfy.utils
import folder_paths
import comfy.model_management as model_management

import transformers
from packaging import version

_transformers_version = version.parse(transformers.__version__)
_DTYPE_ARG_SUPPORTED = _transformers_version >= version.parse("4.56.0")
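# transformers >= 4.56 accepts `dtype` in place of the older `torch_dtype`
# from_pretrained keyword; _DTYPE_ARG_SUPPORTED selects the right one at load time.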

from transformers import BitsAndBytesConfig
from ..vibevoice.modular.configuration_vibevoice import VibeVoiceConfig
from ..vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from ..vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
from ..vibevoice.processor.vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor
from ..vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceTextTokenizerFast

from .model_info import AVAILABLE_VIBEVOICE_MODELS, MODEL_CONFIGS
from .. import SAGE_ATTENTION_AVAILABLE
if SAGE_ATTENTION_AVAILABLE:
    from ..vibevoice.modular.sage_attention_patch import set_sage_attention

logger = logging.getLogger(__name__)
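
# Process-wide caches, keyed by "<model>_attn_<attention_mode>_q4_<0|1>"
# (see VibeVoiceModelHandler.cache_key and VibeVoiceLoader.load_model).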
LOADED_MODELS = {}
VIBEVOICE_PATCHER_CACHE = {}

ATTENTION_MODES = ["eager", "sdpa", "flash_attention_2"]
if SAGE_ATTENTION_AVAILABLE:
    ATTENTION_MODES.append("sage")


def cleanup_old_models(keep_cache_key=None):
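    """Evict every cached model/patcher except `keep_cache_key`, then release memory."""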
    global LOADED_MODELS, VIBEVOICE_PATCHER_CACHE
    keys_to_remove = []
    for key in list(LOADED_MODELS.keys()):
        if key != keep_cache_key:
            keys_to_remove.append(key)
            del LOADED_MODELS[key]
    for key in list(VIBEVOICE_PATCHER_CACHE.keys()):
        if key != keep_cache_key:
            try:
                patcher = VIBEVOICE_PATCHER_CACHE[key]
                if hasattr(patcher, 'model') and patcher.model:
                    # Drop references so the underlying weights can be garbage collected.
                    patcher.model.model = None
                    patcher.model.processor = None
                del VIBEVOICE_PATCHER_CACHE[key]
            except Exception as e:
                logger.warning(f"Error cleaning up patcher {key}: {e}")
    if keys_to_remove:
        logger.info(f"Cleaned up cached models: {keys_to_remove}")
    gc.collect()
    model_management.soft_empty_cache()


class VibeVoiceModelHandler(torch.nn.Module):
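    """torch.nn.Module wrapper exposing a VibeVoice model pack (model + processor)
    to ComfyUI's model management, including a size estimate for memory budgeting."""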

    def __init__(self, model_pack_name, attention_mode="eager", use_llm_4bit=False):
        super().__init__()
        self.model_pack_name = model_pack_name
        self.attention_mode = attention_mode
        self.use_llm_4bit = use_llm_4bit
        self.cache_key = f"{self.model_pack_name}_attn_{attention_mode}_q4_{int(use_llm_4bit)}"
        self.model = None
        self.processor = None
        info = AVAILABLE_VIBEVOICE_MODELS.get(model_pack_name, {})
        size_gb = MODEL_CONFIGS.get(model_pack_name, {}).get("size_gb", 4.0)
        self.size = int(size_gb * (1024**3))

    def load_model(self, device, attention_mode="eager"):
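        """Populate self.model / self.processor via VibeVoiceLoader and move the model to `device`."""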
        self.model, self.processor = VibeVoiceLoader.load_model(
            self.model_pack_name, device, attention_mode, use_llm_4bit=self.use_llm_4bit
        )
        if self.model.device != device:
            self.model.to(device)


class VibeVoiceLoader:
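    """Namespace for locating, downloading, and instantiating VibeVoice models."""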

    @staticmethod
    def _check_gpu_for_sage_attention():
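        """Return True only when SageAttention is installed and the GPU is compute capability 8.0+."""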
        if not SAGE_ATTENTION_AVAILABLE:
            return False
        if not torch.cuda.is_available():
            return False
        major, _ = torch.cuda.get_device_capability()
        if major < 8:
            logger.warning(f"Your GPU (compute capability {major}.x) does not support SageAttention, which requires CC 8.0+. The 'sage' option will be disabled.")
            return False
        return True

    @staticmethod
    def load_model(model_name: str, device, attention_mode: str = "eager", use_llm_4bit: bool = False):
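        """Load (or fetch from cache) a VibeVoice model and its processor.

        Resolves the checkpoint for official / local_dir / standalone model
        types, downloading from Hugging Face when needed, then builds the
        processor stack and instantiates the model with the requested
        attention implementation and optional 4-bit LLM quantization.
        """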
        if model_name not in AVAILABLE_VIBEVOICE_MODELS:
            raise ValueError(f"Unknown VibeVoice model: {model_name}. Available models: {list(AVAILABLE_VIBEVOICE_MODELS.keys())}")

        if use_llm_4bit and attention_mode in ["eager", "flash_attention_2"]:
            logger.warning(f"Attention mode '{attention_mode}' is not recommended with 4-bit quantization. Falling back to 'sdpa' for stability and performance.")
            attention_mode = "sdpa"
        if attention_mode not in ATTENTION_MODES:
            logger.warning(f"Unknown attention mode '{attention_mode}', falling back to eager")
            attention_mode = "eager"

        cache_key = f"{model_name}_attn_{attention_mode}_q4_{int(use_llm_4bit)}"
        if cache_key in LOADED_MODELS:
            logger.info(f"Using cached model with {attention_mode} attention and q4={use_llm_4bit}")
            return LOADED_MODELS[cache_key]

        model_info = AVAILABLE_VIBEVOICE_MODELS[model_name]
        model_type = model_info["type"]
        vibevoice_base_path = os.path.join(folder_paths.get_folder_paths("tts")[0], "VibeVoice")

        model_path_or_none = None
        config_path = None
        preprocessor_config_path = None
        tokenizer_dir = None

        if model_type == "official":
            model_path_or_none = os.path.join(vibevoice_base_path, model_name)
            if not os.path.exists(os.path.join(model_path_or_none, "model.safetensors.index.json")):
                logger.info(f"Downloading official VibeVoice model: {model_name}...")
                snapshot_download(repo_id=model_info["repo_id"], local_dir=model_path_or_none, local_dir_use_symlinks=False)
            config_path = os.path.join(model_path_or_none, "config.json")
            preprocessor_config_path = os.path.join(model_path_or_none, "preprocessor_config.json")
            tokenizer_dir = model_path_or_none
        elif model_type == "local_dir":
            model_path_or_none = model_info["path"]
            config_path = os.path.join(model_path_or_none, "config.json")
            preprocessor_config_path = os.path.join(model_path_or_none, "preprocessor_config.json")
            tokenizer_dir = model_path_or_none
        elif model_type == "standalone":
            model_path_or_none = None  # IMPORTANT: This must be None when loading from state_dict
            config_path = os.path.splitext(model_info["path"])[0] + ".config.json"
            preprocessor_config_path = os.path.splitext(model_info["path"])[0] + ".preprocessor.json"
            tokenizer_dir = os.path.dirname(model_info["path"])

        if os.path.exists(config_path):
            config = VibeVoiceConfig.from_pretrained(config_path)
        else:
            fallback_name = "default_VibeVoice-Large_config.json" if "large" in model_name.lower() else "default_VibeVoice-1.5B_config.json"
            fallback_path = os.path.join(os.path.dirname(__file__), "..", "vibevoice", "configs", fallback_name)
            logger.warning(f"Config not found for '{model_name}'. Using fallback: {fallback_name}")
            config = VibeVoiceConfig.from_pretrained(fallback_path)

        # Processor & Tokenizer setup
        tokenizer_repo = model_info["tokenizer_repo"]
        tokenizer_file_path = os.path.join(tokenizer_dir, "tokenizer.json")
        if not os.path.exists(tokenizer_file_path):
            logger.info(f"tokenizer.json not found. Downloading from '{tokenizer_repo}'...")
            hf_hub_download(repo_id=tokenizer_repo, filename="tokenizer.json", local_dir=tokenizer_dir, local_dir_use_symlinks=False)
        vibevoice_tokenizer = VibeVoiceTextTokenizerFast(tokenizer_file=tokenizer_file_path)

        processor_config_data = {}
        if os.path.exists(preprocessor_config_path):
            with open(preprocessor_config_path, 'r', encoding='utf-8') as f:
                processor_config_data = json.load(f)

        audio_processor = VibeVoiceTokenizerProcessor()
        processor = VibeVoiceProcessor(
            tokenizer=vibevoice_tokenizer,
            audio_processor=audio_processor,
            speech_tok_compress_ratio=processor_config_data.get("speech_tok_compress_ratio", 3200),
            db_normalize=processor_config_data.get("db_normalize", True),
        )

        # Model Loading Prep
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
            model_dtype = torch.bfloat16
        else:
            model_dtype = torch.float16
        quant_config = None
        final_load_dtype = model_dtype

        if use_llm_4bit:
            bnb_compute_dtype = model_dtype
            if attention_mode == 'sage':
                # The SageAttention + 4-bit path loads and computes in fp32.
                bnb_compute_dtype, final_load_dtype = torch.float32, torch.float32
            quant_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=bnb_compute_dtype,
            )

        # SageAttention is applied as a patch after loading, so load with SDPA first.
        attn_implementation_for_load = "sdpa" if attention_mode == "sage" else attention_mode

        try:
            logger.info(f"Loading model '{model_name}' with dtype: {final_load_dtype} and attention: '{attn_implementation_for_load}'")

            # Unified model loading logic
            from_pretrained_kwargs = {
                "config": config,
                "attn_implementation": attn_implementation_for_load,
                "device_map": "auto" if quant_config else device,
                "quantization_config": quant_config,
            }
            if _DTYPE_ARG_SUPPORTED:
                from_pretrained_kwargs['dtype'] = final_load_dtype
            else:
                from_pretrained_kwargs['torch_dtype'] = final_load_dtype

            if model_type == "standalone":
                logger.info(f"Loading standalone model state_dict directly to device: {device}")
                # Load the state dict directly onto the target device.
                state_dict = comfy.utils.load_torch_file(model_info["path"], device=device)
                from_pretrained_kwargs["state_dict"] = state_dict

            model = VibeVoiceForConditionalGenerationInference.from_pretrained(model_path_or_none, **from_pretrained_kwargs)

            if attention_mode == "sage":
                if VibeVoiceLoader._check_gpu_for_sage_attention():
                    set_sage_attention(model)
                else:
                    raise RuntimeError("Incompatible hardware/setup for SageAttention.")

            model.eval()
            setattr(model, "_llm_4bit", bool(quant_config))
            LOADED_MODELS[cache_key] = (model, processor)
            logger.info(f"Successfully configured model '{model_name}' with {attention_mode} attention")
            return model, processor

        except Exception as e:
            # It's not ideal to automatically reload the model; let the user
            # decide what to do in case of an error.
            logger.error(f"Failed to load model '{model_name}' with {attention_mode} attention: {e}")
            # if attention_mode in ["sage", "flash_attention_2"]:
            #     return VibeVoiceLoader.load_model(model_name, device, "sdpa", use_llm_4bit)
            # elif attention_mode == "sdpa":
            #     return VibeVoiceLoader.load_model(model_name, device, "eager", use_llm_4bit)
            raise RuntimeError(f"Failed to load model '{model_name}' with {attention_mode} attention: {e}") from e
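

# Example usage (a minimal sketch; "VibeVoice-1.5B" is a hypothetical key --
# use an actual entry from AVAILABLE_VIBEVOICE_MODELS):
#
#   device = model_management.get_torch_device()
#   model, processor = VibeVoiceLoader.load_model(
#       "VibeVoice-1.5B", device, attention_mode="sdpa", use_llm_4bit=False
#   )
#   cleanup_old_models(keep_cache_key="VibeVoice-1.5B_attn_sdpa_q4_0")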