fixed issues with converting and saving models. Cleaned keys. Improved testing for cycle load saving.

This commit is contained in:
Jaret Burkett
2023-08-29 12:31:19 -06:00
parent 714854ee86
commit 14ff51ceb4
9 changed files with 784 additions and 1568 deletions

View File

@@ -32,6 +32,10 @@ def convert_state_dict_to_ldm_with_mapping(
with open(mapping_path, 'r') as f:
mapping = json.load(f, object_pairs_hook=OrderedDict)
# keep track of keys not matched
ldm_matched_keys = []
diffusers_matched_keys = []
ldm_diffusers_keymap = mapping['ldm_diffusers_keymap']
ldm_diffusers_shape_map = mapping['ldm_diffusers_shape_map']
ldm_diffusers_operator_map = mapping['ldm_diffusers_operator_map']
@@ -52,11 +56,15 @@ def convert_state_dict_to_ldm_with_mapping(
for diffusers_key in ldm_diffusers_operator_map[ldm_key]['cat']:
cat_list.append(diffusers_state_dict[diffusers_key].detach())
converted_state_dict[ldm_key] = torch.cat(cat_list, dim=0).to(device, dtype=dtype)
diffusers_matched_keys.extend(ldm_diffusers_operator_map[ldm_key]['cat'])
ldm_matched_keys.append(ldm_key)
if 'slice' in ldm_diffusers_operator_map[ldm_key]:
tensor_to_slice = diffusers_state_dict[ldm_diffusers_operator_map[ldm_key]['slice'][0]]
slice_text = diffusers_state_dict[ldm_diffusers_operator_map[ldm_key]['slice'][1]]
converted_state_dict[ldm_key] = tensor_to_slice[get_slices_from_string(slice_text)].detach().to(device,
dtype=dtype)
diffusers_matched_keys.extend(ldm_diffusers_operator_map[ldm_key]['slice'])
ldm_matched_keys.append(ldm_key)
# process the rest of the keys
for ldm_key in ldm_diffusers_keymap:
@@ -67,6 +75,22 @@ def convert_state_dict_to_ldm_with_mapping(
if ldm_key in ldm_diffusers_shape_map:
tensor = tensor.view(ldm_diffusers_shape_map[ldm_key][0])
converted_state_dict[ldm_key] = tensor
diffusers_matched_keys.append(ldm_diffusers_keymap[ldm_key])
ldm_matched_keys.append(ldm_key)
# see if any are missing from know mapping
mapped_diffusers_keys = list(ldm_diffusers_keymap.values())
mapped_ldm_keys = list(ldm_diffusers_keymap.keys())
missing_diffusers_keys = [x for x in mapped_diffusers_keys if x not in diffusers_matched_keys]
missing_ldm_keys = [x for x in mapped_ldm_keys if x not in ldm_matched_keys]
if len(missing_diffusers_keys) > 0:
print(f"WARNING!!!! Missing {len(missing_diffusers_keys)} diffusers keys")
print(missing_diffusers_keys)
if len(missing_ldm_keys) > 0:
print(f"WARNING!!!! Missing {len(missing_ldm_keys)} ldm keys")
print(missing_ldm_keys)
return converted_state_dict