This commit is contained in:
layerdiffusion
2024-08-06 20:31:13 -07:00
parent b57573c8da
commit 1be7f9ea6f

View File

@@ -6,6 +6,7 @@ from modules import paths, shared, devices, script_callbacks, sd_models, extra_n
import glob
from copy import deepcopy
from backend.utils import load_torch_file
vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE"))
@@ -186,9 +187,7 @@ def resolve_vae(checkpoint_file) -> VaeResolution:
def load_vae_dict(filename, map_location):
vae_ckpt = sd_models.read_state_dict(filename, map_location=map_location)
vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
return vae_dict_1
return load_torch_file(filename)
def load_vae(model, vae_file=None, vae_source="from unknown source"):
@@ -236,7 +235,7 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
# don't call this from outside
def _load_vae_dict(model, vae_dict_1):
model.first_stage_model.load_state_dict(vae_dict_1)
model.first_stage_model.load_state_dict(vae_dict_1, strict=False)
def clear_loaded_vae():