mirror of
https://github.com/wildminder/ComfyUI-VibeVoice.git
synced 2026-01-26 22:49:46 +00:00
Merge pull request #12 from Shadowfita/patch-1
Add optional Q4 (4-bit) LLM quantization for VibeVoice
This commit is contained in:
@@ -14,7 +14,7 @@ import comfy.model_patcher
|
||||
from comfy.utils import ProgressBar
|
||||
from comfy.model_management import throw_exception_if_processing_interrupted
|
||||
|
||||
from transformers import set_seed, AutoTokenizer
|
||||
from transformers import set_seed, AutoTokenizer, BitsAndBytesConfig
|
||||
from .vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
|
||||
from .vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
|
||||
from .vibevoice.processor.vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor
|
||||
@@ -80,17 +80,18 @@ def cleanup_old_models(keep_cache_key=None):
|
||||
|
||||
class VibeVoiceModelHandler(torch.nn.Module):
|
||||
"""A torch.nn.Module wrapper to hold the VibeVoice model and processor."""
|
||||
def __init__(self, model_pack_name, attention_mode="eager"):
|
||||
def __init__(self, model_pack_name, attention_mode="eager", use_llm_4bit=False):
|
||||
super().__init__()
|
||||
self.model_pack_name = model_pack_name
|
||||
self.attention_mode = attention_mode
|
||||
self.use_llm_4bit = use_llm_4bit
|
||||
self.cache_key = f"{model_pack_name}_attn_{attention_mode}"
|
||||
self.model = None
|
||||
self.processor = None
|
||||
self.size = int(MODEL_CONFIGS[model_pack_name].get("size_gb", 4.0) * (1024**3))
|
||||
|
||||
def load_model(self, device, attention_mode="eager"):
|
||||
self.model, self.processor = VibeVoiceLoader.load_model(self.model_pack_name, device , attention_mode)
|
||||
self.model, self.processor = VibeVoiceLoader.load_model(self.model_pack_name, device, attention_mode, use_llm_4bit=self.use_llm_4bit)
|
||||
self.model.to(device)
|
||||
|
||||
class VibeVoicePatcher(comfy.model_patcher.ModelPatcher):
|
||||
@@ -180,11 +181,13 @@ class VibeVoiceLoader:
|
||||
return attention_mode
|
||||
|
||||
@staticmethod
|
||||
def load_model(model_name: str, device, attention_mode: str = "eager"):
|
||||
def load_model(model_name: str, device, attention_mode: str = "eager", use_llm_4bit: bool = False):
|
||||
# Validate attention mode
|
||||
if attention_mode not in ATTENTION_MODES:
|
||||
logger.warning(f"Unknown attention mode '{attention_mode}', falling back to eager")
|
||||
attention_mode = "eager"
|
||||
if use_llm_4bit and attention_mode == "flash_attention_2":
|
||||
attention_mode = "sdpa"
|
||||
|
||||
# Create cache key that includes attention mode
|
||||
cache_key = f"{model_name}_attn_{attention_mode}"
|
||||
@@ -214,6 +217,16 @@ class VibeVoiceLoader:
|
||||
final_attention_mode = VibeVoiceLoader._check_attention_compatibility(
|
||||
attention_mode, torch_dtype, device_name
|
||||
)
|
||||
|
||||
# Build optional 4-bit config (LLM only)
|
||||
quant_config = None
|
||||
if use_llm_4bit:
|
||||
quant_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
logger.info(f"Requested attention mode: {attention_mode}")
|
||||
if final_attention_mode != attention_mode:
|
||||
@@ -253,11 +266,13 @@ class VibeVoiceLoader:
|
||||
try:
|
||||
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
torch_dtype=torch.bfloat16 if quant_config else torch_dtype,
|
||||
attn_implementation=final_attention_mode,
|
||||
device_map=device
|
||||
device_map="auto" if quant_config else device,
|
||||
quantization_config=quant_config, # <- forwarded if supported
|
||||
)
|
||||
model.eval()
|
||||
setattr(model, "_llm_4bit", bool(quant_config))
|
||||
|
||||
# Store with the actual attention mode used (not the requested one)
|
||||
LOADED_MODELS[cache_key] = (model, processor)
|
||||
@@ -381,6 +396,10 @@ class VibeVoiceTTSNode:
|
||||
"default": "Speaker 1: Hello from ComfyUI!\nSpeaker 2: VibeVoice sounds amazing.",
|
||||
"tooltip": "The script for the conversation. Use 'Speaker 1:', 'Speaker 2:', etc. to assign lines to different voices. Each speaker line should be on a new line."
|
||||
}),
|
||||
"quantize_llm_4bit": ("BOOLEAN", {
|
||||
"default": False, "label_on": "Q4 (LLM only)", "label_off": "Full precision",
|
||||
"tooltip": "Quantize the Qwen2.5 LLM to 4-bit NF4 via bitsandbytes. Diffusion head stays BF16/FP32."
|
||||
}),
|
||||
"attention_mode": (["eager", "sdpa", "flash_attention_2"], {
|
||||
"default": "sdpa",
|
||||
"tooltip": "Attention implementation: Eager (safest), SDPA (balanced), Flash Attention 2 (fastest but may cause issues on some GPUs like RTX 5090)"
|
||||
@@ -426,20 +445,20 @@ class VibeVoiceTTSNode:
|
||||
FUNCTION = "generate_audio"
|
||||
CATEGORY = "audio/tts"
|
||||
|
||||
def generate_audio(self, model_name, text, attention_mode, cfg_scale, inference_steps, seed, do_sample, temperature, top_p, top_k, **kwargs):
|
||||
def generate_audio(self, model_name, text, attention_mode, cfg_scale, inference_steps, seed, do_sample, temperature, top_p, top_k, quantize_llm_4bit, **kwargs):
|
||||
if not text.strip():
|
||||
logger.warning("VibeVoiceTTS: Empty text provided, returning silent audio.")
|
||||
return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},)
|
||||
|
||||
# Create cache key that includes attention mode
|
||||
cache_key = f"{model_name}_attn_{attention_mode}"
|
||||
cache_key = f"{model_name}_attn_{attention_mode}_q4_{int(quantize_llm_4bit)}"
|
||||
|
||||
# Clean up old models when switching to a different model
|
||||
if cache_key not in VIBEVOICE_PATCHER_CACHE:
|
||||
# Only keep models that are currently being requested
|
||||
cleanup_old_models(keep_cache_key=cache_key)
|
||||
|
||||
model_handler = VibeVoiceModelHandler(model_name, attention_mode)
|
||||
model_handler = VibeVoiceModelHandler(model_name, attention_mode, use_llm_4bit=quantize_llm_4bit)
|
||||
patcher = VibeVoicePatcher(
|
||||
model_handler,
|
||||
attention_mode=attention_mode,
|
||||
|
||||
Reference in New Issue
Block a user