mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-02-26 09:43:56 +00:00
improve vae key mapping
This commit is contained in:
@@ -24,16 +24,20 @@ def split_state_dict_with_prefix(sd, prefix):
|
|||||||
return vae_sd
|
return vae_sd
|
||||||
|
|
||||||
|
|
||||||
def shrink_last_key(t):
|
|
||||||
ts = t.split('.')
|
|
||||||
del ts[-1]
|
|
||||||
return '.'.join(ts)
|
|
||||||
|
|
||||||
|
|
||||||
def compile_state_dict(state_dict):
|
def compile_state_dict(state_dict):
|
||||||
sd = {}
|
sd = {}
|
||||||
mapping = {}
|
mapping = {}
|
||||||
for k, v in state_dict.items():
|
for k, v in state_dict.items():
|
||||||
sd[k] = v.value
|
sd[k] = v.value
|
||||||
mapping[shrink_last_key(v.key)] = shrink_last_key(k)
|
mapping[v.key] = (k, v.advanced_indexing)
|
||||||
return sd, mapping
|
return sd, mapping
|
||||||
|
|
||||||
|
|
||||||
|
def map_state_dict(sd, mapping):
|
||||||
|
new_sd = {}
|
||||||
|
for k, v in sd.items():
|
||||||
|
k, indexing = mapping.get(k, (k, None))
|
||||||
|
if indexing is not None:
|
||||||
|
v = v[indexing]
|
||||||
|
new_sd[k] = v
|
||||||
|
return new_sd
|
||||||
|
|||||||
@@ -7,6 +7,11 @@ from diffusers.loaders.single_file_model import convert_ldm_vae_checkpoint
|
|||||||
|
|
||||||
|
|
||||||
class BaseVAE(AutoencoderKL):
|
class BaseVAE(AutoencoderKL):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.state_dict_mapping = {}
|
||||||
|
|
||||||
def encode(self, x, regulation=None, mode=False):
|
def encode(self, x, regulation=None, mode=False):
|
||||||
latent_dist = super().encode(x).latent_dist
|
latent_dist = super().encode(x).latent_dist
|
||||||
if mode:
|
if mode:
|
||||||
@@ -31,5 +36,6 @@ def load_vae_from_state_dict(state_dict):
|
|||||||
vae_state_dict, mapping = compile_state_dict(vae_state_dict)
|
vae_state_dict, mapping = compile_state_dict(vae_state_dict)
|
||||||
model.load_state_dict(vae_state_dict, strict=True)
|
model.load_state_dict(vae_state_dict, strict=True)
|
||||||
model.set_attn_processor(AttentionProcessorForge())
|
model.set_attn_processor(AttentionProcessorForge())
|
||||||
|
model.state_dict_mapping = mapping
|
||||||
|
|
||||||
return model, mapping
|
return model
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import collections
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks, lowvram, sd_hijack, hashes
|
from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks, lowvram, sd_hijack, hashes
|
||||||
|
from backend.state_dict import map_state_dict
|
||||||
|
|
||||||
import glob
|
import glob
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
@@ -236,7 +237,8 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
|
|||||||
|
|
||||||
# don't call this from outside
|
# don't call this from outside
|
||||||
def _load_vae_dict(model, vae_dict_1):
|
def _load_vae_dict(model, vae_dict_1):
|
||||||
model.first_stage_model.load_state_dict(vae_dict_1)
|
sd_mapped = map_state_dict(vae_dict_1, model.first_stage_model.state_dict_mapping)
|
||||||
|
model.first_stage_model.load_state_dict(sd_mapped)
|
||||||
|
|
||||||
|
|
||||||
def clear_loaded_vae():
|
def clear_loaded_vae():
|
||||||
|
|||||||
@@ -106,8 +106,8 @@ def load_checkpoint_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||||||
model.load_model_weights(sd, "model.diffusion_model.")
|
model.load_model_weights(sd, "model.diffusion_model.")
|
||||||
|
|
||||||
if output_vae:
|
if output_vae:
|
||||||
vae, mapping = load_vae_from_state_dict(sd)
|
vae = load_vae_from_state_dict(sd)
|
||||||
vae = VAE(model=vae, mapping=mapping)
|
vae = VAE(model=vae, mapping=vae.state_dict_mapping)
|
||||||
|
|
||||||
if output_clip:
|
if output_clip:
|
||||||
w = WeightsLoader()
|
w = WeightsLoader()
|
||||||
|
|||||||
Reference in New Issue
Block a user