diff --git a/backend/state_dict.py b/backend/state_dict.py index 02634b6d..b9e662fe 100644 --- a/backend/state_dict.py +++ b/backend/state_dict.py @@ -24,16 +24,20 @@ def split_state_dict_with_prefix(sd, prefix): return vae_sd -def shrink_last_key(t): - ts = t.split('.') - del ts[-1] - return '.'.join(ts) - - def compile_state_dict(state_dict): sd = {} mapping = {} for k, v in state_dict.items(): sd[k] = v.value - mapping[shrink_last_key(v.key)] = shrink_last_key(k) + mapping[v.key] = (k, v.advanced_indexing) return sd, mapping + + +def map_state_dict(sd, mapping): + new_sd = {} + for k, v in sd.items(): + k, indexing = mapping.get(k, (k, None)) + if indexing is not None: + v = v[indexing] + new_sd[k] = v + return new_sd diff --git a/backend/vae/loader.py b/backend/vae/loader.py index e904ea73..1694dbe7 100644 --- a/backend/vae/loader.py +++ b/backend/vae/loader.py @@ -7,6 +7,11 @@ from diffusers.loaders.single_file_model import convert_ldm_vae_checkpoint class BaseVAE(AutoencoderKL): + def __init__(self, *args, **kwargs): + + super().__init__(*args, **kwargs) + self.state_dict_mapping = {} + def encode(self, x, regulation=None, mode=False): latent_dist = super().encode(x).latent_dist if mode: @@ -31,5 +36,6 @@ def load_vae_from_state_dict(state_dict): vae_state_dict, mapping = compile_state_dict(vae_state_dict) model.load_state_dict(vae_state_dict, strict=True) model.set_attn_processor(AttentionProcessorForge()) + model.state_dict_mapping = mapping - return model, mapping + return model diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 62fd6524..e7ef5a77 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -3,6 +3,7 @@ import collections from dataclasses import dataclass from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks, lowvram, sd_hijack, hashes +from backend.state_dict import map_state_dict import glob from copy import deepcopy @@ -236,7 +237,8 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"): # don't call this from outside def _load_vae_dict(model, vae_dict_1): - model.first_stage_model.load_state_dict(vae_dict_1) + sd_mapped = map_state_dict(vae_dict_1, model.first_stage_model.state_dict_mapping) + model.first_stage_model.load_state_dict(sd_mapped) def clear_loaded_vae(): diff --git a/modules_forge/forge_loader.py b/modules_forge/forge_loader.py index bf342b4e..a77c0ef7 100644 --- a/modules_forge/forge_loader.py +++ b/modules_forge/forge_loader.py @@ -106,8 +106,8 @@ def load_checkpoint_guess_config(sd, output_vae=True, output_clip=True, output_c model.load_model_weights(sd, "model.diffusion_model.") if output_vae: - vae, mapping = load_vae_from_state_dict(sd) - vae = VAE(model=vae, mapping=mapping) + vae = load_vae_from_state_dict(sd) + vae = VAE(model=vae, mapping=vae.state_dict_mapping) if output_clip: w = WeightsLoader()