make results more consistent with A1111

This commit is contained in:
layerdiffusion
2024-08-08 01:53:03 -07:00
parent e396307e9d
commit a05a06b337
3 changed files with 8 additions and 6 deletions

View File

@@ -54,7 +54,7 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
to_args = dict(device=memory_management.text_encoder_device(), dtype=memory_management.text_encoder_dtype())
with modeling_utils.no_init_weights():
with using_forge_operations(**to_args):
with using_forge_operations(**to_args, manual_cast_enabled=True):
model = IntegratedCLIP(CLIPTextModel, config, add_text_projection=True).to(**to_args)
load_state_dict(model, state_dict, ignore_errors=[
@@ -70,14 +70,12 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
dtype = memory_management.text_encoder_dtype()
sd_dtype = state_dict['transformer.encoder.block.0.layer.0.SelfAttention.k.weight'].dtype
need_cast = False
if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
dtype = sd_dtype
need_cast = True
with modeling_utils.no_init_weights():
with using_forge_operations(device=memory_management.cpu, dtype=dtype, manual_cast_enabled=need_cast):
with using_forge_operations(device=memory_management.cpu, dtype=dtype, manual_cast_enabled=True):
model = IntegratedT5(config)
load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight'])

View File

@@ -120,8 +120,12 @@ class ClassicTextProcessingEngine:
return tokenized
def encode_with_transformers(self, tokens):
target_device = self.text_encoder.transformer.text_model.embeddings.token_embedding.weight.device
target_device = memory_management.text_encoder_device()
self.text_encoder.transformer.text_model.embeddings.position_ids = self.text_encoder.transformer.text_model.embeddings.position_ids.to(device=target_device)
self.text_encoder.transformer.text_model.embeddings.position_embedding = self.text_encoder.transformer.text_model.embeddings.position_embedding.to(dtype=torch.float32)
self.text_encoder.transformer.text_model.embeddings.token_embedding = self.text_encoder.transformer.text_model.embeddings.token_embedding.to(dtype=torch.float32)
tokens = tokens.to(target_device)
outputs = self.text_encoder.transformer(tokens, output_hidden_states=True)

View File

@@ -58,7 +58,7 @@ class T5TextProcessingEngine:
return tokenized
def encode_with_transformers(self, tokens):
device = memory_management.get_torch_device()
device = memory_management.text_encoder_device()
tokens = tokens.to(device)
self.text_encoder.shared.to(device=device, dtype=torch.float32)