SageAttention support, fixes

WildAi
2025-09-03 11:42:43 +03:00
parent 2aa03a8254
commit 52cee71368
6 changed files with 185 additions and 228 deletions

View File

@@ -2,7 +2,12 @@ import os
import sys
import logging
# allowing absolute imports like 'from vibevoice.modular...' to work.
+try:
+import sageattention
+SAGE_ATTENTION_AVAILABLE = True
+except ImportError:
+SAGE_ATTENTION_AVAILABLE = False
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
sys.path.append(current_dir)
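The guarded import above is the standard optional-dependency pattern: probe once at import time, record a module-level flag, and let downstream code branch on it. A minimal self-contained sketch of the same pattern (fancy_lib is a placeholder name, not a real dependency of this repo):

# Minimal sketch of the optional-dependency pattern used above.
try:
    import fancy_lib  # noqa: F401  -- placeholder, not a real dependency
    FANCY_LIB_AVAILABLE = True
except ImportError:
    FANCY_LIB_AVAILABLE = False

def run():
    if FANCY_LIB_AVAILABLE:
        print("accelerated path")
    else:
        print("fallback path")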
@@ -13,7 +18,7 @@ from .vibevoice_nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
# Configure a logger for the entire custom node package
logger = logging.getLogger(__name__)
-logger.setLevel(logging.WARNING)
+logger.setLevel(logging.INFO)
logger.propagate = False
if not logger.hasHandlers():
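The hunk cuts off before the handler body. A typical guarded setup consistent with the lines above, where the format string is an assumption rather than the repo's actual choice:

import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.propagate = False
if not logger.hasHandlers():
    handler = logging.StreamHandler()
    # Format string is an assumption; the repo's actual handler setup is cut off.
    handler.setFormatter(logging.Formatter("[%(name)s] %(levelname)s: %(message)s"))
    logger.addHandler(handler)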

View File

@@ -47,77 +47,6 @@
null
]
},
{
"id": 11,
"type": "VibeVoiceTTS",
"pos": [
-1570,
-1130
],
"size": [
460,
510
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"name": "speaker_1_voice",
"shape": 7,
"type": "AUDIO",
"link": 28
},
{
"name": "speaker_2_voice",
"shape": 7,
"type": "AUDIO",
"link": 29
},
{
"name": "speaker_3_voice",
"shape": 7,
"type": "AUDIO",
"link": null
},
{
"name": "speaker_4_voice",
"shape": 7,
"type": "AUDIO",
"link": null
}
],
"outputs": [
{
"name": "AUDIO",
"type": "AUDIO",
"links": [
27
]
}
],
"properties": {
"cnr_id": "ComfyUI-VibeVoice",
"ver": "37803a884fb8f9b43c38286f6d654c7f97181a73",
"Node name for S&R": "VibeVoiceTTS"
},
"widgets_values": [
"VibeVoice-1.5B",
"Speaker 1: I can't believe you did it again. I waited for two hours. Two hours! Not a single call, not a text. Do you have any idea how embarrassing that was, just sitting there alone?\nSpeaker 2: Look, I know, I'm sorry, alright? Work was a complete nightmare. My boss dropped a critical deadline on me at the last minute. I didn't even have a second to breathe, let alone check my phone.\nSpeaker 1: A nightmare? That's the same excuse you used last time. I'm starting to think you just don't care. It's easier to say 'work was crazy' than to just admit that I'm not a priority for you anymore.",
false,
"sdpa",
1.3,
10,
56109085141530,
"randomize",
true,
0.95,
0.95,
0
],
"color": "#232",
"bgcolor": "#353"
},
{
"id": 8,
"type": "LoadAudio",
@@ -227,6 +156,94 @@
"widgets_values": [
"audio/VibeVoice"
]
},
{
"id": 11,
"type": "VibeVoiceTTS",
"pos": [
-1570,
-1130
],
"size": [
460,
510
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"name": "speaker_1_voice",
"shape": 7,
"type": "AUDIO",
"link": 28
},
{
"name": "speaker_2_voice",
"shape": 7,
"type": "AUDIO",
"link": 29
},
{
"name": "speaker_3_voice",
"shape": 7,
"type": "AUDIO",
"link": null
},
{
"name": "speaker_4_voice",
"shape": 7,
"type": "AUDIO",
"link": null
}
],
"outputs": [
{
"name": "AUDIO",
"type": "AUDIO",
"links": [
27
]
}
],
"properties": {
"cnr_id": "ComfyUI-VibeVoice",
"ver": "37803a884fb8f9b43c38286f6d654c7f97181a73",
"Node name for S&R": "VibeVoiceTTS",
"ue_properties": {
"widget_ue_connectable": {
"model_name": true,
"text": true,
"quantize_llm_4bit": true,
"attention_mode": true,
"cfg_scale": true,
"inference_steps": true,
"seed": true,
"do_sample": true,
"temperature": true,
"top_p": true,
"top_k": true
},
"version": "7.0.1"
}
},
"widgets_values": [
"VibeVoice-1.5B",
"Speaker 1: I can't believe you did it again. I waited for two hours. Two hours! Not a single call, not a text. Do you have any idea how embarrassing that was, just sitting there alone?\nSpeaker 2: Look, I know, I'm sorry, alright? Work was a complete nightmare. My boss dropped a critical deadline on me at the last minute. I didn't even have a second to breathe, let alone check my phone.\nSpeaker 1: A nightmare? That's the same excuse you used last time. I'm starting to think you just don't care. It's easier to say 'work was crazy' than to just admit that I'm not a priority for you anymore.",
false,
"flash_attention_2",
1.3,
10,
1,
"fixed",
true,
0.95,
0.95,
0,
false
],
"color": "#232",
"bgcolor": "#353"
}
],
"links": [
@@ -261,10 +278,10 @@
"ue_links": [],
"links_added_by_ue": [],
"ds": {
"scale": 1.2100000000000004,
"scale": 1.2100000000000002,
"offset": [
2024.7933884297524,
1252.3140495867776
2000,
1230
]
},
"frontendVersion": "1.25.11",

Binary file not shown.

Before: 138 KiB | After: 145 KiB

View File

@@ -55,7 +55,7 @@
"rms_norm_eps": 1e-06,
"rope_scaling": null,
"rope_theta": 1000000.0,
"sliding_window": null,
"sliding_window": 4096,
"tie_word_embeddings": true,
"torch_dtype": "bfloat16",
"use_cache": true,

View File

@@ -54,7 +54,7 @@
"num_key_value_heads": 4,
"rms_norm_eps": 1e-06,
"rope_theta": 1000000.0,
"sliding_window": null,
"sliding_window": 4096,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.40.1",

View File

@@ -20,6 +20,10 @@ from .vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
from .vibevoice.processor.vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor
from .vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceTextTokenizerFast
+from . import SAGE_ATTENTION_AVAILABLE
+if SAGE_ATTENTION_AVAILABLE:
+from .vibevoice.modular.sage_attention_patch import set_sage_attention
try:
import librosa
except ImportError:
@@ -40,11 +44,13 @@ MODEL_CONFIGS = {
"VibeVoice-Large": {
"repo_id": "microsoft/VibeVoice-Large",
"size_gb": 17.4,
"tokenizer_repo": "Qwen/Qwen2.5-7B"
"tokenizer_repo": "Qwen/Qwen2.5-7B"
}
}
+ATTENTION_MODES = ["eager", "sdpa", "flash_attention_2"]
+if SAGE_ATTENTION_AVAILABLE:
+ATTENTION_MODES.append("sage")
def cleanup_old_models(keep_cache_key=None):
"""Clean up old models, optionally keeping one specific model loaded"""
@@ -61,14 +67,11 @@ def cleanup_old_models(keep_cache_key=None):
# Clear VIBEVOICE_PATCHER_CACHE - but more carefully
for key in list(VIBEVOICE_PATCHER_CACHE.keys()):
if key != keep_cache_key:
-# Set the model/processor to None but don't delete the patcher itself
-# This lets ComfyUI's model management handle the patcher cleanup
try:
patcher = VIBEVOICE_PATCHER_CACHE[key]
if hasattr(patcher, 'model') and patcher.model:
patcher.model.model = None
patcher.model.processor = None
-# Remove from our cache but let ComfyUI handle the rest
del VIBEVOICE_PATCHER_CACHE[key]
except Exception as e:
logger.warning(f"Error cleaning up patcher {key}: {e}")
@@ -85,7 +88,7 @@ class VibeVoiceModelHandler(torch.nn.Module):
self.model_pack_name = model_pack_name
self.attention_mode = attention_mode
self.use_llm_4bit = use_llm_4bit
self.cache_key = f"{model_pack_name}_attn_{attention_mode}"
self.cache_key = f"{model_pack_name}_attn_{attention_mode}_q4_{int(use_llm_4bit)}"
self.model = None
self.processor = None
self.size = int(MODEL_CONFIGS[model_pack_name].get("size_gb", 4.0) * (1024**3))
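Folding the 4-bit flag into the cache key keeps a quantized and a full-precision load of the same model from colliding under one entry. A quick sketch of the collision the old key allowed:

def old_key(name, attn, q4):
    return f"{name}_attn_{attn}"          # q4 ignored: variants collide

def new_key(name, attn, q4):
    return f"{name}_attn_{attn}_q4_{int(q4)}"

assert old_key("VibeVoice-1.5B", "sdpa", True) == old_key("VibeVoice-1.5B", "sdpa", False)
assert new_key("VibeVoice-1.5B", "sdpa", True) != new_key("VibeVoice-1.5B", "sdpa", False)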
@@ -113,7 +116,8 @@ class VibeVoicePatcher(comfy.model_patcher.ModelPatcher):
mode_names = {
"eager": "Eager (Most Compatible)",
"sdpa": "SDPA (Balanced Speed/Compatibility)",
"flash_attention_2": "Flash Attention 2 (Fastest)"
"flash_attention_2": "Flash Attention 2 (Fastest)",
"sage": "SageAttention (Quantized High-Performance)",
}
logger.info(f"Attention Mode: {mode_names.get(self.attention_mode, self.attention_mode)}")
self.model.load_model(target_device, self.attention_mode)
@@ -126,15 +130,10 @@ class VibeVoicePatcher(comfy.model_patcher.ModelPatcher):
self.model.model = None
self.model.processor = None
# Clear using the correct cache key
if self.cache_key in LOADED_MODELS:
del LOADED_MODELS[self.cache_key]
logger.info(f"Cleared LOADED_MODELS cache for: {self.cache_key}")
-# DON'T delete from VIBEVOICE_PATCHER_CACHE here - let ComfyUI handle it
-# This prevents the IndexError in ComfyUI's model management
-# Force garbage collection
gc.collect()
model_management.soft_empty_cache()
@@ -157,145 +156,112 @@ class VibeVoiceLoader:
return model_path
@staticmethod
-def _check_attention_compatibility(attention_mode: str, torch_dtype, device_name: str = ""):
-"""Check if the requested attention mode is compatible with current setup."""
+def _check_gpu_for_sage_attention():
+"""Check if the current GPU is compatible with SageAttention."""
+if not SAGE_ATTENTION_AVAILABLE:
+return False
+if not torch.cuda.is_available():
+return False
+major, _ = torch.cuda.get_device_capability()
+if major < 8:
+logger.warning(f"Your GPU (compute capability {major}.x) does not support SageAttention, which requires CC 8.0+. Sage option will be disabled.")
+return False
+return True
-# Check for SDPA availability (PyTorch 2.0+)
-if attention_mode == "sdpa":
-if not hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
-logger.warning("SDPA not available (requires PyTorch 2.0+), falling back to eager")
-return "eager"
-# Check for Flash Attention availability
-elif attention_mode == "flash_attention_2":
-if not hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
-logger.warning("Flash Attention not available, falling back to eager")
-return "eager"
-elif torch_dtype == torch.float32:
-logger.warning("Flash Attention not recommended with float32, falling back to SDPA")
-return "sdpa" if hasattr(torch.nn.functional, 'scaled_dot_product_attention') else "eager"
-# Just informational messages, no forced fallbacks
-if device_name and torch.cuda.is_available():
-if "RTX 50" in device_name or "Blackwell" in device_name:
-if attention_mode == "flash_attention_2":
-logger.info(f"Using Flash Attention on {device_name}")
-elif attention_mode == "sdpa":
-logger.info(f"Using SDPA on {device_name}")
-return attention_mode
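The replacement helper gates sage on compute capability 8.0+ (Ampere and newer), which SageAttention's quantized kernels require. The same check in isolation, assuming only torch:

import torch

def supports_sage_attention() -> bool:
    # SageAttention's quantized kernels target compute capability 8.0+;
    # older GPUs must use the eager/sdpa/flash modes instead.
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8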
@staticmethod
def load_model(model_name: str, device, attention_mode: str = "eager", use_llm_4bit: bool = False):
# Validate attention mode
+if use_llm_4bit and attention_mode in ["eager", "flash_attention_2"]:
+logger.warning(f"Attention mode '{attention_mode}' is not recommended with 4-bit quantization. Falling back to 'sdpa' for stability and performance.")
+attention_mode = "sdpa"
if attention_mode not in ATTENTION_MODES:
logger.warning(f"Unknown attention mode '{attention_mode}', falling back to eager")
attention_mode = "eager"
-if use_llm_4bit and attention_mode == "flash_attention_2":
-attention_mode = "sdpa"
# Create cache key that includes attention mode
-cache_key = f"{model_name}_attn_{attention_mode}"
+cache_key = f"{model_name}_attn_{attention_mode}_q4_{int(use_llm_4bit)}"
if cache_key in LOADED_MODELS:
-logger.info(f"Using cached model with {attention_mode} attention")
+logger.info(f"Using cached model with {attention_mode} attention and q4={use_llm_4bit}")
return LOADED_MODELS[cache_key]
model_path = VibeVoiceLoader.get_model_path(model_name)
logger.info(f"Loading VibeVoice model components from: {model_path}")
tokenizer_repo = MODEL_CONFIGS[model_name].get("tokenizer_repo")
-try:
-tokenizer_file_path = hf_hub_download(repo_id=tokenizer_repo, filename="tokenizer.json")
-except Exception as e:
-raise RuntimeError(f"Could not download tokenizer.json for {tokenizer_repo}. Error: {e}")
+tokenizer_file_path = hf_hub_download(repo_id=tokenizer_repo, filename="tokenizer.json")
vibevoice_tokenizer = VibeVoiceTextTokenizerFast(tokenizer_file=tokenizer_file_path)
audio_processor = VibeVoiceTokenizerProcessor()
processor = VibeVoiceProcessor(tokenizer=vibevoice_tokenizer, audio_processor=audio_processor)
-torch_dtype = model_management.text_encoder_dtype(device)
-device_name = torch.cuda.get_device_name() if torch.cuda.is_available() else ""
-# Check compatibility and potentially fall back to safer mode
-final_attention_mode = VibeVoiceLoader._check_attention_compatibility(
-attention_mode, torch_dtype, device_name
-)
+# Base dtype for full precision and memory-optimized 4-bit
+if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
+model_dtype = torch.bfloat16
+else:
+model_dtype = torch.float16
# Build optional 4-bit config (LLM only)
quant_config = None
+final_load_dtype = model_dtype
if use_llm_4bit:
+# Default to bfloat16/float16 for memory savings
+bnb_compute_dtype = model_dtype
+# SageAttention is numerically sensitive and requires fp32 compute dtype for stability
+# SDPA is more robust and can use bf16.
+if attention_mode == 'sage':
+logger.info("Using SageAttention with 4-bit quant. Forcing fp32 compute dtype for maximum stability.")
+bnb_compute_dtype = torch.float32
+final_load_dtype = torch.float32
+else:
+logger.info(f"Using {attention_mode} with 4-bit quant. Using {model_dtype} compute dtype for memory efficiency.")
quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
-bnb_4bit_compute_dtype=torch.bfloat16,
+bnb_4bit_compute_dtype=bnb_compute_dtype,
)
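For context, the same NF4 setup in isolation; transformers hands quantization_config to bitsandbytes at load time. The model id below is a stand-in:

# Sketch: load an LLM with the same NF4 settings used above.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quant = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,         # also quantize the quant constants
    bnb_4bit_compute_dtype=torch.bfloat16,  # fp32 instead when pairing with sage
)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-7B",  # stand-in model id for illustration
    quantization_config=quant,
    device_map="auto",
)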
logger.info(f"Requested attention mode: {attention_mode}")
if final_attention_mode != attention_mode:
logger.info(f"Using attention mode: {final_attention_mode} (automatic fallback)")
# Update cache key to reflect actual mode used
cache_key = f"{model_name}_attn_{final_attention_mode}"
if cache_key in LOADED_MODELS:
return LOADED_MODELS[cache_key]
else:
logger.info(f"Using attention mode: {final_attention_mode}")
logger.info(f"Final attention implementation: {final_attention_mode}")
# Modify config for non-flash attention modes
if final_attention_mode in ["eager", "sdpa"]:
import json
config_path = os.path.join(model_path, "config.json")
if os.path.exists(config_path):
try:
with open(config_path, 'r') as f:
config = json.load(f)
# Remove flash attention settings
removed_keys = []
for key in ['_attn_implementation', 'attn_implementation', 'use_flash_attention_2']:
if key in config:
config.pop(key)
removed_keys.append(key)
if removed_keys:
with open(config_path, 'w') as f:
json.dump(config, f, indent=2)
logger.info(f"Removed FlashAttention settings from config.json: {removed_keys}")
except Exception as e:
logger.warning(f"Could not modify config.json: {e}")
+attn_implementation_for_load = "sdpa" if attention_mode == "sage" else attention_mode
try:
+logger.info(f"Loading model with dtype: {final_load_dtype} and attention: '{attn_implementation_for_load}'")
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
model_path,
-torch_dtype=torch.bfloat16 if quant_config else torch_dtype,
-attn_implementation=final_attention_mode,
+dtype=final_load_dtype,
+attn_implementation=attn_implementation_for_load,
device_map="auto" if quant_config else device,
-quantization_config=quant_config, # <- forwarded if supported
+quantization_config=quant_config,
)
if attention_mode == "sage":
if VibeVoiceLoader._check_gpu_for_sage_attention():
logger.info("Applying SageAttention patch to the model...")
set_sage_attention(model)
else:
logger.error("Cannot apply SageAttention due to incompatible GPU. Falling back.")
raise RuntimeError("Incompatible hardware/setup for SageAttention.")
model.eval()
+setattr(model, "_llm_4bit", bool(quant_config))
-# Store with the actual attention mode used (not the requested one)
LOADED_MODELS[cache_key] = (model, processor)
-logger.info(f"Successfully loaded model with {final_attention_mode} attention")
+logger.info(f"Successfully configured model with {attention_mode} attention")
return model, processor
except Exception as e:
logger.error(f"Failed to load model with {final_attention_mode} attention: {e}")
# Progressive fallback: flash -> sdpa -> eager
if final_attention_mode == "flash_attention_2":
logger.error(f"Failed to load model with {attention_mode} attention: {e}")
# Fallback logic
if attention_mode in ["sage", "flash_attention_2"]:
logger.info("Attempting fallback to SDPA...")
return VibeVoiceLoader.load_model(model_name, device, "sdpa")
elif final_attention_mode == "sdpa":
return VibeVoiceLoader.load_model(model_name, device, "sdpa", use_llm_4bit)
elif attention_mode == "sdpa":
logger.info("Attempting fallback to eager...")
return VibeVoiceLoader.load_model(model_name, device, "eager")
return VibeVoiceLoader.load_model(model_name, device, "eager", use_llm_4bit)
else:
# If eager fails, something is seriously wrong
raise RuntimeError(f"Failed to load model even with eager attention: {e}")
@@ -405,9 +371,9 @@ class VibeVoiceTTSNode:
"default": False, "label_on": "Q4 (LLM only)", "label_off": "Full precision",
"tooltip": "Quantize the Qwen2.5 LLM to 4-bit NF4 via bitsandbytes. Diffusion head stays BF16/FP32."
}),
"attention_mode": (["eager", "sdpa", "flash_attention_2"], {
"attention_mode": (ATTENTION_MODES, {
"default": "sdpa",
"tooltip": "Attention implementation: Eager (safest), SDPA (balanced), Flash Attention 2 (fastest but may cause issues on some GPUs like RTX 5090)"
"tooltip": "Attention implementation: Eager (safest), SDPA (balanced), Flash Attention 2 (fastest), Sage (quantized)"
}),
"cfg_scale": ("FLOAT", {
"default": 1.3, "min": 1.0, "max": 2.0, "step": 0.05,
@@ -455,12 +421,12 @@ class VibeVoiceTTSNode:
CATEGORY = "audio/tts"
def generate_audio(self, model_name, text, attention_mode, cfg_scale, inference_steps, seed, do_sample, temperature, top_p, top_k, quantize_llm_4bit, force_offload, **kwargs):
if not text.strip():
logger.warning("VibeVoiceTTS: Empty text provided, returning silent audio.")
return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},)
# Create cache key that includes attention mode
cache_key = f"{model_name}_attn_{attention_mode}_q4_{int(quantize_llm_4bit)}"
actual_attention_mode = attention_mode
if quantize_llm_4bit and attention_mode in ["eager", "flash_attention_2"]:
actual_attention_mode = "sdpa"
cache_key = f"{model_name}_attn_{actual_attention_mode}_q4_{int(quantize_llm_4bit)}"
# Clean up old models when switching to a different model
if cache_key not in VIBEVOICE_PATCHER_CACHE:
@@ -489,7 +455,7 @@ class VibeVoiceTTSNode:
if not parsed_lines_0_based:
raise ValueError("Script is empty or invalid. Use 'Speaker 1:', 'Speaker 2:', etc. format.")
full_script = "\n".join([f"Speaker {spk}: {txt}" for spk, txt in parsed_lines_0_based])
full_script = "\n".join([f"Speaker {spk+1}: {txt}" for spk, txt in parsed_lines_0_based])
speaker_inputs = {i: kwargs.get(f"speaker_{i}_voice") for i in range(1, 5)}
voice_samples_np = [preprocess_comfy_audio(speaker_inputs[sid]) for sid in speaker_ids_1_based]
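The +1 matters because the parser is 0-based while script labels are 1-based; without it the rebuilt script would start at 'Speaker 0' and mismatch the voice slots. A two-line check:

parsed_lines_0_based = [(0, "I waited for two hours."), (1, "Work was a nightmare.")]
full_script = "\n".join(f"Speaker {spk+1}: {txt}" for spk, txt in parsed_lines_0_based)
assert full_script.startswith("Speaker 1:")  # 0-based index, 1-based label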
@@ -523,51 +489,24 @@ class VibeVoiceTTSNode:
generation_config['top_p'] = top_p
if top_k > 0:
generation_config['top_k'] = top_k
-# Hardware-specific optimizations - only for eager mode
-if attention_mode == "eager":
-# Apply RTX 5090 / Blackwell compatibility fixes only for eager
-torch.backends.cuda.matmul.allow_tf32 = False
-torch.backends.cudnn.allow_tf32 = False
-torch.cuda.empty_cache()
-# Apply additional tensor fixes for eager mode
-model = model.float()
-processed_inputs = {}
-for k, v in inputs.items():
-if isinstance(v, torch.Tensor):
-# Keep integer/boolean tensors as-is (token IDs, attention masks, etc.)
-if v.dtype in [torch.int, torch.long, torch.int32, torch.int64, torch.bool, torch.uint8]:
-processed_inputs[k] = v
-# Keep tensors with "mask" in their name as boolean
-elif "mask" in k.lower():
-processed_inputs[k] = v.bool() if v.dtype != torch.bool else v
-else:
-# Convert float/bfloat16 tensors to float32
-processed_inputs[k] = v.float()
-else:
-processed_inputs[k] = v
-inputs = processed_inputs
+# cause float() error for q4+eager
+# model = model.float() IS REMOVED
with torch.no_grad():
# Create progress bar for inference steps
pbar = ProgressBar(inference_steps)
def progress_callback(step, total_steps):
pbar.update(1)
# Check for interruption from ComfyUI
if model_management.interrupt_current_processing:
raise comfy.model_management.InterruptProcessingException()
# Custom generation loop with interruption support
try:
outputs = model.generate(
**inputs, max_new_tokens=None, cfg_scale=cfg_scale,
tokenizer=processor.tokenizer, generation_config=generation_config,
verbose=False, stop_check_fn=check_for_interrupt
)
# Note: The model.generate method doesn't support progress callbacks in the current VibeVoice implementation
# But we check for interruption at the start and end of generation
pbar.update(inference_steps - pbar.current)
except RuntimeError as e:
@@ -585,7 +524,6 @@ class VibeVoiceTTSNode:
except comfy.model_management.InterruptProcessingException:
logger.info("VibeVoice TTS generation was cancelled")
-# Return silent audio on cancellation
return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},)
except Exception as e:
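check_for_interrupt is passed to model.generate above but is not defined in this hunk. A plausible definition, using ComfyUI's query for the interrupt flag (an assumption about the repo's actual helper):

# Hypothetical: mirrors the interrupt flag the removed progress_callback polled.
import comfy.model_management as model_management

def check_for_interrupt() -> bool:
    return model_management.processing_interrupted()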
@@ -599,13 +537,10 @@ class VibeVoiceTTSNode:
if output_waveform.ndim == 1: output_waveform = output_waveform.unsqueeze(0)
if output_waveform.ndim == 2: output_waveform = output_waveform.unsqueeze(0)
-# Force offload model if requested
if force_offload:
logger.info(f"Force offloading VibeVoice model '{model_name}' from VRAM...")
-# Force offload by unpatching the model and freeing memory
if patcher.is_loaded:
patcher.unpatch_model(unpatch_weights=True)
-# Force unload all models to free memory
model_management.unload_all_models()
gc.collect()
model_management.soft_empty_cache()
@@ -614,4 +549,4 @@ class VibeVoiceTTSNode:
return ({"waveform": output_waveform.detach().cpu(), "sample_rate": 24000},)
NODE_CLASS_MAPPINGS = {"VibeVoiceTTS": VibeVoiceTTSNode}
-NODE_DISPLAY_NAME_MAPPINGS = {"VibeVoiceTTS": "VibeVoice TTS"}
+NODE_DISPLAY_NAME_MAPPINGS = {"VibeVoiceTTS": "VibeVoice TTS"}