From 1be7f9ea6f096c4c55d9c9907a5fde3532630aba Mon Sep 17 00:00:00 2001 From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com> Date: Tue, 6 Aug 2024 20:31:13 -0700 Subject: [PATCH] fix vae --- modules/sd_vae.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 6f17c5b6..05c1cb1b 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -6,6 +6,7 @@ from modules import paths, shared, devices, script_callbacks, sd_models, extra_n import glob from copy import deepcopy +from backend.utils import load_torch_file vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE")) @@ -186,9 +187,7 @@ def resolve_vae(checkpoint_file) -> VaeResolution: def load_vae_dict(filename, map_location): - vae_ckpt = sd_models.read_state_dict(filename, map_location=map_location) - vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys} - return vae_dict_1 + return load_torch_file(filename) def load_vae(model, vae_file=None, vae_source="from unknown source"): @@ -236,7 +235,7 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"): # don't call this from outside def _load_vae_dict(model, vae_dict_1): - model.first_stage_model.load_state_dict(vae_dict_1) + model.first_stage_model.load_state_dict(vae_dict_1, strict=False) def clear_loaded_vae():