Corrected key saving and loading to better match kohya

Jaret Burkett
2023-09-04 00:22:34 -06:00
parent 22ed539321
commit fa8fc32c0a
5 changed files with 3371 additions and 4 deletions

View File

@@ -452,6 +452,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 train_text_encoder=self.train_config.train_text_encoder,
                 conv_lora_dim=self.network_config.conv,
                 conv_alpha=self.network_config.conv_alpha,
+                is_sdxl=self.model_config.is_xl,
+                is_v2=self.model_config.is_v2,
             )
             self.network.force_to(self.device_torch, dtype=dtype)

View File

@@ -0,0 +1,130 @@
from collections import OrderedDict
import torch
from safetensors.torch import load_file
import argparse
import os
import json

PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
keymap_path = os.path.join(PROJECT_ROOT, 'toolkit', 'keymaps', 'stable_diffusion_sdxl.json')

# load keymap
with open(keymap_path, 'r') as f:
    keymap = json.load(f)

lora_keymap = OrderedDict()

# convert keymap to lora key naming
for ldm_key, diffusers_key in keymap['ldm_diffusers_keymap'].items():
    if ldm_key.endswith('.bias') or diffusers_key.endswith('.bias'):
        # skip it
        continue
    # sdxl has same te for locon with kohya and ours
    if ldm_key.startswith('conditioner'):
        # skip it
        continue
    # ignore vae
    if ldm_key.startswith('first_stage_model'):
        continue
    ldm_key = ldm_key.replace('model.diffusion_model.', 'lora_unet_')
    ldm_key = ldm_key.replace('.weight', '')
    ldm_key = ldm_key.replace('.', '_')
    diffusers_key = diffusers_key.replace('unet_', 'lora_unet_')
    diffusers_key = diffusers_key.replace('.weight', '')
    diffusers_key = diffusers_key.replace('.', '_')

    lora_keymap[f"{ldm_key}.alpha"] = f"{diffusers_key}.alpha"
    lora_keymap[f"{ldm_key}.lora_down.weight"] = f"{diffusers_key}.lora_down.weight"
    lora_keymap[f"{ldm_key}.lora_up.weight"] = f"{diffusers_key}.lora_up.weight"
parser = argparse.ArgumentParser()
parser.add_argument("input", help="input file")
parser.add_argument("input2", help="input2 file")

args = parser.parse_args()

# name = args.name
# if args.sdxl:
#     name += '_sdxl'
# elif args.sd2:
#     name += '_sd2'
# else:
#     name += '_sd1'
name = 'stable_diffusion_locon_sdxl'

locon_save = load_file(args.input)
our_save = load_file(args.input2)

our_extra_keys = list(set(our_save.keys()) - set(locon_save.keys()))
locon_extra_keys = list(set(locon_save.keys()) - set(our_save.keys()))

print(f"we have {len(our_extra_keys)} extra keys")
print(f"locon has {len(locon_extra_keys)} extra keys")

save_dtype = torch.float16

print(f"our extra keys: {our_extra_keys}")
print(f"locon extra keys: {locon_extra_keys}")


def export_state_dict(our_save):
    converted_state_dict = OrderedDict()
    for key, value in our_save.items():
        # text encoders share keys for some reason
        if key.startswith('lora_te'):
            converted_state_dict[key] = value.detach().to('cpu', dtype=save_dtype)
        else:
            converted_key = key
            for ldm_key, diffusers_key in lora_keymap.items():
                if converted_key == diffusers_key:
                    converted_key = ldm_key
            converted_state_dict[converted_key] = value.detach().to('cpu', dtype=save_dtype)
    return converted_state_dict


def import_state_dict(loaded_state_dict):
    converted_state_dict = OrderedDict()
    for key, value in loaded_state_dict.items():
        if key.startswith('lora_te'):
            converted_state_dict[key] = value.detach().to('cpu', dtype=save_dtype)
        else:
            converted_key = key
            for ldm_key, diffusers_key in lora_keymap.items():
                if converted_key == ldm_key:
                    converted_key = diffusers_key
            converted_state_dict[converted_key] = value.detach().to('cpu', dtype=save_dtype)
    return converted_state_dict


# check it again
converted_state_dict = export_state_dict(our_save)

converted_extra_keys = list(set(converted_state_dict.keys()) - set(locon_save.keys()))
locon_extra_keys = list(set(locon_save.keys()) - set(converted_state_dict.keys()))

print(f"we have {len(converted_extra_keys)} extra keys")
print(f"locon has {len(locon_extra_keys)} extra keys")
print(f"our extra keys: {converted_extra_keys}")

# convert back
cycle_state_dict = import_state_dict(converted_state_dict)

cycle_extra_keys = list(set(cycle_state_dict.keys()) - set(our_save.keys()))
our_extra_keys = list(set(our_save.keys()) - set(cycle_state_dict.keys()))

print(f"we have {len(our_extra_keys)} extra keys")
print(f"cycle has {len(cycle_extra_keys)} extra keys")

# save keymap
to_save = OrderedDict()
to_save['ldm_diffusers_keymap'] = lora_keymap

with open(os.path.join(PROJECT_ROOT, 'toolkit', 'keymaps', f'{name}.json'), 'w') as f:
    json.dump(to_save, f, indent=4)
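A minimal invocation sketch for the check script above (the script path and the two .safetensors file names are illustrative; the first argument is the kohya/locon save, the second is the file produced by this toolkit):

    python tools/check_lora_keymap_sdxl.py kohya_locon_sdxl.safetensors ours_locon_sdxl.safetensors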

File diff suppressed because it is too large
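The suppressed file is presumably the generated keymap JSON written by the script above (toolkit/keymaps/stable_diffusion_locon_sdxl.json). Per that script, its structure is a single "ldm_diffusers_keymap" object; the key names below are placeholders:

    {
        "ldm_diffusers_keymap": {
            "<ldm_style_key>.alpha": "<diffusers_style_key>.alpha",
            "<ldm_style_key>.lora_down.weight": "<diffusers_style_key>.lora_down.weight",
            "<ldm_style_key>.lora_up.weight": "<diffusers_style_key>.lora_up.weight"
        }
    }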

View File

@@ -1,13 +1,15 @@
+import json
 import math
 import os
 import re
 import sys
+from collections import OrderedDict
 from typing import List, Optional, Dict, Type, Union
 
 import torch
 from transformers import CLIPTextModel
 
-from .paths import SD_SCRIPTS_ROOT
+from .paths import SD_SCRIPTS_ROOT, KEYMAPS_ROOT
 from .train_tools import get_torch_dtype
 
 sys.path.append(SD_SCRIPTS_ROOT)
@@ -268,6 +270,8 @@ class LoRASpecialNetwork(LoRANetwork):
             varbose: Optional[bool] = False,
             train_text_encoder: Optional[bool] = True,
             train_unet: Optional[bool] = True,
+            is_sdxl=False,
+            is_v2=False,
     ) -> None:
         """
         LoRA network: there are a lot of arguments, but the patterns are as follows
@@ -293,6 +297,8 @@ class LoRASpecialNetwork(LoRANetwork):
         self._is_normalizing: bool = False
         # triggers the state updates
         self.multiplier = multiplier
+        self.is_sdxl = is_sdxl
+        self.is_v2 = is_v2
 
         if modules_dim is not None:
             print(f"create LoRA network from weights")
@@ -440,23 +446,71 @@ class LoRASpecialNetwork(LoRANetwork):
             assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
             names.add(lora.lora_name)
 
+    def get_keymap(self):
+        if self.is_sdxl:
+            keymap_tail = 'sdxl'
+        elif self.is_v2:
+            keymap_tail = 'sd2'
+        else:
+            keymap_tail = 'sd1'
+        # load keymap
+        keymap_name = f"stable_diffusion_locon_{keymap_tail}.json"
+        keymap_path = os.path.join(KEYMAPS_ROOT, keymap_name)
+        keymap = None
+        # check if file exists
+        if os.path.exists(keymap_path):
+            with open(keymap_path, 'r') as f:
+                keymap = json.load(f)['ldm_diffusers_keymap']
+        return keymap
+
     def save_weights(self, file, dtype, metadata):
+        keymap = self.get_keymap()
+
+        save_keymap = {}
+        if keymap is not None:
+            for ldm_key, diffusers_key in keymap.items():
+                # invert them
+                save_keymap[diffusers_key] = ldm_key
+
         if metadata is not None and len(metadata) == 0:
             metadata = None
 
         state_dict = self.state_dict()
+        save_dict = OrderedDict()
 
         if dtype is not None:
             for key in list(state_dict.keys()):
                 v = state_dict[key]
                 v = v.detach().clone().to("cpu").to(dtype)
-                state_dict[key] = v
+                save_key = save_keymap[key] if key in save_keymap else key
+                save_dict[save_key] = v
 
         if os.path.splitext(file)[1] == ".safetensors":
             from safetensors.torch import save_file
-            save_file(state_dict, file, metadata)
+            save_file(save_dict, file, metadata)
         else:
-            torch.save(state_dict, file)
+            torch.save(save_dict, file)
+
+    def load_weights(self, file):
+        # allows us to save and load to and from ldm weights
+        keymap = self.get_keymap()
+        keymap = {} if keymap is None else keymap
+
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import load_file
+            weights_sd = load_file(file)
+        else:
+            weights_sd = torch.load(file, map_location="cpu")
+
+        load_sd = OrderedDict()
+        for key, value in weights_sd.items():
+            load_key = keymap[key] if key in keymap else key
+            load_sd[load_key] = value
+
+        info = self.load_state_dict(load_sd, False)
+        return info
 
     @property
     def multiplier(self) -> Union[float, List[float]]:
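A minimal usage sketch of the new save/load path (illustrative only; the file name and dtype are placeholders, and `network` stands for a constructed LoRASpecialNetwork):

    # with a keymap present, weights are written with kohya/ldm-style keys
    network.save_weights("my_lora.safetensors", torch.float16, metadata=None)
    # on load, kohya/ldm-style keys are mapped back to this network's own names
    info = network.load_weights("my_lora.safetensors")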

View File

@@ -141,3 +141,30 @@ def save_ldm_model_from_diffusers(
     # make sure parent folder exists
     os.makedirs(os.path.dirname(output_file), exist_ok=True)
     save_file(converted_state_dict, output_file, metadata=meta)
+
+
+def save_lora_from_diffusers(
+        lora_state_dict: 'OrderedDict',
+        output_file: str,
+        meta: 'OrderedDict',
+        save_dtype=get_torch_dtype('fp16'),
+        sd_version: Literal['1', '2', 'sdxl'] = '2'
+):
+    converted_state_dict = OrderedDict()
+    # only handle sdxl for now
+    if sd_version != 'sdxl':
+        raise ValueError(f"Invalid sd_version {sd_version}")
+
+    for key, value in lora_state_dict.items():
+        # text encoders share keys for some reason
+        if key.startswith('lora_te'):
+            converted_state_dict[key] = value.detach().to('cpu', dtype=save_dtype)
+        else:
+            converted_key = key
+
+    # make sure parent folder exists
+    os.makedirs(os.path.dirname(output_file), exist_ok=True)
+    save_file(converted_state_dict, output_file, metadata=meta)
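A minimal call sketch for the new helper, assuming the portion of the diff not shown here converts the remaining unet keys with the same ldm/diffusers keymap used elsewhere in this commit; all argument values below are placeholders:

    save_lora_from_diffusers(
        lora_state_dict=lora_state_dict,  # OrderedDict of LoRA weights in diffusers-style naming
        output_file='/path/to/output/my_lora.safetensors',
        meta=OrderedDict(),
        save_dtype=get_torch_dtype('fp16'),
        sd_version='sdxl',  # only sdxl is handled for now
    )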