improve clip cast

This commit is contained in:
layerdiffusion
2024-08-06 21:09:40 -07:00
parent 0128ae6041
commit 64baac36b6

View File

@@ -4,6 +4,8 @@ import torch
from collections import namedtuple
from backend.text_processing import parsing, emphasis
from backend.text_processing.textual_inversion import EmbeddingDatabase
from backend import memory_management
PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
last_extra_generation_params = {}
@@ -118,7 +120,9 @@ class ClassicTextProcessingEngine:
return tokenized
def encode_with_transformers(self, tokens):
tokens = tokens.to(self.text_encoder.transformer.text_model.embeddings.token_embedding.weight.device)
target_device = memory_management.get_torch_device()
self.text_encoder.transformer.text_model.embeddings.position_ids = self.text_encoder.transformer.text_model.embeddings.position_ids.to(device=target_device)
tokens = tokens.to(target_device)
outputs = self.text_encoder.transformer(tokens, output_hidden_states=True)