mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Corrected key saving and loading to better match kohya
@@ -452,6 +452,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
             train_text_encoder=self.train_config.train_text_encoder,
             conv_lora_dim=self.network_config.conv,
             conv_alpha=self.network_config.conv_alpha,
+            is_sdxl=self.model_config.is_xl,
+            is_v2=self.model_config.is_v2,
         )

         self.network.force_to(self.device_torch, dtype=dtype)
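These two new arguments pass the base-model family from the training config down into the network, so the network itself can pick the matching keymap file when saving and loading (see get_keymap further down).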
130 testing/generate_lora_mapping.py Normal file
@@ -0,0 +1,130 @@
+from collections import OrderedDict
+
+import torch
+from safetensors.torch import load_file
+import argparse
+import os
+import json
+
+PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+keymap_path = os.path.join(PROJECT_ROOT, 'toolkit', 'keymaps', 'stable_diffusion_sdxl.json')
+
+# load keymap
+with open(keymap_path, 'r') as f:
+    keymap = json.load(f)
+
+lora_keymap = OrderedDict()
+
+# convert keymap to lora key naming
+for ldm_key, diffusers_key in keymap['ldm_diffusers_keymap'].items():
+    if ldm_key.endswith('.bias') or diffusers_key.endswith('.bias'):
+        # skip it
+        continue
+    # sdxl has same te for locon with kohya and ours
+    if ldm_key.startswith('conditioner'):
+        # skip it
+        continue
+    # ignore vae
+    if ldm_key.startswith('first_stage_model'):
+        continue
+    ldm_key = ldm_key.replace('model.diffusion_model.', 'lora_unet_')
+    ldm_key = ldm_key.replace('.weight', '')
+    ldm_key = ldm_key.replace('.', '_')
+
+    diffusers_key = diffusers_key.replace('unet_', 'lora_unet_')
+    diffusers_key = diffusers_key.replace('.weight', '')
+    diffusers_key = diffusers_key.replace('.', '_')
+
+    lora_keymap[f"{ldm_key}.alpha"] = f"{diffusers_key}.alpha"
+    lora_keymap[f"{ldm_key}.lora_down.weight"] = f"{diffusers_key}.lora_down.weight"
+    lora_keymap[f"{ldm_key}.lora_up.weight"] = f"{diffusers_key}.lora_up.weight"
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("input", help="input file")
+parser.add_argument("input2", help="input2 file")
+
+args = parser.parse_args()
+
+# name = args.name
+# if args.sdxl:
+#     name += '_sdxl'
+# elif args.sd2:
+#     name += '_sd2'
+# else:
+#     name += '_sd1'
+name = 'stable_diffusion_locon_sdxl'
+
+locon_save = load_file(args.input)
+our_save = load_file(args.input2)
+
+our_extra_keys = list(set(our_save.keys()) - set(locon_save.keys()))
+locon_extra_keys = list(set(locon_save.keys()) - set(our_save.keys()))
+
+print(f"we have {len(our_extra_keys)} extra keys")
+print(f"locon has {len(locon_extra_keys)} extra keys")
+
+save_dtype = torch.float16
+print(f"our extra keys: {our_extra_keys}")
+print(f"locon extra keys: {locon_extra_keys}")
+
+
+def export_state_dict(our_save):
+    converted_state_dict = OrderedDict()
+    for key, value in our_save.items():
+        # text encoders share keys for some reason
+        if key.startswith('lora_te'):
+            converted_state_dict[key] = value.detach().to('cpu', dtype=save_dtype)
+        else:
+            converted_key = key
+            for ldm_key, diffusers_key in lora_keymap.items():
+                if converted_key == diffusers_key:
+                    converted_key = ldm_key
+
+            converted_state_dict[converted_key] = value.detach().to('cpu', dtype=save_dtype)
+    return converted_state_dict
+
+
+def import_state_dict(loaded_state_dict):
+    converted_state_dict = OrderedDict()
+    for key, value in loaded_state_dict.items():
+        if key.startswith('lora_te'):
+            converted_state_dict[key] = value.detach().to('cpu', dtype=save_dtype)
+        else:
+            converted_key = key
+            for ldm_key, diffusers_key in lora_keymap.items():
+                if converted_key == ldm_key:
+                    converted_key = diffusers_key
+
+            converted_state_dict[converted_key] = value.detach().to('cpu', dtype=save_dtype)
+    return converted_state_dict
+
+
+# check it again
+converted_state_dict = export_state_dict(our_save)
+converted_extra_keys = list(set(converted_state_dict.keys()) - set(locon_save.keys()))
+locon_extra_keys = list(set(locon_save.keys()) - set(converted_state_dict.keys()))
+
+print(f"we have {len(converted_extra_keys)} extra keys")
+print(f"locon has {len(locon_extra_keys)} extra keys")
+
+print(f"our extra keys: {converted_extra_keys}")
+
+# convert back
+cycle_state_dict = import_state_dict(converted_state_dict)
+cycle_extra_keys = list(set(cycle_state_dict.keys()) - set(our_save.keys()))
+our_extra_keys = list(set(our_save.keys()) - set(cycle_state_dict.keys()))
+
+print(f"we have {len(our_extra_keys)} extra keys")
+print(f"cycle has {len(cycle_extra_keys)} extra keys")
+
+# save keymap
+to_save = OrderedDict()
+to_save['ldm_diffusers_keymap'] = lora_keymap
+
+with open(os.path.join(PROJECT_ROOT, 'toolkit', 'keymaps', f'{name}.json'), 'w') as f:
+    json.dump(to_save, f, indent=4)
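The script is invoked as `python testing/generate_lora_mapping.py kohya_locon.safetensors our_lora.safetensors` (file names hypothetical): it diffs the key spaces of the two saves, sanity-checks an export/import round trip, and writes the keymap JSON. A minimal sketch of consuming that output (the LoRA key shown is hypothetical):

    import json
    import os

    # load the keymap the script wrote (ldm/kohya key -> diffusers-style key)
    with open(os.path.join('toolkit', 'keymaps', 'stable_diffusion_locon_sdxl.json'), 'r') as f:
        lora_keymap = json.load(f)['ldm_diffusers_keymap']

    ldm_key = 'lora_unet_output_blocks_0_1_proj_in.lora_down.weight'  # hypothetical key
    our_key = lora_keymap.get(ldm_key, ldm_key)  # unmapped keys pass through unchanged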
3154 toolkit/keymaps/stable_diffusion_locon_sdxl.json Normal file
File diff suppressed because it is too large
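This 3154-line JSON is the output of testing/generate_lora_mapping.py above: the `ldm_diffusers_keymap` table mapping kohya/LDM LoRA key names to this repo's diffusers-style names.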
@@ -1,13 +1,15 @@
+import json
 import math
 import os
 import re
 import sys
+from collections import OrderedDict
 from typing import List, Optional, Dict, Type, Union

 import torch
 from transformers import CLIPTextModel

-from .paths import SD_SCRIPTS_ROOT
+from .paths import SD_SCRIPTS_ROOT, KEYMAPS_ROOT
 from .train_tools import get_torch_dtype

 sys.path.append(SD_SCRIPTS_ROOT)
@@ -268,6 +270,8 @@ class LoRASpecialNetwork(LoRANetwork):
             varbose: Optional[bool] = False,
             train_text_encoder: Optional[bool] = True,
             train_unet: Optional[bool] = True,
+            is_sdxl=False,
+            is_v2=False,
     ) -> None:
         """
         LoRA network: there are a lot of arguments, but the patterns are as follows
@@ -293,6 +297,8 @@ class LoRASpecialNetwork(LoRANetwork):
         self._is_normalizing: bool = False
         # triggers the state updates
         self.multiplier = multiplier
+        self.is_sdxl = is_sdxl
+        self.is_v2 = is_v2

         if modules_dim is not None:
             print(f"create LoRA network from weights")
@@ -440,23 +446,71 @@ class LoRASpecialNetwork(LoRANetwork):
             assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
             names.add(lora.lora_name)

+    def get_keymap(self):
+        if self.is_sdxl:
+            keymap_tail = 'sdxl'
+        elif self.is_v2:
+            keymap_tail = 'sd2'
+        else:
+            keymap_tail = 'sd1'
+        # load keymap
+        keymap_name = f"stable_diffusion_locon_{keymap_tail}.json"
+        keymap_path = os.path.join(KEYMAPS_ROOT, keymap_name)
+
+        keymap = None
+        # check if file exists
+        if os.path.exists(keymap_path):
+            with open(keymap_path, 'r') as f:
+                keymap = json.load(f)
+
+        return keymap
+
     def save_weights(self, file, dtype, metadata):
+        keymap = self.get_keymap()
+
+        save_keymap = {}
+        if keymap is not None:
+            for ldm_key, diffusers_key in keymap.items():
+                # invert them
+                save_keymap[diffusers_key] = ldm_key
+
         if metadata is not None and len(metadata) == 0:
             metadata = None

         state_dict = self.state_dict()
+        save_dict = OrderedDict()
+
         if dtype is not None:
             for key in list(state_dict.keys()):
                 v = state_dict[key]
                 v = v.detach().clone().to("cpu").to(dtype)
-                state_dict[key] = v
+                save_key = save_keymap[key] if key in save_keymap else key
+                save_dict[save_key] = v

         if os.path.splitext(file)[1] == ".safetensors":
             from safetensors.torch import save_file
-            save_file(state_dict, file, metadata)
+            save_file(save_dict, file, metadata)
         else:
-            torch.save(state_dict, file)
+            torch.save(save_dict, file)
+
+    def load_weights(self, file):
+        # allows us to save and load to and from ldm weights
+        keymap = self.get_keymap()
+        keymap = {} if keymap is None else keymap
+
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import load_file
+
+            weights_sd = load_file(file)
+        else:
+            weights_sd = torch.load(file, map_location="cpu")
+
+        load_sd = OrderedDict()
+        for key, value in weights_sd.items():
+            load_key = keymap[key] if key in keymap else key
+            load_sd[load_key] = value
+
+        info = self.load_state_dict(load_sd, False)
+        return info
+
     @property
     def multiplier(self) -> Union[float, List[float]]:
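A hedged usage sketch of the new round trip (assumes an already-constructed LoRASpecialNetwork `network`; file names hypothetical): load_weights remaps kohya/LDM keys to the internal diffusers-style names on the way in, and save_weights inverts the keymap on the way out, so a kohya-style LoCon file survives a load/save cycle.

    import torch

    # network = LoRASpecialNetwork(..., is_sdxl=True)  # assumed built elsewhere
    info = network.load_weights('kohya_locon.safetensors')  # ldm keys -> ours
    network.save_weights('round_trip.safetensors', torch.float16, metadata=None)  # ours -> ldm keys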
@@ -141,3 +141,30 @@ def save_ldm_model_from_diffusers(
     # make sure parent folder exists
     os.makedirs(os.path.dirname(output_file), exist_ok=True)
     save_file(converted_state_dict, output_file, metadata=meta)
+
+
+def save_lora_from_diffusers(
+        lora_state_dict: 'OrderedDict',
+        output_file: str,
+        meta: 'OrderedDict',
+        save_dtype=get_torch_dtype('fp16'),
+        sd_version: Literal['1', '2', 'sdxl'] = '2'
+):
+    converted_state_dict = OrderedDict()
+    # only handle sdxl for now
+    if sd_version != 'sdxl':
+        raise ValueError(f"Invalid sd_version {sd_version}")
+    for key, value in lora_state_dict.items():
+        # text encoders share keys for some reason
+        if key.startswith('lora_te'):
+            converted_state_dict[key] = value.detach().to('cpu', dtype=save_dtype)
+        else:
+            converted_key = key
+            # pass the key through as-is for now
+            converted_state_dict[converted_key] = value.detach().to('cpu', dtype=save_dtype)
+
+    # make sure parent folder exists
+    os.makedirs(os.path.dirname(output_file), exist_ok=True)
+    save_file(converted_state_dict, output_file, metadata=meta)
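A minimal usage sketch (variable names hypothetical), showing the sdxl-only guard:

    from collections import OrderedDict

    lora_sd = OrderedDict()  # a diffusers-style LoRA state dict (assumed populated)
    save_lora_from_diffusers(lora_sd, 'out/locon.safetensors', OrderedDict(), sd_version='sdxl')
    # any other sd_version raises ValueError("Invalid sd_version ...")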