mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-07 05:59:48 +00:00
rename state_dict method to make it clear
This commit is contained in:
@@ -13,6 +13,17 @@ class StateDictItem:
|
||||
return StateDictItem(self.key, t, advanced_indexing=advanced_indexing)
|
||||
|
||||
|
||||
def split_state_dict_with_prefix(sd, prefix):
|
||||
vae_sd = {}
|
||||
|
||||
for k, v in list(sd.items()):
|
||||
if k.startswith(prefix):
|
||||
vae_sd[k] = StateDictItem(k[len(prefix):], v)
|
||||
del sd[k]
|
||||
|
||||
return vae_sd
|
||||
|
||||
|
||||
def shrink_last_key(t):
|
||||
ts = t.split('.')
|
||||
del ts[-1]
|
||||
|
||||
@@ -1,30 +1,18 @@
|
||||
from diffusers import AutoencoderKL
|
||||
from backend.vae.configs.guess import guess_vae_config
|
||||
from backend.state_dict import StateDictItem, compile_state_dict
|
||||
from backend.state_dict import split_state_dict_with_prefix, compile_state_dict
|
||||
from backend.operations import using_forge_operations
|
||||
from backend.attention import AttentionProcessorForge
|
||||
from diffusers.loaders.single_file_model import convert_ldm_vae_checkpoint
|
||||
|
||||
|
||||
def convert_vae_state_dict(sd):
|
||||
vae_sd = {}
|
||||
prefix = "first_stage_model."
|
||||
|
||||
for k, v in list(sd.items()):
|
||||
if k.startswith(prefix):
|
||||
vae_sd[k] = StateDictItem(k[len(prefix):], v)
|
||||
del sd[k]
|
||||
|
||||
return vae_sd
|
||||
|
||||
|
||||
def load_vae_from_state_dict(state_dict):
|
||||
config = guess_vae_config(state_dict)
|
||||
|
||||
with using_forge_operations():
|
||||
model = AutoencoderKL(**config)
|
||||
|
||||
vae_state_dict = convert_vae_state_dict(state_dict)
|
||||
vae_state_dict = split_state_dict_with_prefix(state_dict, "first_stage_model.")
|
||||
vae_state_dict = convert_ldm_vae_checkpoint(vae_state_dict, config)
|
||||
vae_state_dict, mapping = compile_state_dict(vae_state_dict)
|
||||
model.load_state_dict(vae_state_dict, strict=True)
|
||||
|
||||
Reference in New Issue
Block a user