make results more consistent to A1111

This commit is contained in:
layerdiffusion
2024-08-08 01:53:03 -07:00
parent e396307e9d
commit a05a06b337
3 changed files with 8 additions and 6 deletions

View File

@@ -120,8 +120,12 @@ class ClassicTextProcessingEngine:
return tokenized
def encode_with_transformers(self, tokens):
target_device = self.text_encoder.transformer.text_model.embeddings.token_embedding.weight.device
target_device = memory_management.text_encoder_device()
self.text_encoder.transformer.text_model.embeddings.position_ids = self.text_encoder.transformer.text_model.embeddings.position_ids.to(device=target_device)
self.text_encoder.transformer.text_model.embeddings.position_embedding = self.text_encoder.transformer.text_model.embeddings.position_embedding.to(dtype=torch.float32)
self.text_encoder.transformer.text_model.embeddings.token_embedding = self.text_encoder.transformer.text_model.embeddings.token_embedding.to(dtype=torch.float32)
tokens = tokens.to(target_device)
outputs = self.text_encoder.transformer(tokens, output_hidden_states=True)