Files
stable-diffusion-webui-forge/backend/vae/loader.py
2024-07-29 22:26:32 -06:00

34 lines
1.1 KiB
Python

from diffusers import AutoencoderKL
from backend.vae.configs.guess import guess_vae_config
from backend.state_dict import StateDictItem, compile_state_dict
from backend.operations import using_forge_operations
from backend.attention import AttentionProcessorForge
from diffusers.loaders.single_file_model import convert_ldm_vae_checkpoint
def convert_vae_state_dict(sd):
    """Split the VAE entries out of a checkpoint state dict.

    Every entry of *sd* whose key starts with ``"first_stage_model."`` is
    removed from *sd* and returned in a new dict. The caller's dict is mutated
    in place. Entries without that prefix stay in *sd* untouched.

    The returned dict is keyed by the ORIGINAL (still-prefixed) key. Each
    value is ``StateDictItem(<key with the prefix stripped>, <original value>)``.

    Args:
        sd: Checkpoint state dict; VAE entries are deleted from it.

    Returns:
        A new dict ``{prefixed_key: StateDictItem(stripped_key, value)}``. It is
        empty when no key carries the prefix.
    """
    vae_prefix = "first_stage_model."
    prefix_len = len(vae_prefix)

    # Snapshot the matching keys first so that deleting from ``sd`` below
    # does not disturb the iteration.
    vae_keys = [key for key in sd if key.startswith(vae_prefix)]

    extracted = {}
    for key in vae_keys:
        # Wrap before deleting, so a failing constructor leaves this entry in ``sd``.
        extracted[key] = StateDictItem(key[prefix_len:], sd[key])
        del sd[key]
    return extracted
def load_vae_from_state_dict(state_dict):
    """Build a diffusers ``AutoencoderKL`` and load VAE weights into it.

    Pipeline:
      1. Infer the model config from *state_dict* with ``guess_vae_config``.
      2. Instantiate ``AutoencoderKL(**config)`` inside
         ``using_forge_operations()``.
      3. Move the ``"first_stage_model."`` entries out of *state_dict* with
         ``convert_vae_state_dict``, which mutates the caller's dict.
      4. Remap the keys with diffusers' ``convert_ldm_vae_checkpoint``.
      5. Pass the result through ``compile_state_dict``, which yields the
         loadable state dict and a mapping.
      6. Load the weights with ``strict=True`` and install
         ``AttentionProcessorForge``.

    Args:
        state_dict: Checkpoint state dict. Its VAE entries are removed.

    Returns:
        ``(model, mapping)``: the loaded ``AutoencoderKL`` and the mapping
        returned by ``compile_state_dict``.
    """
    # Guess the config while the VAE entries are still present in the
    # caller's dict; they are removed further down.
    vae_config = guess_vae_config(state_dict)

    # NOTE(review): presumably swaps in Forge's layer implementations while
    # the model is constructed; confirm in backend.operations.
    with using_forge_operations():
        vae = AutoencoderKL(**vae_config)

    ldm_weights = convert_vae_state_dict(state_dict)
    diffusers_weights = convert_ldm_vae_checkpoint(ldm_weights, vae_config)
    compiled_weights, mapping = compile_state_dict(diffusers_weights)

    # strict=True: a missing or unexpected key raises instead of being ignored.
    vae.load_state_dict(compiled_weights, strict=True)
    vae.set_attn_processor(AttentionProcessorForge())
    return vae, mapping