From a05a06b33754aff57e8cbbe823e7fd6336571568 Mon Sep 17 00:00:00 2001
From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com>
Date: Thu, 8 Aug 2024 01:53:03 -0700
Subject: [PATCH] make results more consistent to A1111

---
 backend/loader.py                          | 6 ++----
 backend/text_processing/classic_engine.py  | 6 +++++-
 backend/text_processing/t5_engine.py       | 2 +-
 3 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/backend/loader.py b/backend/loader.py
index d5ddc10c..4d928c38 100644
--- a/backend/loader.py
+++ b/backend/loader.py
@@ -54,7 +54,7 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
         to_args = dict(device=memory_management.text_encoder_device(), dtype=memory_management.text_encoder_dtype())
 
         with modeling_utils.no_init_weights():
-            with using_forge_operations(**to_args):
+            with using_forge_operations(**to_args, manual_cast_enabled=True):
                 model = IntegratedCLIP(CLIPTextModel, config, add_text_projection=True).to(**to_args)
 
         load_state_dict(model, state_dict, ignore_errors=[
@@ -70,14 +70,12 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
         dtype = memory_management.text_encoder_dtype()
         sd_dtype = state_dict['transformer.encoder.block.0.layer.0.SelfAttention.k.weight'].dtype
 
-        need_cast = False
 
         if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
             dtype = sd_dtype
-            need_cast = True
 
         with modeling_utils.no_init_weights():
-            with using_forge_operations(device=memory_management.cpu, dtype=dtype, manual_cast_enabled=need_cast):
+            with using_forge_operations(device=memory_management.cpu, dtype=dtype, manual_cast_enabled=True):
                 model = IntegratedT5(config)
 
         load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight'])
diff --git a/backend/text_processing/classic_engine.py b/backend/text_processing/classic_engine.py
index 14a27ca3..bd2cf674 100644
--- a/backend/text_processing/classic_engine.py
+++ b/backend/text_processing/classic_engine.py
@@ -120,8 +120,12 @@ class ClassicTextProcessingEngine:
         return tokenized
 
     def encode_with_transformers(self, tokens):
-        target_device = self.text_encoder.transformer.text_model.embeddings.token_embedding.weight.device
+        target_device = memory_management.text_encoder_device()
 
+        self.text_encoder.transformer.text_model.embeddings.position_ids = self.text_encoder.transformer.text_model.embeddings.position_ids.to(device=target_device)
+        self.text_encoder.transformer.text_model.embeddings.position_embedding = self.text_encoder.transformer.text_model.embeddings.position_embedding.to(dtype=torch.float32)
+        self.text_encoder.transformer.text_model.embeddings.token_embedding = self.text_encoder.transformer.text_model.embeddings.token_embedding.to(dtype=torch.float32)
+        tokens = tokens.to(target_device)
 
         outputs = self.text_encoder.transformer(tokens, output_hidden_states=True)
 
diff --git a/backend/text_processing/t5_engine.py b/backend/text_processing/t5_engine.py
index 1d7065c6..49b9e89e 100644
--- a/backend/text_processing/t5_engine.py
+++ b/backend/text_processing/t5_engine.py
@@ -58,7 +58,7 @@ class T5TextProcessingEngine:
         return tokenized
 
     def encode_with_transformers(self, tokens):
-        device = memory_management.get_torch_device()
+        device = memory_management.text_encoder_device()
         tokens = tokens.to(device)
 
         self.text_encoder.shared.to(device=device, dtype=torch.float32)
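
Note on the loader change: manual_cast_enabled=True is now passed unconditionally for both the CLIP and T5 text encoders, instead of only when the checkpoint stores fp8 weights (the old need_cast flag). Below is a minimal sketch of the manual-cast pattern this enables, assuming stock PyTorch; ManualCastLinear is a hypothetical stand-in for the operations that using_forge_operations installs, not Forge's actual implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ManualCastLinear(nn.Linear):
    def forward(self, x):
        # Cast parameters from their storage dtype (e.g. fp16 or fp8) to the
        # activation dtype for this one computation only; storage is untouched.
        weight = self.weight.to(device=x.device, dtype=x.dtype)
        bias = self.bias.to(device=x.device, dtype=x.dtype) if self.bias is not None else None
        return F.linear(x, weight, bias)

layer = ManualCastLinear(8, 4).half()   # parameters stored in fp16
out = layer(torch.randn(2, 8))          # activations arrive in fp32
assert out.dtype == torch.float32       # computed in fp32, regardless of storage

With casting always on, the encoder computes in the same dtype whether the checkpoint shipped fp8 or fp16 weights, which is presumably what brings the outputs closer to A1111's.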
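Note on the classic_engine change: the CLIP embedding tables are pinned to float32 before encoding, and target_device is now taken from memory_management.text_encoder_device() rather than read off the weights, so position_ids and the prompt tokens must be moved to that device explicitly. A toy comparison, again assuming nothing beyond stock PyTorch (not Forge code), of how much a half-precision embedding lookup alone drifts from its fp32 original:

import copy
import torch

torch.manual_seed(0)
emb_fp32 = torch.nn.Embedding(49408, 768)   # CLIP-sized vocab and width
emb_fp16 = copy.deepcopy(emb_fp32).half()   # same table stored in fp16

tokens = torch.randint(0, 49408, (1, 77))   # one 77-token prompt
drift = (emb_fp32(tokens) - emb_fp16(tokens).float()).abs().max()
print(f"max abs drift from fp16 storage alone: {drift:.2e}")

The difference appears before any transformer layer runs, so keeping the embedding layers in fp32 (as A1111 does) removes one source of divergence at the very start of prompt conditioning.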