fix dtype issue

This commit is contained in:
WildAi
2025-09-03 18:19:48 +03:00
parent e42b4aa76e
commit ce4a487379

View File

@@ -14,6 +14,13 @@ import comfy.model_patcher
from comfy.utils import ProgressBar
from comfy.model_management import throw_exception_if_processing_interrupted
# Import transformers and packaging to handle different library versions.
import transformers
from packaging import version
_transformers_version = version.parse(transformers.__version__)
_DTYPE_ARG_SUPPORTED = _transformers_version >= version.parse("4.56.0")
from transformers import set_seed, AutoTokenizer, BitsAndBytesConfig
from .vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from .vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
@@ -229,12 +236,21 @@ class VibeVoiceLoader:
try:
logger.info(f"Loading model with dtype: {final_load_dtype} and attention: '{attn_implementation_for_load}'")
# Build a dictionary of keyword arguments for from_pretrained.
from_pretrained_kwargs = {
"attn_implementation": attn_implementation_for_load,
"device_map": "auto" if quant_config else device,
"quantization_config": quant_config,
}
# Use the correct dtype argument based on the transformers version.
if _DTYPE_ARG_SUPPORTED:
from_pretrained_kwargs['dtype'] = final_load_dtype
else:
from_pretrained_kwargs['torch_dtype'] = final_load_dtype
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
model_path,
dtype=final_load_dtype,
attn_implementation=attn_implementation_for_load,
device_map="auto" if quant_config else device,
quantization_config=quant_config,
**from_pretrained_kwargs
)
if attention_mode == "sage":