Merge pull request #12 from Shadowfita/patch-1

Add optional Q4 (4-bit) LLM quantization for VibeVoice
This commit is contained in:
WildAi
2025-08-31 20:53:05 +03:00
committed by GitHub

View File

@@ -14,7 +14,7 @@ import comfy.model_patcher
from comfy.utils import ProgressBar
from comfy.model_management import throw_exception_if_processing_interrupted
from transformers import set_seed, AutoTokenizer
from transformers import set_seed, AutoTokenizer, BitsAndBytesConfig
from .vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from .vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
from .vibevoice.processor.vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor
@@ -80,17 +80,18 @@ def cleanup_old_models(keep_cache_key=None):
class VibeVoiceModelHandler(torch.nn.Module):
"""A torch.nn.Module wrapper to hold the VibeVoice model and processor."""
def __init__(self, model_pack_name, attention_mode="eager"):
def __init__(self, model_pack_name, attention_mode="eager", use_llm_4bit=False):
super().__init__()
self.model_pack_name = model_pack_name
self.attention_mode = attention_mode
self.use_llm_4bit = use_llm_4bit
self.cache_key = f"{model_pack_name}_attn_{attention_mode}"
self.model = None
self.processor = None
self.size = int(MODEL_CONFIGS[model_pack_name].get("size_gb", 4.0) * (1024**3))
def load_model(self, device, attention_mode="eager"):
self.model, self.processor = VibeVoiceLoader.load_model(self.model_pack_name, device , attention_mode)
self.model, self.processor = VibeVoiceLoader.load_model(self.model_pack_name, device, attention_mode, use_llm_4bit=self.use_llm_4bit)
self.model.to(device)
class VibeVoicePatcher(comfy.model_patcher.ModelPatcher):
@@ -180,11 +181,13 @@ class VibeVoiceLoader:
return attention_mode
@staticmethod
def load_model(model_name: str, device, attention_mode: str = "eager"):
def load_model(model_name: str, device, attention_mode: str = "eager", use_llm_4bit: bool = False):
# Validate attention mode
if attention_mode not in ATTENTION_MODES:
logger.warning(f"Unknown attention mode '{attention_mode}', falling back to eager")
attention_mode = "eager"
if use_llm_4bit and attention_mode == "flash_attention_2":
attention_mode = "sdpa"
# Create cache key that includes attention mode
cache_key = f"{model_name}_attn_{attention_mode}"
@@ -214,6 +217,16 @@ class VibeVoiceLoader:
final_attention_mode = VibeVoiceLoader._check_attention_compatibility(
attention_mode, torch_dtype, device_name
)
# Build optional 4-bit config (LLM only)
quant_config = None
if use_llm_4bit:
quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=torch.bfloat16,
)
logger.info(f"Requested attention mode: {attention_mode}")
if final_attention_mode != attention_mode:
@@ -253,11 +266,13 @@ class VibeVoiceLoader:
try:
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
model_path,
torch_dtype=torch_dtype,
torch_dtype=torch.bfloat16 if quant_config else torch_dtype,
attn_implementation=final_attention_mode,
device_map=device
device_map="auto" if quant_config else device,
quantization_config=quant_config, # <- forwarded if supported
)
model.eval()
setattr(model, "_llm_4bit", bool(quant_config))
# Store with the actual attention mode used (not the requested one)
LOADED_MODELS[cache_key] = (model, processor)
@@ -381,6 +396,10 @@ class VibeVoiceTTSNode:
"default": "Speaker 1: Hello from ComfyUI!\nSpeaker 2: VibeVoice sounds amazing.",
"tooltip": "The script for the conversation. Use 'Speaker 1:', 'Speaker 2:', etc. to assign lines to different voices. Each speaker line should be on a new line."
}),
"quantize_llm_4bit": ("BOOLEAN", {
"default": False, "label_on": "Q4 (LLM only)", "label_off": "Full precision",
"tooltip": "Quantize the Qwen2.5 LLM to 4-bit NF4 via bitsandbytes. Diffusion head stays BF16/FP32."
}),
"attention_mode": (["eager", "sdpa", "flash_attention_2"], {
"default": "sdpa",
"tooltip": "Attention implementation: Eager (safest), SDPA (balanced), Flash Attention 2 (fastest but may cause issues on some GPUs like RTX 5090)"
@@ -426,20 +445,20 @@ class VibeVoiceTTSNode:
FUNCTION = "generate_audio"
CATEGORY = "audio/tts"
def generate_audio(self, model_name, text, attention_mode, cfg_scale, inference_steps, seed, do_sample, temperature, top_p, top_k, **kwargs):
def generate_audio(self, model_name, text, attention_mode, cfg_scale, inference_steps, seed, do_sample, temperature, top_p, top_k, quantize_llm_4bit, **kwargs):
if not text.strip():
logger.warning("VibeVoiceTTS: Empty text provided, returning silent audio.")
return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},)
# Create cache key that includes attention mode
cache_key = f"{model_name}_attn_{attention_mode}"
cache_key = f"{model_name}_attn_{attention_mode}_q4_{int(quantize_llm_4bit)}"
# Clean up old models when switching to a different model
if cache_key not in VIBEVOICE_PATCHER_CACHE:
# Only keep models that are currently being requested
cleanup_old_models(keep_cache_key=cache_key)
model_handler = VibeVoiceModelHandler(model_name, attention_mode)
model_handler = VibeVoiceModelHandler(model_name, attention_mode, use_llm_4bit=quantize_llm_4bit)
patcher = VibeVoicePatcher(
model_handler,
attention_mode=attention_mode,