diff --git a/modules/api/api.py b/modules/api/api.py index 617867f2..1f62b453 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -687,12 +687,22 @@ class Api: checkpoint_name = req.get("sd_model_checkpoint", None) if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases: raise RuntimeError(f"model {checkpoint_name!r} not found") + + refresh_memory = False + memory_keys = ['forge_inference_memory', 'forge_async_loading', 'forge_pin_shared_memory'] for k, v in req.items(): shared.opts.set(k, v, is_api=True) + if k in memory_keys: + refresh_memory = True main_entry.checkpoint_change(checkpoint_name) # shared.opts.save(shared.config_filename) --- applied in checkpoint_change() + + if refresh_memory: + model_memory = main_entry.total_vram - shared.opts.forge_inference_memory + main_entry.refresh_memory_management_settings(model_memory, shared.opts.forge_async_loading, shared.opts.forge_pin_shared_memory) + return def get_cmd_flags(self):