diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 6f17c5b6..05c1cb1b 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -6,6 +6,7 @@ from modules import paths, shared, devices, script_callbacks, sd_models, extra_n import glob from copy import deepcopy +from backend.utils import load_torch_file vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE")) @@ -186,9 +187,7 @@ def resolve_vae(checkpoint_file) -> VaeResolution: def load_vae_dict(filename, map_location): - vae_ckpt = sd_models.read_state_dict(filename, map_location=map_location) - vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys} - return vae_dict_1 + return load_torch_file(filename) def load_vae(model, vae_file=None, vae_source="from unknown source"): @@ -236,7 +235,7 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"): # don't call this from outside def _load_vae_dict(model, vae_dict_1): - model.first_stage_model.load_state_dict(vae_dict_1) + model.first_stage_model.load_state_dict(vae_dict_1, strict=False) def clear_loaded_vae():