prevent undoing refresh model load params (#2092)

Ensures `should_refresh_model_loading_params()` is called when needed. Improved code clarity.
This commit is contained in:
altoiddealer
2024-10-16 09:52:35 -04:00
committed by GitHub
parent 9efa4eabfd
commit 6dc71b7e1d
2 changed files with 4 additions and 2 deletions

View File

@@ -237,7 +237,9 @@ def set_config(req: dict[str, Any], is_api=False, run_callbacks=True, save_confi
main_entry.checkpoint_change(v, save=False, refresh=False) main_entry.checkpoint_change(v, save=False, refresh=False)
should_refresh_model_loading_params = True should_refresh_model_loading_params = True
elif k == 'forge_additional_modules': elif k == 'forge_additional_modules':
should_refresh_model_loading_params = main_entry.modules_change(v, save=False, refresh=False) modules_changed = main_entry.modules_change(v, save=False, refresh=False)
if modules_changed:
should_refresh_model_loading_params = True
elif k in memory_keys: elif k in memory_keys:
mem_key = k[len('forge_'):] # remove 'forge_' prefix mem_key = k[len('forge_'):] # remove 'forge_' prefix
memory_changes[mem_key] = v memory_changes[mem_key] = v

View File

@@ -250,7 +250,7 @@ def checkpoint_change(ckpt_name:str, save=True, refresh=True):
def modules_change(module_values:list, save=True, refresh=True) -> bool: def modules_change(module_values:list, save=True, refresh=True) -> bool:
""" module values may be provided as file paths or as simply the module names """ """ module values may be provided as file paths, or just the module names. Returns True if modules changed. """
modules = [] modules = []
for v in module_values: for v in module_values:
module_name = os.path.basename(v) # If the input is a filepath, extract the file name module_name = os.path.basename(v) # If the input is a filepath, extract the file name