mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-30 11:11:15 +00:00
rename state_dict method to make it clear
This commit is contained in:
@@ -13,6 +13,17 @@ class StateDictItem:
|
|||||||
return StateDictItem(self.key, t, advanced_indexing=advanced_indexing)
|
return StateDictItem(self.key, t, advanced_indexing=advanced_indexing)
|
||||||
|
|
||||||
|
|
||||||
|
def split_state_dict_with_prefix(sd, prefix):
|
||||||
|
vae_sd = {}
|
||||||
|
|
||||||
|
for k, v in list(sd.items()):
|
||||||
|
if k.startswith(prefix):
|
||||||
|
vae_sd[k] = StateDictItem(k[len(prefix):], v)
|
||||||
|
del sd[k]
|
||||||
|
|
||||||
|
return vae_sd
|
||||||
|
|
||||||
|
|
||||||
def shrink_last_key(t):
|
def shrink_last_key(t):
|
||||||
ts = t.split('.')
|
ts = t.split('.')
|
||||||
del ts[-1]
|
del ts[-1]
|
||||||
|
|||||||
@@ -1,30 +1,18 @@
|
|||||||
from diffusers import AutoencoderKL
|
from diffusers import AutoencoderKL
|
||||||
from backend.vae.configs.guess import guess_vae_config
|
from backend.vae.configs.guess import guess_vae_config
|
||||||
from backend.state_dict import StateDictItem, compile_state_dict
|
from backend.state_dict import split_state_dict_with_prefix, compile_state_dict
|
||||||
from backend.operations import using_forge_operations
|
from backend.operations import using_forge_operations
|
||||||
from backend.attention import AttentionProcessorForge
|
from backend.attention import AttentionProcessorForge
|
||||||
from diffusers.loaders.single_file_model import convert_ldm_vae_checkpoint
|
from diffusers.loaders.single_file_model import convert_ldm_vae_checkpoint
|
||||||
|
|
||||||
|
|
||||||
def convert_vae_state_dict(sd):
|
|
||||||
vae_sd = {}
|
|
||||||
prefix = "first_stage_model."
|
|
||||||
|
|
||||||
for k, v in list(sd.items()):
|
|
||||||
if k.startswith(prefix):
|
|
||||||
vae_sd[k] = StateDictItem(k[len(prefix):], v)
|
|
||||||
del sd[k]
|
|
||||||
|
|
||||||
return vae_sd
|
|
||||||
|
|
||||||
|
|
||||||
def load_vae_from_state_dict(state_dict):
|
def load_vae_from_state_dict(state_dict):
|
||||||
config = guess_vae_config(state_dict)
|
config = guess_vae_config(state_dict)
|
||||||
|
|
||||||
with using_forge_operations():
|
with using_forge_operations():
|
||||||
model = AutoencoderKL(**config)
|
model = AutoencoderKL(**config)
|
||||||
|
|
||||||
vae_state_dict = convert_vae_state_dict(state_dict)
|
vae_state_dict = split_state_dict_with_prefix(state_dict, "first_stage_model.")
|
||||||
vae_state_dict = convert_ldm_vae_checkpoint(vae_state_dict, config)
|
vae_state_dict = convert_ldm_vae_checkpoint(vae_state_dict, config)
|
||||||
vae_state_dict, mapping = compile_state_dict(vae_state_dict)
|
vae_state_dict, mapping = compile_state_dict(vae_state_dict)
|
||||||
model.load_state_dict(vae_state_dict, strict=True)
|
model.load_state_dict(vae_state_dict, strict=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user