diff --git a/backend/text_processing/classic_engine.py b/backend/text_processing/classic_engine.py
index 3832fe65..b5d192fc 100644
--- a/backend/text_processing/classic_engine.py
+++ b/backend/text_processing/classic_engine.py
@@ -4,6 +4,8 @@ import torch
 from collections import namedtuple
 from backend.text_processing import parsing, emphasis
 from backend.text_processing.textual_inversion import EmbeddingDatabase
+from backend import memory_management
+
 
 PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
 last_extra_generation_params = {}
@@ -118,7 +120,9 @@ class ClassicTextProcessingEngine:
         return tokenized
 
     def encode_with_transformers(self, tokens):
-        tokens = tokens.to(self.text_encoder.transformer.text_model.embeddings.token_embedding.weight.device)
+        target_device = memory_management.get_torch_device()
+        self.text_encoder.transformer.text_model.embeddings.position_ids = self.text_encoder.transformer.text_model.embeddings.position_ids.to(device=target_device)
+        tokens = tokens.to(target_device)
 
         outputs = self.text_encoder.transformer(tokens, output_hidden_states=True)