mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
Added converters for all stable diffusion models to convert back to ldm format from diffusers.
This commit is contained in:
@@ -56,7 +56,7 @@ def convert_state_dict_to_ldm_with_mapping(
|
||||
tensor_to_slice = diffusers_state_dict[ldm_diffusers_operator_map[ldm_key]['slice'][0]]
|
||||
slice_text = diffusers_state_dict[ldm_diffusers_operator_map[ldm_key]['slice'][1]]
|
||||
converted_state_dict[ldm_key] = tensor_to_slice[get_slices_from_string(slice_text)].detach().to(device,
|
||||
dtype=dtype)
|
||||
dtype=dtype)
|
||||
|
||||
# process the rest of the keys
|
||||
for ldm_key in ldm_diffusers_keymap:
|
||||
@@ -71,6 +71,35 @@ def convert_state_dict_to_ldm_with_mapping(
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def get_ldm_state_dict_from_diffusers(
|
||||
state_dict: 'OrderedDict',
|
||||
sd_version: Literal['1', '2', 'sdxl'] = '2',
|
||||
device='cpu',
|
||||
dtype=get_torch_dtype('fp32'),
|
||||
):
|
||||
if sd_version == '1':
|
||||
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sd1_ldm_base.safetensors')
|
||||
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sd1.json')
|
||||
elif sd_version == '2':
|
||||
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sd2_ldm_base.safetensors')
|
||||
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sd2.json')
|
||||
elif sd_version == 'sdxl':
|
||||
# load our base
|
||||
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl_ldm_base.safetensors')
|
||||
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl.json')
|
||||
else:
|
||||
raise ValueError(f"Invalid sd_version {sd_version}")
|
||||
|
||||
# convert the state dict
|
||||
return convert_state_dict_to_ldm_with_mapping(
|
||||
state_dict,
|
||||
mapping_path,
|
||||
base_path,
|
||||
device=device,
|
||||
dtype=dtype
|
||||
)
|
||||
|
||||
|
||||
def save_ldm_model_from_diffusers(
|
||||
sd: 'StableDiffusion',
|
||||
output_file: str,
|
||||
@@ -78,21 +107,13 @@ def save_ldm_model_from_diffusers(
|
||||
save_dtype=get_torch_dtype('fp16'),
|
||||
sd_version: Literal['1', '2', 'sdxl'] = '2'
|
||||
):
|
||||
if sd_version != 'sdxl':
|
||||
# not supported yet
|
||||
raise NotImplementedError("Only SDXL is supported at this time with this method")
|
||||
# load our base
|
||||
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl_ldm_base.safetensors')
|
||||
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl.json')
|
||||
|
||||
# convert the state dict
|
||||
converted_state_dict = convert_state_dict_to_ldm_with_mapping(
|
||||
converted_state_dict = get_ldm_state_dict_from_diffusers(
|
||||
sd.state_dict(),
|
||||
mapping_path,
|
||||
base_path,
|
||||
sd_version,
|
||||
device='cpu',
|
||||
dtype=save_dtype
|
||||
)
|
||||
|
||||
# make sure parent folder exists
|
||||
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
||||
save_file(converted_state_dict, output_file, metadata=meta)
|
||||
|
||||
Reference in New Issue
Block a user