From 78c65708ea405a0ff73cae006f4e8c6aad692cc6 Mon Sep 17 00:00:00 2001
From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com>
Date: Wed, 7 Aug 2024 21:55:00 -0700
Subject: [PATCH] fix t5

---
 backend/text_processing/t5_engine.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/backend/text_processing/t5_engine.py b/backend/text_processing/t5_engine.py
index e6c8c55d..1d7065c6 100644
--- a/backend/text_processing/t5_engine.py
+++ b/backend/text_processing/t5_engine.py
@@ -58,10 +58,9 @@ class T5TextProcessingEngine:
         return tokenized
 
     def encode_with_transformers(self, tokens):
-        tokens = tokens.to(memory_management.get_torch_device())
         device = memory_management.get_torch_device()
-        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
-        self.text_encoder.shared.to(device=device, dtype=dtype)
+        tokens = tokens.to(device)
+        self.text_encoder.shared.to(device=device, dtype=torch.float32)
         z = self.text_encoder(
             input_ids=tokens,
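
For context, below is a minimal sketch of how `encode_with_transformers` reads after this patch: the device is resolved once and reused for both the token ids and the shared embedding, and the previous bf16-if-supported branch is replaced by an unconditional float32 cast of `self.text_encoder.shared`. The `memory_management` import path and the closing of the encoder call are assumptions; the hunk above cuts off after `input_ids=tokens,`, so anything past that point is not part of the patch.

```python
import torch

# Repo-internal helper module; the exact import path is an assumption.
from backend import memory_management


class T5TextProcessingEngine:
    # Tokenizer setup, tokenize(), and the other methods of the engine
    # are elided here; only the patched method is sketched.

    def encode_with_transformers(self, tokens):
        # Resolve the target device once and reuse it for both the token
        # ids and the shared embedding table (the old code queried the
        # device separately when moving the tokens).
        device = memory_management.get_torch_device()
        tokens = tokens.to(device)

        # The patch drops the bf16-if-supported branch: the shared
        # embedding is now always kept in float32.
        self.text_encoder.shared.to(device=device, dtype=torch.float32)

        # The remaining keyword arguments of this call are cut off in the
        # hunk above; only input_ids is shown here.
        z = self.text_encoder(
            input_ids=tokens,
        )
        # What happens to z afterwards is outside the hunk; returning it
        # directly is an assumption for this sketch.
        return z
```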