# Convert a kohya-style LoRA safetensors file into the PEFT / diffusers
# key format. Currently only works with Flux, as support for other models
# is not quite there yet.
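#
# Example invocation (the script file name here is illustrative):
#   python convert_lora_to_peft_format.py /path/to/lora.safetensors /path/to/out.safetensors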

import argparse
import os.path
from collections import OrderedDict

from safetensors.torch import load_file, save_file

parser = argparse.ArgumentParser()
parser.add_argument(
    'input_path',
    type=str,
    help='Path to the input LoRA safetensors file'
)
parser.add_argument(
    'output_path',
    type=str,
    help='Path to write the converted safetensors file'
)
args = parser.parse_args()
args.input_path = os.path.abspath(args.input_path)
args.output_path = os.path.abspath(args.output_path)

meta = OrderedDict()
meta['format'] = 'pt'

state_dict = load_file(args.input_path)

# peft doesn't have an alpha, so we need to scale the weights
alpha_keys = [
    'lora_transformer_single_transformer_blocks_0_attn_to_q.alpha'  # flux
]

# keys where the rank is in the first dimension of the weight
rank_idx0_keys = [
    'lora_transformer_single_transformer_blocks_0_attn_to_q.lora_down.weight'
    # 'transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight'
]
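
# The commented pair above also shows the key rename performed further down:
#   lora_transformer_single_transformer_blocks_0_attn_to_q.lora_down.weight
#   -> transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight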

alpha = None
rank = None

for key in rank_idx0_keys:
    if key in state_dict:
        # lora_down weights have shape (rank, in_features)
        rank = int(state_dict[key].shape[0])
        break

if rank is None:
    raise ValueError('Could not find rank in state dict')
for key in alpha_keys:
|
|
if key in state_dict:
|
|
alpha = int(state_dict[key])
|
|
break
|
|
|
|
if alpha is None:
|
|
# set to rank if not found
|
|
alpha = rank
|
|
|
|
|
|
up_multiplier = alpha / rank
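
# LoRA applies delta_W = (alpha / rank) * up @ down. A PEFT-style state dict
# carries no alpha tensors, so loaders assume alpha == rank (a scale of 1.0);
# the alpha / rank factor therefore has to be baked into the weights below.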

new_state_dict = {}

for key, value in state_dict.items():
    # alpha tensors are folded into the weights, so drop them
    if key.endswith('.alpha'):
        continue

    orig_dtype = value.dtype

    # scale in float32 and cast back to the original dtype at the end; the
    # factor is applied to the up weights only, since scaling both halves of
    # the up @ down product would apply it twice
    if 'lora_up' in key:
        new_val = value.float() * up_multiplier
    else:
        new_val = value.float()

    # rewrite the kohya-style underscore key into the PEFT / diffusers dotted form
    new_key = key
    new_key = new_key.replace('lora_transformer_', 'transformer.')
    for i in range(100):
        new_key = new_key.replace(f'transformer_blocks_{i}_', f'transformer_blocks.{i}.')
    new_key = new_key.replace('lora_down', 'lora_A')
    new_key = new_key.replace('lora_up', 'lora_B')
    new_key = new_key.replace('_lora', '.lora')
    new_key = new_key.replace('attn_', 'attn.')
    new_key = new_key.replace('ff_', 'ff.')
    new_key = new_key.replace('context_net_', 'context.net.')
    new_key = new_key.replace('0_proj', '0.proj')
    new_key = new_key.replace('norm_linear', 'norm.linear')
    new_key = new_key.replace('norm_out_linear', 'norm_out.linear')
    new_key = new_key.replace('to_out_', 'to_out.')

    new_state_dict[new_key] = new_val.to(orig_dtype)

save_file(new_state_dict, args.output_path, meta)
print(f'Saved to {args.output_path}')
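
# To sanity-check the result, the converted file can be loaded with diffusers.
# Minimal sketch (the pipeline and model id are assumptions, not part of this script):
#   from diffusers import FluxPipeline
#   pipe = FluxPipeline.from_pretrained('black-forest-labs/FLUX.1-dev')
#   pipe.load_lora_weights(args.output_path)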