Added converters for all stable diffusion models to convert back to ldm format from diffusers.

This commit is contained in:
Jaret Burkett
2023-08-28 16:12:32 -06:00
parent fab7c2b04a
commit bee0b6a235
9 changed files with 5120 additions and 150 deletions

View File

@@ -56,7 +56,7 @@ def convert_state_dict_to_ldm_with_mapping(
tensor_to_slice = diffusers_state_dict[ldm_diffusers_operator_map[ldm_key]['slice'][0]]
slice_text = diffusers_state_dict[ldm_diffusers_operator_map[ldm_key]['slice'][1]]
converted_state_dict[ldm_key] = tensor_to_slice[get_slices_from_string(slice_text)].detach().to(device,
dtype=dtype)
dtype=dtype)
# process the rest of the keys
for ldm_key in ldm_diffusers_keymap:
@@ -71,6 +71,35 @@ def convert_state_dict_to_ldm_with_mapping(
return converted_state_dict
def get_ldm_state_dict_from_diffusers(
state_dict: 'OrderedDict',
sd_version: Literal['1', '2', 'sdxl'] = '2',
device='cpu',
dtype=get_torch_dtype('fp32'),
):
if sd_version == '1':
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sd1_ldm_base.safetensors')
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sd1.json')
elif sd_version == '2':
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sd2_ldm_base.safetensors')
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sd2.json')
elif sd_version == 'sdxl':
# load our base
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl_ldm_base.safetensors')
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl.json')
else:
raise ValueError(f"Invalid sd_version {sd_version}")
# convert the state dict
return convert_state_dict_to_ldm_with_mapping(
state_dict,
mapping_path,
base_path,
device=device,
dtype=dtype
)
def save_ldm_model_from_diffusers(
sd: 'StableDiffusion',
output_file: str,
@@ -78,21 +107,13 @@ def save_ldm_model_from_diffusers(
save_dtype=get_torch_dtype('fp16'),
sd_version: Literal['1', '2', 'sdxl'] = '2'
):
if sd_version != 'sdxl':
# not supported yet
raise NotImplementedError("Only SDXL is supported at this time with this method")
# load our base
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl_ldm_base.safetensors')
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl.json')
# convert the state dict
converted_state_dict = convert_state_dict_to_ldm_with_mapping(
converted_state_dict = get_ldm_state_dict_from_diffusers(
sd.state_dict(),
mapping_path,
base_path,
sd_version,
device='cpu',
dtype=save_dtype
)
# make sure parent folder exists
os.makedirs(os.path.dirname(output_file), exist_ok=True)
save_file(converted_state_dict, output_file, metadata=meta)