improve vae key mapping

This commit is contained in:
layerdiffusion
2024-07-30 09:23:58 -06:00
parent 3289ccb53f
commit 40dd61ba6c
4 changed files with 23 additions and 11 deletions

View File

@@ -3,6 +3,7 @@ import collections
from dataclasses import dataclass
from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks, lowvram, sd_hijack, hashes
from backend.state_dict import map_state_dict
import glob
from copy import deepcopy
@@ -236,7 +237,8 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
# don't call this from outside
def _load_vae_dict(model, vae_dict_1):
model.first_stage_model.load_state_dict(vae_dict_1)
sd_mapped = map_state_dict(vae_dict_1, model.first_stage_model.state_dict_mapping)
model.first_stage_model.load_state_dict(sd_mapped)
def clear_loaded_vae():