fix text encoder dtype

This commit is contained in:
layerdiffusion
2024-08-09 15:11:07 -07:00
parent dad1d17f15
commit 4014013d05

View File

@@ -51,7 +51,7 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
from transformers import CLIPTextConfig, CLIPTextModel
config = CLIPTextConfig.from_pretrained(config_path)
to_args = dict(device=memory_management.text_encoder_device(), dtype=memory_management.text_encoder_dtype())
to_args = dict(device=memory_management.cpu, dtype=memory_management.text_encoder_dtype())
with modeling_utils.no_init_weights():
with using_forge_operations(**to_args, manual_cast_enabled=True):