API Improvements: Modules Change AND Restore override_settings (#2027)

* Improve API modules change
* Restore override_settings and make it work
* Simplify some memory management
This commit is contained in:
altoiddealer
2024-10-13 07:29:02 -04:00
committed by GitHub
parent ae8187bf2d
commit 862c7a589e
3 changed files with 89 additions and 22 deletions

View File

@@ -112,16 +112,17 @@ def make_checkpoint_manager_ui():
mem_comps = [ui_forge_inference_memory, ui_forge_async_loading, ui_forge_pin_shared_memory]
ui_forge_inference_memory.change(refresh_memory_management_settings, inputs=mem_comps, queue=False, show_progress=False)
ui_forge_async_loading.change(refresh_memory_management_settings, inputs=mem_comps, queue=False, show_progress=False)
ui_forge_pin_shared_memory.change(refresh_memory_management_settings, inputs=mem_comps, queue=False, show_progress=False)
Context.root_block.load(refresh_memory_management_settings, inputs=mem_comps, queue=False, show_progress=False)
ui_forge_inference_memory.change(ui_refresh_memory_management_settings, inputs=mem_comps, queue=False, show_progress=False)
ui_forge_async_loading.change(ui_refresh_memory_management_settings, inputs=mem_comps, queue=False, show_progress=False)
ui_forge_pin_shared_memory.change(ui_refresh_memory_management_settings, inputs=mem_comps, queue=False, show_progress=False)
Context.root_block.load(ui_refresh_memory_management_settings, inputs=mem_comps, queue=False, show_progress=False)
ui_clip_skip = gr.Slider(label="Clip skip", value=lambda: shared.opts.CLIP_stop_at_last_layers, **{"minimum": 1, "maximum": 12, "step": 1})
bind_to_opts(ui_clip_skip, 'CLIP_stop_at_last_layers', save=False)
ui_checkpoint.change(checkpoint_change, inputs=[ui_checkpoint], show_progress=False)
ui_vae.change(vae_change, inputs=[ui_vae], queue=False, show_progress=False)
ui_vae.change(modules_change, inputs=[ui_vae], queue=False, show_progress=False)
return
@@ -163,15 +164,32 @@ def refresh_models():
return ckpt_list, module_list.keys()
def refresh_memory_management_settings(model_memory, async_loading, pin_shared_memory):
inference_memory = total_vram - model_memory
def ui_refresh_memory_management_settings(model_memory, async_loading, pin_shared_memory):
""" Passes precalculated 'model_memory' from "GPU Weights" UI slider (skip redundant calculation) """
refresh_memory_management_settings(
async_loading=async_loading,
pin_shared_memory=pin_shared_memory,
model_memory=model_memory # Use model_memory directly from UI slider value
)
def refresh_memory_management_settings(async_loading=None, inference_memory=None, pin_shared_memory=None, model_memory=None):
# Fallback to defaults if values are not passed
async_loading = async_loading if async_loading is not None else shared.opts.forge_async_loading
inference_memory = inference_memory if inference_memory is not None else shared.opts.forge_inference_memory
pin_shared_memory = pin_shared_memory if pin_shared_memory is not None else shared.opts.forge_pin_shared_memory
# If model_memory is provided, calculate inference memory accordingly, otherwise use inference_memory directly
if model_memory is None:
model_memory = total_vram - inference_memory
else:
inference_memory = total_vram - model_memory
shared.opts.set('forge_async_loading', async_loading)
shared.opts.set('forge_inference_memory', inference_memory)
shared.opts.set('forge_pin_shared_memory', pin_shared_memory)
stream.stream_activated = async_loading == 'Async'
memory_management.current_inference_memory = inference_memory * 1024 * 1024
memory_management.current_inference_memory = inference_memory * 1024 * 1024 # Convert MB to bytes
memory_management.PIN_SHARED_MEMORY = pin_shared_memory == 'Shared'
log_dict = dict(
@@ -221,15 +239,16 @@ def refresh_model_loading_parameters():
return
def checkpoint_change(ckpt_name):
def checkpoint_change(ckpt_name, refresh_params=True):
shared.opts.set('sd_model_checkpoint', ckpt_name)
shared.opts.save(shared.config_filename)
refresh_model_loading_parameters()
if refresh_params:
refresh_model_loading_parameters()
return
def vae_change(module_names):
def modules_change(module_names, refresh_params=True):
modules = []
for n in module_names:
@@ -239,7 +258,8 @@ def vae_change(module_names):
shared.opts.set('forge_additional_modules', modules)
shared.opts.save(shared.config_filename)
refresh_model_loading_parameters()
if refresh_params:
refresh_model_loading_parameters()
return