Merge pull request #12 from Shadowfita/patch-1

Add optional Q4 (4-bit) LLM quantization for VibeVoice
This commit is contained in:
WildAi
2025-08-31 20:53:05 +03:00
committed by GitHub

View File

@@ -14,7 +14,7 @@ import comfy.model_patcher
from comfy.utils import ProgressBar from comfy.utils import ProgressBar
from comfy.model_management import throw_exception_if_processing_interrupted from comfy.model_management import throw_exception_if_processing_interrupted
from transformers import set_seed, AutoTokenizer from transformers import set_seed, AutoTokenizer, BitsAndBytesConfig
from .vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference from .vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from .vibevoice.processor.vibevoice_processor import VibeVoiceProcessor from .vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
from .vibevoice.processor.vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor from .vibevoice.processor.vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor
@@ -80,17 +80,18 @@ def cleanup_old_models(keep_cache_key=None):
class VibeVoiceModelHandler(torch.nn.Module): class VibeVoiceModelHandler(torch.nn.Module):
"""A torch.nn.Module wrapper to hold the VibeVoice model and processor.""" """A torch.nn.Module wrapper to hold the VibeVoice model and processor."""
def __init__(self, model_pack_name, attention_mode="eager"): def __init__(self, model_pack_name, attention_mode="eager", use_llm_4bit=False):
super().__init__() super().__init__()
self.model_pack_name = model_pack_name self.model_pack_name = model_pack_name
self.attention_mode = attention_mode self.attention_mode = attention_mode
self.use_llm_4bit = use_llm_4bit
self.cache_key = f"{model_pack_name}_attn_{attention_mode}" self.cache_key = f"{model_pack_name}_attn_{attention_mode}"
self.model = None self.model = None
self.processor = None self.processor = None
self.size = int(MODEL_CONFIGS[model_pack_name].get("size_gb", 4.0) * (1024**3)) self.size = int(MODEL_CONFIGS[model_pack_name].get("size_gb", 4.0) * (1024**3))
def load_model(self, device, attention_mode="eager"): def load_model(self, device, attention_mode="eager"):
self.model, self.processor = VibeVoiceLoader.load_model(self.model_pack_name, device , attention_mode) self.model, self.processor = VibeVoiceLoader.load_model(self.model_pack_name, device, attention_mode, use_llm_4bit=self.use_llm_4bit)
self.model.to(device) self.model.to(device)
class VibeVoicePatcher(comfy.model_patcher.ModelPatcher): class VibeVoicePatcher(comfy.model_patcher.ModelPatcher):
@@ -180,11 +181,13 @@ class VibeVoiceLoader:
return attention_mode return attention_mode
@staticmethod @staticmethod
def load_model(model_name: str, device, attention_mode: str = "eager"): def load_model(model_name: str, device, attention_mode: str = "eager", use_llm_4bit: bool = False):
# Validate attention mode # Validate attention mode
if attention_mode not in ATTENTION_MODES: if attention_mode not in ATTENTION_MODES:
logger.warning(f"Unknown attention mode '{attention_mode}', falling back to eager") logger.warning(f"Unknown attention mode '{attention_mode}', falling back to eager")
attention_mode = "eager" attention_mode = "eager"
if use_llm_4bit and attention_mode == "flash_attention_2":
attention_mode = "sdpa"
# Create cache key that includes attention mode # Create cache key that includes attention mode
cache_key = f"{model_name}_attn_{attention_mode}" cache_key = f"{model_name}_attn_{attention_mode}"
@@ -214,6 +217,16 @@ class VibeVoiceLoader:
final_attention_mode = VibeVoiceLoader._check_attention_compatibility( final_attention_mode = VibeVoiceLoader._check_attention_compatibility(
attention_mode, torch_dtype, device_name attention_mode, torch_dtype, device_name
) )
# Build optional 4-bit config (LLM only)
quant_config = None
if use_llm_4bit:
quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=torch.bfloat16,
)
logger.info(f"Requested attention mode: {attention_mode}") logger.info(f"Requested attention mode: {attention_mode}")
if final_attention_mode != attention_mode: if final_attention_mode != attention_mode:
@@ -253,11 +266,13 @@ class VibeVoiceLoader:
try: try:
model = VibeVoiceForConditionalGenerationInference.from_pretrained( model = VibeVoiceForConditionalGenerationInference.from_pretrained(
model_path, model_path,
torch_dtype=torch_dtype, torch_dtype=torch.bfloat16 if quant_config else torch_dtype,
attn_implementation=final_attention_mode, attn_implementation=final_attention_mode,
device_map=device device_map="auto" if quant_config else device,
quantization_config=quant_config, # <- forwarded if supported
) )
model.eval() model.eval()
setattr(model, "_llm_4bit", bool(quant_config))
# Store with the actual attention mode used (not the requested one) # Store with the actual attention mode used (not the requested one)
LOADED_MODELS[cache_key] = (model, processor) LOADED_MODELS[cache_key] = (model, processor)
@@ -381,6 +396,10 @@ class VibeVoiceTTSNode:
"default": "Speaker 1: Hello from ComfyUI!\nSpeaker 2: VibeVoice sounds amazing.", "default": "Speaker 1: Hello from ComfyUI!\nSpeaker 2: VibeVoice sounds amazing.",
"tooltip": "The script for the conversation. Use 'Speaker 1:', 'Speaker 2:', etc. to assign lines to different voices. Each speaker line should be on a new line." "tooltip": "The script for the conversation. Use 'Speaker 1:', 'Speaker 2:', etc. to assign lines to different voices. Each speaker line should be on a new line."
}), }),
"quantize_llm_4bit": ("BOOLEAN", {
"default": False, "label_on": "Q4 (LLM only)", "label_off": "Full precision",
"tooltip": "Quantize the Qwen2.5 LLM to 4-bit NF4 via bitsandbytes. Diffusion head stays BF16/FP32."
}),
"attention_mode": (["eager", "sdpa", "flash_attention_2"], { "attention_mode": (["eager", "sdpa", "flash_attention_2"], {
"default": "sdpa", "default": "sdpa",
"tooltip": "Attention implementation: Eager (safest), SDPA (balanced), Flash Attention 2 (fastest but may cause issues on some GPUs like RTX 5090)" "tooltip": "Attention implementation: Eager (safest), SDPA (balanced), Flash Attention 2 (fastest but may cause issues on some GPUs like RTX 5090)"
@@ -426,20 +445,20 @@ class VibeVoiceTTSNode:
FUNCTION = "generate_audio" FUNCTION = "generate_audio"
CATEGORY = "audio/tts" CATEGORY = "audio/tts"
def generate_audio(self, model_name, text, attention_mode, cfg_scale, inference_steps, seed, do_sample, temperature, top_p, top_k, **kwargs): def generate_audio(self, model_name, text, attention_mode, cfg_scale, inference_steps, seed, do_sample, temperature, top_p, top_k, quantize_llm_4bit, **kwargs):
if not text.strip(): if not text.strip():
logger.warning("VibeVoiceTTS: Empty text provided, returning silent audio.") logger.warning("VibeVoiceTTS: Empty text provided, returning silent audio.")
return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},) return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},)
# Create cache key that includes attention mode # Create cache key that includes attention mode
cache_key = f"{model_name}_attn_{attention_mode}" cache_key = f"{model_name}_attn_{attention_mode}_q4_{int(quantize_llm_4bit)}"
# Clean up old models when switching to a different model # Clean up old models when switching to a different model
if cache_key not in VIBEVOICE_PATCHER_CACHE: if cache_key not in VIBEVOICE_PATCHER_CACHE:
# Only keep models that are currently being requested # Only keep models that are currently being requested
cleanup_old_models(keep_cache_key=cache_key) cleanup_old_models(keep_cache_key=cache_key)
model_handler = VibeVoiceModelHandler(model_name, attention_mode) model_handler = VibeVoiceModelHandler(model_name, attention_mode, use_llm_4bit=quantize_llm_4bit)
patcher = VibeVoicePatcher( patcher = VibeVoicePatcher(
model_handler, model_handler,
attention_mode=attention_mode, attention_mode=attention_mode,