rename state_dict method to make it clear

This commit is contained in:
layerdiffusion
2024-07-30 08:05:36 -06:00
parent 9cb69baf9f
commit 9a48c9eff3
2 changed files with 13 additions and 14 deletions

View File

@@ -13,6 +13,17 @@ class StateDictItem:
return StateDictItem(self.key, t, advanced_indexing=advanced_indexing)
def split_state_dict_with_prefix(sd, prefix):
vae_sd = {}
for k, v in list(sd.items()):
if k.startswith(prefix):
vae_sd[k] = StateDictItem(k[len(prefix):], v)
del sd[k]
return vae_sd
def shrink_last_key(t):
ts = t.split('.')
del ts[-1]

View File

@@ -1,30 +1,18 @@
from diffusers import AutoencoderKL
from backend.vae.configs.guess import guess_vae_config
from backend.state_dict import StateDictItem, compile_state_dict
from backend.state_dict import split_state_dict_with_prefix, compile_state_dict
from backend.operations import using_forge_operations
from backend.attention import AttentionProcessorForge
from diffusers.loaders.single_file_model import convert_ldm_vae_checkpoint
def convert_vae_state_dict(sd):
vae_sd = {}
prefix = "first_stage_model."
for k, v in list(sd.items()):
if k.startswith(prefix):
vae_sd[k] = StateDictItem(k[len(prefix):], v)
del sd[k]
return vae_sd
def load_vae_from_state_dict(state_dict):
config = guess_vae_config(state_dict)
with using_forge_operations():
model = AutoencoderKL(**config)
vae_state_dict = convert_vae_state_dict(state_dict)
vae_state_dict = split_state_dict_with_prefix(state_dict, "first_stage_model.")
vae_state_dict = convert_ldm_vae_checkpoint(vae_state_dict, config)
vae_state_dict, mapping = compile_state_dict(vae_state_dict)
model.load_state_dict(vae_state_dict, strict=True)