Add optional 4-bit (bitsandbytes NF4) loading to the VibeVoice nodes.

A new "quantize_llm_4bit" node input is threaded through
VibeVoiceModelHandler into VibeVoiceLoader.load_model, which builds a
BitsAndBytesConfig and forwards it to from_pretrained.

Review adjustments folded into this patch:
- VibeVoiceModelHandler.cache_key uses the same "_q4_<0|1>" suffix as the
  key built in VibeVoiceTTSNode.generate_audio.
- VibeVoiceModelHandler.load_model no longer calls .to(device) on a model
  that was loaded with a quantization config and device_map="auto".
- The flash_attention_2 -> sdpa switch for 4-bit loads is logged.
- generate_audio's new quantize_llm_4bit parameter defaults to False.
- Two NOTE(review) comments mark points that need confirmation against
  code outside this diff.

diff --git a/vibevoice_nodes.py b/vibevoice_nodes.py
index 1b75e2d..0f2fe33 100644
--- a/vibevoice_nodes.py
+++ b/vibevoice_nodes.py
@@ -14,7 +14,7 @@
 import comfy.model_patcher
 from comfy.utils import ProgressBar
 from comfy.model_management import throw_exception_if_processing_interrupted
-from transformers import set_seed, AutoTokenizer
+from transformers import set_seed, AutoTokenizer, BitsAndBytesConfig
 from .vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
 from .vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
 from .vibevoice.processor.vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor
@@ -80,17 +80,23 @@ def cleanup_old_models(keep_cache_key=None):
 
 class VibeVoiceModelHandler(torch.nn.Module):
     """A torch.nn.Module wrapper to hold the VibeVoice model and processor."""
-    def __init__(self, model_pack_name, attention_mode="eager"):
+    def __init__(self, model_pack_name, attention_mode="eager", use_llm_4bit=False):
         super().__init__()
         self.model_pack_name = model_pack_name
         self.attention_mode = attention_mode
-        self.cache_key = f"{model_pack_name}_attn_{attention_mode}"
+        self.use_llm_4bit = use_llm_4bit
+        # Keep the 4-bit flag in the key (same format the TTS node uses) so a
+        # quantized and a full-precision handler never share a cache entry.
+        self.cache_key = f"{model_pack_name}_attn_{attention_mode}_q4_{int(use_llm_4bit)}"
         self.model = None
         self.processor = None
         self.size = int(MODEL_CONFIGS[model_pack_name].get("size_gb", 4.0) * (1024**3))
 
     def load_model(self, device, attention_mode="eager"):
-        self.model, self.processor = VibeVoiceLoader.load_model(self.model_pack_name, device , attention_mode)
-        self.model.to(device)
+        self.model, self.processor = VibeVoiceLoader.load_model(self.model_pack_name, device, attention_mode, use_llm_4bit=self.use_llm_4bit)
+        # A 4-bit model is already placed by the loader (device_map="auto"), and
+        # transformers may reject .to() on bitsandbytes models; move only unquantized ones.
+        if not self.use_llm_4bit:
+            self.model.to(device)
 
 class VibeVoicePatcher(comfy.model_patcher.ModelPatcher):
@@ -180,11 +186,16 @@ class VibeVoiceLoader:
         return attention_mode
 
     @staticmethod
-    def load_model(model_name: str, device, attention_mode: str = "eager"):
+    def load_model(model_name: str, device, attention_mode: str = "eager", use_llm_4bit: bool = False):
         # Validate attention mode
         if attention_mode not in ATTENTION_MODES:
             logger.warning(f"Unknown attention mode '{attention_mode}', falling back to eager")
             attention_mode = "eager"
+        if use_llm_4bit and attention_mode == "flash_attention_2":
+            logger.warning("4-bit quantization requested; using 'sdpa' instead of 'flash_attention_2'")
+            attention_mode = "sdpa"
+        # NOTE(review): the cache key built below does not include use_llm_4bit, so a
+        # model cached at the other precision could be reused - confirm and extend it.
 
         # Create cache key that includes attention mode
         cache_key = f"{model_name}_attn_{attention_mode}"
