Load Model only when click Generate

#964
This commit is contained in:
lllyasviel
2024-08-08 14:51:13 -07:00
committed by GitHub
parent ce3f0f86b4
commit 6921420b3f
8 changed files with 76 additions and 140 deletions

View File

@@ -30,7 +30,7 @@ import modules.sd_vae as sd_vae
from einops import repeat, rearrange
from blendmodes.blend import blendLayers, BlendType
from modules.sd_models import apply_token_merging
from modules.sd_models import apply_token_merging, forge_model_reload
from modules_forge.utils import apply_circular_forge
@@ -774,41 +774,16 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
def process_images(p: StableDiffusionProcessing) -> Processed:
forge_model_reload()
if p.scripts is not None:
p.scripts.before_process(p)
stored_opts = {k: opts.data[k] if k in opts.data else opts.get_default(k) for k in p.override_settings.keys() if k in opts.data}
# backwards compatibility, fix sampler and scheduler if invalid
sd_samplers.fix_p_invalid_sampler_and_scheduler(p)
try:
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
# and if after running refiner, the refiner model is not unloaded - webui swaps back to main model here, if model over is present it will be reloaded afterwards
if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
p.override_settings.pop('sd_model_checkpoint', None)
sd_models.reload_model_weights()
for k, v in p.override_settings.items():
opts.set(k, v, is_api=True, run_callbacks=False)
if k == 'sd_model_checkpoint':
sd_models.reload_model_weights()
if k == 'sd_vae':
sd_vae.reload_vae_weights()
# backwards compatibility, fix sampler and scheduler if invalid
sd_samplers.fix_p_invalid_sampler_and_scheduler(p)
with profiling.Profiler():
res = process_images_inner(p)
finally:
# restore opts to original state
if p.override_settings_restore_afterwards:
for k, v in stored_opts.items():
setattr(opts, k, v)
if k == 'sd_vae':
sd_vae.reload_vae_weights()
with profiling.Profiler():
res = process_images_inner(p)
return res