mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Fixed issue with key mapping from diffusers back to ldm
This commit is contained in:
@@ -47,15 +47,15 @@ def convert_state_dict_to_ldm_with_mapping(
|
||||
# process operators first
|
||||
for ldm_key in ldm_diffusers_operator_map:
|
||||
# if the key cat is in the ldm key, we need to process it
|
||||
if 'cat' in ldm_key:
|
||||
if 'cat' in ldm_diffusers_operator_map[ldm_key]:
|
||||
cat_list = []
|
||||
for diffusers_key in ldm_diffusers_operator_map[ldm_key]['cat']:
|
||||
cat_list.append(diffusers_state_dict[diffusers_key].detatch())
|
||||
cat_list.append(diffusers_state_dict[diffusers_key].detach())
|
||||
converted_state_dict[ldm_key] = torch.cat(cat_list, dim=0).to(device, dtype=dtype)
|
||||
if 'slice' in ldm_key:
|
||||
if 'slice' in ldm_diffusers_operator_map[ldm_key]:
|
||||
tensor_to_slice = diffusers_state_dict[ldm_diffusers_operator_map[ldm_key]['slice'][0]]
|
||||
slice_text = diffusers_state_dict[ldm_diffusers_operator_map[ldm_key]['slice'][1]]
|
||||
converted_state_dict[ldm_key] = tensor_to_slice[get_slices_from_string(slice_text)].detatch().to(device,
|
||||
converted_state_dict[ldm_key] = tensor_to_slice[get_slices_from_string(slice_text)].detach().to(device,
|
||||
dtype=dtype)
|
||||
|
||||
# process the rest of the keys
|
||||
|
||||
Reference in New Issue
Block a user