@@ -214,6 +225,19 @@ class VibeVoiceLoader:
         final_attention_mode = VibeVoiceLoader._check_attention_compatibility(
             attention_mode, torch_dtype, device_name
         )
+
+        # Build optional 4-bit config (LLM only)
+        # NOTE(review): nothing in this config restricts quantization to the LLM;
+        # if the diffusion head must stay BF16/FP32 as the node tooltip says, the
+        # non-LLM modules likely need listing in llm_int8_skip_modules - confirm.
+        quant_config = None
+        if use_llm_4bit:
+            quant_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_compute_dtype=torch.bfloat16,
+            )
 
         logger.info(f"Requested attention mode: {attention_mode}")
         if final_attention_mode != attention_mode:
@@ -253,11 +277,13 @@ class VibeVoiceLoader:
         try:
             model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                 model_path,
-                torch_dtype=torch_dtype,
+                torch_dtype=torch.bfloat16 if quant_config else torch_dtype,
                 attn_implementation=final_attention_mode,
-                device_map=device
+                device_map="auto" if quant_config else device,
+                quantization_config=quant_config,  # <- forwarded if supported
             )
             model.eval()
+            model._llm_4bit = quant_config is not None
 
             # Store with the actual attention mode used (not the requested one)
             LOADED_MODELS[cache_key] = (model, processor)
@@ -381,6 +407,10 @@ class VibeVoiceTTSNode:
                     "default": "Speaker 1: Hello from ComfyUI!\nSpeaker 2: VibeVoice sounds amazing.",
                     "tooltip": "The script for the conversation. Use 'Speaker 1:', 'Speaker 2:', etc. to assign lines to different voices. Each speaker line should be on a new line."
                 }),
+                "quantize_llm_4bit": ("BOOLEAN", {
+                    "default": False, "label_on": "Q4 (LLM only)", "label_off": "Full precision",
+                    "tooltip": "Quantize the Qwen2.5 LLM to 4-bit NF4 via bitsandbytes. Diffusion head stays BF16/FP32."
+                }),
                 "attention_mode": (["eager", "sdpa", "flash_attention_2"], {
                     "default": "sdpa",
                     "tooltip": "Attention implementation: Eager (safest), SDPA (balanced), Flash Attention 2 (fastest but may cause issues on some GPUs like RTX 5090)"
@@ -426,20 +456,20 @@ class VibeVoiceTTSNode:
     FUNCTION = "generate_audio"
     CATEGORY = "audio/tts"
 
-    def generate_audio(self, model_name, text, attention_mode, cfg_scale, inference_steps, seed, do_sample, temperature, top_p, top_k, **kwargs):
+    def generate_audio(self, model_name, text, attention_mode, cfg_scale, inference_steps, seed, do_sample, temperature, top_p, top_k, quantize_llm_4bit=False, **kwargs):
         if not text.strip():
             logger.warning("VibeVoiceTTS: Empty text provided, returning silent audio.")
             return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},)
 
         # Create cache key that includes attention mode
-        cache_key = f"{model_name}_attn_{attention_mode}"
+        cache_key = f"{model_name}_attn_{attention_mode}_q4_{int(quantize_llm_4bit)}"
 
         # Clean up old models when switching to a different model
         if cache_key not in VIBEVOICE_PATCHER_CACHE:
             # Only keep models that are currently being requested
             cleanup_old_models(keep_cache_key=cache_key)
 
-            model_handler = VibeVoiceModelHandler(model_name, attention_mode)
+            model_handler = VibeVoiceModelHandler(model_name, attention_mode, use_llm_4bit=quantize_llm_4bit)
             patcher = VibeVoicePatcher(
                 model_handler,
                 attention_mode=attention_mode,