prevent undoing refresh model load params (#2092)

Ensures `should_refresh_model_loading_params()` is called when needed. Improved code clarity.
This commit is contained in:
altoiddealer
2024-10-16 09:52:35 -04:00
committed by GitHub
parent 9efa4eabfd
commit 6dc71b7e1d
2 changed files with 4 additions and 2 deletions

View File

@@ -237,7 +237,9 @@ def set_config(req: dict[str, Any], is_api=False, run_callbacks=True, save_confi
main_entry.checkpoint_change(v, save=False, refresh=False)
should_refresh_model_loading_params = True
elif k == 'forge_additional_modules':
should_refresh_model_loading_params = main_entry.modules_change(v, save=False, refresh=False)
modules_changed = main_entry.modules_change(v, save=False, refresh=False)
if modules_changed:
should_refresh_model_loading_params = True
elif k in memory_keys:
mem_key = k[len('forge_'):] # remove 'forge_' prefix
memory_changes[mem_key] = v

View File

@@ -250,7 +250,7 @@ def checkpoint_change(ckpt_name:str, save=True, refresh=True):
def modules_change(module_values:list, save=True, refresh=True) -> bool:
""" module values may be provided as file paths or as simply the module names """
""" module values may be provided as file paths, or just the module names. Returns True if modules changed. """
modules = []
for v in module_values:
module_name = os.path.basename(v) # If the input is a filepath, extract the file name