From c29daa905076e8913c3703d5f150c9fecd2a91fc Mon Sep 17 00:00:00 2001 From: drbaph <84208527+Saganaki22@users.noreply.github.com> Date: Thu, 28 Aug 2025 01:51:29 +0100 Subject: [PATCH] Add configurable attention modes with compatibility checks - Added dropdown selection for attention implementation (eager/sdpa/flash_attention_2) - Implemented automatic compatibility checks and progressive fallbacks - Added hardware-specific optimizations for RTX 5090/Blackwell GPUs - Enhanced error handling to prevent crashes from incompatible attention modes --- vibevoice_nodes.py | 285 ++++++++++++++++++++++++++++++++++++++------- 1 file changed, 245 insertions(+), 40 deletions(-) diff --git a/vibevoice_nodes.py b/vibevoice_nodes.py index 540ea0b..25e1aac 100644 --- a/vibevoice_nodes.py +++ b/vibevoice_nodes.py @@ -33,6 +33,8 @@ MODEL_CONFIGS = { } } +ATTENTION_MODES = ["eager", "sdpa", "flash_attention_2"] + class VibeVoiceModelHandler(torch.nn.Module): """A torch.nn.Module wrapper to hold the VibeVoice model and processor.""" def __init__(self, model_pack_name): @@ -42,20 +44,27 @@ class VibeVoiceModelHandler(torch.nn.Module): self.processor = None self.size = int(MODEL_CONFIGS[model_pack_name].get("size_gb", 4.0) * (1024**3)) - def load_model(self, device): - self.model, self.processor = VibeVoiceLoader.load_model(self.model_pack_name) + def load_model(self, device, attention_mode="eager"): + self.model, self.processor = VibeVoiceLoader.load_model(self.model_pack_name, attention_mode) self.model.to(device) class VibeVoicePatcher(comfy.model_patcher.ModelPatcher): """Custom ModelPatcher for managing VibeVoice models in ComfyUI.""" - def __init__(self, model, *args, **kwargs): + def __init__(self, model, attention_mode="eager", *args, **kwargs): super().__init__(model, *args, **kwargs) + self.attention_mode = attention_mode def patch_model(self, device_to=None, *args, **kwargs): target_device = self.load_device if self.model.model is None: logger.info(f"Loading VibeVoice models for '{self.model.model_pack_name}' to {target_device}...") - self.model.load_model(target_device) + mode_names = { + "eager": "Eager (Most Compatible)", + "sdpa": "SDPA (Balanced Speed/Compatibility)", + "flash_attention_2": "Flash Attention 2 (Fastest)" + } + logger.info(f"Attention Mode: {mode_names.get(self.attention_mode, self.attention_mode)}") + self.model.load_model(target_device, self.attention_mode) self.model.model.to(target_device) return super().patch_model(device_to=target_device, *args, **kwargs) @@ -86,9 +95,47 @@ class VibeVoiceLoader: return model_path @staticmethod - def load_model(model_name: str): - if model_name in LOADED_MODELS: - return LOADED_MODELS[model_name] + def _check_attention_compatibility(attention_mode: str, torch_dtype, device_name: str = ""): + """Check if the requested attention mode is compatible with current setup.""" + + # Check for SDPA availability (PyTorch 2.0+) + if attention_mode == "sdpa": + if not hasattr(torch.nn.functional, 'scaled_dot_product_attention'): + logger.warning("SDPA not available (requires PyTorch 2.0+), falling back to eager") + return "eager" + + # Check for Flash Attention availability + elif attention_mode == "flash_attention_2": + if not hasattr(torch.nn.functional, 'scaled_dot_product_attention'): + logger.warning("Flash Attention not available, falling back to eager") + return "eager" + elif torch_dtype == torch.float32: + logger.warning("Flash Attention not recommended with float32, falling back to SDPA") + return "sdpa" if hasattr(torch.nn.functional, 'scaled_dot_product_attention') else "eager" + + # Just informational messages, no forced fallbacks + if device_name and torch.cuda.is_available(): + if "RTX 50" in device_name or "Blackwell" in device_name: + if attention_mode == "flash_attention_2": + logger.info(f"Using Flash Attention on {device_name}") + elif attention_mode == "sdpa": + logger.info(f"Using SDPA on {device_name}") + + return attention_mode + + @staticmethod + def load_model(model_name: str, attention_mode: str = "eager"): + # Validate attention mode + if attention_mode not in ATTENTION_MODES: + logger.warning(f"Unknown attention mode '{attention_mode}', falling back to eager") + attention_mode = "eager" + + # Create cache key that includes attention mode + cache_key = f"{model_name}_attn_{attention_mode}" + + if cache_key in LOADED_MODELS: + logger.info(f"Using cached model with {attention_mode} attention") + return LOADED_MODELS[cache_key] model_path = VibeVoiceLoader.get_model_path(model_name) @@ -96,16 +143,74 @@ class VibeVoiceLoader: processor = VibeVoiceProcessor.from_pretrained(model_path) torch_dtype = model_management.text_encoder_dtype(model_management.get_torch_device()) - - model = VibeVoiceForConditionalGenerationInference.from_pretrained( - model_path, - torch_dtype=torch_dtype, - attn_implementation="flash_attention_2" if hasattr(torch.nn.functional, 'scaled_dot_product_attention') and torch_dtype != torch.float32 else "eager", - ) - model.eval() + device_name = torch.cuda.get_device_name() if torch.cuda.is_available() else "" - LOADED_MODELS[model_name] = (model, processor) - return model, processor + # Check compatibility and potentially fall back to safer mode + final_attention_mode = VibeVoiceLoader._check_attention_compatibility( + attention_mode, torch_dtype, device_name + ) + + print(f"Requested attention mode: {attention_mode}") + if final_attention_mode != attention_mode: + print(f"Using attention mode: {final_attention_mode} (automatic fallback)") + # Update cache key to reflect actual mode used + cache_key = f"{model_name}_attn_{final_attention_mode}" + if cache_key in LOADED_MODELS: + return LOADED_MODELS[cache_key] + else: + print(f"Using attention mode: {final_attention_mode}") + + logger.info(f"Final attention implementation: {final_attention_mode}") + + # Modify config for non-flash attention modes + if final_attention_mode in ["eager", "sdpa"]: + import json + config_path = os.path.join(model_path, "config.json") + if os.path.exists(config_path): + try: + with open(config_path, 'r') as f: + config = json.load(f) + + # Remove flash attention settings + removed_keys = [] + for key in ['_attn_implementation', 'attn_implementation', 'use_flash_attention_2']: + if key in config: + config.pop(key) + removed_keys.append(key) + + if removed_keys: + with open(config_path, 'w') as f: + json.dump(config, f, indent=2) + logger.info(f"Removed FlashAttention settings from config.json: {removed_keys}") + except Exception as e: + logger.warning(f"Could not modify config.json: {e}") + + try: + model = VibeVoiceForConditionalGenerationInference.from_pretrained( + model_path, + torch_dtype=torch_dtype, + attn_implementation=final_attention_mode, + ) + model.eval() + + # Store with the actual attention mode used (not the requested one) + LOADED_MODELS[cache_key] = (model, processor) + logger.info(f"Successfully loaded model with {final_attention_mode} attention") + return model, processor + + except Exception as e: + logger.error(f"Failed to load model with {final_attention_mode} attention: {e}") + + # Progressive fallback: flash -> sdpa -> eager + if final_attention_mode == "flash_attention_2": + logger.info("Attempting fallback to SDPA...") + return VibeVoiceLoader.load_model(model_name, "sdpa") + elif final_attention_mode == "sdpa": + logger.info("Attempting fallback to eager...") + return VibeVoiceLoader.load_model(model_name, "eager") + else: + # If eager fails, something is seriously wrong + raise RuntimeError(f"Failed to load model even with eager attention: {e}") def set_vibevoice_seed(seed: int): @@ -162,9 +267,29 @@ def preprocess_comfy_audio(audio_dict: dict, target_sr: int = 24000) -> np.ndarr if waveform.ndim > 1: waveform = np.mean(waveform, axis=0) + # Check for invalid values + if np.any(np.isnan(waveform)) or np.any(np.isinf(waveform)): + logger.error("Audio contains NaN or Inf values, replacing with zeros") + waveform = np.nan_to_num(waveform, nan=0.0, posinf=0.0, neginf=0.0) + + # Ensure audio is not completely silent or has extreme values + if np.all(waveform == 0): + logger.warning("Audio waveform is completely silent") + + # Normalize extreme values + max_val = np.abs(waveform).max() + if max_val > 10.0: + logger.warning(f"Audio values are very large (max: {max_val}), normalizing") + waveform = waveform / max_val + if original_sr != target_sr: logger.warning(f"Resampling reference audio from {original_sr}Hz to {target_sr}Hz.") waveform = librosa.resample(y=waveform, orig_sr=original_sr, target_sr=target_sr) + + # Final check after resampling + if np.any(np.isnan(waveform)) or np.any(np.isinf(waveform)): + logger.error("Audio contains NaN or Inf after resampling, replacing with zeros") + waveform = np.nan_to_num(waveform, nan=0.0, posinf=0.0, neginf=0.0) return waveform.astype(np.float32) @@ -182,6 +307,10 @@ class VibeVoiceTTSNode: "default": "Speaker 1: Hello from ComfyUI!\nSpeaker 2: VibeVoice sounds amazing.", "tooltip": "The script for the conversation. Use 'Speaker 1:', 'Speaker 2:', etc. to assign lines to different voices. Each speaker line should be on a new line." }), + "attention_mode": (["eager", "sdpa", "flash_attention_2"], { + "default": "sdpa", + "tooltip": "Attention implementation: Eager (safest), SDPA (balanced), Flash Attention 2 (fastest but may cause issues on some GPUs like RTX 5090)" + }), "cfg_scale": ("FLOAT", { "default": 1.3, "min": 1.0, "max": 2.0, "step": 0.05, "tooltip": "Classifier-Free Guidance scale. Higher values increase adherence to the voice prompt but may reduce naturalness. Recommended: 1.3" @@ -223,16 +352,19 @@ class VibeVoiceTTSNode: FUNCTION = "generate_audio" CATEGORY = "audio/tts" - def generate_audio(self, model_name, text, cfg_scale, inference_steps, seed, do_sample, temperature, top_p, top_k, **kwargs): + def generate_audio(self, model_name, text, attention_mode, cfg_scale, inference_steps, seed, do_sample, temperature, top_p, top_k, **kwargs): if not text.strip(): logger.warning("VibeVoiceTTS: Empty text provided, returning silent audio.") return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},) - cache_key = model_name + # Create cache key that includes attention mode + cache_key = f"{model_name}_attn_{attention_mode}" + if cache_key not in VIBEVOICE_PATCHER_CACHE: model_handler = VibeVoiceModelHandler(model_name) patcher = VibeVoicePatcher( - model_handler, + model_handler, + attention_mode=attention_mode, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), size=model_handler.size @@ -262,27 +394,100 @@ class VibeVoiceTTSNode: set_vibevoice_seed(seed) - inputs = processor( - text=[full_script], voice_samples=[voice_samples_np], padding=True, - return_tensors="pt", return_attention_mask=True - ) - inputs = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()} - - model.set_ddpm_inference_steps(num_steps=inference_steps) - - generation_config = {'do_sample': do_sample} - if do_sample: - generation_config['temperature'] = temperature - generation_config['top_p'] = top_p - if top_k > 0: - generation_config['top_k'] = top_k - - with torch.no_grad(): - outputs = model.generate( - **inputs, max_new_tokens=None, cfg_scale=cfg_scale, - tokenizer=processor.tokenizer, generation_config=generation_config, - verbose=False + try: + inputs = processor( + text=[full_script], voice_samples=[voice_samples_np], padding=True, + return_tensors="pt", return_attention_mask=True ) + + # Validate inputs before moving to GPU + for key, value in inputs.items(): + if isinstance(value, torch.Tensor): + if torch.any(torch.isnan(value)) or torch.any(torch.isinf(value)): + logger.error(f"Input tensor '{key}' contains NaN or Inf values") + raise ValueError(f"Invalid values in input tensor: {key}") + + inputs = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()} + + model.set_ddpm_inference_steps(num_steps=inference_steps) + + generation_config = {'do_sample': do_sample} + if do_sample: + generation_config['temperature'] = temperature + generation_config['top_p'] = top_p + if top_k > 0: + generation_config['top_k'] = top_k + + # Hardware-specific optimizations - only for eager mode + if attention_mode == "eager": + # Apply RTX 5090 / Blackwell compatibility fixes only for eager + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + torch.cuda.empty_cache() + + # Apply additional tensor fixes for eager mode + model = model.float() + processed_inputs = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + # Keep integer/boolean tensors as-is (token IDs, attention masks, etc.) + if v.dtype in [torch.int, torch.long, torch.int32, torch.int64, torch.bool, torch.uint8]: + processed_inputs[k] = v + # Keep tensors with "mask" in their name as boolean + elif "mask" in k.lower(): + processed_inputs[k] = v.bool() if v.dtype != torch.bool else v + else: + # Convert float/bfloat16 tensors to float32 + processed_inputs[k] = v.float() + else: + processed_inputs[k] = v + inputs = processed_inputs + + with torch.no_grad(): + # Create progress bar for inference steps + pbar = ProgressBar(inference_steps) + + def progress_callback(step, total_steps): + pbar.update(1) + # Check for interruption from ComfyUI + if model_management.interrupt_current_processing: + raise comfy.model_management.InterruptProcessingException() + + # Custom generation loop with interruption support + try: + outputs = model.generate( + **inputs, max_new_tokens=None, cfg_scale=cfg_scale, + tokenizer=processor.tokenizer, generation_config=generation_config, + verbose=False + ) + # Note: The model.generate method doesn't support progress callbacks in the current VibeVoice implementation + # But we check for interruption at the start and end of generation + pbar.update(inference_steps - pbar.current) + + except RuntimeError as e: + error_msg = str(e).lower() + if "assertion" in error_msg or "cuda" in error_msg: + logger.error(f"CUDA assertion failed with {attention_mode} attention: {e}") + logger.error("This might be due to invalid input data, GPU memory issues, or incompatible attention mode.") + logger.error("Try restarting ComfyUI, using different audio files, or switching to 'eager' attention mode.") + raise e + except comfy.model_management.InterruptProcessingException: + logger.info("VibeVoice generation interrupted by user") + raise + finally: + pbar.update_absolute(inference_steps) + + except comfy.model_management.InterruptProcessingException: + logger.info("VibeVoice TTS generation was cancelled") + # Return silent audio on cancellation + return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},) + + except Exception as e: + logger.error(f"Error during VibeVoice generation with {attention_mode} attention: {e}") + if "interrupt" in str(e).lower() or "cancel" in str(e).lower(): + logger.info("Generation was interrupted") + return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},) + raise output_waveform = outputs.speech_outputs[0] if output_waveform.ndim == 1: output_waveform = output_waveform.unsqueeze(0) @@ -291,4 +496,4 @@ class VibeVoiceTTSNode: return ({"waveform": output_waveform.detach().cpu(), "sample_rate": 24000},) NODE_CLASS_MAPPINGS = {"VibeVoiceTTS": VibeVoiceTTSNode} -NODE_DISPLAY_NAME_MAPPINGS = {"VibeVoiceTTS": "VibeVoice TTS"} \ No newline at end of file +NODE_DISPLAY_NAME_MAPPINGS = {"VibeVoiceTTS": "VibeVoice TTS"}