Mirror of https://github.com/wildminder/ComfyUI-VibeVoice.git, synced 2026-04-27 10:13:43 +00:00

Merge pull request #16 from Saganaki22/main

Transformers 4.56+ Compatibility & Force Offload Fix
README.md (28 changed lines)
@@ -36,6 +36,8 @@ The custom node handles everything from model downloading and memory management
 * **Automatic Model Management:** Models are downloaded automatically from Hugging Face and managed efficiently by ComfyUI to save VRAM.
 * **Fine-Grained Control:** Adjust parameters like CFG scale, temperature, and sampling methods to tune the performance and style of the generated speech.
 * **4-Bit Quantization:** Run the large language model component in 4-bit mode to significantly reduce VRAM usage and improve speed on memory-constrained GPUs, especially for the 7B model.
+* **Transformers 4.56+ Compatibility:** Fully backwards compatible with both older and newer versions of the Transformers library.
+* **Force Offload Option:** Toggle to force model offloading from VRAM after generation to save memory between runs, now with improved ComfyUI compatibility.
 
 <p align="right">(<a href="#readme-top">back to top</a>)</p>
 
@@ -92,6 +94,7 @@ _For a complete workflow, you can drag the example image from the `example_workf
 * **`inference_steps`**: Number of diffusion steps for the audio decoder.
 * **`seed`**: A seed for reproducibility.
 * **`do_sample`, `temperature`, `top_p`, `top_k`**: Standard sampling parameters for controlling the creativity and determinism of the speech generation (a usage sketch follows this hunk).
+* **`force_offload`**: (New!) Forces the model to be completely offloaded from VRAM after generation. Useful for memory management, but may slow down subsequent runs.
 * **`speaker_*_voice` (Optional)**: Connect an `AUDIO` output from a `Load Audio` node to provide a voice reference.
 
 ### Performance & Quantization
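The sampling options above are standard Hugging Face generation arguments. As a rough sketch, assuming the node forwards them unchanged to the underlying `generate()` call (which this diff does not show), they map on like this; `model` and `inputs` are placeholders, not names from this repository:

```python
def sample_speech_tokens(model, inputs, temperature=0.95, top_p=0.95, top_k=50):
    # Hypothetical helper (not from this repo): shows how the node's sampling
    # options correspond to Hugging Face's standard generate() arguments.
    return model.generate(
        **inputs,
        do_sample=True,           # False would give deterministic (greedy) decoding
        temperature=temperature,  # <1.0 sharpens the token distribution, >1.0 flattens it
        top_p=top_p,              # nucleus sampling: smallest token set with cumulative prob >= top_p
        top_k=top_k,              # only the K most likely tokens are considered (0 disables)
    )
```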
@@ -107,6 +110,16 @@ A key feature of this node is the optional **4-bit quantization** for the langua
 
 As shown, quantization provides a massive speedup and VRAM reduction for the 7B model, making it accessible on a wider range of hardware. While it slightly slows down the 1.5B model, the significant VRAM savings may still be beneficial for complex workflows (an illustrative 4-bit loading sketch follows this hunk).
+
+### Transformers Library Compatibility
+
+This version includes automatic detection and compatibility for both older and newer versions of the Transformers library:
+
+* **Transformers 4.56+**: Automatically uses the new method signature for `_prepare_cache_for_generation`.
+* **Older Versions**: Maintains compatibility with pre-4.56 versions using the legacy method signature.
+* **Fallback Mechanism**: If detection fails, the node automatically tries both signatures to ensure maximum compatibility.
+
+This ensures the node works seamlessly regardless of your Transformers version, without requiring manual updates (a condensed sketch of the detection follows this hunk).
 
 ### Tips from the Original Authors
 
 * **Punctuation:** For Chinese text, using English punctuation (commas and periods) can improve stability.
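For the 4-bit path referenced above, a minimal sketch of what such a load typically looks like, assuming the LLM component is pulled through Hugging Face Transformers with bitsandbytes; the model id, model class, and loader call here are placeholders, not this repo's actual quantization code:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Sketch only: a typical bitsandbytes 4-bit setup.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # quantize linear layers to 4-bit
    bnb_4bit_quant_type="nf4",              # NF4 is the usual choice for LLM weights
    bnb_4bit_compute_dtype=torch.bfloat16,  # compute in bf16 for speed/stability
)

model = AutoModelForCausalLM.from_pretrained(
    "microsoft/VibeVoice-1.5B",             # placeholder model id
    quantization_config=bnb_config,
)
```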
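The compatibility detection itself boils down to counting the parameters of the bound method; this is condensed from the code diff further down (only `inspect.signature` is needed):

```python
import inspect

def prepare_cache_compat(model, generation_config, model_kwargs,
                         batch_size, max_cache_length, device):
    # Transformers 4.56+ dropped the third positional argument
    # (assistant_model), so the bound method takes 5 parameters instead of 6.
    sig = inspect.signature(model._prepare_cache_for_generation)
    if len(sig.parameters) == 5:  # new signature (4.56+)
        model._prepare_cache_for_generation(
            generation_config, model_kwargs, batch_size, max_cache_length, device)
    else:                         # legacy signature (pre-4.56)
        model._prepare_cache_for_generation(
            generation_config, model_kwargs, None, batch_size, max_cache_length, device)
```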
@@ -116,6 +129,21 @@ As shown, quantization provides a massive speedup and VRAM reduction for the 7B
 
 
 <p align="right">(<a href="#readme-top">back to top</a>)</p>
+
+<!-- BUG FIXES -->
+## Recent Bug Fixes
+
+### Force Offload Compatibility Fix
+* **Fixed:** Resolved the `AttributeError: module 'comfy.model_management' has no attribute 'unload_model_clones'` error when using the force offload option.
+* **Details:** Updated the force offload implementation to use ComfyUI's standard `unload_all_models()` API instead of the deprecated `unload_model_clones()` function (sketched directly below).
+* **Impact:** Force offload functionality now works correctly with all versions of ComfyUI.
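For reference, the fixed offload path reduces to a few supported ComfyUI calls; this is condensed from the node code shown further down in this diff:

```python
import gc
import comfy.model_management as model_management

def force_offload(patcher):
    # Condensed from the node's offload path: unpatch the wrapped model,
    # then free VRAM through ComfyUI's supported APIs rather than the
    # removed unload_model_clones().
    if patcher.is_loaded:
        patcher.unpatch_model(unpatch_weights=True)
    model_management.unload_all_models()
    gc.collect()
    model_management.soft_empty_cache()
```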
+
+### Multi-Speaker DynamicCache Fix
+* **Fixed:** Resolved the `'DynamicCache' object has no attribute 'key_cache'` error when using multiple speakers.
+* **Details:** Updated cache access in `modeling_vibevoice_inference.py` to use the proper `DynamicCache` API: layers are accessed via indexing instead of the deprecated `.key_cache` and `.value_cache` attributes.
+* **Impact:** Multi-speaker functionality now works correctly with newer versions of the Transformers library.
+
+<p align="right">(<a href="#readme-top">back to top</a>)</p>
 
 <!-- LICENSE -->
 ## License
modeling_vibevoice_inference.py

@@ -300,7 +300,23 @@ class VibeVoiceForConditionalGenerationInference(VibeVoicePreTrainedModel, Gener
         )
 
         max_cache_length = generation_config.max_length - 1
-        self._prepare_cache_for_generation(generation_config, model_kwargs, None, batch_size, max_cache_length, device)
+        # Backwards compatible fix for _prepare_cache_for_generation method signature
+        # New transformers version expects 5 args, old version expects 6
+        import inspect
+        try:
+            sig = inspect.signature(self._prepare_cache_for_generation)
+            if len(sig.parameters) == 5:
+                # New transformers version (4.56+)
+                self._prepare_cache_for_generation(generation_config, model_kwargs, batch_size, max_cache_length, device)
+            else:
+                # Old transformers version (pre-4.56)
+                self._prepare_cache_for_generation(generation_config, model_kwargs, None, batch_size, max_cache_length, device)
+        except Exception as e:
+            # Fallback to try both versions
+            try:
+                self._prepare_cache_for_generation(generation_config, model_kwargs, batch_size, max_cache_length, device)
+            except TypeError:
+                self._prepare_cache_for_generation(generation_config, model_kwargs, None, batch_size, max_cache_length, device)
         model_kwargs['cache_position'] = torch.arange(input_ids_length, device=device, dtype=torch.long)
         for k, v in model_kwargs.items():
             if isinstance(v, torch.Tensor):
@@ -551,8 +567,8 @@ class VibeVoiceForConditionalGenerationInference(VibeVoicePreTrainedModel, Gener
                     negative_model_kwargs['attention_mask'][sample_idx, :] = 0
                     negative_model_kwargs['attention_mask'][sample_idx, -1] = 1
                 # update past key values
-                for layer_idx, (k_cache, v_cache) in enumerate(zip(negative_model_kwargs['past_key_values'].key_cache,
-                                                                   negative_model_kwargs['past_key_values'].value_cache)):
+                for layer_idx in range(len(negative_model_kwargs['past_key_values'])):
+                    k_cache, v_cache = negative_model_kwargs['past_key_values'][layer_idx]
                     # Process each non-diffusion sample
                     for sample_idx in diffusion_start_indices.tolist():
                         # Shift cache for this sample
@@ -604,8 +620,8 @@ class VibeVoiceForConditionalGenerationInference(VibeVoicePreTrainedModel, Gener
                     negative_model_kwargs['attention_mask'][sample_idx, start_idx] = 0
 
                 # 2. Update past_key_values
-                for layer_idx, (k_cache, v_cache) in enumerate(zip(negative_model_kwargs['past_key_values'].key_cache,
-                                                                   negative_model_kwargs['past_key_values'].value_cache)):
+                for layer_idx in range(len(negative_model_kwargs['past_key_values'])):
+                    k_cache, v_cache = negative_model_kwargs['past_key_values'][layer_idx]
                     # Process each non-diffusion sample
                     for sample_idx, start_idx in zip(non_diffusion_indices.tolist(), start_indices.tolist()):
                         if start_idx + 1 < k_cache.shape[2] - 1:
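Both hunks above apply the same fix. In isolation, the change to `DynamicCache` access looks like this:

```python
from transformers.cache_utils import DynamicCache

def iter_layer_caches(past_key_values: DynamicCache):
    # Old pattern, broken on recent Transformers releases:
    #   for k, v in zip(past_key_values.key_cache, past_key_values.value_cache): ...
    # New pattern: a DynamicCache is indexable per layer and yields (key, value) pairs.
    for layer_idx in range(len(past_key_values)):
        k_cache, v_cache = past_key_values[layer_idx]
        yield layer_idx, k_cache, v_cache
```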
@@ -100,6 +100,11 @@ class VibeVoicePatcher(comfy.model_patcher.ModelPatcher):
         super().__init__(model, *args, **kwargs)
         self.attention_mode = attention_mode
         self.cache_key = model.cache_key
 
+    @property
+    def is_loaded(self):
+        """Check if the model is currently loaded in memory."""
+        return hasattr(self, 'model') and self.model is not None and hasattr(self.model, 'model') and self.model.model is not None
+
     def patch_model(self, device_to=None, *args, **kwargs):
         target_device = self.load_device
@@ -432,6 +437,10 @@ class VibeVoiceTTSNode:
                     "default": 0, "min": 0, "max": 500, "step": 1,
                     "tooltip": "Top-K sampling. Restricts sampling to the K most likely next tokens. Set to 0 to disable. Active only if 'do_sample' is enabled."
                 }),
+                "force_offload": ("BOOLEAN", {
+                    "default": False, "label_on": "Force Offload", "label_off": "Keep in VRAM",
+                    "tooltip": "Force model to be offloaded from VRAM after generation. Useful to free up memory between generations but may slow down subsequent runs."
+                }),
             },
             "optional": {
                 "speaker_1_voice": ("AUDIO", {"tooltip": "Reference audio for 'Speaker 1' in the script."}),
@@ -445,7 +454,7 @@ class VibeVoiceTTSNode:
     FUNCTION = "generate_audio"
     CATEGORY = "audio/tts"
 
-    def generate_audio(self, model_name, text, attention_mode, cfg_scale, inference_steps, seed, do_sample, temperature, top_p, top_k, quantize_llm_4bit, **kwargs):
+    def generate_audio(self, model_name, text, attention_mode, cfg_scale, inference_steps, seed, do_sample, temperature, top_p, top_k, quantize_llm_4bit, force_offload, **kwargs):
         if not text.strip():
             logger.warning("VibeVoiceTTS: Empty text provided, returning silent audio.")
             return ({"waveform": torch.zeros((1, 1, 24000), dtype=torch.float32), "sample_rate": 24000},)
@@ -589,6 +598,18 @@ class VibeVoiceTTSNode:
             output_waveform = outputs.speech_outputs[0]
             if output_waveform.ndim == 1: output_waveform = output_waveform.unsqueeze(0)
             if output_waveform.ndim == 2: output_waveform = output_waveform.unsqueeze(0)
 
+            # Force offload model if requested
+            if force_offload:
+                logger.info(f"Force offloading VibeVoice model '{model_name}' from VRAM...")
+                # Force offload by unpatching the model and freeing memory
+                if patcher.is_loaded:
+                    patcher.unpatch_model(unpatch_weights=True)
+                # Force unload all models to free memory
+                model_management.unload_all_models()
+                gc.collect()
+                model_management.soft_empty_cache()
+                logger.info("Model force offload completed")
+
             return ({"waveform": output_waveform.detach().cpu(), "sample_rate": 24000},)
 