From ce4a487379639ad26be049a937c01094b22b9902 Mon Sep 17 00:00:00 2001 From: WildAi <2853742+wildminder@users.noreply.github.com> Date: Wed, 3 Sep 2025 18:19:48 +0300 Subject: [PATCH] fix dtype issue --- vibevoice_nodes.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/vibevoice_nodes.py b/vibevoice_nodes.py index 5221619..641edc9 100644 --- a/vibevoice_nodes.py +++ b/vibevoice_nodes.py @@ -14,6 +14,13 @@ import comfy.model_patcher from comfy.utils import ProgressBar from comfy.model_management import throw_exception_if_processing_interrupted +# Import transformers and packaging to handle different library versions. +import transformers +from packaging import version + +_transformers_version = version.parse(transformers.__version__) +_DTYPE_ARG_SUPPORTED = _transformers_version >= version.parse("4.56.0") + from transformers import set_seed, AutoTokenizer, BitsAndBytesConfig from .vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference from .vibevoice.processor.vibevoice_processor import VibeVoiceProcessor @@ -229,12 +236,21 @@ class VibeVoiceLoader: try: logger.info(f"Loading model with dtype: {final_load_dtype} and attention: '{attn_implementation_for_load}'") + # Build a dictionary of keyword arguments for from_pretrained. + from_pretrained_kwargs = { + "attn_implementation": attn_implementation_for_load, + "device_map": "auto" if quant_config else device, + "quantization_config": quant_config, + } + + # Use the correct dtype argument based on the transformers version. + if _DTYPE_ARG_SUPPORTED: + from_pretrained_kwargs['dtype'] = final_load_dtype + else: + from_pretrained_kwargs['torch_dtype'] = final_load_dtype model = VibeVoiceForConditionalGenerationInference.from_pretrained( model_path, - dtype=final_load_dtype, - attn_implementation=attn_implementation_for_load, - device_map="auto" if quant_config else device, - quantization_config=quant_config, + **from_pretrained_kwargs ) if attention_mode == "sage":