improve vae key mapping

This commit is contained in:
layerdiffusion
2024-07-30 09:23:58 -06:00
parent 3289ccb53f
commit 40dd61ba6c
4 changed files with 23 additions and 11 deletions

View File

@@ -24,16 +24,20 @@ def split_state_dict_with_prefix(sd, prefix):
return vae_sd return vae_sd
def shrink_last_key(t):
    """Return the dotted name ``t`` with its final component removed.

    Args:
        t: A dot-separated key string, e.g. ``"a.b.c"``.

    Returns:
        Everything before the last ``"."`` (``"a.b"`` for the example
        above). An empty string is returned when ``t`` contains no dot.
    """
    # rpartition splits on the LAST dot only; element 0 is the prefix and is
    # "" when no dot exists, matching the split / drop-last / join result.
    parent, _, _ = t.rpartition('.')
    return parent
def compile_state_dict(state_dict): def compile_state_dict(state_dict):
sd = {} sd = {}
mapping = {} mapping = {}
for k, v in state_dict.items(): for k, v in state_dict.items():
sd[k] = v.value sd[k] = v.value
mapping[shrink_last_key(v.key)] = shrink_last_key(k) mapping[v.key] = (k, v.advanced_indexing)
return sd, mapping return sd, mapping
def map_state_dict(sd, mapping):
    """Rename, and optionally index, the entries of ``sd`` via ``mapping``.

    Args:
        sd: Dict of key -> value to translate.
        mapping: Dict of source key -> ``(target_key, indexing)``. When
            ``indexing`` is not ``None`` the value is replaced by
            ``value[indexing]``. Keys of ``sd`` that are absent from
            ``mapping`` are copied through unchanged.

    Returns:
        A new dict holding the translated entries; ``sd`` itself is not
        modified. If two source keys map to the same target key, the one
        iterated last wins.
    """
    new_sd = {}
    for source_key, value in sd.items():
        # Unmapped keys fall back to themselves with no indexing applied.
        target_key, indexing = mapping.get(source_key, (source_key, None))
        if indexing is not None:
            value = value[indexing]
        new_sd[target_key] = value
    return new_sd

View File

@@ -7,6 +7,11 @@ from diffusers.loaders.single_file_model import convert_ldm_vae_checkpoint
class BaseVAE(AutoencoderKL): class BaseVAE(AutoencoderKL):
# Forward all constructor arguments to the parent class, then attach the
# VAE key-mapping attribute.
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Starts empty; load_vae_from_state_dict later assigns the mapping returned
# by compile_state_dict. With an empty mapping, map_state_dict passes every
# key through unchanged.
self.state_dict_mapping = {}
def encode(self, x, regulation=None, mode=False): def encode(self, x, regulation=None, mode=False):
latent_dist = super().encode(x).latent_dist latent_dist = super().encode(x).latent_dist
if mode: if mode:
@@ -31,5 +36,6 @@ def load_vae_from_state_dict(state_dict):
vae_state_dict, mapping = compile_state_dict(vae_state_dict) vae_state_dict, mapping = compile_state_dict(vae_state_dict)
model.load_state_dict(vae_state_dict, strict=True) model.load_state_dict(vae_state_dict, strict=True)
model.set_attn_processor(AttentionProcessorForge()) model.set_attn_processor(AttentionProcessorForge())
model.state_dict_mapping = mapping
return model, mapping return model

View File

@@ -3,6 +3,7 @@ import collections
from dataclasses import dataclass from dataclasses import dataclass
from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks, lowvram, sd_hijack, hashes from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks, lowvram, sd_hijack, hashes
from backend.state_dict import map_state_dict
import glob import glob
from copy import deepcopy from copy import deepcopy
@@ -236,7 +237,8 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
# don't call this from outside # don't call this from outside
def _load_vae_dict(model, vae_dict_1): def _load_vae_dict(model, vae_dict_1):
model.first_stage_model.load_state_dict(vae_dict_1) sd_mapped = map_state_dict(vae_dict_1, model.first_stage_model.state_dict_mapping)
model.first_stage_model.load_state_dict(sd_mapped)
def clear_loaded_vae(): def clear_loaded_vae():

View File

@@ -106,8 +106,8 @@ def load_checkpoint_guess_config(sd, output_vae=True, output_clip=True, output_c
model.load_model_weights(sd, "model.diffusion_model.") model.load_model_weights(sd, "model.diffusion_model.")
if output_vae: if output_vae:
vae, mapping = load_vae_from_state_dict(sd) vae = load_vae_from_state_dict(sd)
vae = VAE(model=vae, mapping=mapping) vae = VAE(model=vae, mapping=vae.state_dict_mapping)
if output_clip: if output_clip:
w = WeightsLoader() w = WeightsLoader()