mirror of
https://github.com/wildminder/ComfyUI-VibeVoice.git
synced 2026-01-26 14:39:45 +00:00
fix dtype issue
This commit is contained in:
@@ -14,6 +14,13 @@ import comfy.model_patcher
|
|||||||
from comfy.utils import ProgressBar
|
from comfy.utils import ProgressBar
|
||||||
from comfy.model_management import throw_exception_if_processing_interrupted
|
from comfy.model_management import throw_exception_if_processing_interrupted
|
||||||
|
|
||||||
|
# Import transformers and packaging to handle different library versions.
|
||||||
|
import transformers
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
|
_transformers_version = version.parse(transformers.__version__)
|
||||||
|
_DTYPE_ARG_SUPPORTED = _transformers_version >= version.parse("4.56.0")
|
||||||
|
|
||||||
from transformers import set_seed, AutoTokenizer, BitsAndBytesConfig
|
from transformers import set_seed, AutoTokenizer, BitsAndBytesConfig
|
||||||
from .vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
|
from .vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
|
||||||
from .vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
|
from .vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
|
||||||
@@ -229,12 +236,21 @@ class VibeVoiceLoader:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(f"Loading model with dtype: {final_load_dtype} and attention: '{attn_implementation_for_load}'")
|
logger.info(f"Loading model with dtype: {final_load_dtype} and attention: '{attn_implementation_for_load}'")
|
||||||
|
# Build a dictionary of keyword arguments for from_pretrained.
|
||||||
|
from_pretrained_kwargs = {
|
||||||
|
"attn_implementation": attn_implementation_for_load,
|
||||||
|
"device_map": "auto" if quant_config else device,
|
||||||
|
"quantization_config": quant_config,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Use the correct dtype argument based on the transformers version.
|
||||||
|
if _DTYPE_ARG_SUPPORTED:
|
||||||
|
from_pretrained_kwargs['dtype'] = final_load_dtype
|
||||||
|
else:
|
||||||
|
from_pretrained_kwargs['torch_dtype'] = final_load_dtype
|
||||||
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
||||||
model_path,
|
model_path,
|
||||||
dtype=final_load_dtype,
|
**from_pretrained_kwargs
|
||||||
attn_implementation=attn_implementation_for_load,
|
|
||||||
device_map="auto" if quant_config else device,
|
|
||||||
quantization_config=quant_config,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if attention_mode == "sage":
|
if attention_mode == "sage":
|
||||||
|
|||||||
Reference in New Issue
Block a user