From 7419fcd66ff1bd11fe5ba54ad3b419affde0dac8 Mon Sep 17 00:00:00 2001 From: Orion Date: Sat, 30 Aug 2025 23:48:25 +1000 Subject: [PATCH] Add optional Q4 (4-bit) LLM quantization for VibeVoice MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR introduces an **optional 4-bit (NF4) quantization path** for the **Qwen2.5 LLM component** inside VibeVoice, using Transformers + bitsandbytes. The diffusion head and processors remain BF16/FP32. This mirrors the project’s architecture and enables the **7B preview** to run on smaller GPUs while preserving output quality. **Changes / Additions:** * New toggle to run the LLM in **4-bit NF4** via `BitsAndBytesConfig`; default remains full precision. * Q4 prefers **SDPA** attention (Flash-Attn auto-downshifts) for stability. **Improvements on my 3080 12GB:** * **7B** * **Timing:** 29m 27s → **203.47s** (\~3m 23s) — **−88.5% time** (\~**8.68× faster**) * **VRAM:** Q4 ≈ **7.6 GB**; FP16 ≳ **12 GB** — Q4 saves **≥4.4 GB** (≥**36.7%**) * **1.5B** * **Timing (Q4):** 105s → **154s** — **+49s** (\~**1.47× slower**) * **VRAM:** Q4 ≈ **3.2 GB**; FP16 ≈ **8.7 GB** — Q4 saves **\~5.5 GB** (\~**63.2%**) These changes have resulted in a nearly 90% reduction in inference time and over 40% reduction in VRAM usage with the 7B model in VRAM constrained environments with no perceptible change in quality in my limited testing. While there is an increase in inference time with the 1.5B model, some may consider the smaller VRAM footprint worth it. --- vibevoice_nodes.py | 37 ++++++++++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/vibevoice_nodes.py b/vibevoice_nodes.py index 1b75e2d..0f2fe33 100644 --- a/vibevoice_nodes.py +++ b/vibevoice_nodes.py @@ -14,7 +14,7 @@ import comfy.model_patcher from comfy.utils import ProgressBar from comfy.model_management import throw_exception_if_processing_interrupted -from transformers import set_seed, AutoTokenizer +from transformers import set_seed, AutoTokenizer, BitsAndBytesConfig from .vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference from .vibevoice.processor.vibevoice_processor import VibeVoiceProcessor from .vibevoice.processor.vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor @@ -80,17 +80,18 @@ def cleanup_old_models(keep_cache_key=None): class VibeVoiceModelHandler(torch.nn.Module): """A torch.nn.Module wrapper to hold the VibeVoice model and processor.""" - def __init__(self, model_pack_name, attention_mode="eager"): + def __init__(self, model_pack_name, attention_mode="eager", use_llm_4bit=False): super().__init__() self.model_pack_name = model_pack_name self.attention_mode = attention_mode + self.use_llm_4bit = use_llm_4bit self.cache_key = f"{model_pack_name}_attn_{attention_mode}" self.model = None self.processor = None self.size = int(MODEL_CONFIGS[model_pack_name].get("size_gb", 4.0) * (1024**3)) def load_model(self, device, attention_mode="eager"): - self.model, self.processor = VibeVoiceLoader.load_model(self.model_pack_name, device , attention_mode) + self.model, self.processor = VibeVoiceLoader.load_model(self.model_pack_name, device, attention_mode, use_llm_4bit=self.use_llm_4bit) self.model.to(device) class VibeVoicePatcher(comfy.model_patcher.ModelPatcher): @@ -180,11 +181,13 @@ class VibeVoiceLoader: return attention_mode @staticmethod - def load_model(model_name: str, device, attention_mode: str = "eager"): + def load_model(model_name: str, device, attention_mode: str = "eager", use_llm_4bit: bool = False): # Validate attention mode if attention_mode not in ATTENTION_MODES: logger.warning(f"Unknown attention mode '{attention_mode}', falling back to eager") attention_mode = "eager" + if use_llm_4bit and attention_mode == "flash_attention_2": + attention_mode = "sdpa" # Create cache key that includes attention mode cache_key = f"{model_name}_attn_{attention_mode}" @@ -214,6 +217,16 @@ class VibeVoiceLoader: final_attention_mode = VibeVoiceLoader._check_attention_compatibility( attention_mode, torch_dtype, device_name ) + + # Build optional 4-bit config (LLM only) + quant_config = None + if use_llm_4bit: + quant_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_use_double_quant=True, + bnb_4bit_compute_dtype=torch.bfloat16, + ) logger.info(f"Requested attention mode: {attention_mode}") if final_attention_mode != attention_mode: @@ -253,11 +266,13 @@ class VibeVoiceLoader: try: model = VibeVoiceForConditionalGenerationInference.from_pretrained( model_path, - torch_dtype=torch_dtype, + torch_dtype=torch.bfloat16 if quant_config else torch_dtype, attn_implementation=final_attention_mode, - device_map=device + device_map="auto" if quant_config else device, + quantization_config=quant_config, # <- forwarded if supported ) model.eval() + setattr(model, "_llm_4bit", bool(quant_config)) # Store with the actual attention mode used (not the requested one) LOADED_MODELS[cache_key] = (model, processor) @@ -381,6 +396,10 @@ class VibeVoiceTTSNode: "default": "Speaker 1: Hello from ComfyUI!\nSpeaker 2: VibeVoice sounds amazing.", "tooltip": "The script for the conversation. Use 'Speaker 1:', 'Speaker 2:', etc. to assign lines to different voices. Each speaker line should be on a new line." }), + "quantize_llm_4bit": ("BOOLEAN", { + "default": False, "label_on": "Q4 (LLM only)", "label_off": "Full precision", + "tooltip": "Quantize the Qwen2.5 LLM to 4-bit NF4 via bitsandbytes. Diffusion head stays BF16/FP32." + }), "attention_mode": (["eager", "sdpa", "flash_attention_2"], { "default": "sdpa", "tooltip": "Attention implementation: Eager (safest), SDPA (balanced), Flash Attention 2 (fastest but may cause issues on some GPUs like RTX 5090)" @@ -426,20 +445,20 @@ class VibeVoiceTTSNode: FUNCTION = "generate_audio" CATEGORY = "audio/tts" - def generate_audio(self, model_name, text, attention_mode, cfg_scale, inference_steps, seed, do_sample, temperature, top_p, top_k, **kwargs): + def generate_audio(self, model_name, text, attention_mode, cfg_scale, inference_steps, seed, do_sample, temperature, top_p, top_k, quantize_llm_4bit, **kwargs): if not text.strip(): logger.warning("VibeVoiceTTS: Empty text provided, returning silent audio.") return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},) # Create cache key that includes attention mode - cache_key = f"{model_name}_attn_{attention_mode}" + cache_key = f"{model_name}_attn_{attention_mode}_q4_{int(quantize_llm_4bit)}" # Clean up old models when switching to a different model if cache_key not in VIBEVOICE_PATCHER_CACHE: # Only keep models that are currently being requested cleanup_old_models(keep_cache_key=cache_key) - model_handler = VibeVoiceModelHandler(model_name, attention_mode) + model_handler = VibeVoiceModelHandler(model_name, attention_mode, use_llm_4bit=quantize_llm_4bit) patcher = VibeVoicePatcher( model_handler, attention_mode=attention_mode,