From 4c9785da8b3ca3d3f862659480c35df258b3ebb1 Mon Sep 17 00:00:00 2001 From: WildAi <2853742+wildminder@users.noreply.github.com> Date: Wed, 10 Sep 2025 12:06:26 +0300 Subject: [PATCH] major refactoring --- __init__.py | 83 +++- modules/__init__.py | 0 modules/loader.py | 212 +++++++++ modules/model_info.py | 13 + modules/patcher.py | 49 +++ modules/utils.py | 105 +++++ .../default_VibeVoice-1.5B_config.json | 115 +++++ .../default_VibeVoice-Large_config.json | 116 +++++ vibevoice/modular/sage_attention_patch.py | 3 +- vibevoice_nodes.py | 401 +----------------- 10 files changed, 699 insertions(+), 398 deletions(-) create mode 100644 modules/__init__.py create mode 100644 modules/loader.py create mode 100644 modules/model_info.py create mode 100644 modules/patcher.py create mode 100644 modules/utils.py create mode 100644 vibevoice/configs/default_VibeVoice-1.5B_config.json create mode 100644 vibevoice/configs/default_VibeVoice-Large_config.json diff --git a/__init__.py b/__init__.py index 2706c8f..311d2b7 100644 --- a/__init__.py +++ b/__init__.py @@ -1,6 +1,8 @@ import os import sys import logging +import folder_paths +import json try: import sageattention @@ -12,34 +14,95 @@ current_dir = os.path.dirname(os.path.abspath(__file__)) if current_dir not in sys.path: sys.path.append(current_dir) -import folder_paths +from .modules.model_info import AVAILABLE_VIBEVOICE_MODELS, MODEL_CONFIGS -from .vibevoice_nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS - -# Configure a logger for the entire custom node package +# Configure a logger logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) logger.propagate = False - if not logger.hasHandlers(): handler = logging.StreamHandler() formatter = logging.Formatter(f"[ComfyUI-VibeVoice] %(message)s") handler.setFormatter(formatter) logger.addHandler(handler) +# This is just the *name* of the subdirectory, not the full path. +VIBEVOICE_SUBDIR_NAME = "VibeVoice" -VIBEVOICE_MODEL_SUBDIR = os.path.join("tts", "VibeVoice") +# This is the *primary* path where official models will be downloaded. 
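+# (Illustrative: on a default install this resolves to <ComfyUI>/models/tts/VibeVoice.)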
+primary_vibevoice_models_path = os.path.join(folder_paths.models_dir, "tts", VIBEVOICE_SUBDIR_NAME)
+os.makedirs(primary_vibevoice_models_path, exist_ok=True)
 
-vibevoice_models_full_path = os.path.join(folder_paths.models_dir, VIBEVOICE_MODEL_SUBDIR)
-os.makedirs(vibevoice_models_full_path, exist_ok=True)
-
-# Register the tts/VibeVoice path with ComfyUI
+# Register the tts path type with ComfyUI so get_folder_paths works
 tts_path = os.path.join(folder_paths.models_dir, "tts")
 if "tts" not in folder_paths.folder_names_and_paths:
     supported_exts = folder_paths.supported_pt_extensions.union({".safetensors", ".json"})
     folder_paths.folder_names_and_paths["tts"] = ([tts_path], supported_exts)
 else:
+    # Ensure the default path is in the list if it's not already
     if tts_path not in folder_paths.folder_names_and_paths["tts"][0]:
         folder_paths.folder_names_and_paths["tts"][0].append(tts_path)
 
+# Dynamic model discovery
+# TODO: optimize the discovery scan
+
+# Register the official models that can be auto-downloaded
+for model_name, config in MODEL_CONFIGS.items():
+    AVAILABLE_VIBEVOICE_MODELS[model_name] = {
+        "type": "official",
+        "repo_id": config["repo_id"],
+        "tokenizer_repo": "Qwen/Qwen2.5-7B" if "Large" in model_name else "Qwen/Qwen2.5-1.5B"
+    }
+
+# Collect the search paths: the primary location plus any user-registered 'tts' folders
+vibevoice_search_paths = []
+# Use ComfyUI's API to get all registered 'tts' folders
+for tts_folder in folder_paths.get_folder_paths("tts"):
+    potential_path = os.path.join(tts_folder, VIBEVOICE_SUBDIR_NAME)
+    if os.path.isdir(potential_path) and potential_path not in vibevoice_search_paths:
+        vibevoice_search_paths.append(potential_path)
+
+# Add the primary path in case it was not registered for some reason
+if primary_vibevoice_models_path not in vibevoice_search_paths:
+    vibevoice_search_paths.insert(0, primary_vibevoice_models_path)
+
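+# Illustrative layouts the scan below recognizes (all names here are hypothetical):
+#
+#   models/tts/VibeVoice/
+#       MyFinetune/                 -> "local_dir" (directory with config.json + weights)
+#           config.json
+#           model.safetensors
+#       my-quant.safetensors        -> "standalone" (single checkpoint file)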
+# Discover all local models in the search paths
+for search_path in vibevoice_search_paths:
+    logger.info(f"Scanning for VibeVoice models in: {search_path}")
+    if not os.path.exists(search_path): continue
+    for item in os.listdir(search_path):
+        item_path = os.path.join(search_path, item)
+
+        # Case 1: the item is a standard Hugging Face model directory
+        if os.path.isdir(item_path):
+            model_name = item
+            if model_name in AVAILABLE_VIBEVOICE_MODELS: continue
+
+            config_exists = os.path.exists(os.path.join(item_path, "config.json"))
+            weights_exist = os.path.exists(os.path.join(item_path, "model.safetensors.index.json")) or any(f.endswith(('.safetensors', '.bin')) for f in os.listdir(item_path))
+
+            if config_exists and weights_exist:
+                tokenizer_repo = "Qwen/Qwen2.5-7B" if "large" in model_name.lower() else "Qwen/Qwen2.5-1.5B"
+                AVAILABLE_VIBEVOICE_MODELS[model_name] = {
+                    "type": "local_dir",
+                    "path": item_path,
+                    "tokenizer_repo": tokenizer_repo
+                }
+
+        # Case 2: the item is a standalone checkpoint file
+        elif os.path.isfile(item_path) and any(item.endswith(ext) for ext in folder_paths.supported_pt_extensions):
+            model_name = os.path.splitext(item)[0]
+            if model_name in AVAILABLE_VIBEVOICE_MODELS: continue
+
+            tokenizer_repo = "Qwen/Qwen2.5-7B" if "large" in model_name.lower() else "Qwen/Qwen2.5-1.5B"
+            AVAILABLE_VIBEVOICE_MODELS[model_name] = {
+                "type": "standalone",
+                "path": item_path,
+                "tokenizer_repo": tokenizer_repo
+            }
+
+logger.info(f"Discovered VibeVoice models: {sorted(list(AVAILABLE_VIBEVOICE_MODELS.keys()))}")
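+# (Illustrative log line, assuming only the two official models are registered:
+#  [ComfyUI-VibeVoice] Discovered VibeVoice models: ['VibeVoice-1.5B', 'VibeVoice-Large'])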
+
+from .vibevoice_nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
+
 __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
\ No newline at end of file
diff --git a/modules/__init__.py b/modules/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/modules/loader.py b/modules/loader.py
new file mode 100644
index 0000000..7c915c1
--- /dev/null
+++ b/modules/loader.py
@@ -0,0 +1,212 @@
+import os
+import torch
+import gc
+import json
+import logging
+from huggingface_hub import hf_hub_download, snapshot_download
+
+import comfy.utils
+import folder_paths
+import comfy.model_management as model_management
+
+import transformers
+from packaging import version
+
+_transformers_version = version.parse(transformers.__version__)
+_DTYPE_ARG_SUPPORTED = _transformers_version >= version.parse("4.56.0")
+
+from transformers import BitsAndBytesConfig
+from ..vibevoice.modular.configuration_vibevoice import VibeVoiceConfig
+from ..vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
+from ..vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
+from ..vibevoice.processor.vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor
+from ..vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceTextTokenizerFast
+
+from .model_info import AVAILABLE_VIBEVOICE_MODELS, MODEL_CONFIGS
+from .. import SAGE_ATTENTION_AVAILABLE
+if SAGE_ATTENTION_AVAILABLE:
+    from ..vibevoice.modular.sage_attention_patch import set_sage_attention
+
+logger = logging.getLogger(__name__)
+
+LOADED_MODELS = {}
+VIBEVOICE_PATCHER_CACHE = {}
+
+ATTENTION_MODES = ["eager", "sdpa", "flash_attention_2"]
+if SAGE_ATTENTION_AVAILABLE:
+    ATTENTION_MODES.append("sage")
+
+def cleanup_old_models(keep_cache_key=None):
+    global LOADED_MODELS, VIBEVOICE_PATCHER_CACHE
+    keys_to_remove = []
+    for key in list(LOADED_MODELS.keys()):
+        if key != keep_cache_key:
+            keys_to_remove.append(key)
+            del LOADED_MODELS[key]
+    for key in list(VIBEVOICE_PATCHER_CACHE.keys()):
+        if key != keep_cache_key:
+            try:
+                patcher = VIBEVOICE_PATCHER_CACHE[key]
+                if hasattr(patcher, 'model') and patcher.model:
+                    patcher.model.model = None
+                    patcher.model.processor = None
+                del VIBEVOICE_PATCHER_CACHE[key]
+            except Exception as e:
+                logger.warning(f"Error cleaning up patcher {key}: {e}")
+    if keys_to_remove:
+        logger.info(f"Cleaned up cached models: {keys_to_remove}")
+    gc.collect()
+    model_management.soft_empty_cache()
+
+
+class VibeVoiceModelHandler(torch.nn.Module):
+    def __init__(self, model_pack_name, attention_mode="eager", use_llm_4bit=False):
+        super().__init__()
+        self.model_pack_name = model_pack_name
+        self.attention_mode = attention_mode
+        self.use_llm_4bit = use_llm_4bit
+        # Cache key identifies model + attention mode + quantization, e.g. "VibeVoice-1.5B_attn_sdpa_q4_0"
+        self.cache_key = f"{self.model_pack_name}_attn_{attention_mode}_q4_{int(use_llm_4bit)}"
+        self.model = None
+        self.processor = None
+        size_gb = MODEL_CONFIGS.get(model_pack_name, {}).get("size_gb", 4.0)
+        self.size = int(size_gb * (1024**3))
+
+    def load_model(self, device, attention_mode="eager"):
+        self.model, self.processor = VibeVoiceLoader.load_model(self.model_pack_name, device, attention_mode, use_llm_4bit=self.use_llm_4bit)
+        if self.model.device != device:
+            self.model.to(device)
+
+class VibeVoiceLoader:
+    @staticmethod
+    def _check_gpu_for_sage_attention():
+        if not SAGE_ATTENTION_AVAILABLE: return False
+        if not torch.cuda.is_available(): return False
+        major, _ = torch.cuda.get_device_capability()
+        if major < 8:
+            logger.warning(f"Your GPU (compute capability {major}.x) does not support SageAttention, which requires CC 8.0+. Sage option will be disabled.")
+            return False
+        return True
+
+    @staticmethod
+    def load_model(model_name: str, device, attention_mode: str = "eager", use_llm_4bit: bool = False):
+        if model_name not in AVAILABLE_VIBEVOICE_MODELS:
+            raise ValueError(f"Unknown VibeVoice model: {model_name}. Available models: {list(AVAILABLE_VIBEVOICE_MODELS.keys())}")
+
+        if use_llm_4bit and attention_mode in ["eager", "flash_attention_2"]:
+            logger.warning(f"Attention mode '{attention_mode}' is not recommended with 4-bit quantization. 
Falling back to 'sdpa' for stability and performance.") + attention_mode = "sdpa" + if attention_mode not in ATTENTION_MODES: + logger.warning(f"Unknown attention mode '{attention_mode}', falling back to eager") + attention_mode = "eager" + + cache_key = f"{model_name}_attn_{attention_mode}_q4_{int(use_llm_4bit)}" + if cache_key in LOADED_MODELS: + logger.info(f"Using cached model with {attention_mode} attention and q4={use_llm_4bit}") + return LOADED_MODELS[cache_key] + + model_info = AVAILABLE_VIBEVOICE_MODELS[model_name] + model_type = model_info["type"] + vibevoice_base_path = os.path.join(folder_paths.get_folder_paths("tts")[0], "VibeVoice") + + model_path_or_none = None + config_path = None + preprocessor_config_path = None + tokenizer_dir = None + + if model_type == "official": + model_path_or_none = os.path.join(vibevoice_base_path, model_name) + if not os.path.exists(os.path.join(model_path_or_none, "model.safetensors.index.json")): + logger.info(f"Downloading official VibeVoice model: {model_name}...") + snapshot_download(repo_id=model_info["repo_id"], local_dir=model_path_or_none, local_dir_use_symlinks=False) + config_path = os.path.join(model_path_or_none, "config.json") + preprocessor_config_path = os.path.join(model_path_or_none, "preprocessor_config.json") + tokenizer_dir = model_path_or_none + elif model_type == "local_dir": + model_path_or_none = model_info["path"] + config_path = os.path.join(model_path_or_none, "config.json") + preprocessor_config_path = os.path.join(model_path_or_none, "preprocessor_config.json") + tokenizer_dir = model_path_or_none + elif model_type == "standalone": + model_path_or_none = None # IMPORTANT: This must be None when loading from state_dict + config_path = os.path.splitext(model_info["path"])[0] + ".config.json" + preprocessor_config_path = os.path.splitext(model_info["path"])[0] + ".preprocessor.json" + tokenizer_dir = os.path.dirname(model_info["path"]) + + if os.path.exists(config_path): + config = VibeVoiceConfig.from_pretrained(config_path) + else: + fallback_name = "default_VibeVoice-Large_config.json" if "large" in model_name.lower() else "default_VibeVoice-1.5B_config.json" + fallback_path = os.path.join(os.path.dirname(__file__), "..", "vibevoice", "configs", fallback_name) + logger.warning(f"Config not found for '{model_name}'. Using fallback: {fallback_name}") + config = VibeVoiceConfig.from_pretrained(fallback_path) + + # Processor & Tokenizer setup + tokenizer_repo = model_info["tokenizer_repo"] + tokenizer_file_path = os.path.join(tokenizer_dir, "tokenizer.json") + if not os.path.exists(tokenizer_file_path): + logger.info(f"tokenizer.json not found. 
Downloading from '{tokenizer_repo}'...")
+            hf_hub_download(repo_id=tokenizer_repo, filename="tokenizer.json", local_dir=tokenizer_dir, local_dir_use_symlinks=False)
+        vibevoice_tokenizer = VibeVoiceTextTokenizerFast(tokenizer_file=tokenizer_file_path)
+
+        processor_config_data = {}
+        if os.path.exists(preprocessor_config_path):
+            with open(preprocessor_config_path, 'r', encoding='utf-8') as f: processor_config_data = json.load(f)
+
+        audio_processor = VibeVoiceTokenizerProcessor()
+        processor = VibeVoiceProcessor(tokenizer=vibevoice_tokenizer, audio_processor=audio_processor, speech_tok_compress_ratio=processor_config_data.get("speech_tok_compress_ratio", 3200), db_normalize=processor_config_data.get("db_normalize", True))
+
+        # Model loading prep
+        if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): model_dtype = torch.bfloat16
+        else: model_dtype = torch.float16
+        quant_config = None
+        final_load_dtype = model_dtype
+
+        if use_llm_4bit:
+            bnb_compute_dtype = model_dtype
+            # SageAttention is numerically sensitive with 4-bit weights; force fp32 compute for stability
+            if attention_mode == 'sage': bnb_compute_dtype, final_load_dtype = torch.float32, torch.float32
+            quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=bnb_compute_dtype)
+
+        attn_implementation_for_load = "sdpa" if attention_mode == "sage" else attention_mode
+
+        try:
+            logger.info(f"Loading model '{model_name}' with dtype: {final_load_dtype} and attention: '{attn_implementation_for_load}'")
+
+            # Unified model loading logic
+            from_pretrained_kwargs = {
+                "config": config,
+                "attn_implementation": attn_implementation_for_load,
+                "device_map": "auto" if quant_config else device,
+                "quantization_config": quant_config,
+            }
+            if _DTYPE_ARG_SUPPORTED:
+                from_pretrained_kwargs['dtype'] = final_load_dtype
+            else:
+                from_pretrained_kwargs['torch_dtype'] = final_load_dtype
+
+            if model_type == "standalone":
+                logger.info(f"Loading standalone model state_dict directly to device: {device}")
+                # Load the state dict directly onto the target device
+                state_dict = comfy.utils.load_torch_file(model_info["path"], device=device)
+                from_pretrained_kwargs["state_dict"] = state_dict
+
+            model = VibeVoiceForConditionalGenerationInference.from_pretrained(model_path_or_none, **from_pretrained_kwargs)
+
+            if attention_mode == "sage":
+                if VibeVoiceLoader._check_gpu_for_sage_attention():
+                    set_sage_attention(model)
+                else:
+                    raise RuntimeError("Incompatible hardware/setup for SageAttention.")
+
+            model.eval()
+            setattr(model, "_llm_4bit", bool(quant_config))
+            LOADED_MODELS[cache_key] = (model, processor)
+            logger.info(f"Successfully configured model '{model_name}' with {attention_mode} attention")
+            return model, processor
+
+        except Exception as e:
+            # Automatic fallback to another attention mode is intentionally disabled;
+            # surface the error and let the user decide how to proceed.
+            logger.error(f"Failed to load model '{model_name}' with {attention_mode} attention: {e}")
+            # if attention_mode in ["sage", "flash_attention_2"]: return VibeVoiceLoader.load_model(model_name, device, "sdpa", use_llm_4bit)
+            # elif attention_mode == "sdpa": return VibeVoiceLoader.load_model(model_name, device, "eager", use_llm_4bit)
+            raise RuntimeError(f"Failed to load model '{model_name}' with {attention_mode} attention: {e}")
\ No newline at end of file
diff --git a/modules/model_info.py b/modules/model_info.py
new file mode 100644
index 0000000..25e501f
--- /dev/null
+++ b/modules/model_info.py
@@ -0,0 +1,13 @@
+# This dictionary contains the configurations for official, downloadable models.
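+# Note: size_gb is an approximate checkpoint footprint; VibeVoiceModelHandler multiplies
+# it by 1024**3 to report a byte size to ComfyUI's model management.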
+MODEL_CONFIGS = { + "VibeVoice-1.5B": { + "repo_id": "microsoft/VibeVoice-1.5B", + "size_gb": 3.0, + }, + "VibeVoice-Large": { + "repo_id": "microsoft/VibeVoice-Large", + "size_gb": 17.4, + } +} + +AVAILABLE_VIBEVOICE_MODELS = {} \ No newline at end of file diff --git a/modules/patcher.py b/modules/patcher.py new file mode 100644 index 0000000..bb48e8a --- /dev/null +++ b/modules/patcher.py @@ -0,0 +1,49 @@ +import torch +import gc +import logging +import comfy.model_patcher +import comfy.model_management as model_management + +from .loader import LOADED_MODELS, logger + +class VibeVoicePatcher(comfy.model_patcher.ModelPatcher): + """Custom ModelPatcher for managing VibeVoice models in ComfyUI.""" + def __init__(self, model, attention_mode="eager", *args, **kwargs): + super().__init__(model, *args, **kwargs) + self.attention_mode = attention_mode + self.cache_key = model.cache_key + + @property + def is_loaded(self): + """Check if the model is currently loaded in memory.""" + return hasattr(self, 'model') and self.model is not None and hasattr(self.model, 'model') and self.model.model is not None + + def patch_model(self, device_to=None, *args, **kwargs): + target_device = self.load_device + if self.model.model is None: + logger.info(f"Loading VibeVoice models for '{self.model.model_pack_name}' to {target_device}...") + mode_names = { + "eager": "Eager (Most Compatible)", + "sdpa": "SDPA (Balanced Speed/Compatibility)", + "flash_attention_2": "Flash Attention 2 (Fastest)", + "sage": "SageAttention (Quantized High-Performance)", + } + logger.info(f"Attention Mode: {mode_names.get(self.attention_mode, self.attention_mode)}") + self.model.load_model(target_device, self.attention_mode) + self.model.model.to(target_device) + return super().patch_model(device_to=target_device, *args, **kwargs) + + def unpatch_model(self, device_to=None, unpatch_weights=True, *args, **kwargs): + if unpatch_weights: + logger.info(f"Offloading VibeVoice models for '{self.model.model_pack_name}' ({self.attention_mode}) to {device_to}...") + self.model.model = None + self.model.processor = None + + if self.cache_key in LOADED_MODELS: + del LOADED_MODELS[self.cache_key] + logger.info(f"Cleared LOADED_MODELS cache for: {self.cache_key}") + + gc.collect() + model_management.soft_empty_cache() + + return super().unpatch_model(device_to, unpatch_weights, *args, **kwargs) \ No newline at end of file diff --git a/modules/utils.py b/modules/utils.py new file mode 100644 index 0000000..01fc7c8 --- /dev/null +++ b/modules/utils.py @@ -0,0 +1,105 @@ +import re +import torch +import numpy as np +import random +import logging + +from comfy.utils import ProgressBar +from comfy.model_management import throw_exception_if_processing_interrupted + +try: + import librosa +except ImportError: + print("VibeVoice Node: `librosa` is not installed. 
Resampling of reference audio will not be available.")
+    librosa = None
+
+logger = logging.getLogger(__name__)
+
+def set_vibevoice_seed(seed: int):
+    """Sets the seed for torch, numpy, and random, handling large seeds for numpy."""
+    if seed == 0:
+        seed = random.randint(1, 0xffffffffffffffff)
+
+    MAX_NUMPY_SEED = 2**32 - 1
+    numpy_seed = seed % MAX_NUMPY_SEED
+
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
+    np.random.seed(numpy_seed)
+    random.seed(seed)
+
+def parse_script_1_based(script: str) -> tuple[list[tuple[int, str]], list[int]]:
+    """
+    Parses a 1-based speaker script (lines of the form "Speaker 1: Hello there.")
+    into a list of (speaker_id, text) tuples and a sorted list of the unique
+    1-based speaker IDs found in the script.
+    Internally, it converts speaker IDs to 0-based for the model.
+    """
+    parsed_lines = []
+    speaker_ids_in_script = [] # This will store the 1-based IDs from the script
+    for line in script.strip().split("\n"):
+        if not (line := line.strip()): continue
+        match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line, re.IGNORECASE)
+        if match:
+            speaker_id = int(match.group(1))
+            if speaker_id < 1:
+                logger.warning(f"Speaker ID must be 1 or greater. Skipping line: '{line}'")
+                continue
+            text = ' ' + match.group(2).strip()
+            # Internally, the model expects 0-based indexing for speakers
+            internal_speaker_id = speaker_id - 1
+            parsed_lines.append((internal_speaker_id, text))
+            if speaker_id not in speaker_ids_in_script:
+                speaker_ids_in_script.append(speaker_id)
+        else:
+            logger.warning(f"Could not parse line, skipping: '{line}'")
+    return parsed_lines, sorted(list(set(speaker_ids_in_script)))
+
+def preprocess_comfy_audio(audio_dict: dict, target_sr: int = 24000) -> np.ndarray:
+    """
+    Converts a ComfyUI AUDIO dict to a mono NumPy array, resampling if necessary.
+    """
+    if not audio_dict: return None
+    waveform_tensor = audio_dict.get('waveform')
+    if waveform_tensor is None or waveform_tensor.numel() == 0: return None
+
+    waveform = waveform_tensor[0].cpu().numpy()
+    original_sr = audio_dict['sample_rate']
+
+    if waveform.ndim > 1:
+        waveform = np.mean(waveform, axis=0)
+
+    # Check for invalid values
+    if np.any(np.isnan(waveform)) or np.any(np.isinf(waveform)):
+        logger.error("Audio contains NaN or Inf values, replacing with zeros")
+        waveform = np.nan_to_num(waveform, nan=0.0, posinf=0.0, neginf=0.0)
+
+    # Ensure audio is not completely silent or has extreme values
+    if np.all(waveform == 0):
+        logger.warning("Audio waveform is completely silent")
+
+    # Normalize extreme values
+    max_val = np.abs(waveform).max()
+    if max_val > 10.0:
+        logger.warning(f"Audio values are very large (max: {max_val}), normalizing")
+        waveform = waveform / max_val
+
+    if original_sr != target_sr:
+        if librosa is None:
+            raise ImportError("`librosa` package is required for audio resampling. 
Please install it with `pip install librosa`.") + logger.warning(f"Resampling reference audio from {original_sr}Hz to {target_sr}Hz.") + waveform = librosa.resample(y=waveform, orig_sr=original_sr, target_sr=target_sr) + + # Final check after resampling + if np.any(np.isnan(waveform)) or np.any(np.isinf(waveform)): + logger.error("Audio contains NaN or Inf after resampling, replacing with zeros") + waveform = np.nan_to_num(waveform, nan=0.0, posinf=0.0, neginf=0.0) + + return waveform.astype(np.float32) + +def check_for_interrupt(): + try: + throw_exception_if_processing_interrupted() + return False + except: + return True \ No newline at end of file diff --git a/vibevoice/configs/default_VibeVoice-1.5B_config.json b/vibevoice/configs/default_VibeVoice-1.5B_config.json new file mode 100644 index 0000000..17feb55 --- /dev/null +++ b/vibevoice/configs/default_VibeVoice-1.5B_config.json @@ -0,0 +1,115 @@ +{ + "acoustic_vae_dim": 64, + "acoustic_tokenizer_config": { + "causal": true, + "channels": 1, + "conv_bias": true, + "conv_norm": "none", + "corpus_normalize": 0.0, + "decoder_depths": null, + "decoder_n_filters": 32, + "decoder_ratios": [ + 8, + 5, + 5, + 4, + 2, + 2 + ], + "disable_last_norm": true, + "encoder_depths": "3-3-3-3-3-3-8", + "encoder_n_filters": 32, + "encoder_ratios": [ + 8, + 5, + 5, + 4, + 2, + 2 + ], + "fix_std": 0.5, + "layer_scale_init_value": 1e-06, + "layernorm": "RMSNorm", + "layernorm_elementwise_affine": true, + "layernorm_eps": 1e-05, + "mixer_layer": "depthwise_conv", + "model_type": "vibevoice_acoustic_tokenizer", + "pad_mode": "constant", + "std_dist_type": "gaussian", + "vae_dim": 64, + "weight_init_value": 0.01 + }, + "architectures": [ + "VibeVoiceForConditionalGeneration" + ], + "decoder_config": { + "attention_dropout": 0.0, + "hidden_act": "silu", + "hidden_size": 1536, + "initializer_range": 0.02, + "intermediate_size": 8960, + "max_position_embeddings": 65536, + "max_window_layers": 28, + "model_type": "qwen2", + "num_attention_heads": 12, + "num_hidden_layers": 28, + "num_key_value_heads": 2, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000.0, + "sliding_window": null, + "tie_word_embeddings": true, + "torch_dtype": "bfloat16", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 151936 + }, + "diffusion_head_config": { + "ddpm_batch_mul": 4, + "ddpm_beta_schedule": "cosine", + "ddpm_num_inference_steps": 20, + "ddpm_num_steps": 1000, + "diffusion_type": "ddpm", + "head_ffn_ratio": 3.0, + "head_layers": 4, + "hidden_size": 1536, + "latent_size": 64, + "model_type": "vibevoice_diffusion_head", + "prediction_type": "v_prediction", + "rms_norm_eps": 1e-05, + "speech_vae_dim": 64 + }, + "model_type": "vibevoice", + "semantic_tokenizer_config": { + "causal": true, + "channels": 1, + "conv_bias": true, + "conv_norm": "none", + "corpus_normalize": 0.0, + "disable_last_norm": true, + "encoder_depths": "3-3-3-3-3-3-8", + "encoder_n_filters": 32, + "encoder_ratios": [ + 8, + 5, + 5, + 4, + 2, + 2 + ], + "fix_std": 0, + "layer_scale_init_value": 1e-06, + "layernorm": "RMSNorm", + "layernorm_elementwise_affine": true, + "layernorm_eps": 1e-05, + "mixer_layer": "depthwise_conv", + "model_type": "vibevoice_semantic_tokenizer", + "pad_mode": "constant", + "std_dist_type": "none", + "vae_dim": 128, + "weight_init_value": 0.01 + }, + "semantic_vae_dim": 128, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.3" +} diff --git a/vibevoice/configs/default_VibeVoice-Large_config.json 
b/vibevoice/configs/default_VibeVoice-Large_config.json new file mode 100644 index 0000000..5efb7d3 --- /dev/null +++ b/vibevoice/configs/default_VibeVoice-Large_config.json @@ -0,0 +1,116 @@ +{ + "acostic_vae_dim": 64, + "acoustic_tokenizer_config": { + "causal": true, + "channels": 1, + "conv_bias": true, + "conv_norm": "none", + "corpus_normalize": 0.0, + "decoder_depths": null, + "decoder_n_filters": 32, + "decoder_ratios": [ + 8, + 5, + 5, + 4, + 2, + 2 + ], + "disable_last_norm": true, + "encoder_depths": "3-3-3-3-3-3-8", + "encoder_n_filters": 32, + "encoder_ratios": [ + 8, + 5, + 5, + 4, + 2, + 2 + ], + "fix_std": 0.5, + "layer_scale_init_value": 1e-06, + "layernorm": "RMSNorm", + "layernorm_elementwise_affine": true, + "layernorm_eps": 1e-05, + "mixer_layer": "depthwise_conv", + "model_type": "vibevoice_acoustic_tokenizer", + "pad_mode": "constant", + "std_dist_type": "gaussian", + "vae_dim": 64, + "weight_init_value": 0.01 + }, + "architectures": [ + "VibeVoiceForConditionalGeneration" + ], + "decoder_config": { + "attention_dropout": 0.0, + "hidden_act": "silu", + "hidden_size": 3584, + "initializer_range": 0.02, + "intermediate_size": 18944, + "max_position_embeddings": 32768, + "max_window_layers": 28, + "model_type": "qwen2", + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000.0, + "sliding_window": null, + "torch_dtype": "bfloat16", + "use_cache": true, + "use_mrope": false, + "use_sliding_window": false, + "vocab_size": 152064 + }, + "diffusion_head_config": { + "ddpm_batch_mul": 4, + "ddpm_beta_schedule": "cosine", + "ddpm_num_inference_steps": 20, + "ddpm_num_steps": 1000, + "diffusion_type": "ddpm", + "head_ffn_ratio": 3.0, + "head_layers": 4, + "hidden_size": 3584, + "latent_size": 64, + "model_type": "vibevoice_diffusion_head", + "prediction_type": "v_prediction", + "rms_norm_eps": 1e-05, + "speech_vae_dim": 64 + }, + "model_type": "vibevoice", + "semantic_tokenizer_config": { + "causal": true, + "channels": 1, + "conv_bias": true, + "conv_norm": "none", + "corpus_normalize": 0.0, + "disable_last_norm": true, + "encoder_depths": "3-3-3-3-3-3-8", + "encoder_n_filters": 32, + "encoder_ratios": [ + 8, + 5, + 5, + 4, + 2, + 2 + ], + "fix_std": 0, + "layer_scale_init_value": 1e-06, + "layernorm": "RMSNorm", + "layernorm_elementwise_affine": true, + "layernorm_eps": 1e-05, + "mixer_layer": "depthwise_conv", + "model_type": "vibevoice_semantic_tokenizer", + "pad_mode": "constant", + "std_dist_type": "none", + "vae_dim": 128, + "weight_init_value": 0.01 + }, + "semantic_vae_dim": 128, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.3" +} diff --git a/vibevoice/modular/sage_attention_patch.py b/vibevoice/modular/sage_attention_patch.py index 4111d62..96dca11 100644 --- a/vibevoice/modular/sage_attention_patch.py +++ b/vibevoice/modular/sage_attention_patch.py @@ -1,5 +1,6 @@ # Author: Wildminder # Desc: SageAttention and patcher +# License: Apache 2.0 import torch from typing import Optional, Tuple @@ -145,4 +146,4 @@ def set_sage_attention(model): for module in model.modules(): if isinstance(module, Qwen2Attention): - module.forward = sage_attention_forward.__get__(module, Qwen2Attention) + module.forward = sage_attention_forward.__get__(module, Qwen2Attention) \ No newline at end of file diff --git a/vibevoice_nodes.py b/vibevoice_nodes.py index 749d87a..0ba8ddd 100644 --- a/vibevoice_nodes.py +++ b/vibevoice_nodes.py @@ -1,395 
+1,29 @@ -import os -import re import torch -import numpy as np -import random -from huggingface_hub import hf_hub_download, snapshot_download +import gc import logging -import gc - -import folder_paths import comfy.model_management as model_management -import comfy.model_patcher from comfy.utils import ProgressBar -from comfy.model_management import throw_exception_if_processing_interrupted -# Import transformers and packaging to handle different library versions. -import transformers -from packaging import version - -_transformers_version = version.parse(transformers.__version__) -_DTYPE_ARG_SUPPORTED = _transformers_version >= version.parse("4.56.0") - -from transformers import set_seed, AutoTokenizer, BitsAndBytesConfig -from .vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference -from .vibevoice.processor.vibevoice_processor import VibeVoiceProcessor -from .vibevoice.processor.vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor -from .vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceTextTokenizerFast - -from . import SAGE_ATTENTION_AVAILABLE -if SAGE_ATTENTION_AVAILABLE: - from .vibevoice.modular.sage_attention_patch import set_sage_attention - -try: - import librosa -except ImportError: - print("VibeVoice Node: `librosa` is not installed. Resampling of reference audio will not be available.") - librosa = None +# Import from the dedicated model_info module +from .modules.model_info import AVAILABLE_VIBEVOICE_MODELS +from .modules.loader import VibeVoiceModelHandler, ATTENTION_MODES, VIBEVOICE_PATCHER_CACHE, cleanup_old_models +from .modules.patcher import VibeVoicePatcher +from .modules.utils import parse_script_1_based, preprocess_comfy_audio, set_vibevoice_seed, check_for_interrupt logger = logging.getLogger(__name__) -LOADED_MODELS = {} -VIBEVOICE_PATCHER_CACHE = {} - -MODEL_CONFIGS = { - "VibeVoice-1.5B": { - "repo_id": "microsoft/VibeVoice-1.5B", - "size_gb": 3.0, - "tokenizer_repo": "Qwen/Qwen2.5-1.5B" - }, - "VibeVoice-Large": { - "repo_id": "aoi-ot/VibeVoice-Large", - "size_gb": 17.4, - "tokenizer_repo": "Qwen/Qwen2.5-7B" - } -} - -ATTENTION_MODES = ["eager", "sdpa", "flash_attention_2"] -if SAGE_ATTENTION_AVAILABLE: - ATTENTION_MODES.append("sage") - -def cleanup_old_models(keep_cache_key=None): - """Clean up old models, optionally keeping one specific model loaded""" - global LOADED_MODELS, VIBEVOICE_PATCHER_CACHE - - keys_to_remove = [] - - # Clear LOADED_MODELS - for key in list(LOADED_MODELS.keys()): - if key != keep_cache_key: - keys_to_remove.append(key) - del LOADED_MODELS[key] - - # Clear VIBEVOICE_PATCHER_CACHE - but more carefully - for key in list(VIBEVOICE_PATCHER_CACHE.keys()): - if key != keep_cache_key: - try: - patcher = VIBEVOICE_PATCHER_CACHE[key] - if hasattr(patcher, 'model') and patcher.model: - patcher.model.model = None - patcher.model.processor = None - del VIBEVOICE_PATCHER_CACHE[key] - except Exception as e: - logger.warning(f"Error cleaning up patcher {key}: {e}") - - if keys_to_remove: - logger.info(f"Cleaned up cached models: {keys_to_remove}") - gc.collect() - model_management.soft_empty_cache() - -class VibeVoiceModelHandler(torch.nn.Module): - """A torch.nn.Module wrapper to hold the VibeVoice model and processor.""" - def __init__(self, model_pack_name, attention_mode="eager", use_llm_4bit=False): - super().__init__() - self.model_pack_name = model_pack_name - self.attention_mode = attention_mode - self.use_llm_4bit = use_llm_4bit - self.cache_key = 
f"{model_pack_name}_attn_{attention_mode}_q4_{int(use_llm_4bit)}" - self.model = None - self.processor = None - self.size = int(MODEL_CONFIGS[model_pack_name].get("size_gb", 4.0) * (1024**3)) - - def load_model(self, device, attention_mode="eager"): - self.model, self.processor = VibeVoiceLoader.load_model(self.model_pack_name, device, attention_mode, use_llm_4bit=self.use_llm_4bit) - self.model.to(device) - -class VibeVoicePatcher(comfy.model_patcher.ModelPatcher): - """Custom ModelPatcher for managing VibeVoice models in ComfyUI.""" - def __init__(self, model, attention_mode="eager", *args, **kwargs): - super().__init__(model, *args, **kwargs) - self.attention_mode = attention_mode - self.cache_key = model.cache_key - - @property - def is_loaded(self): - """Check if the model is currently loaded in memory.""" - return hasattr(self, 'model') and self.model is not None and hasattr(self.model, 'model') and self.model.model is not None - - def patch_model(self, device_to=None, *args, **kwargs): - target_device = self.load_device - if self.model.model is None: - logger.info(f"Loading VibeVoice models for '{self.model.model_pack_name}' to {target_device}...") - mode_names = { - "eager": "Eager (Most Compatible)", - "sdpa": "SDPA (Balanced Speed/Compatibility)", - "flash_attention_2": "Flash Attention 2 (Fastest)", - "sage": "SageAttention (Quantized High-Performance)", - } - logger.info(f"Attention Mode: {mode_names.get(self.attention_mode, self.attention_mode)}") - self.model.load_model(target_device, self.attention_mode) - self.model.model.to(target_device) - return super().patch_model(device_to=target_device, *args, **kwargs) - - def unpatch_model(self, device_to=None, unpatch_weights=True, *args, **kwargs): - if unpatch_weights: - logger.info(f"Offloading VibeVoice models for '{self.model.model_pack_name}' ({self.attention_mode}) to {device_to}...") - self.model.model = None - self.model.processor = None - - if self.cache_key in LOADED_MODELS: - del LOADED_MODELS[self.cache_key] - logger.info(f"Cleared LOADED_MODELS cache for: {self.cache_key}") - - gc.collect() - model_management.soft_empty_cache() - - return super().unpatch_model(device_to, unpatch_weights, *args, **kwargs) - -class VibeVoiceLoader: - @staticmethod - def get_model_path(model_name: str): - if model_name not in MODEL_CONFIGS: - raise ValueError(f"Unknown VibeVoice model: {model_name}") - - vibevoice_path = os.path.join(folder_paths.get_folder_paths("tts")[0], "VibeVoice") - model_path = os.path.join(vibevoice_path, model_name) - - index_file = os.path.join(model_path, "model.safetensors.index.json") - if not os.path.exists(index_file): - print(f"Downloading VibeVoice model: {model_name}...") - repo_id = MODEL_CONFIGS[model_name]["repo_id"] - snapshot_download(repo_id=repo_id, local_dir=model_path) - return model_path - - @staticmethod - def _check_gpu_for_sage_attention(): - """Check if the current GPU is compatible with SageAttention.""" - if not SAGE_ATTENTION_AVAILABLE: - return False - if not torch.cuda.is_available(): - return False - major, _ = torch.cuda.get_device_capability() - if major < 8: - logger.warning(f"Your GPU (compute capability {major}.x) does not support SageAttention, which requires CC 8.0+. 
Sage option will be disabled.") - return False - return True - - @staticmethod - def load_model(model_name: str, device, attention_mode: str = "eager", use_llm_4bit: bool = False): - - if use_llm_4bit and attention_mode in ["eager", "flash_attention_2"]: - logger.warning(f"Attention mode '{attention_mode}' is not recommended with 4-bit quantization. Falling back to 'sdpa' for stability and performance.") - attention_mode = "sdpa" - - if attention_mode not in ATTENTION_MODES: - logger.warning(f"Unknown attention mode '{attention_mode}', falling back to eager") - attention_mode = "eager" - - cache_key = f"{model_name}_attn_{attention_mode}_q4_{int(use_llm_4bit)}" - - if cache_key in LOADED_MODELS: - logger.info(f"Using cached model with {attention_mode} attention and q4={use_llm_4bit}") - return LOADED_MODELS[cache_key] - - model_path = VibeVoiceLoader.get_model_path(model_name) - - logger.info(f"Loading VibeVoice model components from: {model_path}") - - tokenizer_repo = MODEL_CONFIGS[model_name].get("tokenizer_repo") - tokenizer_file_path = os.path.join(model_path, "tokenizer.json") - # Check if tokenizer.json exists locally. If not, download it directly to the model folder. - if not os.path.exists(tokenizer_file_path): - logger.info(f"tokenizer.json not found in {model_path}. Downloading from '{tokenizer_repo}'...") - try: - hf_hub_download( - repo_id=tokenizer_repo, - filename="tokenizer.json", - local_dir=model_path, - ) - except Exception as e: - logger.error(f"Failed to download tokenizer.json: {e}") - raise - - - vibevoice_tokenizer = VibeVoiceTextTokenizerFast(tokenizer_file=tokenizer_file_path) - audio_processor = VibeVoiceTokenizerProcessor() - processor = VibeVoiceProcessor(tokenizer=vibevoice_tokenizer, audio_processor=audio_processor) - - # Base dtype for full precision and memory-optimized 4-bit - if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): - model_dtype = torch.bfloat16 - else: - model_dtype = torch.float16 - - quant_config = None - final_load_dtype = model_dtype - - if use_llm_4bit: - # Default to bfloat16/float16 for memory savings - bnb_compute_dtype = model_dtype - - # SageAttention is numerically sensitive and requires fp32 compute dtype for stability - # SDPA is more robust and can use bf16. - if attention_mode == 'sage': - logger.info("Using SageAttention with 4-bit quant. Forcing fp32 compute dtype for maximum stability.") - bnb_compute_dtype = torch.float32 - final_load_dtype = torch.float32 - else: - logger.info(f"Using {attention_mode} with 4-bit quant. Using {model_dtype} compute dtype for memory efficiency.") - - quant_config = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_quant_type="nf4", - bnb_4bit_use_double_quant=True, - bnb_4bit_compute_dtype=bnb_compute_dtype, - ) - - attn_implementation_for_load = "sdpa" if attention_mode == "sage" else attention_mode - - try: - logger.info(f"Loading model with dtype: {final_load_dtype} and attention: '{attn_implementation_for_load}'") - # Build a dictionary of keyword arguments for from_pretrained. - from_pretrained_kwargs = { - "attn_implementation": attn_implementation_for_load, - "device_map": "auto" if quant_config else device, - "quantization_config": quant_config, - } - - # Use the correct dtype argument based on the transformers version. 
- if _DTYPE_ARG_SUPPORTED: - from_pretrained_kwargs['dtype'] = final_load_dtype - else: - from_pretrained_kwargs['torch_dtype'] = final_load_dtype - model = VibeVoiceForConditionalGenerationInference.from_pretrained( - model_path, - **from_pretrained_kwargs - ) - - if attention_mode == "sage": - if VibeVoiceLoader._check_gpu_for_sage_attention(): - logger.info("Applying SageAttention patch to the model...") - set_sage_attention(model) - else: - logger.error("Cannot apply SageAttention due to incompatible GPU. Falling back.") - raise RuntimeError("Incompatible hardware/setup for SageAttention.") - - model.eval() - setattr(model, "_llm_4bit", bool(quant_config)) - - LOADED_MODELS[cache_key] = (model, processor) - logger.info(f"Successfully configured model with {attention_mode} attention") - return model, processor - - except Exception as e: - logger.error(f"Failed to load model with {attention_mode} attention: {e}") - # Fallback logic - if attention_mode in ["sage", "flash_attention_2"]: - logger.info("Attempting fallback to SDPA...") - return VibeVoiceLoader.load_model(model_name, device, "sdpa", use_llm_4bit) - elif attention_mode == "sdpa": - logger.info("Attempting fallback to eager...") - return VibeVoiceLoader.load_model(model_name, device, "eager", use_llm_4bit) - else: - raise RuntimeError(f"Failed to load model even with eager attention: {e}") - - -def set_vibevoice_seed(seed: int): - """Sets the seed for torch, numpy, and random, handling large seeds for numpy.""" - if seed == 0: - seed = random.randint(1, 0xffffffffffffffff) - - MAX_NUMPY_SEED = 2**32 - 1 - numpy_seed = seed % MAX_NUMPY_SEED - - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) - np.random.seed(numpy_seed) - random.seed(seed) - -def parse_script_1_based(script: str) -> tuple[list[tuple[int, str]], list[int]]: - """ - Parses a 1-based speaker script into a list of (speaker_id, text) tuples - and a list of unique speaker IDs in the order of their first appearance. - Internally, it converts speaker IDs to 0-based for the model. - """ - parsed_lines = [] - speaker_ids_in_script = [] # This will store the 1-based IDs from the script - for line in script.strip().split("\n"): - if not (line := line.strip()): continue - match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line, re.IGNORECASE) - if match: - speaker_id = int(match.group(1)) - if speaker_id < 1: - logger.warning(f"Speaker ID must be 1 or greater. Skipping line: '{line}'") - continue - text = ' ' + match.group(2).strip() - # Internally, the model expects 0-based indexing for speakers - internal_speaker_id = speaker_id - 1 - parsed_lines.append((internal_speaker_id, text)) - if speaker_id not in speaker_ids_in_script: - speaker_ids_in_script.append(speaker_id) - else: - logger.warning(f"Could not parse line, skipping: '{line}'") - return parsed_lines, sorted(list(set(speaker_ids_in_script))) - -def preprocess_comfy_audio(audio_dict: dict, target_sr: int = 24000) -> np.ndarray: - """ - Converts a ComfyUI AUDIO dict to a mono NumPy array, resampling if necessary. 
- """ - if not audio_dict: return None - waveform_tensor = audio_dict.get('waveform') - if waveform_tensor is None or waveform_tensor.numel() == 0: return None - - waveform = waveform_tensor[0].cpu().numpy() - original_sr = audio_dict['sample_rate'] - - if waveform.ndim > 1: - waveform = np.mean(waveform, axis=0) - - # Check for invalid values - if np.any(np.isnan(waveform)) or np.any(np.isinf(waveform)): - logger.error("Audio contains NaN or Inf values, replacing with zeros") - waveform = np.nan_to_num(waveform, nan=0.0, posinf=0.0, neginf=0.0) - - # Ensure audio is not completely silent or has extreme values - if np.all(waveform == 0): - logger.warning("Audio waveform is completely silent") - - # Normalize extreme values - max_val = np.abs(waveform).max() - if max_val > 10.0: - logger.warning(f"Audio values are very large (max: {max_val}), normalizing") - waveform = waveform / max_val - - if original_sr != target_sr: - if librosa is None: - raise ImportError("`librosa` package is required for audio resampling. Please install it with `pip install librosa`.") - logger.warning(f"Resampling reference audio from {original_sr}Hz to {target_sr}Hz.") - waveform = librosa.resample(y=waveform, orig_sr=original_sr, target_sr=target_sr) - - # Final check after resampling - if np.any(np.isnan(waveform)) or np.any(np.isinf(waveform)): - logger.error("Audio contains NaN or Inf after resampling, replacing with zeros") - waveform = np.nan_to_num(waveform, nan=0.0, posinf=0.0, neginf=0.0) - - return waveform.astype(np.float32) - -def check_for_interrupt(): - try: - throw_exception_if_processing_interrupted() - return False - except: - return True - class VibeVoiceTTSNode: @classmethod def INPUT_TYPES(cls): + model_names = list(AVAILABLE_VIBEVOICE_MODELS.keys()) + if not model_names: + model_names.append("No models found in models/tts/VibeVoice") + return { "required": { - "model_name": (list(MODEL_CONFIGS.keys()), { - "tooltip": "Select the VibeVoice model to use. Models will be downloaded automatically if not present." + "model_name": (model_names, { + "tooltip": "Select the VibeVoice model to use. Official models will be downloaded automatically." }), "text": ("STRING", { "multiline": True, @@ -405,7 +39,7 @@ class VibeVoiceTTSNode: "tooltip": "Attention implementation: Eager (safest), SDPA (balanced), Flash Attention 2 (fastest), Sage (quantized)" }), "cfg_scale": ("FLOAT", { - "default": 1.3, "min": 1.0, "max": 3.0, "step": 0.05, + "default": 1.3, "min": 1.0, "max": 10.0, "step": 0.05, "tooltip": "Classifier-Free Guidance scale. Higher values increase adherence to the voice prompt but may reduce naturalness. 
Recommended: 1.3" }), "inference_steps": ("INT", { @@ -450,16 +84,13 @@ class VibeVoiceTTSNode: CATEGORY = "audio/tts" def generate_audio(self, model_name, text, attention_mode, cfg_scale, inference_steps, seed, do_sample, temperature, top_p, top_k, quantize_llm_4bit, force_offload, **kwargs): - actual_attention_mode = attention_mode if quantize_llm_4bit and attention_mode in ["eager", "flash_attention_2"]: actual_attention_mode = "sdpa" cache_key = f"{model_name}_attn_{actual_attention_mode}_q4_{int(quantize_llm_4bit)}" - # Clean up old models when switching to a different model if cache_key not in VIBEVOICE_PATCHER_CACHE: - # Only keep models that are currently being requested cleanup_old_models(keep_cache_key=cache_key) model_handler = VibeVoiceModelHandler(model_name, attention_mode, use_llm_4bit=quantize_llm_4bit) @@ -501,7 +132,6 @@ class VibeVoiceTTSNode: return_tensors="pt", return_attention_mask=True ) - # Validate inputs before moving to GPU for key, value in inputs.items(): if isinstance(value, torch.Tensor): if torch.any(torch.isnan(value)) or torch.any(torch.isinf(value)): @@ -519,9 +149,6 @@ class VibeVoiceTTSNode: if top_k > 0: generation_config['top_k'] = top_k - # cause float() error for q4+eager - # model = model.float() IS REMOVED - with torch.no_grad(): pbar = ProgressBar(inference_steps) @@ -578,4 +205,4 @@ class VibeVoiceTTSNode: return ({"waveform": output_waveform.detach().cpu(), "sample_rate": 24000},) NODE_CLASS_MAPPINGS = {"VibeVoiceTTS": VibeVoiceTTSNode} -NODE_DISPLAY_NAME_MAPPINGS = {"VibeVoiceTTS": "VibeVoice TTS"} +NODE_DISPLAY_NAME_MAPPINGS = {"VibeVoiceTTS": "VibeVoice TTS"} \ No newline at end of file