Fixed issue with key mapping from diffusers back to ldm

Jaret Burkett
2023-08-28 14:01:26 -06:00
parent c446f768ea
commit fab7c2b04a
2 changed files with 146 additions and 4 deletions


@@ -47,15 +47,15 @@ def convert_state_dict_to_ldm_with_mapping(
     # process operators first
     for ldm_key in ldm_diffusers_operator_map:
         # if the key cat is in the ldm key, we need to process it
-        if 'cat' in ldm_key:
+        if 'cat' in ldm_diffusers_operator_map[ldm_key]:
             cat_list = []
             for diffusers_key in ldm_diffusers_operator_map[ldm_key]['cat']:
-                cat_list.append(diffusers_state_dict[diffusers_key].detatch())
+                cat_list.append(diffusers_state_dict[diffusers_key].detach())
             converted_state_dict[ldm_key] = torch.cat(cat_list, dim=0).to(device, dtype=dtype)
-        if 'slice' in ldm_key:
+        if 'slice' in ldm_diffusers_operator_map[ldm_key]:
             tensor_to_slice = diffusers_state_dict[ldm_diffusers_operator_map[ldm_key]['slice'][0]]
             slice_text = diffusers_state_dict[ldm_diffusers_operator_map[ldm_key]['slice'][1]]
-            converted_state_dict[ldm_key] = tensor_to_slice[get_slices_from_string(slice_text)].detatch().to(device,
+            converted_state_dict[ldm_key] = tensor_to_slice[get_slices_from_string(slice_text)].detach().to(device,
                                                                                                              dtype=dtype)
     # process the rest of the keys
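
The change above makes the loop check for the 'cat' / 'slice' operator inside the map entry for each ldm key, rather than in the key string itself, and fixes the misspelled detatch() calls. Below is a minimal, self-contained sketch of the corrected behavior; the map layout, the apply_operator_map wrapper, and the toy get_slices_from_string helper are illustrative assumptions, not the repository's actual API.

import torch


def get_slices_from_string(slice_text):
    # Hypothetical helper: turn a string such as "0:320" (or "0:320,1:3" for
    # several dimensions) into slice objects usable for tensor indexing.
    def to_slice(part):
        start, _, stop = part.partition(':')
        return slice(int(start) if start else None, int(stop) if stop else None)
    return tuple(to_slice(p) for p in slice_text.split(','))


def apply_operator_map(operator_map, diffusers_state_dict, device='cpu', dtype=torch.float32):
    # Assumed map layout: {ldm_key: {'cat': [diffusers_key, ...]}} or
    # {ldm_key: {'slice': [diffusers_key, slice_string]}}.
    converted = {}
    for ldm_key, op in operator_map.items():
        # The fix: inspect the operator entry for this ldm key, not the key string itself.
        if 'cat' in op:
            # Concatenate several diffusers tensors into a single ldm tensor.
            cat_list = [diffusers_state_dict[k].detach() for k in op['cat']]
            converted[ldm_key] = torch.cat(cat_list, dim=0).to(device, dtype=dtype)
        if 'slice' in op:
            # Cut one ldm tensor out of a larger diffusers tensor.
            tensor_to_slice = diffusers_state_dict[op['slice'][0]]
            slices = get_slices_from_string(op['slice'][1])
            converted[ldm_key] = tensor_to_slice[slices].detach().to(device, dtype=dtype)
    return converted


# Toy usage: fuse two diffusers tensors into one ldm tensor, and slice an
# ldm tensor out of a larger fused tensor.
diffusers_sd = {
    'a.weight': torch.randn(4, 2),
    'b.weight': torch.randn(4, 2),
    'fused.weight': torch.randn(8, 2),
}
op_map = {
    'ldm.cat.weight': {'cat': ['a.weight', 'b.weight']},
    'ldm.slice.weight': {'slice': ['fused.weight', '0:4']},
}
converted = apply_operator_map(op_map, diffusers_sd)
print(converted['ldm.cat.weight'].shape)    # torch.Size([8, 2])
print(converted['ldm.slice.weight'].shape)  # torch.Size([4, 2])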