Corrected key saving and loading to better match kohya

This commit is contained in:
Jaret Burkett
2023-09-04 00:22:34 -06:00
parent 22ed539321
commit fa8fc32c0a
5 changed files with 3371 additions and 4 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -1,13 +1,15 @@
import json
import math
import os
import re
import sys
from collections import OrderedDict
from typing import List, Optional, Dict, Type, Union
import torch
from transformers import CLIPTextModel
from .paths import SD_SCRIPTS_ROOT
from .paths import SD_SCRIPTS_ROOT, KEYMAPS_ROOT
from .train_tools import get_torch_dtype
sys.path.append(SD_SCRIPTS_ROOT)
@@ -268,6 +270,8 @@ class LoRASpecialNetwork(LoRANetwork):
varbose: Optional[bool] = False,
train_text_encoder: Optional[bool] = True,
train_unet: Optional[bool] = True,
is_sdxl=False,
is_v2=False,
) -> None:
"""
LoRA network: すごく引数が多いが、パターンは以下の通り
@@ -293,6 +297,8 @@ class LoRASpecialNetwork(LoRANetwork):
self._is_normalizing: bool = False
# triggers the state updates
self.multiplier = multiplier
self.is_sdxl = is_sdxl
self.is_v2 = is_v2
if modules_dim is not None:
print(f"create LoRA network from weights")
@@ -440,23 +446,71 @@ class LoRASpecialNetwork(LoRANetwork):
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
names.add(lora.lora_name)
def get_keymap(self):
    """Load the key-name map for the current model family, if one exists.

    The file is chosen from ``self.is_sdxl`` / ``self.is_v2``:
    ``stable_diffusion_locon_sdxl.json``, ``..._sd2.json`` or
    ``..._sd1.json``, and is looked up inside ``KEYMAPS_ROOT``.

    Returns:
        The parsed JSON content of the keymap file, or ``None`` when no
        keymap file exists for this model family.
    """
    if self.is_sdxl:
        keymap_tail = 'sdxl'
    elif self.is_v2:
        keymap_tail = 'sd2'
    else:
        keymap_tail = 'sd1'

    keymap_name = f"stable_diffusion_locon_{keymap_tail}.json"
    # Resolve against the keymaps folder; a bare filename would only be
    # found when the process happens to run from that folder.
    keymap_path = os.path.join(KEYMAPS_ROOT, keymap_name)

    # A missing keymap is not an error: callers fall back to unmapped keys.
    if not os.path.exists(keymap_path):
        return None

    with open(keymap_path, 'r') as f:
        return json.load(f)
def save_weights(self, file, dtype, metadata):
    """Write the network weights to ``file``, renaming keys via the keymap.

    Args:
        file: Destination path. A ``.safetensors`` extension selects the
            safetensors writer; anything else goes through ``torch.save``.
        dtype: Target dtype for the saved tensors, or ``None`` to keep each
            tensor's current dtype.
        metadata: Metadata passed to the safetensors writer. An empty
            container is normalized to ``None``. It is not written on the
            ``torch.save`` path.
    """
    keymap = self.get_keymap()

    # The keymap file maps ldm keys -> diffusers keys; saving needs the
    # inverse direction.
    save_keymap = {}
    if keymap is not None:
        save_keymap = {
            diffusers_key: ldm_key
            for ldm_key, diffusers_key in keymap.items()
        }

    if metadata is not None and len(metadata) == 0:
        metadata = None

    save_dict = OrderedDict()
    for key, value in self.state_dict().items():
        value = value.detach().clone().to("cpu")
        # Only cast when a dtype was requested; every weight is still
        # written when dtype is None.
        if dtype is not None:
            value = value.to(dtype)
        save_dict[save_keymap.get(key, key)] = value

    if os.path.splitext(file)[1] == ".safetensors":
        from safetensors.torch import save_file
        save_file(save_dict, file, metadata)
    else:
        torch.save(save_dict, file)
def load_weights(self, file):
    """Load weights from ``file`` into this network, renaming keys via the keymap.

    Keys found in the keymap are replaced by their mapped name; all other
    keys are passed through unchanged, so the same method accepts files
    saved with or without the key renaming.

    Args:
        file: Path to the weights. A ``.safetensors`` extension is read with
            safetensors; anything else is read with ``torch.load`` onto CPU.

    Returns:
        Whatever ``self.load_state_dict`` returns for a non-strict load.
    """
    key_lookup = self.get_keymap()
    if key_lookup is None:
        key_lookup = {}

    _, extension = os.path.splitext(file)
    if extension == ".safetensors":
        from safetensors.torch import load_file
        raw_weights = load_file(file)
    else:
        # NOTE(review): torch.load unpickles the file; only use on trusted checkpoints.
        raw_weights = torch.load(file, map_location="cpu")

    remapped = OrderedDict(
        (key_lookup.get(name, name), tensor)
        for name, tensor in raw_weights.items()
    )

    return self.load_state_dict(remapped, False)
@property
def multiplier(self) -> Union[float, List[float]]:

View File

@@ -141,3 +141,30 @@ def save_ldm_model_from_diffusers(
# make sure parent folder exists
os.makedirs(os.path.dirname(output_file), exist_ok=True)
save_file(converted_state_dict, output_file, metadata=meta)
def save_lora_from_diffusers(
        lora_state_dict: 'OrderedDict',
        output_file: str,
        meta: 'OrderedDict',
        save_dtype=get_torch_dtype('fp16'),
        sd_version: Literal['1', '2', 'sdxl'] = '2'
):
    """Save a LoRA state dict to ``output_file`` as CPU tensors of ``save_dtype``.

    Args:
        lora_state_dict: Mapping of LoRA weight names to tensors.
        output_file: Destination path; its parent folder is created if needed.
        meta: Metadata passed through to ``save_file``.
        save_dtype: dtype the tensors are cast to before saving.
        sd_version: Model family. Only ``'sdxl'`` is handled for now.

    Raises:
        ValueError: If ``sd_version`` is anything other than ``'sdxl'``.
    """
    # only handle sdxl for now
    if sd_version != 'sdxl':
        raise ValueError(f"Invalid sd_version {sd_version}")

    converted_state_dict = OrderedDict()
    for key, value in lora_state_dict.items():
        if key.startswith('lora_te'):
            # text encoder keys are kept as-is
            converted_key = key
        else:
            # TODO: no key remapping is implemented for the remaining
            # weights yet; keep the key unchanged so they are still saved
            # rather than silently dropped.
            converted_key = key
        converted_state_dict[converted_key] = value.detach().to('cpu', dtype=save_dtype)

    # make sure parent folder exists (dirname is '' for a bare filename)
    parent_dir = os.path.dirname(output_file)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    save_file(converted_state_dict, output_file, metadata=meta)