mirror of
https://github.com/wildminder/ComfyUI-VibeVoice.git
synced 2026-04-30 11:41:35 +00:00
Merge pull request #12 from Shadowfita/patch-1
Add optional Q4 (4-bit) LLM quantization for VibeVoice
This commit is contained in:
@@ -14,7 +14,7 @@ import comfy.model_patcher
|
|||||||
from comfy.utils import ProgressBar
|
from comfy.utils import ProgressBar
|
||||||
from comfy.model_management import throw_exception_if_processing_interrupted
|
from comfy.model_management import throw_exception_if_processing_interrupted
|
||||||
|
|
||||||
from transformers import set_seed, AutoTokenizer
|
from transformers import set_seed, AutoTokenizer, BitsAndBytesConfig
|
||||||
from .vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
|
from .vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
|
||||||
from .vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
|
from .vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
|
||||||
from .vibevoice.processor.vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor
|
from .vibevoice.processor.vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor
|
||||||
@@ -80,17 +80,18 @@ def cleanup_old_models(keep_cache_key=None):
|
|||||||
|
|
||||||
class VibeVoiceModelHandler(torch.nn.Module):
|
class VibeVoiceModelHandler(torch.nn.Module):
|
||||||
"""A torch.nn.Module wrapper to hold the VibeVoice model and processor."""
|
"""A torch.nn.Module wrapper to hold the VibeVoice model and processor."""
|
||||||
def __init__(self, model_pack_name, attention_mode="eager"):
|
def __init__(self, model_pack_name, attention_mode="eager", use_llm_4bit=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model_pack_name = model_pack_name
|
self.model_pack_name = model_pack_name
|
||||||
self.attention_mode = attention_mode
|
self.attention_mode = attention_mode
|
||||||
|
self.use_llm_4bit = use_llm_4bit
|
||||||
self.cache_key = f"{model_pack_name}_attn_{attention_mode}"
|
self.cache_key = f"{model_pack_name}_attn_{attention_mode}"
|
||||||
self.model = None
|
self.model = None
|
||||||
self.processor = None
|
self.processor = None
|
||||||
self.size = int(MODEL_CONFIGS[model_pack_name].get("size_gb", 4.0) * (1024**3))
|
self.size = int(MODEL_CONFIGS[model_pack_name].get("size_gb", 4.0) * (1024**3))
|
||||||
|
|
||||||
def load_model(self, device, attention_mode="eager"):
|
def load_model(self, device, attention_mode="eager"):
|
||||||
self.model, self.processor = VibeVoiceLoader.load_model(self.model_pack_name, device , attention_mode)
|
self.model, self.processor = VibeVoiceLoader.load_model(self.model_pack_name, device, attention_mode, use_llm_4bit=self.use_llm_4bit)
|
||||||
self.model.to(device)
|
self.model.to(device)
|
||||||
|
|
||||||
class VibeVoicePatcher(comfy.model_patcher.ModelPatcher):
|
class VibeVoicePatcher(comfy.model_patcher.ModelPatcher):
|
||||||
@@ -180,11 +181,13 @@ class VibeVoiceLoader:
|
|||||||
return attention_mode
|
return attention_mode
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_model(model_name: str, device, attention_mode: str = "eager"):
|
def load_model(model_name: str, device, attention_mode: str = "eager", use_llm_4bit: bool = False):
|
||||||
# Validate attention mode
|
# Validate attention mode
|
||||||
if attention_mode not in ATTENTION_MODES:
|
if attention_mode not in ATTENTION_MODES:
|
||||||
logger.warning(f"Unknown attention mode '{attention_mode}', falling back to eager")
|
logger.warning(f"Unknown attention mode '{attention_mode}', falling back to eager")
|
||||||
attention_mode = "eager"
|
attention_mode = "eager"
|
||||||
|
if use_llm_4bit and attention_mode == "flash_attention_2":
|
||||||
|
attention_mode = "sdpa"
|
||||||
|
|
||||||
# Create cache key that includes attention mode
|
# Create cache key that includes attention mode
|
||||||
cache_key = f"{model_name}_attn_{attention_mode}"
|
cache_key = f"{model_name}_attn_{attention_mode}"
|
||||||
@@ -214,6 +217,16 @@ class VibeVoiceLoader:
|
|||||||
final_attention_mode = VibeVoiceLoader._check_attention_compatibility(
|
final_attention_mode = VibeVoiceLoader._check_attention_compatibility(
|
||||||
attention_mode, torch_dtype, device_name
|
attention_mode, torch_dtype, device_name
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Build optional 4-bit config (LLM only)
|
||||||
|
quant_config = None
|
||||||
|
if use_llm_4bit:
|
||||||
|
quant_config = BitsAndBytesConfig(
|
||||||
|
load_in_4bit=True,
|
||||||
|
bnb_4bit_quant_type="nf4",
|
||||||
|
bnb_4bit_use_double_quant=True,
|
||||||
|
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f"Requested attention mode: {attention_mode}")
|
logger.info(f"Requested attention mode: {attention_mode}")
|
||||||
if final_attention_mode != attention_mode:
|
if final_attention_mode != attention_mode:
|
||||||
@@ -253,11 +266,13 @@ class VibeVoiceLoader:
|
|||||||
try:
|
try:
|
||||||
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
||||||
model_path,
|
model_path,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch.bfloat16 if quant_config else torch_dtype,
|
||||||
attn_implementation=final_attention_mode,
|
attn_implementation=final_attention_mode,
|
||||||
device_map=device
|
device_map="auto" if quant_config else device,
|
||||||
|
quantization_config=quant_config, # <- forwarded if supported
|
||||||
)
|
)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
setattr(model, "_llm_4bit", bool(quant_config))
|
||||||
|
|
||||||
# Store with the actual attention mode used (not the requested one)
|
# Store with the actual attention mode used (not the requested one)
|
||||||
LOADED_MODELS[cache_key] = (model, processor)
|
LOADED_MODELS[cache_key] = (model, processor)
|
||||||
@@ -381,6 +396,10 @@ class VibeVoiceTTSNode:
|
|||||||
"default": "Speaker 1: Hello from ComfyUI!\nSpeaker 2: VibeVoice sounds amazing.",
|
"default": "Speaker 1: Hello from ComfyUI!\nSpeaker 2: VibeVoice sounds amazing.",
|
||||||
"tooltip": "The script for the conversation. Use 'Speaker 1:', 'Speaker 2:', etc. to assign lines to different voices. Each speaker line should be on a new line."
|
"tooltip": "The script for the conversation. Use 'Speaker 1:', 'Speaker 2:', etc. to assign lines to different voices. Each speaker line should be on a new line."
|
||||||
}),
|
}),
|
||||||
|
"quantize_llm_4bit": ("BOOLEAN", {
|
||||||
|
"default": False, "label_on": "Q4 (LLM only)", "label_off": "Full precision",
|
||||||
|
"tooltip": "Quantize the Qwen2.5 LLM to 4-bit NF4 via bitsandbytes. Diffusion head stays BF16/FP32."
|
||||||
|
}),
|
||||||
"attention_mode": (["eager", "sdpa", "flash_attention_2"], {
|
"attention_mode": (["eager", "sdpa", "flash_attention_2"], {
|
||||||
"default": "sdpa",
|
"default": "sdpa",
|
||||||
"tooltip": "Attention implementation: Eager (safest), SDPA (balanced), Flash Attention 2 (fastest but may cause issues on some GPUs like RTX 5090)"
|
"tooltip": "Attention implementation: Eager (safest), SDPA (balanced), Flash Attention 2 (fastest but may cause issues on some GPUs like RTX 5090)"
|
||||||
@@ -426,20 +445,20 @@ class VibeVoiceTTSNode:
|
|||||||
FUNCTION = "generate_audio"
|
FUNCTION = "generate_audio"
|
||||||
CATEGORY = "audio/tts"
|
CATEGORY = "audio/tts"
|
||||||
|
|
||||||
def generate_audio(self, model_name, text, attention_mode, cfg_scale, inference_steps, seed, do_sample, temperature, top_p, top_k, **kwargs):
|
def generate_audio(self, model_name, text, attention_mode, cfg_scale, inference_steps, seed, do_sample, temperature, top_p, top_k, quantize_llm_4bit, **kwargs):
|
||||||
if not text.strip():
|
if not text.strip():
|
||||||
logger.warning("VibeVoiceTTS: Empty text provided, returning silent audio.")
|
logger.warning("VibeVoiceTTS: Empty text provided, returning silent audio.")
|
||||||
return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},)
|
return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},)
|
||||||
|
|
||||||
# Create cache key that includes attention mode
|
# Create cache key that includes attention mode
|
||||||
cache_key = f"{model_name}_attn_{attention_mode}"
|
cache_key = f"{model_name}_attn_{attention_mode}_q4_{int(quantize_llm_4bit)}"
|
||||||
|
|
||||||
# Clean up old models when switching to a different model
|
# Clean up old models when switching to a different model
|
||||||
if cache_key not in VIBEVOICE_PATCHER_CACHE:
|
if cache_key not in VIBEVOICE_PATCHER_CACHE:
|
||||||
# Only keep models that are currently being requested
|
# Only keep models that are currently being requested
|
||||||
cleanup_old_models(keep_cache_key=cache_key)
|
cleanup_old_models(keep_cache_key=cache_key)
|
||||||
|
|
||||||
model_handler = VibeVoiceModelHandler(model_name, attention_mode)
|
model_handler = VibeVoiceModelHandler(model_name, attention_mode, use_llm_4bit=quantize_llm_4bit)
|
||||||
patcher = VibeVoicePatcher(
|
patcher = VibeVoicePatcher(
|
||||||
model_handler,
|
model_handler,
|
||||||
attention_mode=attention_mode,
|
attention_mode=attention_mode,
|
||||||
|
|||||||
Reference in New Issue
Block a user