mirror of https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 22:49:48 +00:00
Added experimental concept replacer, replicate converter, bucket maker, and other goodies
scripts/convert_cog.py (new file, 128 lines)
@@ -0,0 +1,128 @@
import json
from collections import OrderedDict
import os

import torch
from safetensors import safe_open
from safetensors.torch import save_file

device = torch.device('cpu')

# [diffusers] -> kohya
embedding_mapping = {
    'text_encoders_0': 'clip_l',
    'text_encoders_1': 'clip_g'
}
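# Note (assumption): 'clip_l' / 'clip_g' are the names kohya-style tooling
# expects for SDXL's two text encoders (CLIP-L and OpenCLIP-G), while
# 'text_encoders_0' / 'text_encoders_1' are the keys the cog embedding ships with.
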
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
KEYMAP_ROOT = os.path.join(PROJECT_ROOT, 'toolkit', 'keymaps')
sdxl_keymap_path = os.path.join(KEYMAP_ROOT, 'stable_diffusion_locon_sdxl.json')

# load keymap
with open(sdxl_keymap_path, 'r') as f:
    ldm_diffusers_keymap = json.load(f)['ldm_diffusers_keymap']

# invert the key / value pairs so ldm keys can be looked up from diffusers keys
diffusers_ldm_keymap = {v: k for k, v in ldm_diffusers_keymap.items()}

def get_ldm_key(diffuser_key):
    # rewrite a diffusers-style lora key into kohya/ldm naming, then resolve it
    # through the inverted keymap
    diffuser_key = f"lora_unet_{diffuser_key.replace('.', '_')}"
    diffuser_key = diffuser_key.replace('_lora_down_weight', '.lora_down.weight')
    diffuser_key = diffuser_key.replace('_lora_up_weight', '.lora_up.weight')
    diffuser_key = diffuser_key.replace('_alpha', '.alpha')
    diffuser_key = diffuser_key.replace('_processor_to_', '_to_')
    diffuser_key = diffuser_key.replace('_to_out.', '_to_out_0.')
    if diffuser_key in diffusers_ldm_keymap:
        return diffusers_ldm_keymap[diffuser_key]
    else:
        raise KeyError(f"Key {diffuser_key} not found in keymap")
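# Illustrative example (hypothetical key; real keys depend on the shipped keymap):
#   get_ldm_key('down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight')
# normalizes the name to
#   'lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight'
# before the diffusers_ldm_keymap lookup produces the final ldm key.
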
def convert_cog(lora_path, embedding_path):
    embedding_state_dict = OrderedDict()
    lora_state_dict = OrderedDict()

    # # normal dict
    # normal_dict = OrderedDict()
    # example_path = "/mnt/Models/stable-diffusion/models/LoRA/sdxl/LogoRedmond_LogoRedAF.safetensors"
    # with safe_open(example_path, framework="pt", device='cpu') as f:
    #     keys = list(f.keys())
    #     for key in keys:
    #         normal_dict[key] = f.get_tensor(key)

    with safe_open(embedding_path, framework="pt", device='cpu') as f:
        keys = list(f.keys())
        for key in keys:
            new_key = embedding_mapping[key]
            embedding_state_dict[new_key] = f.get_tensor(key)

    with safe_open(lora_path, framework="pt", device='cpu') as f:
        keys = list(f.keys())
        lora_rank = None

        # get the lora dim first; check the first 3 linear layers just to be safe
        num_checked = 0
        for key in keys:
            tensor = f.get_tensor(key)
            if len(tensor.shape) == 2:
                # assume the smallest dim of a 2D weight is the lora rank
                this_dim = min(tensor.shape)
                if lora_rank is None:
                    lora_rank = this_dim
                elif lora_rank != this_dim:
                    raise ValueError(f"lora rank is not consistent, got {tensor.shape}")
                else:
                    num_checked += 1
                    if num_checked >= 3:
                        break
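        # For intuition (hypothetical rank-16 lora): lora_down weights have shape
        # (rank, in_features) and lora_up weights (out_features, rank), so
        # min(tensor.shape) recovers the rank from either 2D matrix.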

        for key in keys:
            new_key = get_ldm_key(key)
            tensor = f.get_tensor(key)
            if new_key.endswith('.lora_down.weight'):
                alpha_key = new_key.replace('.lora_down.weight', '.alpha')
                # diffusers checkpoints do not store alpha; they use an implicit
                # multiplier of 1 (i.e. alpha == rank), so write an alpha tensor
                # equal to the lora rank detected above
                lora_state_dict[alpha_key] = torch.ones(1).to(tensor.device, tensor.dtype) * lora_rank

            lora_state_dict[new_key] = tensor

    return lora_state_dict, embedding_state_dict

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        'lora_path',
        type=str,
        help='Path to lora file'
    )
    parser.add_argument(
        'embedding_path',
        type=str,
        help='Path to embedding file'
    )

    parser.add_argument(
        '--lora_output',
        type=str,
        default="lora_output",
        help='Path to write the converted lora file to'
    )

    parser.add_argument(
        '--embedding_output',
        type=str,
        default="embedding_output",
        help='Path to write the converted embedding file to'
    )

    args = parser.parse_args()

    lora_state_dict, embedding_state_dict = convert_cog(args.lora_path, args.embedding_path)

    # save both state dicts as safetensors
    save_file(lora_state_dict, args.lora_output)
    save_file(embedding_state_dict, args.embedding_output)
    print(f"Saved lora to {args.lora_output}")
    print(f"Saved embedding to {args.embedding_output}")
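A quick usage sketch (hypothetical file names; the output flags default to
"lora_output" / "embedding_output", so explicit .safetensors paths are passed here):

    python scripts/convert_cog.py cog_lora.safetensors cog_embedding.safetensors \
        --lora_output converted_lora.safetensors \
        --embedding_output converted_embedding.safetensors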