From c29daa905076e8913c3703d5f150c9fecd2a91fc Mon Sep 17 00:00:00 2001
From: drbaph <84208527+Saganaki22@users.noreply.github.com>
Date: Thu, 28 Aug 2025 01:51:29 +0100
Subject: [PATCH 1/2] Add configurable attention modes with compatibility checks

- Added dropdown selection for attention implementation (eager/sdpa/flash_attention_2)
- Implemented automatic compatibility checks and progressive fallbacks
- Added hardware-specific optimizations for RTX 5090/Blackwell GPUs
- Enhanced error handling to prevent crashes from incompatible attention modes
---
 vibevoice_nodes.py | 285 ++++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 245 insertions(+), 40 deletions(-)

diff --git a/vibevoice_nodes.py b/vibevoice_nodes.py
index 540ea0b..25e1aac 100644
--- a/vibevoice_nodes.py
+++ b/vibevoice_nodes.py
@@ -33,6 +33,8 @@ MODEL_CONFIGS = {
     }
 }
 
+ATTENTION_MODES = ["eager", "sdpa", "flash_attention_2"]
+
 class VibeVoiceModelHandler(torch.nn.Module):
     """A torch.nn.Module wrapper to hold the VibeVoice model and processor."""
     def __init__(self, model_pack_name):
@@ -42,20 +44,27 @@ class VibeVoiceModelHandler(torch.nn.Module):
         self.processor = None
         self.size = int(MODEL_CONFIGS[model_pack_name].get("size_gb", 4.0) * (1024**3))
 
-    def load_model(self, device):
-        self.model, self.processor = VibeVoiceLoader.load_model(self.model_pack_name)
+    def load_model(self, device, attention_mode="eager"):
+        self.model, self.processor = VibeVoiceLoader.load_model(self.model_pack_name, attention_mode)
         self.model.to(device)
 
 class VibeVoicePatcher(comfy.model_patcher.ModelPatcher):
     """Custom ModelPatcher for managing VibeVoice models in ComfyUI."""
-    def __init__(self, model, *args, **kwargs):
+    def __init__(self, model, attention_mode="eager", *args, **kwargs):
         super().__init__(model, *args, **kwargs)
+        self.attention_mode = attention_mode
 
     def patch_model(self, device_to=None, *args, **kwargs):
         target_device = self.load_device
         if self.model.model is None:
             logger.info(f"Loading VibeVoice models for '{self.model.model_pack_name}' to {target_device}...")
-            self.model.load_model(target_device)
+            mode_names = {
+                "eager": "Eager (Most Compatible)",
+                "sdpa": "SDPA (Balanced Speed/Compatibility)",
+                "flash_attention_2": "Flash Attention 2 (Fastest)"
+            }
+            logger.info(f"Attention Mode: {mode_names.get(self.attention_mode, self.attention_mode)}")
+            self.model.load_model(target_device, self.attention_mode)
             self.model.model.to(target_device)
         return super().patch_model(device_to=target_device, *args, **kwargs)
 
@@ -86,9 +95,47 @@ class VibeVoiceLoader:
         return model_path
 
     @staticmethod
-    def load_model(model_name: str):
-        if model_name in LOADED_MODELS:
-            return LOADED_MODELS[model_name]
+    def _check_attention_compatibility(attention_mode: str, torch_dtype, device_name: str = ""):
+        """Check whether the requested attention mode is compatible with the current setup."""
+
+        # Check for SDPA availability (PyTorch 2.0+)
+        if attention_mode == "sdpa":
+            if not hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
+                logger.warning("SDPA not available (requires PyTorch 2.0+), falling back to eager")
+                return "eager"
+
+        # Check Flash Attention prerequisites (PyTorch 2.0+ is a cheap proxy; a missing flash-attn package is caught by the load-time fallback below)
+        elif attention_mode == "flash_attention_2":
+            if not hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
+                logger.warning("Flash Attention 2 requires PyTorch 2.0+, falling back to eager")
+                return "eager"
+            elif torch_dtype == torch.float32:
+                logger.warning("Flash Attention 2 requires fp16/bf16 (not float32), falling back to SDPA")
+                return "sdpa" if hasattr(torch.nn.functional, 'scaled_dot_product_attention') else "eager"
+
+        # Informational messages only, no forced fallbacks
+        if device_name and torch.cuda.is_available():
+            if "RTX 50" in device_name or "Blackwell" in device_name:
+                if attention_mode == "flash_attention_2":
+                    logger.info(f"Using Flash Attention on {device_name}")
+                elif attention_mode == "sdpa":
+                    logger.info(f"Using SDPA on {device_name}")
+
+        return attention_mode
+
+    @staticmethod
+    def load_model(model_name: str, attention_mode: str = "eager"):
+        # Validate attention mode
+        if attention_mode not in ATTENTION_MODES:
+            logger.warning(f"Unknown attention mode '{attention_mode}', falling back to eager")
+            attention_mode = "eager"
+
+        # The cache key includes the attention mode, so each (model, mode) pair is cached separately
+        cache_key = f"{model_name}_attn_{attention_mode}"
+
+        if cache_key in LOADED_MODELS:
+            logger.info(f"Using cached model with {attention_mode} attention")
+            return LOADED_MODELS[cache_key]
 
         model_path = VibeVoiceLoader.get_model_path(model_name)
 
@@ -96,16 +143,74 @@ class VibeVoiceLoader:
         processor = VibeVoiceProcessor.from_pretrained(model_path)
 
         torch_dtype = model_management.text_encoder_dtype(model_management.get_torch_device())
-
-        model = VibeVoiceForConditionalGenerationInference.from_pretrained(
-            model_path,
-            torch_dtype=torch_dtype,
-            attn_implementation="flash_attention_2" if hasattr(torch.nn.functional, 'scaled_dot_product_attention') and torch_dtype != torch.float32 else "eager",
-        )
-        model.eval()
+        device_name = torch.cuda.get_device_name() if torch.cuda.is_available() else ""
 
-        LOADED_MODELS[model_name] = (model, processor)
-        return model, processor
+        # Check compatibility and potentially fall back to a safer mode
+        final_attention_mode = VibeVoiceLoader._check_attention_compatibility(
+            attention_mode, torch_dtype, device_name
+        )
+
+        logger.info(f"Requested attention mode: {attention_mode}")
+        if final_attention_mode != attention_mode:
+            logger.info(f"Using attention mode: {final_attention_mode} (automatic fallback)")
+            # Update the cache key to reflect the mode actually used
+            cache_key = f"{model_name}_attn_{final_attention_mode}"
+            if cache_key in LOADED_MODELS:
+                return LOADED_MODELS[cache_key]
+        else:
+            logger.info(f"Using attention mode: {final_attention_mode}")
+
+        logger.info(f"Final attention implementation: {final_attention_mode}")
+
+        # Strip flash-attention settings for non-flash modes (note: this rewrites config.json on disk)
+        if final_attention_mode in ["eager", "sdpa"]:
+            import json
+            config_path = os.path.join(model_path, "config.json")
+            if os.path.exists(config_path):
+                try:
+                    with open(config_path, 'r') as f:
+                        config = json.load(f)
+
+                    # Remove flash attention settings
+                    removed_keys = []
+                    for key in ['_attn_implementation', 'attn_implementation', 'use_flash_attention_2']:
+                        if key in config:
+                            config.pop(key)
+                            removed_keys.append(key)
+
+                    if removed_keys:
+                        with open(config_path, 'w') as f:
+                            json.dump(config, f, indent=2)
+                        logger.info(f"Removed FlashAttention settings from config.json: {removed_keys}")
+                except Exception as e:
+                    logger.warning(f"Could not modify config.json: {e}")
+
+        try:
+            model = VibeVoiceForConditionalGenerationInference.from_pretrained(
+                model_path,
+                torch_dtype=torch_dtype,
+                attn_implementation=final_attention_mode,
+            )
+            model.eval()
+
+            # Store under the attention mode actually used (not the requested one)
+            LOADED_MODELS[cache_key] = (model, processor)
+            logger.info(f"Successfully loaded model with {final_attention_mode} attention")
+            return model, processor
+
+        except Exception as e:
+            logger.error(f"Failed to load model with {final_attention_mode} attention: {e}")
+
+            # 
Progressive fallback: flash -> sdpa -> eager + if final_attention_mode == "flash_attention_2": + logger.info("Attempting fallback to SDPA...") + return VibeVoiceLoader.load_model(model_name, "sdpa") + elif final_attention_mode == "sdpa": + logger.info("Attempting fallback to eager...") + return VibeVoiceLoader.load_model(model_name, "eager") + else: + # If eager fails, something is seriously wrong + raise RuntimeError(f"Failed to load model even with eager attention: {e}") def set_vibevoice_seed(seed: int): @@ -162,9 +267,29 @@ def preprocess_comfy_audio(audio_dict: dict, target_sr: int = 24000) -> np.ndarr if waveform.ndim > 1: waveform = np.mean(waveform, axis=0) + # Check for invalid values + if np.any(np.isnan(waveform)) or np.any(np.isinf(waveform)): + logger.error("Audio contains NaN or Inf values, replacing with zeros") + waveform = np.nan_to_num(waveform, nan=0.0, posinf=0.0, neginf=0.0) + + # Ensure audio is not completely silent or has extreme values + if np.all(waveform == 0): + logger.warning("Audio waveform is completely silent") + + # Normalize extreme values + max_val = np.abs(waveform).max() + if max_val > 10.0: + logger.warning(f"Audio values are very large (max: {max_val}), normalizing") + waveform = waveform / max_val + if original_sr != target_sr: logger.warning(f"Resampling reference audio from {original_sr}Hz to {target_sr}Hz.") waveform = librosa.resample(y=waveform, orig_sr=original_sr, target_sr=target_sr) + + # Final check after resampling + if np.any(np.isnan(waveform)) or np.any(np.isinf(waveform)): + logger.error("Audio contains NaN or Inf after resampling, replacing with zeros") + waveform = np.nan_to_num(waveform, nan=0.0, posinf=0.0, neginf=0.0) return waveform.astype(np.float32) @@ -182,6 +307,10 @@ class VibeVoiceTTSNode: "default": "Speaker 1: Hello from ComfyUI!\nSpeaker 2: VibeVoice sounds amazing.", "tooltip": "The script for the conversation. Use 'Speaker 1:', 'Speaker 2:', etc. to assign lines to different voices. Each speaker line should be on a new line." }), + "attention_mode": (["eager", "sdpa", "flash_attention_2"], { + "default": "sdpa", + "tooltip": "Attention implementation: Eager (safest), SDPA (balanced), Flash Attention 2 (fastest but may cause issues on some GPUs like RTX 5090)" + }), "cfg_scale": ("FLOAT", { "default": 1.3, "min": 1.0, "max": 2.0, "step": 0.05, "tooltip": "Classifier-Free Guidance scale. Higher values increase adherence to the voice prompt but may reduce naturalness. 
Recommended: 1.3" @@ -223,16 +352,19 @@ class VibeVoiceTTSNode: FUNCTION = "generate_audio" CATEGORY = "audio/tts" - def generate_audio(self, model_name, text, cfg_scale, inference_steps, seed, do_sample, temperature, top_p, top_k, **kwargs): + def generate_audio(self, model_name, text, attention_mode, cfg_scale, inference_steps, seed, do_sample, temperature, top_p, top_k, **kwargs): if not text.strip(): logger.warning("VibeVoiceTTS: Empty text provided, returning silent audio.") return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},) - cache_key = model_name + # Create cache key that includes attention mode + cache_key = f"{model_name}_attn_{attention_mode}" + if cache_key not in VIBEVOICE_PATCHER_CACHE: model_handler = VibeVoiceModelHandler(model_name) patcher = VibeVoicePatcher( - model_handler, + model_handler, + attention_mode=attention_mode, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), size=model_handler.size @@ -262,27 +394,100 @@ class VibeVoiceTTSNode: set_vibevoice_seed(seed) - inputs = processor( - text=[full_script], voice_samples=[voice_samples_np], padding=True, - return_tensors="pt", return_attention_mask=True - ) - inputs = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()} - - model.set_ddpm_inference_steps(num_steps=inference_steps) - - generation_config = {'do_sample': do_sample} - if do_sample: - generation_config['temperature'] = temperature - generation_config['top_p'] = top_p - if top_k > 0: - generation_config['top_k'] = top_k - - with torch.no_grad(): - outputs = model.generate( - **inputs, max_new_tokens=None, cfg_scale=cfg_scale, - tokenizer=processor.tokenizer, generation_config=generation_config, - verbose=False + try: + inputs = processor( + text=[full_script], voice_samples=[voice_samples_np], padding=True, + return_tensors="pt", return_attention_mask=True ) + + # Validate inputs before moving to GPU + for key, value in inputs.items(): + if isinstance(value, torch.Tensor): + if torch.any(torch.isnan(value)) or torch.any(torch.isinf(value)): + logger.error(f"Input tensor '{key}' contains NaN or Inf values") + raise ValueError(f"Invalid values in input tensor: {key}") + + inputs = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()} + + model.set_ddpm_inference_steps(num_steps=inference_steps) + + generation_config = {'do_sample': do_sample} + if do_sample: + generation_config['temperature'] = temperature + generation_config['top_p'] = top_p + if top_k > 0: + generation_config['top_k'] = top_k + + # Hardware-specific optimizations - only for eager mode + if attention_mode == "eager": + # Apply RTX 5090 / Blackwell compatibility fixes only for eager + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + torch.cuda.empty_cache() + + # Apply additional tensor fixes for eager mode + model = model.float() + processed_inputs = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + # Keep integer/boolean tensors as-is (token IDs, attention masks, etc.) 
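+                        # NOTE: torch.int/torch.long alias torch.int32/torch.int64, so the
+                        # list below is partly redundant but deliberately explicit; floating
+                        # tensors must become float32 to match the model.float() cast above,
+                        # otherwise eager-mode matmuls would mix dtypes and error out.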
+                        if v.dtype in [torch.int, torch.long, torch.int32, torch.int64, torch.bool, torch.uint8]:
+                            processed_inputs[k] = v
+                        # Keep tensors with "mask" in their name as boolean
+                        elif "mask" in k.lower():
+                            processed_inputs[k] = v.bool() if v.dtype != torch.bool else v
+                        else:
+                            # Convert float/bfloat16 tensors to float32
+                            processed_inputs[k] = v.float()
+                    else:
+                        processed_inputs[k] = v
+                inputs = processed_inputs
+
+            with torch.no_grad():
+                # Create progress bar for inference steps
+                pbar = ProgressBar(inference_steps)
+
+                def progress_callback(step, total_steps):
+                    pbar.update(1)
+                    # Check for interruption from ComfyUI
+                    if model_management.processing_interrupted():
+                        raise comfy.model_management.InterruptProcessingException()
+
+                # Run generation with interruption support
+                try:
+                    outputs = model.generate(
+                        **inputs, max_new_tokens=None, cfg_scale=cfg_scale,
+                        tokenizer=processor.tokenizer, generation_config=generation_config,
+                        verbose=False
+                    )
+                    # Note: model.generate does not expose per-step progress callbacks in the
+                    # current VibeVoice implementation, so progress_callback above is unused
+                    pbar.update(inference_steps - pbar.current)
+
+                except RuntimeError as e:
+                    error_msg = str(e).lower()
+                    if "assertion" in error_msg or "cuda" in error_msg:
+                        logger.error(f"CUDA assertion failed with {attention_mode} attention: {e}")
+                        logger.error("This might be due to invalid input data, GPU memory issues, or an incompatible attention mode.")
+                        logger.error("Try restarting ComfyUI, using different audio files, or switching to 'eager' attention mode.")
+                    raise
+                except comfy.model_management.InterruptProcessingException:
+                    logger.info("VibeVoice generation interrupted by user")
+                    raise
+                finally:
+                    pbar.update_absolute(inference_steps)
+
+        except comfy.model_management.InterruptProcessingException:
+            logger.info("VibeVoice TTS generation was cancelled")
+            # Return silent audio on cancellation
+            return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},)
+
+        except Exception as e:
+            logger.error(f"Error during VibeVoice generation with {attention_mode} attention: {e}")
+            if "interrupt" in str(e).lower() or "cancel" in str(e).lower():
+                logger.info("Generation was interrupted")
+                return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},)
+            raise
         output_waveform = outputs.speech_outputs[0]
         if output_waveform.ndim == 1:
             output_waveform = output_waveform.unsqueeze(0)
@@ -291,4 +496,4 @@ class VibeVoiceTTSNode:
         return ({"waveform": output_waveform.detach().cpu(), "sample_rate": 24000},)
 
 NODE_CLASS_MAPPINGS = {"VibeVoiceTTS": VibeVoiceTTSNode}
-NODE_DISPLAY_NAME_MAPPINGS = {"VibeVoiceTTS": "VibeVoice TTS"}
\ No newline at end of file
+NODE_DISPLAY_NAME_MAPPINGS = {"VibeVoiceTTS": "VibeVoice TTS"}
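
The fallback chain added in PATCH 1/2 can be exercised directly. A minimal sketch, assuming the module is importable as `vibevoice_nodes` outside ComfyUI and that "VibeVoice-1.5B" is a valid MODEL_CONFIGS key (both are placeholders, not guaranteed by this patch):

```python
# Sketch: progressive attention fallback and per-mode caching from PATCH 1/2.
# "VibeVoice-1.5B" is a hypothetical pack name; substitute a real MODEL_CONFIGS key.
from vibevoice_nodes import VibeVoiceLoader, LOADED_MODELS

# Without the flash-attn package this logs a warning and falls back to "sdpa"
# (or "eager" on PyTorch < 2.0) instead of raising at generation time.
model, processor = VibeVoiceLoader.load_model(
    "VibeVoice-1.5B", attention_mode="flash_attention_2"
)

# The cache is keyed "<pack>_attn_<mode>" on the mode actually used, so a
# repeated request is a dictionary hit rather than a reload from disk.
print(sorted(LOADED_MODELS.keys()))  # e.g. ['VibeVoice-1.5B_attn_sdpa']
```
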
From 33bc1843b9f3962cdd7892979c4821f567a5ead4 Mon Sep 17 00:00:00 2001
From: drbaph <84208527+Saganaki22@users.noreply.github.com>
Date: Thu, 28 Aug 2025 02:28:56 +0100
Subject: [PATCH 2/2] Fix memory leaks and ComfyUI model management compatibility

- Fixed IndexError in ComfyUI's model management system when unloading models
- Improved memory cleanup to prevent VRAM leaks when switching between models
- Updated cache key handling to properly track attention mode variants
- Enhanced patcher lifecycle management to work with ComfyUI's internal systems
- Added safer model cleanup that doesn't interfere with ComfyUI's model tracking
---
 vibevoice_nodes.py | 60 ++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 55 insertions(+), 5 deletions(-)

diff --git a/vibevoice_nodes.py b/vibevoice_nodes.py
index 25e1aac..29cbbb3 100644
--- a/vibevoice_nodes.py
+++ b/vibevoice_nodes.py
@@ -6,6 +6,7 @@ import random
 from huggingface_hub import snapshot_download
 import logging
 import librosa
+import gc
 
 import folder_paths
 import comfy.model_management as model_management
@@ -35,11 +36,45 @@ MODEL_CONFIGS = {
 
 ATTENTION_MODES = ["eager", "sdpa", "flash_attention_2"]
 
+def cleanup_old_models(keep_cache_key=None):
+    """Clean up old models, optionally keeping one specific model loaded"""
+    global LOADED_MODELS, VIBEVOICE_PATCHER_CACHE
+
+    keys_to_remove = []
+
+    # Clear LOADED_MODELS
+    for key in list(LOADED_MODELS.keys()):
+        if key != keep_cache_key:
+            keys_to_remove.append(key)
+            del LOADED_MODELS[key]
+
+    # Clear VIBEVOICE_PATCHER_CACHE, taking care not to break ComfyUI's tracking
+    for key in list(VIBEVOICE_PATCHER_CACHE.keys()):
+        if key != keep_cache_key:
+            # Drop our heavy references (model/processor) without destroying the
+            # patcher object, which ComfyUI's model management may still track
+            try:
+                patcher = VIBEVOICE_PATCHER_CACHE[key]
+                if hasattr(patcher, 'model') and patcher.model:
+                    patcher.model.model = None
+                    patcher.model.processor = None
+                # Remove our cache entry; ComfyUI releases the patcher when done with it
+                del VIBEVOICE_PATCHER_CACHE[key]
+            except Exception as e:
+                logger.warning(f"Error cleaning up patcher {key}: {e}")
+
+    if keys_to_remove:
+        logger.info(f"Cleaned up cached models: {keys_to_remove}")
+        gc.collect()
+        model_management.soft_empty_cache()
+
 class VibeVoiceModelHandler(torch.nn.Module):
     """A torch.nn.Module wrapper to hold the VibeVoice model and processor."""
-    def __init__(self, model_pack_name):
+    def __init__(self, model_pack_name, attention_mode="eager"):
         super().__init__()
         self.model_pack_name = model_pack_name
+        self.attention_mode = attention_mode
+        self.cache_key = f"{model_pack_name}_attn_{attention_mode}"
         self.model = None
         self.processor = None
         self.size = int(MODEL_CONFIGS[model_pack_name].get("size_gb", 4.0) * (1024**3))
@@ -53,6 +88,7 @@ class VibeVoicePatcher(comfy.model_patcher.ModelPatcher):
     def __init__(self, model, attention_mode="eager", *args, **kwargs):
         super().__init__(model, *args, **kwargs)
         self.attention_mode = attention_mode
+        self.cache_key = model.cache_key
 
     def patch_model(self, device_to=None, *args, **kwargs):
         target_device = self.load_device
@@ -70,12 +106,22 @@
 
     def unpatch_model(self, device_to=None, unpatch_weights=True, *args, **kwargs):
         if unpatch_weights:
-            logger.info(f"Offloading VibeVoice models for '{self.model.model_pack_name}' to {device_to}...")
+            logger.info(f"Offloading VibeVoice models for '{self.model.model_pack_name}' ({self.attention_mode}) to {device_to}...")
             self.model.model = None
             self.model.processor = None
-            if self.model.model_pack_name in LOADED_MODELS:
-                del LOADED_MODELS[self.model.model_pack_name]
+
+            # Clear using the correct cache key
+            if self.cache_key in LOADED_MODELS:
+                del LOADED_MODELS[self.cache_key]
+                logger.info(f"Cleared LOADED_MODELS cache for: {self.cache_key}")
+
+            # DON'T delete from VIBEVOICE_PATCHER_CACHE here - let ComfyUI handle it;
+            # removing it mid-unload triggers the IndexError in its model management
+
+            # Force garbage collection
+            gc.collect()
             model_management.soft_empty_cache()
+
         return super().unpatch_model(device_to, unpatch_weights, *args, **kwargs)
 
 class VibeVoiceLoader:
@@ -360,8 +406,12 @@
         # 
Create cache key that includes attention mode cache_key = f"{model_name}_attn_{attention_mode}" + # Clean up old models when switching to a different model if cache_key not in VIBEVOICE_PATCHER_CACHE: - model_handler = VibeVoiceModelHandler(model_name) + # Only keep models that are currently being requested + cleanup_old_models(keep_cache_key=cache_key) + + model_handler = VibeVoiceModelHandler(model_name, attention_mode) patcher = VibeVoicePatcher( model_handler, attention_mode=attention_mode,