mirror of
https://github.com/wildminder/ComfyUI-VibeVoice.git
synced 2026-01-26 14:39:45 +00:00
fix dtype issue
This commit is contained in:
@@ -14,6 +14,13 @@ import comfy.model_patcher
|
||||
from comfy.utils import ProgressBar
|
||||
from comfy.model_management import throw_exception_if_processing_interrupted
|
||||
|
||||
# Import transformers and packaging to handle different library versions.
|
||||
import transformers
|
||||
from packaging import version
|
||||
|
||||
_transformers_version = version.parse(transformers.__version__)
|
||||
_DTYPE_ARG_SUPPORTED = _transformers_version >= version.parse("4.56.0")
|
||||
|
||||
from transformers import set_seed, AutoTokenizer, BitsAndBytesConfig
|
||||
from .vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
|
||||
from .vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
|
||||
@@ -229,12 +236,21 @@ class VibeVoiceLoader:
|
||||
|
||||
try:
|
||||
logger.info(f"Loading model with dtype: {final_load_dtype} and attention: '{attn_implementation_for_load}'")
|
||||
# Build a dictionary of keyword arguments for from_pretrained.
|
||||
from_pretrained_kwargs = {
|
||||
"attn_implementation": attn_implementation_for_load,
|
||||
"device_map": "auto" if quant_config else device,
|
||||
"quantization_config": quant_config,
|
||||
}
|
||||
|
||||
# Use the correct dtype argument based on the transformers version.
|
||||
if _DTYPE_ARG_SUPPORTED:
|
||||
from_pretrained_kwargs['dtype'] = final_load_dtype
|
||||
else:
|
||||
from_pretrained_kwargs['torch_dtype'] = final_load_dtype
|
||||
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
||||
model_path,
|
||||
dtype=final_load_dtype,
|
||||
attn_implementation=attn_implementation_for_load,
|
||||
device_map="auto" if quant_config else device,
|
||||
quantization_config=quant_config,
|
||||
**from_pretrained_kwargs
|
||||
)
|
||||
|
||||
if attention_mode == "sage":
|
||||
|
||||
Reference in New Issue
Block a user