diff --git a/backend/text_processing/t5_engine.py b/backend/text_processing/t5_engine.py index e6c8c55d..1d7065c6 100644 --- a/backend/text_processing/t5_engine.py +++ b/backend/text_processing/t5_engine.py @@ -58,10 +58,9 @@ class T5TextProcessingEngine: return tokenized def encode_with_transformers(self, tokens): - tokens = tokens.to(memory_management.get_torch_device()) device = memory_management.get_torch_device() - dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 - self.text_encoder.shared.to(device=device, dtype=dtype) + tokens = tokens.to(device) + self.text_encoder.shared.to(device=device, dtype=torch.float32) z = self.text_encoder( input_ids=tokens,