Load Model only when click Generate

#964
This commit is contained in:
lllyasviel
2024-08-08 14:51:13 -07:00
committed by GitHub
parent ce3f0f86b4
commit 6921420b3f
8 changed files with 76 additions and 140 deletions

View File

@@ -187,87 +187,24 @@ def resolve_vae(checkpoint_file) -> VaeResolution:
def load_vae_dict(filename, map_location):
return load_torch_file(filename)
pass
def load_vae(model, vae_file=None, vae_source="from unknown source"):
global vae_dict, base_vae, loaded_vae_file
# save_settings = False
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
if vae_file:
if cache_enabled and vae_file in checkpoints_loaded:
# use vae checkpoint cache
print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}")
store_base_vae(model)
_load_vae_dict(model, checkpoints_loaded[vae_file])
else:
assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}"
print(f"Loading VAE weights {vae_source}: {vae_file}")
store_base_vae(model)
vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location)
_load_vae_dict(model, vae_dict_1)
if cache_enabled:
# cache newly loaded vae
checkpoints_loaded[vae_file] = vae_dict_1.copy()
# clean up cache if limit is reached
if cache_enabled:
while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model
checkpoints_loaded.popitem(last=False) # LRU
# If vae used is not in dict, update it
# It will be removed on refresh though
vae_opt = get_filename(vae_file)
if vae_opt not in vae_dict:
vae_dict[vae_opt] = vae_file
elif loaded_vae_file:
restore_base_vae(model)
loaded_vae_file = vae_file
model.base_vae = base_vae
model.loaded_vae_file = loaded_vae_file
raise NotImplementedError('Forge does not use this!')
# don't call this from outside
def _load_vae_dict(model, vae_dict_1):
model.first_stage_model.load_state_dict(vae_dict_1, strict=False)
pass
def clear_loaded_vae():
global loaded_vae_file
loaded_vae_file = None
pass
unspecified = object()
def reload_vae_weights(sd_model=None, vae_file=unspecified):
if not sd_model:
sd_model = shared.sd_model
checkpoint_info = sd_model.sd_checkpoint_info
checkpoint_file = checkpoint_info.filename
if vae_file == unspecified:
vae_file, vae_source = resolve_vae(checkpoint_file).tuple()
else:
vae_source = "from function argument"
if loaded_vae_file == vae_file:
return
# sd_hijack.model_hijack.undo_hijack(sd_model)
load_vae(sd_model, vae_file, vae_source)
# sd_hijack.model_hijack.hijack(sd_model)
script_callbacks.model_loaded_callback(sd_model)
print("VAE weights loaded.")
return sd_model
raise NotImplementedError('Forge does not use this!')