mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
fixed issues with converting and saving models. Cleaned keys. Improved testing for cycle load saving.
This commit is contained in:
@@ -32,6 +32,10 @@ def convert_state_dict_to_ldm_with_mapping(
|
||||
with open(mapping_path, 'r') as f:
|
||||
mapping = json.load(f, object_pairs_hook=OrderedDict)
|
||||
|
||||
# keep track of keys not matched
|
||||
ldm_matched_keys = []
|
||||
diffusers_matched_keys = []
|
||||
|
||||
ldm_diffusers_keymap = mapping['ldm_diffusers_keymap']
|
||||
ldm_diffusers_shape_map = mapping['ldm_diffusers_shape_map']
|
||||
ldm_diffusers_operator_map = mapping['ldm_diffusers_operator_map']
|
||||
@@ -52,11 +56,15 @@ def convert_state_dict_to_ldm_with_mapping(
|
||||
for diffusers_key in ldm_diffusers_operator_map[ldm_key]['cat']:
|
||||
cat_list.append(diffusers_state_dict[diffusers_key].detach())
|
||||
converted_state_dict[ldm_key] = torch.cat(cat_list, dim=0).to(device, dtype=dtype)
|
||||
diffusers_matched_keys.extend(ldm_diffusers_operator_map[ldm_key]['cat'])
|
||||
ldm_matched_keys.append(ldm_key)
|
||||
if 'slice' in ldm_diffusers_operator_map[ldm_key]:
|
||||
tensor_to_slice = diffusers_state_dict[ldm_diffusers_operator_map[ldm_key]['slice'][0]]
|
||||
slice_text = diffusers_state_dict[ldm_diffusers_operator_map[ldm_key]['slice'][1]]
|
||||
converted_state_dict[ldm_key] = tensor_to_slice[get_slices_from_string(slice_text)].detach().to(device,
|
||||
dtype=dtype)
|
||||
diffusers_matched_keys.extend(ldm_diffusers_operator_map[ldm_key]['slice'])
|
||||
ldm_matched_keys.append(ldm_key)
|
||||
|
||||
# process the rest of the keys
|
||||
for ldm_key in ldm_diffusers_keymap:
|
||||
@@ -67,6 +75,22 @@ def convert_state_dict_to_ldm_with_mapping(
|
||||
if ldm_key in ldm_diffusers_shape_map:
|
||||
tensor = tensor.view(ldm_diffusers_shape_map[ldm_key][0])
|
||||
converted_state_dict[ldm_key] = tensor
|
||||
diffusers_matched_keys.append(ldm_diffusers_keymap[ldm_key])
|
||||
ldm_matched_keys.append(ldm_key)
|
||||
|
||||
# see if any are missing from know mapping
|
||||
mapped_diffusers_keys = list(ldm_diffusers_keymap.values())
|
||||
mapped_ldm_keys = list(ldm_diffusers_keymap.keys())
|
||||
|
||||
missing_diffusers_keys = [x for x in mapped_diffusers_keys if x not in diffusers_matched_keys]
|
||||
missing_ldm_keys = [x for x in mapped_ldm_keys if x not in ldm_matched_keys]
|
||||
|
||||
if len(missing_diffusers_keys) > 0:
|
||||
print(f"WARNING!!!! Missing {len(missing_diffusers_keys)} diffusers keys")
|
||||
print(missing_diffusers_keys)
|
||||
if len(missing_ldm_keys) > 0:
|
||||
print(f"WARNING!!!! Missing {len(missing_ldm_keys)} ldm keys")
|
||||
print(missing_ldm_keys)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
Reference in New Issue
Block a user