From 145a46907e9208183d9f3d892e32fbf9c25f3a2a Mon Sep 17 00:00:00 2001 From: altoiddealer Date: Sat, 26 Oct 2024 16:10:05 -0400 Subject: [PATCH] Improve option handling (sd_model_checkpoint / forge_additional_modules) (#2181) * Sort modules when checking for changes * Compare consistent checkpoint values --- modules/processing.py | 14 +++++++------- modules/sysinfo.py | 5 +++-- modules_forge/main_entry.py | 10 ++++++++-- 3 files changed, 18 insertions(+), 11 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index acc2718b..b96bbd74 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1402,24 +1402,24 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): reload = False if 'Use same choices' not in self.hr_additional_modules: - if sorted(self.hr_additional_modules) != sorted(fp_additional_modules): - main_entry.modules_change(self.hr_additional_modules, save=False, refresh=False) + modules_changed = main_entry.modules_change(self.hr_additional_modules, save=False, refresh=False) + if modules_changed: reload = True if self.hr_checkpoint_name and self.hr_checkpoint_name != 'Use same checkpoint': - if self.hr_checkpoint_name != fp_checkpoint: + checkpoint_changed = main_entry.checkpoint_change(self.hr_checkpoint_name, save=False, refresh=False) + if checkpoint_changed: self.firstpass_use_distilled_cfg_scale = self.sd_model.use_distilled_cfg_scale - - main_entry.checkpoint_change(self.hr_checkpoint_name, save=False, refresh=False) reload = True - + if reload: try: main_entry.refresh_model_loading_parameters() sd_models.forge_model_reload() finally: main_entry.modules_change(fp_additional_modules, save=False, refresh=False) - main_entry.checkpoint_change(fp_checkpoint, save=False) + main_entry.checkpoint_change(fp_checkpoint, save=False, refresh=False) + main_entry.refresh_model_loading_parameters() if self.sd_model.use_distilled_cfg_scale: self.extra_generation_params['Hires Distilled CFG Scale'] = self.hr_distilled_cfg diff --git a/modules/sysinfo.py b/modules/sysinfo.py index cb6b523f..7533de12 100644 --- a/modules/sysinfo.py +++ b/modules/sysinfo.py @@ -234,8 +234,9 @@ def set_config(req: dict[str, Any], is_api=False, run_callbacks=True, save_confi if k == 'sd_model_checkpoint': if v is not None and v not in sd_models.checkpoint_aliases: raise RuntimeError(f"model {v!r} not found") - main_entry.checkpoint_change(v, save=False, refresh=False) - should_refresh_model_loading_params = True + checkpoint_changed = main_entry.checkpoint_change(v, save=False, refresh=False) + if checkpoint_changed: + should_refresh_model_loading_params = True elif k == 'forge_additional_modules': modules_changed = main_entry.modules_change(v, save=False, refresh=False) if modules_changed: diff --git a/modules_forge/main_entry.py b/modules_forge/main_entry.py index b3bef87c..dcee2129 100644 --- a/modules_forge/main_entry.py +++ b/modules_forge/main_entry.py @@ -240,13 +240,19 @@ def refresh_model_loading_parameters(): def checkpoint_change(ckpt_name:str, save=True, refresh=True): + """ checkpoint name can be a number of valid aliases. Returns True if checkpoint changed. """ + new_ckpt_info = sd_models.get_closet_checkpoint_match(ckpt_name) + current_ckpt_info = sd_models.get_closet_checkpoint_match(shared.opts.data.get('sd_model_checkpoint', '')) + if new_ckpt_info == current_ckpt_info: + return False + shared.opts.set('sd_model_checkpoint', ckpt_name) if save: shared.opts.save(shared.config_filename) if refresh: refresh_model_loading_parameters() - return + return True def modules_change(module_values:list, save=True, refresh=True) -> bool: @@ -258,7 +264,7 @@ def modules_change(module_values:list, save=True, refresh=True) -> bool: modules.append(module_list[module_name]) # skip further processing if value unchanged - if modules == shared.opts.data.get('forge_additional_modules'): + if sorted(modules) == sorted(shared.opts.data.get('forge_additional_modules', [])): return False shared.opts.set('forge_additional_modules', modules)