From e94130a82aca4a7df3a203da1bb0c22aa34294f8 Mon Sep 17 00:00:00 2001 From: cluder <1590330+cluder@users.noreply.github.com> Date: Tue, 27 Aug 2024 00:39:37 +0200 Subject: [PATCH] Fix Checkpoint Merging #1359,#1095 (#1454) - checkpoints_list[] contains the CheckpointInfo.title which is "checkpointname.safetensor [hash]"; when a checkpoint is selected to be loaded during merge, we try to match it with just "checkpointname.safetensor". -> use checkpoint_aliases[] which already contains the checkpoint key in all possible variants. - replaced removed sd_models.read_state_dict() with sd_models.load_torch_file() - replaced removed sd_vae.load_vae_dict() with sd_vae.load_torch_file() - commented out create_config() for now, since it calls a removed method: sd_models_config.find_checkpoint_config_near_filename() --- modules/extras.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/modules/extras.py b/modules/extras.py index 2a310ae3..6f05b0fe 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -132,17 +132,17 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ if not primary_model_name: return fail("Failed: Merging requires a primary model.") - primary_model_info = sd_models.checkpoints_list[primary_model_name] + primary_model_info = sd_models.checkpoint_aliases[primary_model_name] if theta_func2 and not secondary_model_name: return fail("Failed: Merging requires a secondary model.") - secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None + secondary_model_info = sd_models.checkpoint_aliases[secondary_model_name] if theta_func2 else None if theta_func1 and not tertiary_model_name: return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.") - tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None + tertiary_model_info = sd_models.checkpoint_aliases[tertiary_model_name] if theta_func1 else None 
result_is_inpainting_model = False result_is_instruct_pix2pix_model = False @@ -179,7 +179,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ shared.state.textinfo = f"Loading {primary_model_info.filename}..." print(f"Loading {primary_model_info.filename}...") - theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu') + theta_0 = sd_models.load_torch_file(primary_model_info.filename) print("Merging...") shared.state.textinfo = 'Merging A and B' @@ -222,7 +222,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ if bake_in_vae_filename is not None: print(f"Baking in VAE from {bake_in_vae_filename}") shared.state.textinfo = 'Baking in VAE' - vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu') + vae_dict = sd_vae.load_torch_file(bake_in_vae_filename) for key in vae_dict.keys(): theta_0_key = 'first_stage_model.' + key @@ -321,7 +321,8 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ if created_model: created_model.calculate_shorthash() - create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info) + # TODO inside create_config() sd_models_config.find_checkpoint_config_near_filename() is called which has been commented out + #create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info) print(f"Checkpoint saved to {output_modelname}.") shared.state.textinfo = "Checkpoint saved"