diff --git a/backend/text_processing/classic_engine.py b/backend/text_processing/classic_engine.py index b5d192fc..14a27ca3 100644 --- a/backend/text_processing/classic_engine.py +++ b/backend/text_processing/classic_engine.py @@ -120,7 +120,7 @@ class ClassicTextProcessingEngine: return tokenized def encode_with_transformers(self, tokens): - target_device = memory_management.get_torch_device() + target_device = self.text_encoder.transformer.text_model.embeddings.token_embedding.weight.device self.text_encoder.transformer.text_model.embeddings.position_ids = self.text_encoder.transformer.text_model.embeddings.position_ids.to(device=target_device) tokens = tokens.to(target_device)