improve vae key mapping

This commit is contained in:
layerdiffusion
2024-07-30 09:23:58 -06:00
parent 3289ccb53f
commit 40dd61ba6c
4 changed files with 23 additions and 11 deletions

View File

@@ -7,6 +7,11 @@ from diffusers.loaders.single_file_model import convert_ldm_vae_checkpoint
class BaseVAE(AutoencoderKL):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.state_dict_mapping = {}
def encode(self, x, regulation=None, mode=False):
latent_dist = super().encode(x).latent_dist
if mode:
@@ -31,5 +36,6 @@ def load_vae_from_state_dict(state_dict):
vae_state_dict, mapping = compile_state_dict(vae_state_dict)
model.load_state_dict(vae_state_dict, strict=True)
model.set_attn_processor(AttentionProcessorForge())
model.state_dict_mapping = mapping
return model, mapping
return model