diff --git a/modules/api/api.py b/modules/api/api.py index f15231ba..14359736 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -19,7 +19,6 @@ from secrets import compare_digest import modules.shared as shared from modules import sd_samplers, deepbooru, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext_utils, sd_models, sd_schedulers from modules.api import models -from modules_forge import main_entry from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images, process_extra_images from modules.textual_inversion.textual_inversion import create_embedding @@ -675,42 +674,12 @@ class Api: shared.state.skip() def get_config(self): - options = {} - for key in shared.opts.data.keys(): - metadata = shared.opts.data_labels.get(key) - if(metadata is not None): - options.update({key: shared.opts.data.get(key, shared.opts.data_labels.get(key).default)}) - else: - options.update({key: shared.opts.data.get(key, None)}) - - return options + from modules.sysinfo import get_config + return get_config() def set_config(self, req: dict[str, Any]): - checkpoint_name = req.get("sd_model_checkpoint", None) - if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases: - raise RuntimeError(f"model {checkpoint_name!r} not found") - - memory_changes = {} - memory_keys = ['forge_inference_memory', 'forge_async_loading', 'forge_pin_shared_memory'] - - for k, v in req.items(): - # options for memory/modules are set in their dedicated functions - if k in memory_keys: - mem_key = k[len('forge_'):] # remove 'forge_' prefix - memory_changes[mem_key] = v - elif k == 'forge_additional_modules': - main_entry.modules_change(v, refresh_params=False) # refresh_model_loading_parameters() --- applied in checkpoint_change() - # set all other options - else: - shared.opts.set(k, v, is_api=True) - - main_entry.checkpoint_change(checkpoint_name) - # shared.opts.save(shared.config_filename) --- applied in checkpoint_change() - - if memory_changes: - main_entry.refresh_memory_management_settings(**memory_changes) - - return + from modules.sysinfo import set_config + set_config(req) def get_cmd_flags(self): return vars(shared.cmd_opts) diff --git a/modules/processing.py b/modules/processing.py index 8cf424dc..e27ed7b0 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -20,6 +20,7 @@ from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infot from modules.rng import slerp, get_noise_source_type # noqa: F401 from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes from modules.shared import opts, cmd_opts, state +from modules.sysinfo import set_config import modules.shared as shared import modules.paths as paths import modules.face_restoration @@ -818,24 +819,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None: p.override_settings.pop('sd_model_checkpoint', None) - temp_memory_changes = {} - memory_keys = ['forge_inference_memory', 'forge_async_loading', 'forge_pin_shared_memory'] - - for k, v in p.override_settings.items(): - # options for memory/modules/checkpoints are set in their dedicated functions - if k in memory_keys: - mem_k = k[len('forge_'):] # remove 'forge_' prefix - temp_memory_changes[mem_k] = v - elif k == 'forge_additional_modules': - main_entry.modules_change(v) - elif k == 'sd_model_checkpoint': - main_entry.checkpoint_change(v) - # set all other options - else: - opts.set(k, v, is_api=True, run_callbacks=False) - - if temp_memory_changes: - main_entry.refresh_memory_management_settings(**temp_memory_changes) + # apply any options overrides + set_config(p.override_settings, is_api=True, run_callbacks=False, save_config=False) # load/reload model and manage prompt cache as needed manage_model_and_prompt_cache(p) @@ -850,18 +835,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed: res = process_images_inner(p) finally: - # restore opts to original state + # restore original options if p.override_settings_restore_afterwards: - for k, v in stored_opts.items(): - if k == 'forge_additional_modules': - main_entry.modules_change(v) - elif k == 'sd_model_checkpoint': - main_entry.checkpoint_change(v) - else: - setattr(opts, k, v) - - if temp_memory_changes: - main_entry.refresh_memory_management_settings() # applies the set options by default + set_config(stored_opts, save_config=False) return res diff --git a/modules/sysinfo.py b/modules/sysinfo.py index e9a83d74..a3a7579c 100644 --- a/modules/sysinfo.py +++ b/modules/sysinfo.py @@ -6,6 +6,7 @@ import platform import hashlib import re from pathlib import Path +from typing import Any from modules import paths_internal, timer, shared_cmd_options, errors, launch_utils @@ -213,3 +214,44 @@ def get_config(): return json.load(f) except Exception as e: return str(e) + +def set_config(req: dict[str, Any], is_api=False, run_callbacks=True, save_config=True): + from modules import shared, sd_models + from modules_forge import main_entry + + should_refresh_model_loading_params = False + + memory_changes = {} + memory_keys = ['forge_inference_memory', 'forge_async_loading', 'forge_pin_shared_memory'] + + for k, v in req.items(): + # ignore unchanged options + if v == shared.opts.data.get(k): + continue + + # checkpoints, modules, and options pertaining to memory management are managed in dedicated functions + # If values for these options change, call refresh_model_loading_parameters() + if k == 'sd_model_checkpoint': + if v is not None and v not in sd_models.checkpoint_aliases: + raise RuntimeError(f"model {v!r} not found") + main_entry.checkpoint_change(v, save=False, refresh=False) + should_refresh_model_loading_params = True + elif k == 'forge_additional_modules': + should_refresh_model_loading_params = main_entry.modules_change(v, save=False, refresh=False) + elif k in memory_keys: + mem_key = k[len('forge_'):] # remove 'forge_' prefix + memory_changes[mem_key] = v + + # set all other options + else: + shared.opts.set(k, v, is_api=is_api, run_callbacks=run_callbacks) + + if memory_changes: + main_entry.refresh_memory_management_settings(**memory_changes) + should_refresh_model_loading_params = True + + if should_refresh_model_loading_params: + main_entry.refresh_model_loading_parameters() + + if save_config: + shared.opts.save(shared.config_filename) diff --git a/modules_forge/main_entry.py b/modules_forge/main_entry.py index c8f4373d..01d1d279 100644 --- a/modules_forge/main_entry.py +++ b/modules_forge/main_entry.py @@ -239,28 +239,35 @@ def refresh_model_loading_parameters(): return -def checkpoint_change(ckpt_name, refresh_params=True): +def checkpoint_change(ckpt_name:str, save=True, refresh=True): shared.opts.set('sd_model_checkpoint', ckpt_name) - shared.opts.save(shared.config_filename) - if refresh_params: + if save: + shared.opts.save(shared.config_filename) + if refresh: refresh_model_loading_parameters() return -def modules_change(module_names, refresh_params=True): +def modules_change(module_values:list, save=True, refresh=True) -> bool: + """ module values may be provided as file paths or as simply the module names """ modules = [] - - for n in module_names: - if n in module_list: - modules.append(module_list[n]) + for v in module_values: + module_name = os.path.basename(v) # If the input is a filepath, extract the file name + if module_name in module_list: + modules.append(module_list[module_name]) + + # skip further processing if value unchanged + if modules == shared.opts.data.get('forge_additional_modules'): + return False shared.opts.set('forge_additional_modules', modules) - shared.opts.save(shared.config_filename) - if refresh_params: + if save: + shared.opts.save(shared.config_filename) + if refresh: refresh_model_loading_parameters() - return + return True def get_a1111_ui_component(tab, label):