This commit is contained in:
layerdiffusion
2024-08-07 21:55:00 -07:00
parent 463cff0d89
commit 78c65708ea

View File

@@ -58,10 +58,9 @@ class T5TextProcessingEngine:
return tokenized
def encode_with_transformers(self, tokens):
tokens = tokens.to(memory_management.get_torch_device())
device = memory_management.get_torch_device()
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
self.text_encoder.shared.to(device=device, dtype=dtype)
tokens = tokens.to(device)
self.text_encoder.shared.to(device=device, dtype=torch.float32)
z = self.text_encoder(
input_ids=tokens,