mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added support for training ssd-1B. Added support for saving models into diffusers format. We can currently save in safetensors format for ssd-1b, but diffusers cannot load it yet.
This commit is contained in:
@@ -97,7 +97,7 @@ def convert_state_dict_to_ldm_with_mapping(
|
||||
|
||||
def get_ldm_state_dict_from_diffusers(
|
||||
state_dict: 'OrderedDict',
|
||||
sd_version: Literal['1', '2', 'sdxl'] = '2',
|
||||
sd_version: Literal['1', '2', 'sdxl', 'ssd'] = '2',
|
||||
device='cpu',
|
||||
dtype=get_torch_dtype('fp32'),
|
||||
):
|
||||
@@ -111,6 +111,10 @@ def get_ldm_state_dict_from_diffusers(
|
||||
# load our base
|
||||
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl_ldm_base.safetensors')
|
||||
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl.json')
|
||||
elif sd_version == 'ssd':
|
||||
# load our base
|
||||
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd_ldm_base.safetensors')
|
||||
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd.json')
|
||||
else:
|
||||
raise ValueError(f"Invalid sd_version {sd_version}")
|
||||
|
||||
@@ -129,7 +133,7 @@ def save_ldm_model_from_diffusers(
|
||||
output_file: str,
|
||||
meta: 'OrderedDict',
|
||||
save_dtype=get_torch_dtype('fp16'),
|
||||
sd_version: Literal['1', '2', 'sdxl'] = '2'
|
||||
sd_version: Literal['1', '2', 'sdxl', 'ssd'] = '2'
|
||||
):
|
||||
converted_state_dict = get_ldm_state_dict_from_diffusers(
|
||||
sd.state_dict(),
|
||||
@@ -148,13 +152,14 @@ def save_lora_from_diffusers(
|
||||
output_file: str,
|
||||
meta: 'OrderedDict',
|
||||
save_dtype=get_torch_dtype('fp16'),
|
||||
sd_version: Literal['1', '2', 'sdxl'] = '2'
|
||||
sd_version: Literal['1', '2', 'sdxl', 'ssd'] = '2'
|
||||
):
|
||||
converted_state_dict = OrderedDict()
|
||||
# only handle sxdxl for now
|
||||
if sd_version != 'sdxl':
|
||||
if sd_version != 'sdxl' and sd_version != 'ssd':
|
||||
raise ValueError(f"Invalid sd_version {sd_version}")
|
||||
for key, value in lora_state_dict.items():
|
||||
# todo verify if this works with ssd
|
||||
# test encoders share keys for some reason
|
||||
if key.begins_with('lora_te'):
|
||||
converted_state_dict[key] = value.detach().to('cpu', dtype=save_dtype)
|
||||
|
||||
Reference in New Issue
Block a user