fixed issues with converting and saving models. Cleaned keys. Improved testing for cycle load saving.

This commit is contained in:
Jaret Burkett
2023-08-29 12:31:19 -06:00
parent 714854ee86
commit 14ff51ceb4
9 changed files with 784 additions and 1568 deletions

View File

@@ -95,6 +95,8 @@ matched_diffusers_keys = []
error_margin = 1e-4
tmp_merge_key = "TMP___MERGE"
te_suffix = ''
proj_pattern_weight = None
proj_pattern_bias = None
@@ -139,7 +141,7 @@ if args.sdxl or args.sd2:
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.weight")
# make diffusers convertable_dict
diffusers_state_dict[
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.MERGED.weight"] = new_val
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.{tmp_merge_key}.weight"] = new_val
# add operator
ldm_operator_map[ldm_key] = {
@@ -148,7 +150,6 @@ if args.sdxl or args.sd2:
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.weight",
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.weight",
],
"target": f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.MERGED.weight"
}
# text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
@@ -189,7 +190,7 @@ if args.sdxl or args.sd2:
matched_diffusers_keys.append(f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.bias")
# make diffusers convertable_dict
diffusers_state_dict[
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.MERGED.bias"] = new_val
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.{tmp_merge_key}.bias"] = new_val
# add operator
ldm_operator_map[ldm_key] = {
@@ -198,7 +199,6 @@ if args.sdxl or args.sd2:
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.bias",
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.bias",
],
# "target": f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.MERGED.bias"
}
# add diffusers operators
@@ -359,13 +359,35 @@ for key in unmatched_ldm_keys:
save_file(remaining_ldm_values, os.path.join(KEYMAPS_FOLDER, f'{name}_ldm_base.safetensors'))
print(f'Saved remaining ldm values to {os.path.join(KEYMAPS_FOLDER, f"{name}_ldm_base.safetensors")}')
# do cleanup of some left overs and bugs
to_remove = []
for ldm_key, diffusers_key in ldm_diffusers_keymap.items():
# get rid of tmp merge keys used to slicing
if tmp_merge_key in diffusers_key or tmp_merge_key in ldm_key:
to_remove.append(ldm_key)
for key in to_remove:
del ldm_diffusers_keymap[key]
to_remove = []
# remove identical shape mappings. Not sure why they exist but they do
for ldm_key, shape_list in ldm_diffusers_shape_map.items():
# remove identical shape mappings. Not sure why they exist but they do
# convert to json string to make it easier to compare
ldm_shape = json.dumps(shape_list[0])
diffusers_shape = json.dumps(shape_list[1])
if ldm_shape == diffusers_shape:
to_remove.append(ldm_key)
for key in to_remove:
del ldm_diffusers_shape_map[key]
dest_path = os.path.join(KEYMAPS_FOLDER, f'{name}.json')
save_obj = OrderedDict()
save_obj["ldm_diffusers_keymap"] = ldm_diffusers_keymap
save_obj["ldm_diffusers_shape_map"] = ldm_diffusers_shape_map
save_obj["ldm_diffusers_operator_map"] = ldm_operator_map
save_obj["diffusers_ldm_operator_map"] = diffusers_operator_map
with open(dest_path, 'w') as f:
f.write(json.dumps(save_obj, indent=4))