Corrected key saving and loading to better match kohya

This commit is contained in:
Jaret Burkett
2023-09-04 00:22:34 -06:00
parent 22ed539321
commit fa8fc32c0a
5 changed files with 3371 additions and 4 deletions

View File

@@ -141,3 +141,30 @@ def save_ldm_model_from_diffusers(
# make sure parent folder exists
os.makedirs(os.path.dirname(output_file), exist_ok=True)
save_file(converted_state_dict, output_file, metadata=meta)
def save_lora_from_diffusers(
lora_state_dict: 'OrderedDict',
output_file: str,
meta: 'OrderedDict',
save_dtype=get_torch_dtype('fp16'),
sd_version: Literal['1', '2', 'sdxl'] = '2'
):
converted_state_dict = OrderedDict()
# only handle sxdxl for now
if sd_version != 'sdxl':
raise ValueError(f"Invalid sd_version {sd_version}")
for key, value in lora_state_dict.items():
# test encoders share keys for some reason
if key.begins_with('lora_te'):
converted_state_dict[key] = value.detach().to('cpu', dtype=save_dtype)
else:
converted_key = key
# make sure parent folder exists
os.makedirs(os.path.dirname(output_file), exist_ok=True)
save_file(converted_state_dict, output_file, metadata=meta
)