diff --git a/modules/processing.py b/modules/processing.py index 5f802d58..64b62957 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -794,8 +794,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter need_global_unload = False - -def process_images(p: StableDiffusionProcessing) -> Processed: +def manage_model_and_prompt_cache(p: StableDiffusionProcessing): global need_global_unload p.sd_model, just_reloaded = forge_model_reload() @@ -808,9 +807,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed: need_global_unload = False - if p.scripts is not None: - p.scripts.before_process(p) +def process_images(p: StableDiffusionProcessing) -> Processed: + """applies settings overrides (if any) before processing images, then restores settings as applicable.""" stored_opts = {k: opts.data[k] if k in opts.data else opts.get_default(k) for k in p.override_settings.keys() if k in opts.data} try: @@ -828,7 +827,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: mem_k = k[len('forge_'):] # remove 'forge_' prefix temp_memory_changes[mem_k] = v elif k == 'forge_additional_modules': - main_entry.modules_change(v, refresh_params=False) + main_entry.modules_change(v) elif k == 'sd_model_checkpoint': main_entry.checkpoint_change(v) # set all other options @@ -838,6 +837,12 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if temp_memory_changes: main_entry.refresh_memory_management_settings(**temp_memory_changes) + # load/reload model and manage prompt cache as needed + manage_model_and_prompt_cache(p) + + if p.scripts is not None: + p.scripts.before_process(p) + # backwards compatibility, fix sampler and scheduler if invalid sd_samplers.fix_p_invalid_sampler_and_scheduler(p) @@ -849,7 +854,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if p.override_settings_restore_afterwards: for k, v in stored_opts.items(): if k == 'forge_additional_modules': - main_entry.modules_change(v, refresh_params=False) + main_entry.modules_change(v) elif k == 'sd_model_checkpoint': main_entry.checkpoint_change(v) else: