Fix Checkpoint Merging #1359,#1095 (#1454)

- checkpoints_list[] is keyed by CheckpointInfo.title, which is "checkpointname.safetensors [hash]".
  When a checkpoint is selected to be loaded during a merge, we try to match it with just "checkpointname.safetensors", which fails.
  -> use checkpoint_aliases[] instead, which already maps every variant of the checkpoint key to its CheckpointInfo (see the sketch after this list).
- replaced the removed sd_models.read_state_dict() with sd_models.load_torch_file()
- replaced the removed sd_vae.load_vae_dict() with sd_vae.load_torch_file()
- commented out create_config() for now, since it calls a removed method: sd_models_config.find_checkpoint_config_near_filename()
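
To illustrate the lookup fix: a minimal sketch with hypothetical dictionary contents (the real registries are built by sd_models when checkpoints are scanned), showing why the old lookup failed and the new one succeeds:

    # stand-in for a CheckpointInfo instance (hypothetical contents)
    info = object()

    # checkpoints_list is keyed only by CheckpointInfo.title ("name [hash]") ...
    checkpoints_list = {"model.safetensors [abc12345]": info}
    # ... while checkpoint_aliases also registers the shorter name variants
    checkpoint_aliases = {
        "model.safetensors [abc12345]": info,
        "model.safetensors": info,
        "model": info,
    }

    checkpoints_list.get("model.safetensors")    # None -> KeyError in the old code
    checkpoint_aliases.get("model.safetensors")  # resolves, whichever variant the UI passes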
Author: cluder
Date: 2024-08-27 00:39:37 +02:00
Committed by: GitHub
Parent: d55e6b5bfe
Commit: e94130a82a


@@ -132,17 +132,17 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
     if not primary_model_name:
         return fail("Failed: Merging requires a primary model.")
-    primary_model_info = sd_models.checkpoints_list[primary_model_name]
+    primary_model_info = sd_models.checkpoint_aliases[primary_model_name]
     if theta_func2 and not secondary_model_name:
         return fail("Failed: Merging requires a secondary model.")
-    secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None
+    secondary_model_info = sd_models.checkpoint_aliases[secondary_model_name] if theta_func2 else None
     if theta_func1 and not tertiary_model_name:
         return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")
-    tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None
+    tertiary_model_info = sd_models.checkpoint_aliases[tertiary_model_name] if theta_func1 else None
     result_is_inpainting_model = False
     result_is_instruct_pix2pix_model = False
@@ -179,7 +179,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
     shared.state.textinfo = f"Loading {primary_model_info.filename}..."
     print(f"Loading {primary_model_info.filename}...")
-    theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
+    theta_0 = sd_models.load_torch_file(primary_model_info.filename)
     print("Merging...")
     shared.state.textinfo = 'Merging A and B'
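
For context, a loader like load_torch_file typically dispatches on the file extension; a minimal sketch of the assumed behavior (not necessarily this repo's exact implementation):

    import torch
    import safetensors.torch

    def load_torch_file(path, device="cpu"):
        """Load a checkpoint state dict from either supported format (assumed behavior)."""
        if path.lower().endswith(".safetensors"):
            # safetensors loads tensors directly, without unpickling
            return safetensors.torch.load_file(path, device=device)
        # legacy .ckpt/.pt checkpoints go through torch.load
        return torch.load(path, map_location=device)

This would also explain why the explicit map_location='cpu' argument disappears from the call sites: device placement is handled inside the helper.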
@@ -222,7 +222,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
     if bake_in_vae_filename is not None:
         print(f"Baking in VAE from {bake_in_vae_filename}")
         shared.state.textinfo = 'Baking in VAE'
-        vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')
+        vae_dict = sd_vae.load_torch_file(bake_in_vae_filename)
         for key in vae_dict.keys():
             theta_0_key = 'first_stage_model.' + key
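
The loop opened at the end of this hunk copies the VAE weights into the merged state dict under the "first_stage_model." prefix; a minimal sketch of that step (assuming theta_0 is the merged checkpoint's state dict and only matching keys are replaced):

    # overwrite the merged checkpoint's VAE tensors with the baked-in VAE
    for key, value in vae_dict.items():
        theta_0_key = "first_stage_model." + key
        if theta_0_key in theta_0:
            theta_0[theta_0_key] = value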
@@ -321,7 +321,8 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
     if created_model:
         created_model.calculate_shorthash()
-    create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
+    # TODO inside create_config() sd_models_config.find_checkpoint_config_near_filename() is called which has been commented out
+    #create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
     print(f"Checkpoint saved to {output_modelname}.")
     shared.state.textinfo = "Checkpoint saved"