Improve options management (#2078)

- `/sdapi/v1/options` GET now calls `get_config()` from **sysinfo** module, instead of from its own version of the function.

- Defined a new, flexible and more robust `set_config()` function in **sysinfo** module, which:
  - obsoletes redundant code
  - skips updating values that are unchanged
  - has flexible args for both API and UI use

- `/sdapi/v1/options` POST and `override_settings` now use the new `set_config()` function.  `set_config()` could possibly obsolete additional functions, but I'm not going to get into that just yet.

- Options for `forge_additional_modules` can now be provided either as the file path, or just the module name.

- Most importantly, `refresh_model_loading_parameters()` is now only called ONCE per request, and **only** if necessary.

- It is now much easier to call `shared.opts.save()` as needed
This commit is contained in:
altoiddealer
2024-10-16 06:21:54 -04:00
committed by GitHub
parent cce30d3340
commit 2c543719e3
4 changed files with 69 additions and 75 deletions

View File

@@ -19,7 +19,6 @@ from secrets import compare_digest
import modules.shared as shared
from modules import sd_samplers, deepbooru, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext_utils, sd_models, sd_schedulers
from modules.api import models
from modules_forge import main_entry
from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images, process_extra_images
from modules.textual_inversion.textual_inversion import create_embedding
@@ -675,42 +674,12 @@ class Api:
shared.state.skip()
def get_config(self):
options = {}
for key in shared.opts.data.keys():
metadata = shared.opts.data_labels.get(key)
if(metadata is not None):
options.update({key: shared.opts.data.get(key, shared.opts.data_labels.get(key).default)})
else:
options.update({key: shared.opts.data.get(key, None)})
return options
from modules.sysinfo import get_config
return get_config()
def set_config(self, req: dict[str, Any]):
checkpoint_name = req.get("sd_model_checkpoint", None)
if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases:
raise RuntimeError(f"model {checkpoint_name!r} not found")
memory_changes = {}
memory_keys = ['forge_inference_memory', 'forge_async_loading', 'forge_pin_shared_memory']
for k, v in req.items():
# options for memory/modules are set in their dedicated functions
if k in memory_keys:
mem_key = k[len('forge_'):] # remove 'forge_' prefix
memory_changes[mem_key] = v
elif k == 'forge_additional_modules':
main_entry.modules_change(v, refresh_params=False) # refresh_model_loading_parameters() --- applied in checkpoint_change()
# set all other options
else:
shared.opts.set(k, v, is_api=True)
main_entry.checkpoint_change(checkpoint_name)
# shared.opts.save(shared.config_filename) --- applied in checkpoint_change()
if memory_changes:
main_entry.refresh_memory_management_settings(**memory_changes)
return
from modules.sysinfo import set_config
set_config(req)
def get_cmd_flags(self):
return vars(shared.cmd_opts)

View File

@@ -20,6 +20,7 @@ from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infot
from modules.rng import slerp, get_noise_source_type # noqa: F401
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
from modules.shared import opts, cmd_opts, state
from modules.sysinfo import set_config
import modules.shared as shared
import modules.paths as paths
import modules.face_restoration
@@ -818,24 +819,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
p.override_settings.pop('sd_model_checkpoint', None)
temp_memory_changes = {}
memory_keys = ['forge_inference_memory', 'forge_async_loading', 'forge_pin_shared_memory']
for k, v in p.override_settings.items():
# options for memory/modules/checkpoints are set in their dedicated functions
if k in memory_keys:
mem_k = k[len('forge_'):] # remove 'forge_' prefix
temp_memory_changes[mem_k] = v
elif k == 'forge_additional_modules':
main_entry.modules_change(v)
elif k == 'sd_model_checkpoint':
main_entry.checkpoint_change(v)
# set all other options
else:
opts.set(k, v, is_api=True, run_callbacks=False)
if temp_memory_changes:
main_entry.refresh_memory_management_settings(**temp_memory_changes)
# apply any options overrides
set_config(p.override_settings, is_api=True, run_callbacks=False, save_config=False)
# load/reload model and manage prompt cache as needed
manage_model_and_prompt_cache(p)
@@ -850,18 +835,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
res = process_images_inner(p)
finally:
# restore opts to original state
# restore original options
if p.override_settings_restore_afterwards:
for k, v in stored_opts.items():
if k == 'forge_additional_modules':
main_entry.modules_change(v)
elif k == 'sd_model_checkpoint':
main_entry.checkpoint_change(v)
else:
setattr(opts, k, v)
if temp_memory_changes:
main_entry.refresh_memory_management_settings() # applies the set options by default
set_config(stored_opts, save_config=False)
return res

View File

@@ -6,6 +6,7 @@ import platform
import hashlib
import re
from pathlib import Path
from typing import Any
from modules import paths_internal, timer, shared_cmd_options, errors, launch_utils
@@ -213,3 +214,44 @@ def get_config():
return json.load(f)
except Exception as e:
return str(e)
def set_config(req: dict[str, Any], is_api=False, run_callbacks=True, save_config=True):
from modules import shared, sd_models
from modules_forge import main_entry
should_refresh_model_loading_params = False
memory_changes = {}
memory_keys = ['forge_inference_memory', 'forge_async_loading', 'forge_pin_shared_memory']
for k, v in req.items():
# ignore unchanged options
if v == shared.opts.data.get(k):
continue
# checkpoints, modules, and options pertaining to memory management are managed in dedicated functions
# If values for these options change, call refresh_model_loading_parameters()
if k == 'sd_model_checkpoint':
if v is not None and v not in sd_models.checkpoint_aliases:
raise RuntimeError(f"model {v!r} not found")
main_entry.checkpoint_change(v, save=False, refresh=False)
should_refresh_model_loading_params = True
elif k == 'forge_additional_modules':
should_refresh_model_loading_params = main_entry.modules_change(v, save=False, refresh=False)
elif k in memory_keys:
mem_key = k[len('forge_'):] # remove 'forge_' prefix
memory_changes[mem_key] = v
# set all other options
else:
shared.opts.set(k, v, is_api=is_api, run_callbacks=run_callbacks)
if memory_changes:
main_entry.refresh_memory_management_settings(**memory_changes)
should_refresh_model_loading_params = True
if should_refresh_model_loading_params:
main_entry.refresh_model_loading_parameters()
if save_config:
shared.opts.save(shared.config_filename)

View File

@@ -239,28 +239,35 @@ def refresh_model_loading_parameters():
return
def checkpoint_change(ckpt_name, refresh_params=True):
def checkpoint_change(ckpt_name:str, save=True, refresh=True):
shared.opts.set('sd_model_checkpoint', ckpt_name)
shared.opts.save(shared.config_filename)
if refresh_params:
if save:
shared.opts.save(shared.config_filename)
if refresh:
refresh_model_loading_parameters()
return
def modules_change(module_names, refresh_params=True):
def modules_change(module_values:list, save=True, refresh=True) -> bool:
""" module values may be provided as file paths or as simply the module names """
modules = []
for n in module_names:
if n in module_list:
modules.append(module_list[n])
for v in module_values:
module_name = os.path.basename(v) # If the input is a filepath, extract the file name
if module_name in module_list:
modules.append(module_list[module_name])
# skip further processing if value unchanged
if modules == shared.opts.data.get('forge_additional_modules'):
return False
shared.opts.set('forge_additional_modules', modules)
shared.opts.save(shared.config_filename)
if refresh_params:
if save:
shared.opts.save(shared.config_filename)
if refresh:
refresh_model_loading_parameters()
return
return True
def get_a1111_ui_component(tab, label):