Add optional Q4 (4-bit) LLM quantization for VibeVoice

This PR introduces an **optional 4-bit (NF4) quantization path** for the **Qwen2.5 LLM component** inside VibeVoice, using Transformers + bitsandbytes. Only the LLM is quantized; the diffusion head and processors remain BF16/FP32, mirroring the project's existing split between the language model and the audio pipeline. This enables the **7B preview** to run on smaller GPUs while preserving output quality.

**Changes / Additions:**

* New toggle to run the LLM in **4-bit NF4** via `BitsAndBytesConfig`; default remains full precision. The config the toggle builds is sketched just below.
* When Q4 is enabled, a `flash_attention_2` request automatically downshifts to **SDPA** for stability.
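
For reference, this is the quantization config the toggle builds (a minimal standalone sketch matching the diff below; assumes `bitsandbytes` and `accelerate` are installed alongside `transformers`):

```python
# NF4 config applied to the LLM only; the diffusion head is left untouched.
import torch
from transformers import BitsAndBytesConfig

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # store linear-layer weights in 4 bits
    bnb_4bit_quant_type="nf4",              # NormalFloat4 quantization
    bnb_4bit_use_double_quant=True,         # also quantize the quantization constants
    bnb_4bit_compute_dtype=torch.bfloat16,  # dequantized matmuls run in BF16
)
```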

**Improvements on my 3080 12GB:**

* **7B**
    * **Timing:** 29m 27s → **203.47s** (\~3m 23s) — **−88.5% time** (\~**8.68× faster**)
    * **VRAM:** Q4 ≈ **7.6 GB**; FP16 ≳ **12 GB** — Q4 saves **≥4.4 GB** (≥**36.7%**)
* **1.5B**
    * **Timing:** 105s (FP16) → **154s** (Q4) — **+49s** (\~**1.47× slower**)
    * **VRAM:** Q4 ≈ **3.2 GB**; FP16 ≈ **8.7 GB** — Q4 saves **\~5.5 GB** (\~**63.2%**)

In my limited testing, these changes give the 7B model a nearly 90% reduction in inference time and a VRAM saving of at least \~37% (≥4.4 GB) in VRAM-constrained environments, with no perceptible change in output quality.

While there is an increase in inference time with the 1.5B model, some may consider the smaller VRAM footprint worth it.
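
The peak-VRAM figures can be sanity-checked with PyTorch's allocator statistics (a rough sketch, assuming a single-GPU run; the generation call itself is elided):

```python
import torch

torch.cuda.reset_peak_memory_stats()
# ... run one VibeVoice generation here ...
peak_gb = torch.cuda.max_memory_allocated() / 1024**3
print(f"Peak VRAM: {peak_gb:.1f} GB")
```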
Commit `7419fcd66f` (parent `20baa02e9a`), authored by Orion on 2025-08-30 23:48:25 +10:00 and committed via GitHub.


```diff
@@ -14,7 +14,7 @@ import comfy.model_patcher
 from comfy.utils import ProgressBar
 from comfy.model_management import throw_exception_if_processing_interrupted
-from transformers import set_seed, AutoTokenizer
+from transformers import set_seed, AutoTokenizer, BitsAndBytesConfig
 from .vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
 from .vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
 from .vibevoice.processor.vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor
```

```diff
@@ -80,17 +80,18 @@ def cleanup_old_models(keep_cache_key=None):
 class VibeVoiceModelHandler(torch.nn.Module):
     """A torch.nn.Module wrapper to hold the VibeVoice model and processor."""
-    def __init__(self, model_pack_name, attention_mode="eager"):
+    def __init__(self, model_pack_name, attention_mode="eager", use_llm_4bit=False):
         super().__init__()
         self.model_pack_name = model_pack_name
         self.attention_mode = attention_mode
+        self.use_llm_4bit = use_llm_4bit
         self.cache_key = f"{model_pack_name}_attn_{attention_mode}"
         self.model = None
         self.processor = None
         self.size = int(MODEL_CONFIGS[model_pack_name].get("size_gb", 4.0) * (1024**3))

     def load_model(self, device, attention_mode="eager"):
-        self.model, self.processor = VibeVoiceLoader.load_model(self.model_pack_name, device , attention_mode)
+        self.model, self.processor = VibeVoiceLoader.load_model(self.model_pack_name, device, attention_mode, use_llm_4bit=self.use_llm_4bit)
         self.model.to(device)

 class VibeVoicePatcher(comfy.model_patcher.ModelPatcher):
```

```diff
@@ -180,11 +181,13 @@ class VibeVoiceLoader:
         return attention_mode

     @staticmethod
-    def load_model(model_name: str, device, attention_mode: str = "eager"):
+    def load_model(model_name: str, device, attention_mode: str = "eager", use_llm_4bit: bool = False):
         # Validate attention mode
         if attention_mode not in ATTENTION_MODES:
             logger.warning(f"Unknown attention mode '{attention_mode}', falling back to eager")
             attention_mode = "eager"
+        if use_llm_4bit and attention_mode == "flash_attention_2":
+            attention_mode = "sdpa"

         # Create cache key that includes attention mode
         cache_key = f"{model_name}_attn_{attention_mode}"
```

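One subtlety in this hunk: the downshift happens before the cache key is computed, so with Q4 enabled a `flash_attention_2` request and an `sdpa` request resolve to the same cache entry (illustrative snippet; the model name is hypothetical):

```python
# Illustrative only: the effective mode, not the requested one, lands in the key.
model_name, attention_mode, use_llm_4bit = "VibeVoice-7B", "flash_attention_2", True
if use_llm_4bit and attention_mode == "flash_attention_2":
    attention_mode = "sdpa"
print(f"{model_name}_attn_{attention_mode}")  # -> VibeVoice-7B_attn_sdpa
```
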
```diff
@@ -214,6 +217,16 @@ class VibeVoiceLoader:
         final_attention_mode = VibeVoiceLoader._check_attention_compatibility(
             attention_mode, torch_dtype, device_name
         )
+
+        # Build optional 4-bit config (LLM only)
+        quant_config = None
+        if use_llm_4bit:
+            quant_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_compute_dtype=torch.bfloat16,
+            )

         logger.info(f"Requested attention mode: {attention_mode}")
         if final_attention_mode != attention_mode:
```

```diff
@@ -253,11 +266,13 @@ class VibeVoiceLoader:
         try:
             model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                 model_path,
-                torch_dtype=torch_dtype,
+                torch_dtype=torch.bfloat16 if quant_config else torch_dtype,
                 attn_implementation=final_attention_mode,
-                device_map=device
+                device_map="auto" if quant_config else device,
+                quantization_config=quant_config,  # <- forwarded if supported
             )
             model.eval()
+            setattr(model, "_llm_4bit", bool(quant_config))

             # Store with the actual attention mode used (not the requested one)
             LOADED_MODELS[cache_key] = (model, processor)
```

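The `from_pretrained` changes follow the standard Transformers quantization pattern. For context, the same pattern with a generic causal LM (a sketch; the model id is a placeholder, not the VibeVoice checkpoint path):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-7B",                  # placeholder id for illustration
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    device_map="auto",                  # bitsandbytes expects accelerate-style placement
    quantization_config=quant_config,
).eval()
```
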
```diff
@@ -381,6 +396,10 @@ class VibeVoiceTTSNode:
                     "default": "Speaker 1: Hello from ComfyUI!\nSpeaker 2: VibeVoice sounds amazing.",
                     "tooltip": "The script for the conversation. Use 'Speaker 1:', 'Speaker 2:', etc. to assign lines to different voices. Each speaker line should be on a new line."
                 }),
+                "quantize_llm_4bit": ("BOOLEAN", {
+                    "default": False, "label_on": "Q4 (LLM only)", "label_off": "Full precision",
+                    "tooltip": "Quantize the Qwen2.5 LLM to 4-bit NF4 via bitsandbytes. Diffusion head stays BF16/FP32."
+                }),
                 "attention_mode": (["eager", "sdpa", "flash_attention_2"], {
                     "default": "sdpa",
                     "tooltip": "Attention implementation: Eager (safest), SDPA (balanced), Flash Attention 2 (fastest but may cause issues on some GPUs like RTX 5090)"
```

```diff
@@ -426,20 +445,20 @@ class VibeVoiceTTSNode:
     FUNCTION = "generate_audio"
     CATEGORY = "audio/tts"

-    def generate_audio(self, model_name, text, attention_mode, cfg_scale, inference_steps, seed, do_sample, temperature, top_p, top_k, **kwargs):
+    def generate_audio(self, model_name, text, attention_mode, cfg_scale, inference_steps, seed, do_sample, temperature, top_p, top_k, quantize_llm_4bit, **kwargs):
         if not text.strip():
             logger.warning("VibeVoiceTTS: Empty text provided, returning silent audio.")
             return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},)

         # Create cache key that includes attention mode
-        cache_key = f"{model_name}_attn_{attention_mode}"
+        cache_key = f"{model_name}_attn_{attention_mode}_q4_{int(quantize_llm_4bit)}"

         # Clean up old models when switching to a different model
         if cache_key not in VIBEVOICE_PATCHER_CACHE:
             # Only keep models that are currently being requested
             cleanup_old_models(keep_cache_key=cache_key)

-            model_handler = VibeVoiceModelHandler(model_name, attention_mode)
+            model_handler = VibeVoiceModelHandler(model_name, attention_mode, use_llm_4bit=quantize_llm_4bit)
             patcher = VibeVoicePatcher(
                 model_handler,
                 attention_mode=attention_mode,
```