Added support for training ssd-1B. Added support for saving models into diffusers format. We can currently save in safetensors format for ssd-1b, but diffusers cannot load it yet.

This commit is contained in:
Jaret Burkett
2023-11-03 05:01:16 -06:00
parent ceaf1d9454
commit d35733ac06
8 changed files with 3569 additions and 75 deletions

View File

@@ -97,7 +97,7 @@ def convert_state_dict_to_ldm_with_mapping(
def get_ldm_state_dict_from_diffusers(
state_dict: 'OrderedDict',
sd_version: Literal['1', '2', 'sdxl'] = '2',
sd_version: Literal['1', '2', 'sdxl', 'ssd'] = '2',
device='cpu',
dtype=get_torch_dtype('fp32'),
):
@@ -111,6 +111,10 @@ def get_ldm_state_dict_from_diffusers(
# load our base
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl_ldm_base.safetensors')
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl.json')
elif sd_version == 'ssd':
# load our base
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd_ldm_base.safetensors')
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd.json')
else:
raise ValueError(f"Invalid sd_version {sd_version}")
@@ -129,7 +133,7 @@ def save_ldm_model_from_diffusers(
output_file: str,
meta: 'OrderedDict',
save_dtype=get_torch_dtype('fp16'),
sd_version: Literal['1', '2', 'sdxl'] = '2'
sd_version: Literal['1', '2', 'sdxl', 'ssd'] = '2'
):
converted_state_dict = get_ldm_state_dict_from_diffusers(
sd.state_dict(),
@@ -148,13 +152,14 @@ def save_lora_from_diffusers(
output_file: str,
meta: 'OrderedDict',
save_dtype=get_torch_dtype('fp16'),
sd_version: Literal['1', '2', 'sdxl'] = '2'
sd_version: Literal['1', '2', 'sdxl', 'ssd'] = '2'
):
converted_state_dict = OrderedDict()
# only handle sxdxl for now
if sd_version != 'sdxl':
if sd_version != 'sdxl' and sd_version != 'ssd':
raise ValueError(f"Invalid sd_version {sd_version}")
for key, value in lora_state_dict.items():
# todo verify if this works with ssd
# test encoders share keys for some reason
if key.begins_with('lora_te'):
converted_state_dict[key] = value.detach().to('cpu', dtype=save_dtype)