Mirror of https://github.com/wildminder/ComfyUI-VibeVoice.git (synced 2026-01-26 14:39:45 +00:00)
Merge pull request #7 from Saganaki22/main
Add configurable attention modes with compatibility checks
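In short: the loader now accepts an explicit attention mode ("eager", "sdpa", or "flash_attention_2"), checks it against the running PyTorch build, and falls back along flash_attention_2 -> sdpa -> eager instead of hard-coding the implementation. A minimal sketch of that selection logic, separate from the VibeVoice classes in the diff below (the helper name pick_attention_mode and the AutoModel usage in the comment are illustrative assumptions, not code from this commit):

import torch

def pick_attention_mode(requested: str, dtype: torch.dtype) -> str:
    # Illustrative fallback chain: flash_attention_2 -> sdpa -> eager.
    has_sdpa = hasattr(torch.nn.functional, "scaled_dot_product_attention")
    if requested == "flash_attention_2" and (not has_sdpa or dtype == torch.float32):
        requested = "sdpa"  # flash attention wants fp16/bf16 and a recent PyTorch
    if requested == "sdpa" and not has_sdpa:
        requested = "eager"  # SDPA requires PyTorch 2.0+
    return requested

# The chosen string is what ends up in transformers' from_pretrained call, e.g.:
#   model = AutoModel.from_pretrained(model_path, torch_dtype=dtype,
#                                     attn_implementation=pick_attention_mode(mode, dtype))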
@@ -6,6 +6,7 @@ import random
from huggingface_hub import snapshot_download
import logging
import librosa
+import gc

import folder_paths
import comfy.model_management as model_management
@@ -33,40 +34,94 @@ MODEL_CONFIGS = {
    }
}

+ATTENTION_MODES = ["eager", "sdpa", "flash_attention_2"]
+
+def cleanup_old_models(keep_cache_key=None):
+    """Clean up old models, optionally keeping one specific model loaded"""
+    global LOADED_MODELS, VIBEVOICE_PATCHER_CACHE
+
+    keys_to_remove = []
+
+    # Clear LOADED_MODELS
+    for key in list(LOADED_MODELS.keys()):
+        if key != keep_cache_key:
+            keys_to_remove.append(key)
+            del LOADED_MODELS[key]
+
+    # Clear VIBEVOICE_PATCHER_CACHE - but more carefully
+    for key in list(VIBEVOICE_PATCHER_CACHE.keys()):
+        if key != keep_cache_key:
+            # Set the model/processor to None but don't delete the patcher itself
+            # This lets ComfyUI's model management handle the patcher cleanup
+            try:
+                patcher = VIBEVOICE_PATCHER_CACHE[key]
+                if hasattr(patcher, 'model') and patcher.model:
+                    patcher.model.model = None
+                    patcher.model.processor = None
+                # Remove from our cache but let ComfyUI handle the rest
+                del VIBEVOICE_PATCHER_CACHE[key]
+            except Exception as e:
+                logger.warning(f"Error cleaning up patcher {key}: {e}")
+
+    if keys_to_remove:
+        logger.info(f"Cleaned up cached models: {keys_to_remove}")
+        gc.collect()
+        model_management.soft_empty_cache()

class VibeVoiceModelHandler(torch.nn.Module):
    """A torch.nn.Module wrapper to hold the VibeVoice model and processor."""
-    def __init__(self, model_pack_name):
+    def __init__(self, model_pack_name, attention_mode="eager"):
        super().__init__()
        self.model_pack_name = model_pack_name
+        self.attention_mode = attention_mode
+        self.cache_key = f"{model_pack_name}_attn_{attention_mode}"
        self.model = None
        self.processor = None
        self.size = int(MODEL_CONFIGS[model_pack_name].get("size_gb", 4.0) * (1024**3))

-    def load_model(self, device):
-        self.model, self.processor = VibeVoiceLoader.load_model(self.model_pack_name)
+    def load_model(self, device, attention_mode="eager"):
+        self.model, self.processor = VibeVoiceLoader.load_model(self.model_pack_name, attention_mode)
        self.model.to(device)

class VibeVoicePatcher(comfy.model_patcher.ModelPatcher):
    """Custom ModelPatcher for managing VibeVoice models in ComfyUI."""
-    def __init__(self, model, *args, **kwargs):
+    def __init__(self, model, attention_mode="eager", *args, **kwargs):
        super().__init__(model, *args, **kwargs)
+        self.attention_mode = attention_mode
+        self.cache_key = model.cache_key

    def patch_model(self, device_to=None, *args, **kwargs):
        target_device = self.load_device
        if self.model.model is None:
            logger.info(f"Loading VibeVoice models for '{self.model.model_pack_name}' to {target_device}...")
-            self.model.load_model(target_device)
+            mode_names = {
+                "eager": "Eager (Most Compatible)",
+                "sdpa": "SDPA (Balanced Speed/Compatibility)",
+                "flash_attention_2": "Flash Attention 2 (Fastest)"
+            }
+            logger.info(f"Attention Mode: {mode_names.get(self.attention_mode, self.attention_mode)}")
+            self.model.load_model(target_device, self.attention_mode)
            self.model.model.to(target_device)
        return super().patch_model(device_to=target_device, *args, **kwargs)

    def unpatch_model(self, device_to=None, unpatch_weights=True, *args, **kwargs):
        if unpatch_weights:
-            logger.info(f"Offloading VibeVoice models for '{self.model.model_pack_name}' to {device_to}...")
+            logger.info(f"Offloading VibeVoice models for '{self.model.model_pack_name}' ({self.attention_mode}) to {device_to}...")
            self.model.model = None
            self.model.processor = None
-            if self.model.model_pack_name in LOADED_MODELS:
-                del LOADED_MODELS[self.model.model_pack_name]
+
+            # Clear using the correct cache key
+            if self.cache_key in LOADED_MODELS:
+                del LOADED_MODELS[self.cache_key]
+                logger.info(f"Cleared LOADED_MODELS cache for: {self.cache_key}")
+
+            # DON'T delete from VIBEVOICE_PATCHER_CACHE here - let ComfyUI handle it
+            # This prevents the IndexError in ComfyUI's model management
+
+            # Force garbage collection
+            gc.collect()
+            model_management.soft_empty_cache()
+
        return super().unpatch_model(device_to, unpatch_weights, *args, **kwargs)

class VibeVoiceLoader:
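Because the attention mode is now part of the cache key, the same checkpoint loaded with two different modes occupies two separate entries, which is exactly what cleanup_old_models() above is meant to keep in check. A tiny standalone illustration of the key scheme (a plain dict stands in for LOADED_MODELS, and the model name is a placeholder):

LOADED_MODELS = {}

def cache_key(model_name: str, attention_mode: str) -> str:
    # Same "<model>_attn_<mode>" format used throughout the commit
    return f"{model_name}_attn_{attention_mode}"

LOADED_MODELS[cache_key("VibeVoice-1.5B", "sdpa")] = ("model", "processor")
LOADED_MODELS[cache_key("VibeVoice-1.5B", "eager")] = ("model", "processor")

# Evict everything except the entry currently requested, as cleanup_old_models does:
keep = cache_key("VibeVoice-1.5B", "sdpa")
for key in list(LOADED_MODELS):
    if key != keep:
        del LOADED_MODELS[key]
assert list(LOADED_MODELS) == [keep]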
@@ -86,9 +141,47 @@ class VibeVoiceLoader:
        return model_path

    @staticmethod
-    def load_model(model_name: str):
-        if model_name in LOADED_MODELS:
-            return LOADED_MODELS[model_name]
+    def _check_attention_compatibility(attention_mode: str, torch_dtype, device_name: str = ""):
+        """Check if the requested attention mode is compatible with current setup."""
+
+        # Check for SDPA availability (PyTorch 2.0+)
+        if attention_mode == "sdpa":
+            if not hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
+                logger.warning("SDPA not available (requires PyTorch 2.0+), falling back to eager")
+                return "eager"
+
+        # Check for Flash Attention availability
+        elif attention_mode == "flash_attention_2":
+            if not hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
+                logger.warning("Flash Attention not available, falling back to eager")
+                return "eager"
+            elif torch_dtype == torch.float32:
+                logger.warning("Flash Attention not recommended with float32, falling back to SDPA")
+                return "sdpa" if hasattr(torch.nn.functional, 'scaled_dot_product_attention') else "eager"
+
+        # Just informational messages, no forced fallbacks
+        if device_name and torch.cuda.is_available():
+            if "RTX 50" in device_name or "Blackwell" in device_name:
+                if attention_mode == "flash_attention_2":
+                    logger.info(f"Using Flash Attention on {device_name}")
+                elif attention_mode == "sdpa":
+                    logger.info(f"Using SDPA on {device_name}")
+
+        return attention_mode
+
+    @staticmethod
+    def load_model(model_name: str, attention_mode: str = "eager"):
+        # Validate attention mode
+        if attention_mode not in ATTENTION_MODES:
+            logger.warning(f"Unknown attention mode '{attention_mode}', falling back to eager")
+            attention_mode = "eager"
+
+        # Create cache key that includes attention mode
+        cache_key = f"{model_name}_attn_{attention_mode}"
+
+        if cache_key in LOADED_MODELS:
+            logger.info(f"Using cached model with {attention_mode} attention")
+            return LOADED_MODELS[cache_key]

        model_path = VibeVoiceLoader.get_model_path(model_name)

@@ -96,16 +189,74 @@ class VibeVoiceLoader:
        processor = VibeVoiceProcessor.from_pretrained(model_path)

        torch_dtype = model_management.text_encoder_dtype(model_management.get_torch_device())
+        device_name = torch.cuda.get_device_name() if torch.cuda.is_available() else ""

-        model = VibeVoiceForConditionalGenerationInference.from_pretrained(
-            model_path,
-            torch_dtype=torch_dtype,
-            attn_implementation="flash_attention_2" if hasattr(torch.nn.functional, 'scaled_dot_product_attention') and torch_dtype != torch.float32 else "eager",
-        )
-        model.eval()
-
-        LOADED_MODELS[model_name] = (model, processor)
-        return model, processor
+        # Check compatibility and potentially fall back to safer mode
+        final_attention_mode = VibeVoiceLoader._check_attention_compatibility(
+            attention_mode, torch_dtype, device_name
+        )
+
+        print(f"Requested attention mode: {attention_mode}")
+        if final_attention_mode != attention_mode:
+            print(f"Using attention mode: {final_attention_mode} (automatic fallback)")
+            # Update cache key to reflect actual mode used
+            cache_key = f"{model_name}_attn_{final_attention_mode}"
+            if cache_key in LOADED_MODELS:
+                return LOADED_MODELS[cache_key]
+        else:
+            print(f"Using attention mode: {final_attention_mode}")
+
+        logger.info(f"Final attention implementation: {final_attention_mode}")
+
+        # Modify config for non-flash attention modes
+        if final_attention_mode in ["eager", "sdpa"]:
+            import json
+            config_path = os.path.join(model_path, "config.json")
+            if os.path.exists(config_path):
+                try:
+                    with open(config_path, 'r') as f:
+                        config = json.load(f)
+
+                    # Remove flash attention settings
+                    removed_keys = []
+                    for key in ['_attn_implementation', 'attn_implementation', 'use_flash_attention_2']:
+                        if key in config:
+                            config.pop(key)
+                            removed_keys.append(key)
+
+                    if removed_keys:
+                        with open(config_path, 'w') as f:
+                            json.dump(config, f, indent=2)
+                        logger.info(f"Removed FlashAttention settings from config.json: {removed_keys}")
+                except Exception as e:
+                    logger.warning(f"Could not modify config.json: {e}")
+
+        try:
+            model = VibeVoiceForConditionalGenerationInference.from_pretrained(
+                model_path,
+                torch_dtype=torch_dtype,
+                attn_implementation=final_attention_mode,
+            )
+            model.eval()
+
+            # Store with the actual attention mode used (not the requested one)
+            LOADED_MODELS[cache_key] = (model, processor)
+            logger.info(f"Successfully loaded model with {final_attention_mode} attention")
+            return model, processor
+
+        except Exception as e:
+            logger.error(f"Failed to load model with {final_attention_mode} attention: {e}")
+
+            # Progressive fallback: flash -> sdpa -> eager
+            if final_attention_mode == "flash_attention_2":
+                logger.info("Attempting fallback to SDPA...")
+                return VibeVoiceLoader.load_model(model_name, "sdpa")
+            elif final_attention_mode == "sdpa":
+                logger.info("Attempting fallback to eager...")
+                return VibeVoiceLoader.load_model(model_name, "eager")
+            else:
+                # If eager fails, something is seriously wrong
+                raise RuntimeError(f"Failed to load model even with eager attention: {e}")


def set_vibevoice_seed(seed: int):
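The load path now has two safety nets: the pre-flight _check_attention_compatibility() downgrade, and, if from_pretrained still raises, a recursive retry with the next-safest mode. A condensed standalone sketch of that retry chain (load_fn is a stand-in for the real from_pretrained call, not part of the commit):

_NEXT_SAFEST = {"flash_attention_2": "sdpa", "sdpa": "eager"}

def load_with_fallback(load_fn, attention_mode: str):
    # Try the requested mode; on failure walk flash_attention_2 -> sdpa -> eager.
    try:
        return load_fn(attn_implementation=attention_mode), attention_mode
    except Exception as exc:
        next_mode = _NEXT_SAFEST.get(attention_mode)
        if next_mode is None:
            # eager already failed: nothing safer left to try
            raise RuntimeError(f"Failed to load model even with eager attention: {exc}")
        return load_with_fallback(load_fn, next_mode)

# Hypothetical usage with a loader that only supports eager:
def fake_loader(attn_implementation):
    if attn_implementation != "eager":
        raise ValueError("unsupported")
    return "model"

model, used_mode = load_with_fallback(fake_loader, "flash_attention_2")
assert used_mode == "eager"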
@@ -162,10 +313,30 @@ def preprocess_comfy_audio(audio_dict: dict, target_sr: int = 24000) -> np.ndarray:
    if waveform.ndim > 1:
        waveform = np.mean(waveform, axis=0)

+    # Check for invalid values
+    if np.any(np.isnan(waveform)) or np.any(np.isinf(waveform)):
+        logger.error("Audio contains NaN or Inf values, replacing with zeros")
+        waveform = np.nan_to_num(waveform, nan=0.0, posinf=0.0, neginf=0.0)
+
+    # Ensure audio is not completely silent or has extreme values
+    if np.all(waveform == 0):
+        logger.warning("Audio waveform is completely silent")
+
+    # Normalize extreme values
+    max_val = np.abs(waveform).max()
+    if max_val > 10.0:
+        logger.warning(f"Audio values are very large (max: {max_val}), normalizing")
+        waveform = waveform / max_val
+
    if original_sr != target_sr:
        logger.warning(f"Resampling reference audio from {original_sr}Hz to {target_sr}Hz.")
        waveform = librosa.resample(y=waveform, orig_sr=original_sr, target_sr=target_sr)

+    # Final check after resampling
+    if np.any(np.isnan(waveform)) or np.any(np.isinf(waveform)):
+        logger.error("Audio contains NaN or Inf after resampling, replacing with zeros")
+        waveform = np.nan_to_num(waveform, nan=0.0, posinf=0.0, neginf=0.0)
+
    return waveform.astype(np.float32)

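For reference, ComfyUI's AUDIO sockets carry a dict with a (batch, channels, samples) waveform and a sample_rate, and the sanitation steps above can be exercised on their own like this (synthetic numpy input only; the real node receives a torch tensor in the dict):

import numpy as np
import librosa

# Synthetic stand-in for a ComfyUI AUDIO dict: 1 batch, 2 channels, 1 s at 44.1 kHz
audio = {"sample_rate": 44100,
         "waveform": np.random.randn(1, 2, 44100).astype(np.float32)}

wave = audio["waveform"][0]          # drop the batch dimension
wave = np.mean(wave, axis=0)         # mix down to mono, as in the node
wave = np.nan_to_num(wave, nan=0.0, posinf=0.0, neginf=0.0)  # scrub NaN/Inf
peak = np.abs(wave).max()
if peak > 10.0:                      # only rescale clearly out-of-range audio
    wave = wave / peak
wave = librosa.resample(y=wave, orig_sr=audio["sample_rate"], target_sr=24000)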
@@ -182,6 +353,10 @@ class VibeVoiceTTSNode:
                    "default": "Speaker 1: Hello from ComfyUI!\nSpeaker 2: VibeVoice sounds amazing.",
                    "tooltip": "The script for the conversation. Use 'Speaker 1:', 'Speaker 2:', etc. to assign lines to different voices. Each speaker line should be on a new line."
                }),
+                "attention_mode": (["eager", "sdpa", "flash_attention_2"], {
+                    "default": "sdpa",
+                    "tooltip": "Attention implementation: Eager (safest), SDPA (balanced), Flash Attention 2 (fastest but may cause issues on some GPUs like RTX 5090)"
+                }),
                "cfg_scale": ("FLOAT", {
                    "default": 1.3, "min": 1.0, "max": 2.0, "step": 0.05,
                    "tooltip": "Classifier-Free Guidance scale. Higher values increase adherence to the voice prompt but may reduce naturalness. Recommended: 1.3"
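In ComfyUI, a list of strings as the first element of an input tuple renders as a dropdown, so the block above is all that is needed to expose the new choice in the node UI. A stripped-down node showing just that pattern (class name, FUNCTION, and category are illustrative, not part of this commit):

class AttentionModeDemo:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "attention_mode": (["eager", "sdpa", "flash_attention_2"], {
                    "default": "sdpa",
                    "tooltip": "Attention implementation to request when loading the model",
                }),
            }
        }

    RETURN_TYPES = ("STRING",)
    FUNCTION = "pick"
    CATEGORY = "examples"

    def pick(self, attention_mode):
        return (attention_mode,)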
@@ -223,16 +398,23 @@ class VibeVoiceTTSNode:
    FUNCTION = "generate_audio"
    CATEGORY = "audio/tts"

-    def generate_audio(self, model_name, text, cfg_scale, inference_steps, seed, do_sample, temperature, top_p, top_k, **kwargs):
+    def generate_audio(self, model_name, text, attention_mode, cfg_scale, inference_steps, seed, do_sample, temperature, top_p, top_k, **kwargs):
        if not text.strip():
            logger.warning("VibeVoiceTTS: Empty text provided, returning silent audio.")
            return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},)

-        cache_key = model_name
+        # Create cache key that includes attention mode
+        cache_key = f"{model_name}_attn_{attention_mode}"

+        # Clean up old models when switching to a different model
        if cache_key not in VIBEVOICE_PATCHER_CACHE:
-            model_handler = VibeVoiceModelHandler(model_name)
+            # Only keep models that are currently being requested
+            cleanup_old_models(keep_cache_key=cache_key)
+
+            model_handler = VibeVoiceModelHandler(model_name, attention_mode)
            patcher = VibeVoicePatcher(
                model_handler,
+                attention_mode=attention_mode,
                load_device=model_management.get_torch_device(),
                offload_device=model_management.unet_offload_device(),
                size=model_handler.size
@@ -262,28 +444,101 @@ class VibeVoiceTTSNode:
        set_vibevoice_seed(seed)

-        inputs = processor(
-            text=[full_script], voice_samples=[voice_samples_np], padding=True,
-            return_tensors="pt", return_attention_mask=True
-        )
-        inputs = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
-
-        model.set_ddpm_inference_steps(num_steps=inference_steps)
-
-        generation_config = {'do_sample': do_sample}
-        if do_sample:
-            generation_config['temperature'] = temperature
-            generation_config['top_p'] = top_p
-            if top_k > 0:
-                generation_config['top_k'] = top_k
-
-        with torch.no_grad():
-            outputs = model.generate(
-                **inputs, max_new_tokens=None, cfg_scale=cfg_scale,
-                tokenizer=processor.tokenizer, generation_config=generation_config,
-                verbose=False
-            )
+        try:
+            inputs = processor(
+                text=[full_script], voice_samples=[voice_samples_np], padding=True,
+                return_tensors="pt", return_attention_mask=True
+            )
+
+            # Validate inputs before moving to GPU
+            for key, value in inputs.items():
+                if isinstance(value, torch.Tensor):
+                    if torch.any(torch.isnan(value)) or torch.any(torch.isinf(value)):
+                        logger.error(f"Input tensor '{key}' contains NaN or Inf values")
+                        raise ValueError(f"Invalid values in input tensor: {key}")
+
+            inputs = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
+
+            model.set_ddpm_inference_steps(num_steps=inference_steps)
+
+            generation_config = {'do_sample': do_sample}
+            if do_sample:
+                generation_config['temperature'] = temperature
+                generation_config['top_p'] = top_p
+                if top_k > 0:
+                    generation_config['top_k'] = top_k
+
+            # Hardware-specific optimizations - only for eager mode
+            if attention_mode == "eager":
+                # Apply RTX 5090 / Blackwell compatibility fixes only for eager
+                torch.backends.cuda.matmul.allow_tf32 = False
+                torch.backends.cudnn.allow_tf32 = False
+                torch.cuda.empty_cache()
+
+                # Apply additional tensor fixes for eager mode
+                model = model.float()
+                processed_inputs = {}
+                for k, v in inputs.items():
+                    if isinstance(v, torch.Tensor):
+                        # Keep integer/boolean tensors as-is (token IDs, attention masks, etc.)
+                        if v.dtype in [torch.int, torch.long, torch.int32, torch.int64, torch.bool, torch.uint8]:
+                            processed_inputs[k] = v
+                        # Keep tensors with "mask" in their name as boolean
+                        elif "mask" in k.lower():
+                            processed_inputs[k] = v.bool() if v.dtype != torch.bool else v
+                        else:
+                            # Convert float/bfloat16 tensors to float32
+                            processed_inputs[k] = v.float()
+                    else:
+                        processed_inputs[k] = v
+                inputs = processed_inputs
+
+            with torch.no_grad():
+                # Create progress bar for inference steps
+                pbar = ProgressBar(inference_steps)
+
+                def progress_callback(step, total_steps):
+                    pbar.update(1)
+                    # Check for interruption from ComfyUI
+                    if model_management.interrupt_current_processing:
+                        raise comfy.model_management.InterruptProcessingException()
+
+                # Custom generation loop with interruption support
+                try:
+                    outputs = model.generate(
+                        **inputs, max_new_tokens=None, cfg_scale=cfg_scale,
+                        tokenizer=processor.tokenizer, generation_config=generation_config,
+                        verbose=False
+                    )
+                    # Note: The model.generate method doesn't support progress callbacks in the current VibeVoice implementation
+                    # But we check for interruption at the start and end of generation
+                    pbar.update(inference_steps - pbar.current)
+
+                except RuntimeError as e:
+                    error_msg = str(e).lower()
+                    if "assertion" in error_msg or "cuda" in error_msg:
+                        logger.error(f"CUDA assertion failed with {attention_mode} attention: {e}")
+                        logger.error("This might be due to invalid input data, GPU memory issues, or incompatible attention mode.")
+                        logger.error("Try restarting ComfyUI, using different audio files, or switching to 'eager' attention mode.")
+                    raise e
+                except comfy.model_management.InterruptProcessingException:
+                    logger.info("VibeVoice generation interrupted by user")
+                    raise
+                finally:
+                    pbar.update_absolute(inference_steps)
+
+        except comfy.model_management.InterruptProcessingException:
+            logger.info("VibeVoice TTS generation was cancelled")
+            # Return silent audio on cancellation
+            return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},)
+
+        except Exception as e:
+            logger.error(f"Error during VibeVoice generation with {attention_mode} attention: {e}")
+            if "interrupt" in str(e).lower() or "cancel" in str(e).lower():
+                logger.info("Generation was interrupted")
+                return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},)
+            raise

        output_waveform = outputs.speech_outputs[0]
        if output_waveform.ndim == 1: output_waveform = output_waveform.unsqueeze(0)
        if output_waveform.ndim == 2: output_waveform = output_waveform.unsqueeze(0)
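The two trailing unsqueeze calls are what turn the raw 1-D waveform into the (batch, channels, samples) layout ComfyUI expects in an AUDIO dict; the same reshaping in isolation:

import torch

wave = torch.randn(24000)        # one second of mono audio, shape (samples,)
if wave.ndim == 1:
    wave = wave.unsqueeze(0)     # -> (channels, samples)
if wave.ndim == 2:
    wave = wave.unsqueeze(0)     # -> (batch, channels, samples)
audio_out = {"waveform": wave, "sample_rate": 24000}
assert audio_out["waveform"].shape == (1, 1, 24000)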