From 99f24cfb0c8de876cf7366089298385aea0066fe Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Mon, 5 Aug 2024 14:54:10 -0600
Subject: [PATCH] Added a conversion script to convert my loras to peft format
 for flux

---
 scripts/convert_lora_to_peft_format.py | 87 ++++++++++++++++++++++++++
 1 file changed, 87 insertions(+)
 create mode 100644 scripts/convert_lora_to_peft_format.py

diff --git a/scripts/convert_lora_to_peft_format.py b/scripts/convert_lora_to_peft_format.py
new file mode 100644
index 00000000..6d07877e
--- /dev/null
+++ b/scripts/convert_lora_to_peft_format.py
@@ -0,0 +1,87 @@
+# currently only works with flux as support is not quite there yet
+
+import argparse
+import os.path
+from collections import OrderedDict
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    'input_path',
+    type=str,
+    help='Path to the original flux LoRA safetensors file'
+)
+parser.add_argument(
+    'output_path',
+    type=str,
+    help='Path to write the converted (peft format) safetensors file'
+)
+args = parser.parse_args()
+args.input_path = os.path.abspath(args.input_path)
+args.output_path = os.path.abspath(args.output_path)
+
+from safetensors.torch import load_file, save_file
+
+meta = OrderedDict()
+meta['format'] = 'pt'
+
+state_dict = load_file(args.input_path)
+
+# peft format doesn't store an alpha, so we bake the alpha / rank scale into the weights
+alpha_keys = [
+    'lora_transformer_single_transformer_blocks_0_attn_to_q.alpha'  # flux
+]
+
+# keys where the rank is in the first dimension
+rank_idx0_keys = [
+    'lora_transformer_single_transformer_blocks_0_attn_to_q.lora_down.weight'
+    # 'transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight'
+]
+
+alpha = None
+rank = None
+
+for key in rank_idx0_keys:
+    if key in state_dict:
+        rank = int(state_dict[key].shape[0])
+        break
+
+if rank is None:
+    raise ValueError('Could not find rank in state dict')
+
+for key in alpha_keys:
+    if key in state_dict:
+        alpha = int(state_dict[key])
+        break
+
+if alpha is None:
+    # set to rank if not found
+    alpha = rank
+
+up_multiplier = alpha / rank
+
+new_state_dict = {}
+
+for key, value in state_dict.items():
+    if key.endswith('.alpha'):
+        continue
+
+    orig_dtype = value.dtype
+
+    new_val = value.float() * up_multiplier
+
+    # rename kohya-style keys to diffusers / peft naming
+    new_key = key
+    new_key = new_key.replace('lora_transformer_', 'transformer.')
+    for i in range(100):
+        new_key = new_key.replace(f'transformer_blocks_{i}_', f'transformer_blocks.{i}.')
+    new_key = new_key.replace('lora_down', 'lora_A')
+    new_key = new_key.replace('lora_up', 'lora_B')
+    new_key = new_key.replace('_lora', '.lora')
+    new_key = new_key.replace('attn_', 'attn.')
+    new_key = new_key.replace('norm_linear', 'norm.linear')
+    new_key = new_key.replace('norm_out_linear', 'norm_out.linear')
+
+    new_state_dict[new_key] = new_val.to(orig_dtype)
+
+save_file(new_state_dict, args.output_path, meta)
+print(f'Saved to {args.output_path}')
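
Not part of the patch: a minimal sanity check of a converted file, assuming it was produced by the script above. The file name is a placeholder, and the checks assume a transformer-only flux LoRA, which is what the script expects.

from safetensors.torch import load_file

sd = load_file('converted_lora.safetensors')  # placeholder path to a converted file

# the script strips every '.alpha' tensor and renames keys to peft-style names
assert not any(k.endswith('.alpha') for k in sd)
assert all(k.startswith('transformer.') for k in sd)  # assumes a transformer-only flux LoRA

# each lora_A (down) weight should pair with a lora_B (up) weight of the same rank
for k, v in sd.items():
    if '.lora_A.' in k:
        up = sd[k.replace('.lora_A.', '.lora_B.')]
        assert v.shape[0] == up.shape[1], f'rank mismatch for {k}'

print(f'{len(sd)} tensors look peft-formatted')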