diff --git a/backend/state_dict.py b/backend/state_dict.py
index 4bb6c8d8..02634b6d 100644
--- a/backend/state_dict.py
+++ b/backend/state_dict.py
@@ -13,6 +13,17 @@ class StateDictItem:
         return StateDictItem(self.key, t, advanced_indexing=advanced_indexing)
 
 
+def split_state_dict_with_prefix(sd, prefix):
+    vae_sd = {}
+
+    for k, v in list(sd.items()):
+        if k.startswith(prefix):
+            vae_sd[k] = StateDictItem(k[len(prefix):], v)
+            del sd[k]
+
+    return vae_sd
+
+
 def shrink_last_key(t):
     ts = t.split('.')
     del ts[-1]
diff --git a/backend/vae/loader.py b/backend/vae/loader.py
index 035626a0..6d858154 100644
--- a/backend/vae/loader.py
+++ b/backend/vae/loader.py
@@ -1,30 +1,18 @@
 from diffusers import AutoencoderKL
 from backend.vae.configs.guess import guess_vae_config
-from backend.state_dict import StateDictItem, compile_state_dict
+from backend.state_dict import split_state_dict_with_prefix, compile_state_dict
 from backend.operations import using_forge_operations
 from backend.attention import AttentionProcessorForge
 from diffusers.loaders.single_file_model import convert_ldm_vae_checkpoint
 
 
-def convert_vae_state_dict(sd):
-    vae_sd = {}
-    prefix = "first_stage_model."
-
-    for k, v in list(sd.items()):
-        if k.startswith(prefix):
-            vae_sd[k] = StateDictItem(k[len(prefix):], v)
-            del sd[k]
-
-    return vae_sd
-
-
 def load_vae_from_state_dict(state_dict):
     config = guess_vae_config(state_dict)
 
     with using_forge_operations():
         model = AutoencoderKL(**config)
 
-    vae_state_dict = convert_vae_state_dict(state_dict)
+    vae_state_dict = split_state_dict_with_prefix(state_dict, "first_stage_model.")
     vae_state_dict = convert_ldm_vae_checkpoint(vae_state_dict, config)
     vae_state_dict, mapping = compile_state_dict(vae_state_dict)
     model.load_state_dict(vae_state_dict, strict=True)