From 862c7a589e43935fa9271bb05399ca003420cbf9 Mon Sep 17 00:00:00 2001 From: altoiddealer Date: Sun, 13 Oct 2024 07:29:02 -0400 Subject: [PATCH] API Improvements: Modules Change AND Restore override_settings (#2027) * Improve API modules change * Restore override_settings and make it work * Simplify some memory management --- modules/api/api.py | 17 ++++++++----- modules/processing.py | 50 ++++++++++++++++++++++++++++++++++--- modules_forge/main_entry.py | 44 +++++++++++++++++++++++--------- 3 files changed, 89 insertions(+), 22 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 1f62b453..4ba9cad6 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -688,20 +688,25 @@ class Api: if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases: raise RuntimeError(f"model {checkpoint_name!r} not found") - refresh_memory = False + memory_changes = {} memory_keys = ['forge_inference_memory', 'forge_async_loading', 'forge_pin_shared_memory'] for k, v in req.items(): - shared.opts.set(k, v, is_api=True) + # options for memory/modules are set in their dedicated functions if k in memory_keys: - refresh_memory = True + mem_key = k[len('forge_'):] # remove 'forge_' prefix + memory_changes[mem_key] = v + elif k == 'forge_additional_modules': + main_entry.modules_change(v, refresh_params=False) # refresh_model_loading_parameters() --- applied in checkpoint_change() + # set all other options + else: + shared.opts.set(k, v, is_api=True) main_entry.checkpoint_change(checkpoint_name) # shared.opts.save(shared.config_filename) --- applied in checkpoint_change() - if refresh_memory: - model_memory = main_entry.total_vram - shared.opts.forge_inference_memory - main_entry.refresh_memory_management_settings(model_memory, shared.opts.forge_async_loading, shared.opts.forge_pin_shared_memory) + if memory_changes: + main_entry.refresh_memory_management_settings(**memory_changes) return diff --git a/modules/processing.py b/modules/processing.py index 42831196..5f802d58 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -32,6 +32,7 @@ from einops import repeat, rearrange from blendmodes.blend import blendLayers, BlendType from modules.sd_models import apply_token_merging, forge_model_reload from modules_forge.utils import apply_circular_forge +from modules_forge import main_entry from backend import memory_management @@ -810,11 +811,52 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if p.scripts is not None: p.scripts.before_process(p) - # backwards compatibility, fix sampler and scheduler if invalid - sd_samplers.fix_p_invalid_sampler_and_scheduler(p) + stored_opts = {k: opts.data[k] if k in opts.data else opts.get_default(k) for k in p.override_settings.keys() if k in opts.data} - with profiling.Profiler(): - res = process_images_inner(p) + try: + # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint + # and if after running refiner, the refiner model is not unloaded - webui swaps back to main model here, if model over is present it will be reloaded afterwards + if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None: + p.override_settings.pop('sd_model_checkpoint', None) + + temp_memory_changes = {} + memory_keys = ['forge_inference_memory', 'forge_async_loading', 'forge_pin_shared_memory'] + + for k, v in p.override_settings.items(): + # options for memory/modules/checkpoints are set in their dedicated functions + if k in memory_keys: + mem_k = k[len('forge_'):] # remove 'forge_' prefix + temp_memory_changes[mem_k] = v + elif k == 'forge_additional_modules': + main_entry.modules_change(v, refresh_params=False) + elif k == 'sd_model_checkpoint': + main_entry.checkpoint_change(v) + # set all other options + else: + opts.set(k, v, is_api=True, run_callbacks=False) + + if temp_memory_changes: + main_entry.refresh_memory_management_settings(**temp_memory_changes) + + # backwards compatibility, fix sampler and scheduler if invalid + sd_samplers.fix_p_invalid_sampler_and_scheduler(p) + + with profiling.Profiler(): + res = process_images_inner(p) + + finally: + # restore opts to original state + if p.override_settings_restore_afterwards: + for k, v in stored_opts.items(): + if k == 'forge_additional_modules': + main_entry.modules_change(v, refresh_params=False) + elif k == 'sd_model_checkpoint': + main_entry.checkpoint_change(v) + else: + setattr(opts, k, v) + + if temp_memory_changes: + main_entry.refresh_memory_management_settings() # applies the set options by default return res diff --git a/modules_forge/main_entry.py b/modules_forge/main_entry.py index e9ea61d5..c8f4373d 100644 --- a/modules_forge/main_entry.py +++ b/modules_forge/main_entry.py @@ -112,16 +112,17 @@ def make_checkpoint_manager_ui(): mem_comps = [ui_forge_inference_memory, ui_forge_async_loading, ui_forge_pin_shared_memory] - ui_forge_inference_memory.change(refresh_memory_management_settings, inputs=mem_comps, queue=False, show_progress=False) - ui_forge_async_loading.change(refresh_memory_management_settings, inputs=mem_comps, queue=False, show_progress=False) - ui_forge_pin_shared_memory.change(refresh_memory_management_settings, inputs=mem_comps, queue=False, show_progress=False) - Context.root_block.load(refresh_memory_management_settings, inputs=mem_comps, queue=False, show_progress=False) + ui_forge_inference_memory.change(ui_refresh_memory_management_settings, inputs=mem_comps, queue=False, show_progress=False) + ui_forge_async_loading.change(ui_refresh_memory_management_settings, inputs=mem_comps, queue=False, show_progress=False) + ui_forge_pin_shared_memory.change(ui_refresh_memory_management_settings, inputs=mem_comps, queue=False, show_progress=False) + + Context.root_block.load(ui_refresh_memory_management_settings, inputs=mem_comps, queue=False, show_progress=False) ui_clip_skip = gr.Slider(label="Clip skip", value=lambda: shared.opts.CLIP_stop_at_last_layers, **{"minimum": 1, "maximum": 12, "step": 1}) bind_to_opts(ui_clip_skip, 'CLIP_stop_at_last_layers', save=False) ui_checkpoint.change(checkpoint_change, inputs=[ui_checkpoint], show_progress=False) - ui_vae.change(vae_change, inputs=[ui_vae], queue=False, show_progress=False) + ui_vae.change(modules_change, inputs=[ui_vae], queue=False, show_progress=False) return @@ -163,15 +164,32 @@ def refresh_models(): return ckpt_list, module_list.keys() -def refresh_memory_management_settings(model_memory, async_loading, pin_shared_memory): - inference_memory = total_vram - model_memory +def ui_refresh_memory_management_settings(model_memory, async_loading, pin_shared_memory): + """ Passes precalculated 'model_memory' from "GPU Weights" UI slider (skip redundant calculation) """ + refresh_memory_management_settings( + async_loading=async_loading, + pin_shared_memory=pin_shared_memory, + model_memory=model_memory # Use model_memory directly from UI slider value + ) + +def refresh_memory_management_settings(async_loading=None, inference_memory=None, pin_shared_memory=None, model_memory=None): + # Fallback to defaults if values are not passed + async_loading = async_loading if async_loading is not None else shared.opts.forge_async_loading + inference_memory = inference_memory if inference_memory is not None else shared.opts.forge_inference_memory + pin_shared_memory = pin_shared_memory if pin_shared_memory is not None else shared.opts.forge_pin_shared_memory + + # If model_memory is provided, calculate inference memory accordingly, otherwise use inference_memory directly + if model_memory is None: + model_memory = total_vram - inference_memory + else: + inference_memory = total_vram - model_memory shared.opts.set('forge_async_loading', async_loading) shared.opts.set('forge_inference_memory', inference_memory) shared.opts.set('forge_pin_shared_memory', pin_shared_memory) stream.stream_activated = async_loading == 'Async' - memory_management.current_inference_memory = inference_memory * 1024 * 1024 + memory_management.current_inference_memory = inference_memory * 1024 * 1024 # Convert MB to bytes memory_management.PIN_SHARED_MEMORY = pin_shared_memory == 'Shared' log_dict = dict( @@ -221,15 +239,16 @@ def refresh_model_loading_parameters(): return -def checkpoint_change(ckpt_name): +def checkpoint_change(ckpt_name, refresh_params=True): shared.opts.set('sd_model_checkpoint', ckpt_name) shared.opts.save(shared.config_filename) - refresh_model_loading_parameters() + if refresh_params: + refresh_model_loading_parameters() return -def vae_change(module_names): +def modules_change(module_names, refresh_params=True): modules = [] for n in module_names: @@ -239,7 +258,8 @@ def vae_change(module_names): shared.opts.set('forge_additional_modules', modules) shared.opts.save(shared.config_filename) - refresh_model_loading_parameters() + if refresh_params: + refresh_model_loading_parameters() return