mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 10:11:14 +00:00
Corrected key saving and loading to better match kohya
This commit is contained in:
3154
toolkit/keymaps/stable_diffusion_locon_sdxl.json
Normal file
3154
toolkit/keymaps/stable_diffusion_locon_sdxl.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,13 +1,15 @@
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
from typing import List, Optional, Dict, Type, Union
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTextModel
|
||||
|
||||
from .paths import SD_SCRIPTS_ROOT
|
||||
from .paths import SD_SCRIPTS_ROOT, KEYMAPS_ROOT
|
||||
from .train_tools import get_torch_dtype
|
||||
|
||||
sys.path.append(SD_SCRIPTS_ROOT)
|
||||
@@ -268,6 +270,8 @@ class LoRASpecialNetwork(LoRANetwork):
|
||||
varbose: Optional[bool] = False,
|
||||
train_text_encoder: Optional[bool] = True,
|
||||
train_unet: Optional[bool] = True,
|
||||
is_sdxl=False,
|
||||
is_v2=False,
|
||||
) -> None:
|
||||
"""
|
||||
LoRA network: すごく引数が多いが、パターンは以下の通り
|
||||
@@ -293,6 +297,8 @@ class LoRASpecialNetwork(LoRANetwork):
|
||||
self._is_normalizing: bool = False
|
||||
# triggers the state updates
|
||||
self.multiplier = multiplier
|
||||
self.is_sdxl = is_sdxl
|
||||
self.is_v2 = is_v2
|
||||
|
||||
if modules_dim is not None:
|
||||
print(f"create LoRA network from weights")
|
||||
@@ -440,23 +446,71 @@ class LoRASpecialNetwork(LoRANetwork):
|
||||
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
||||
names.add(lora.lora_name)
|
||||
|
||||
def get_keymap(self):
    """Return the LDM(kohya) -> diffusers key translation table for this model.

    Selects the keymap JSON matching the model flavor (sdxl / sd2 / sd1)
    and loads it from the toolkit's keymaps directory.

    Returns:
        dict mapping ldm_key -> diffusers_key, or None when no keymap
        file exists for this flavor.
    """
    if self.is_sdxl:
        keymap_tail = 'sdxl'
    elif self.is_v2:
        keymap_tail = 'sd2'
    else:
        keymap_tail = 'sd1'
    # load keymap
    keymap_name = f"stable_diffusion_locon_{keymap_tail}.json"
    # BUG FIX: previously checked for the bare filename, which resolves
    # against the current working directory. The keymaps ship with the
    # toolkit, so resolve against KEYMAPS_ROOT instead.
    keymap_path = os.path.join(KEYMAPS_ROOT, keymap_name)

    keymap = None
    # check if file exists
    if os.path.exists(keymap_path):
        with open(keymap_path, 'r') as f:
            keymap = json.load(f)

    return keymap
|
||||
|
||||
def save_weights(self, file, dtype, metadata):
    """Save the network's weights to *file*, translating diffusers-style
    keys back to the kohya/LDM convention where a keymap is available.

    Args:
        file: destination path; a ``.safetensors`` extension selects the
            safetensors format, anything else uses ``torch.save``.
        dtype: optional target dtype; when given, every tensor is cast
            before saving. ``None`` saves tensors in their current dtype.
        metadata: optional metadata dict for safetensors; an empty dict is
            normalized to ``None`` (safetensors rejects empty metadata).
    """
    keymap = self.get_keymap()

    # get_keymap() maps ldm_key -> diffusers_key; invert it so our
    # (diffusers-style) state-dict keys are written out as ldm keys.
    save_keymap = {}
    if keymap is not None:
        for ldm_key, diffusers_key in keymap.items():
            # invert them
            save_keymap[diffusers_key] = ldm_key

    if metadata is not None and len(metadata) == 0:
        metadata = None

    state_dict = self.state_dict()
    save_dict = OrderedDict()

    # BUG FIX: the remapping loop previously ran only when dtype was not
    # None, so save_weights(file, None, ...) wrote an *empty* state dict.
    # Build save_dict unconditionally and make the dtype cast optional.
    for key, v in state_dict.items():
        v = v.detach().clone().to("cpu")
        if dtype is not None:
            v = v.to(dtype)
        save_key = save_keymap[key] if key in save_keymap else key
        save_dict[save_key] = v

    if os.path.splitext(file)[1] == ".safetensors":
        from safetensors.torch import save_file
        save_file(save_dict, file, metadata)
    else:
        torch.save(save_dict, file)
|
||||
|
||||
def load_weights(self, file):
    """Load network weights from *file*, translating kohya/LDM-style keys
    to this network's (diffusers-style) keys via the keymap.

    Args:
        file: source path; ``.safetensors`` files are read with
            safetensors, anything else with ``torch.load``.

    Returns:
        The ``load_state_dict`` result (missing/unexpected key info).
    """
    # allows us to save and load to and from ldm weights
    key_translation = self.get_keymap()
    if key_translation is None:
        key_translation = {}

    ext = os.path.splitext(file)[1]
    if ext == ".safetensors":
        from safetensors.torch import load_file
        weights_sd = load_file(file)
    else:
        weights_sd = torch.load(file, map_location="cpu")

    # remap any ldm keys to our own; unknown keys pass through untouched
    remapped = OrderedDict(
        (key_translation.get(key, key), value)
        for key, value in weights_sd.items()
    )

    return self.load_state_dict(remapped, False)
|
||||
|
||||
@property
|
||||
def multiplier(self) -> Union[float, List[float]]:
|
||||
|
||||
@@ -141,3 +141,30 @@ def save_ldm_model_from_diffusers(
|
||||
# make sure parent folder exists
|
||||
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
||||
save_file(converted_state_dict, output_file, metadata=meta)
|
||||
|
||||
|
||||
def save_lora_from_diffusers(
        lora_state_dict: 'OrderedDict',
        output_file: str,
        meta: 'OrderedDict',
        save_dtype=get_torch_dtype('fp16'),
        sd_version: Literal['1', '2', 'sdxl'] = '2'
):
    """Save a diffusers-format LoRA state dict to *output_file* as safetensors.

    Args:
        lora_state_dict: LoRA weights keyed in the diffusers convention.
        output_file: destination ``.safetensors`` path.
        meta: metadata dict passed through to ``save_file``.
        save_dtype: dtype every tensor is cast to before saving.
        sd_version: model flavor; only ``'sdxl'`` is supported for now.

    Raises:
        ValueError: for any ``sd_version`` other than ``'sdxl'``.
    """
    converted_state_dict = OrderedDict()
    # only handle sdxl for now
    if sd_version != 'sdxl':
        raise ValueError(f"Invalid sd_version {sd_version}")
    for key, value in lora_state_dict.items():
        # text encoders share keys for some reason
        # BUG FIX: str has no begins_with(); use startswith().
        if key.startswith('lora_te'):
            converted_state_dict[key] = value.detach().to('cpu', dtype=save_dtype)
        else:
            # BUG FIX: the original assigned a converted key but never stored
            # the tensor, silently dropping every non-text-encoder weight.
            # No key translation is needed here yet, so pass through as-is.
            converted_key = key
            converted_state_dict[converted_key] = value.detach().to('cpu', dtype=save_dtype)

    # make sure parent folder exists (dirname is '' for bare filenames,
    # which os.makedirs rejects)
    parent_dir = os.path.dirname(output_file)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    save_file(converted_state_dict, output_file, metadata=meta)
|
||||
|
||||
Reference in New Issue
Block a user