From 6b88c8273356a34e7abdbeda6d6b0e7008e3e9dd Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Thu, 25 Jan 2024 06:34:28 -0800 Subject: [PATCH] Update sd_models.py --- modules/sd_models.py | 62 +------------------------------------------- 1 file changed, 1 insertion(+), 61 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index b7c1d90d..22d28014 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -638,67 +638,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): def reload_model_weights(sd_model=None, info=None, forced_reload=False): - checkpoint_info = info or select_checkpoint() - - timer = Timer() - - if not sd_model: - sd_model = model_data.sd_model - - if sd_model is None: # previous model load failed - current_checkpoint_info = None - else: - current_checkpoint_info = sd_model.sd_checkpoint_info - if check_fp8(sd_model) != devices.fp8: - # load from state dict again to prevent extra numerical errors - forced_reload = True - elif sd_model.sd_model_checkpoint == checkpoint_info.filename and not forced_reload: - return sd_model - - sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer) - if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename: - return sd_model - - if sd_model is not None: - sd_unet.apply_unet("None") - sd_hijack.model_hijack.undo_hijack(sd_model) - - state_dict = get_checkpoint_state_dict(checkpoint_info, timer) - - checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) - - timer.record("find config") - - if sd_model is None or checkpoint_config != sd_model.used_config: - if sd_model is not None: - sd_model = None - - load_model(checkpoint_info, already_loaded_state_dict=state_dict) - return model_data.sd_model - - try: - load_model_weights(sd_model, checkpoint_info, state_dict, timer) - except Exception: - print("Failed to load checkpoint, restoring previous") - load_model_weights(sd_model, current_checkpoint_info, None, timer) - raise - finally: - sd_hijack.model_hijack.hijack(sd_model) - timer.record("hijack") - - if not sd_model.lowvram: - sd_model.to(devices.device) - timer.record("move model to device") - - script_callbacks.model_loaded_callback(sd_model) - timer.record("script callbacks") - - print(f"Weights loaded in {timer.summary()}.") - - model_data.set_sd_model(sd_model) - sd_unet.apply_unet() - - return sd_model + return load_model(info) def unload_model_weights(sd_model=None, info=None):