mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-01-26 19:09:45 +00:00
Fix Checkpoint Merging #1359,#1095 (#1454)
- checkpoint_list[] contains the CheckpointInfo.title which is "checkpointname.safetensor [hash]" when a checkpoint is selected to be loaded during merge, we try to match it with just "checkpointname.safetensor". -> use checkpoint_aliases[] which already contains the checkpoint key in all possible variants. - replaced removed sd_models.read_state_dict() with sd_models.load_torch_file() - replaced removed sd_vae.load_vae_dict() with sd_vae.load_torch_file() - uncommented create_config() for now, since it calls a removed method: sd_models_config.find_checkpoint_config_near_filename()
This commit is contained in:
@@ -132,17 +132,17 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
||||
if not primary_model_name:
|
||||
return fail("Failed: Merging requires a primary model.")
|
||||
|
||||
primary_model_info = sd_models.checkpoints_list[primary_model_name]
|
||||
primary_model_info = sd_models.checkpoint_aliases[primary_model_name]
|
||||
|
||||
if theta_func2 and not secondary_model_name:
|
||||
return fail("Failed: Merging requires a secondary model.")
|
||||
|
||||
secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None
|
||||
secondary_model_info = sd_models.checkpoint_aliases[secondary_model_name] if theta_func2 else None
|
||||
|
||||
if theta_func1 and not tertiary_model_name:
|
||||
return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")
|
||||
|
||||
tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None
|
||||
tertiary_model_info = sd_models.checkpoint_aliases[tertiary_model_name] if theta_func1 else None
|
||||
|
||||
result_is_inpainting_model = False
|
||||
result_is_instruct_pix2pix_model = False
|
||||
@@ -179,7 +179,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
||||
|
||||
shared.state.textinfo = f"Loading {primary_model_info.filename}..."
|
||||
print(f"Loading {primary_model_info.filename}...")
|
||||
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
|
||||
theta_0 = sd_models.load_torch_file(primary_model_info.filename)
|
||||
|
||||
print("Merging...")
|
||||
shared.state.textinfo = 'Merging A and B'
|
||||
@@ -222,7 +222,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
||||
if bake_in_vae_filename is not None:
|
||||
print(f"Baking in VAE from {bake_in_vae_filename}")
|
||||
shared.state.textinfo = 'Baking in VAE'
|
||||
vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')
|
||||
vae_dict = sd_vae.load_torch_file(bake_in_vae_filename)
|
||||
|
||||
for key in vae_dict.keys():
|
||||
theta_0_key = 'first_stage_model.' + key
|
||||
@@ -321,7 +321,8 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
||||
if created_model:
|
||||
created_model.calculate_shorthash()
|
||||
|
||||
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
|
||||
# TODO inside create_config() sd_models_config.find_checkpoint_config_near_filename() is called which has been commented out
|
||||
#create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
|
||||
|
||||
print(f"Checkpoint saved to {output_modelname}.")
|
||||
shared.state.textinfo = "Checkpoint saved"
|
||||
|
||||
Reference in New Issue
Block a user