Improve option handling (sd_model_checkpoint / forge_additional_modules) (#2181)

* Sort modules when checking for changes
* Compare consistent checkpoint values
This commit is contained in:
altoiddealer
2024-10-26 16:10:05 -04:00
committed by GitHub
parent d4d8ad406e
commit 145a46907e
3 changed files with 18 additions and 11 deletions

View File

@@ -1402,24 +1402,24 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
reload = False
if 'Use same choices' not in self.hr_additional_modules:
if sorted(self.hr_additional_modules) != sorted(fp_additional_modules):
main_entry.modules_change(self.hr_additional_modules, save=False, refresh=False)
modules_changed = main_entry.modules_change(self.hr_additional_modules, save=False, refresh=False)
if modules_changed:
reload = True
if self.hr_checkpoint_name and self.hr_checkpoint_name != 'Use same checkpoint':
if self.hr_checkpoint_name != fp_checkpoint:
checkpoint_changed = main_entry.checkpoint_change(self.hr_checkpoint_name, save=False, refresh=False)
if checkpoint_changed:
self.firstpass_use_distilled_cfg_scale = self.sd_model.use_distilled_cfg_scale
main_entry.checkpoint_change(self.hr_checkpoint_name, save=False, refresh=False)
reload = True
if reload:
try:
main_entry.refresh_model_loading_parameters()
sd_models.forge_model_reload()
finally:
main_entry.modules_change(fp_additional_modules, save=False, refresh=False)
main_entry.checkpoint_change(fp_checkpoint, save=False)
main_entry.checkpoint_change(fp_checkpoint, save=False, refresh=False)
main_entry.refresh_model_loading_parameters()
if self.sd_model.use_distilled_cfg_scale:
self.extra_generation_params['Hires Distilled CFG Scale'] = self.hr_distilled_cfg

View File

@@ -234,8 +234,9 @@ def set_config(req: dict[str, Any], is_api=False, run_callbacks=True, save_confi
if k == 'sd_model_checkpoint':
if v is not None and v not in sd_models.checkpoint_aliases:
raise RuntimeError(f"model {v!r} not found")
main_entry.checkpoint_change(v, save=False, refresh=False)
should_refresh_model_loading_params = True
checkpoint_changed = main_entry.checkpoint_change(v, save=False, refresh=False)
if checkpoint_changed:
should_refresh_model_loading_params = True
elif k == 'forge_additional_modules':
modules_changed = main_entry.modules_change(v, save=False, refresh=False)
if modules_changed:

View File

@@ -240,13 +240,19 @@ def refresh_model_loading_parameters():
def checkpoint_change(ckpt_name:str, save=True, refresh=True):
""" checkpoint name can be a number of valid aliases. Returns True if checkpoint changed. """
new_ckpt_info = sd_models.get_closet_checkpoint_match(ckpt_name)
current_ckpt_info = sd_models.get_closet_checkpoint_match(shared.opts.data.get('sd_model_checkpoint', ''))
if new_ckpt_info == current_ckpt_info:
return False
shared.opts.set('sd_model_checkpoint', ckpt_name)
if save:
shared.opts.save(shared.config_filename)
if refresh:
refresh_model_loading_parameters()
return
return True
def modules_change(module_values:list, save=True, refresh=True) -> bool:
@@ -258,7 +264,7 @@ def modules_change(module_values:list, save=True, refresh=True) -> bool:
modules.append(module_list[module_name])
# skip further processing if value unchanged
if modules == shared.opts.data.get('forge_additional_modules'):
if sorted(modules) == sorted(shared.opts.data.get('forge_additional_modules', [])):
return False
shared.opts.set('forge_additional_modules', modules)