Initial commit

Jaret Burkett
2023-07-05 16:44:58 -06:00
commit e4de8983c9
11 changed files with 2185 additions and 0 deletions

171
.gitignore vendored Normal file

@@ -0,0 +1,171 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
/env.sh
/models
/custom/*
!/custom/.gitkeep
/.tmp
/venv.bkp
/venv.*
/config/*
!/config/examples
!/config/_PUT_YOUR_CONFIGS_HERE).txt

52
README.md Normal file

@@ -0,0 +1,52 @@
# AI Toolkit by Ostris
WIP for now, but this will grow into a collection of AI tools as I need them.
## Installation
I will try to update this to be more beginner-friendly, but for now I am assuming
a general understanding of Python, pip, PyTorch, and virtual environments:
Linux:
```bash
python3 -m venv venv
source venv/bin/activate
pip install -r requirements.txt
```
Windows:
```bash
python3 -m venv venv
venv\Scripts\activate
pip install -r requirements.txt
```
## Current Tools
### LyCORIS extractor
It is similar to the [LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS) extraction tool, but adds some quality-of-life features.
It all runs off a config file, which you can find an example of in `config/examples/locon_config.example.json`.
Just copy that file into the `config` folder and rename it to `whatever_you_want.json`.
Then you can edit the file to your liking and call it like so:
```bash
python3 scripts/extract_locon.py "whatever_you_want"
```
You can also put a full path to a config file, if you want to keep it somewhere else.
```bash
python3 scripts/extract_locon.py "/home/user/whatever_you_want.json"
```
The output file name is auto-generated and written to the `output_folder` set in the config. You can put whatever meta you want in the
`meta` section of the config file, and it will be added to the metadata of the output file; I just have
some recommended fields in the example file. The script will add some other useful metadata as well.
`process` is an array of different extraction processes to run on the conversion, so you can test several settings in one pass. You will normally only need one though.
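For reference, the output file name is built from the model name and the process settings (see `scripts/extract_locon.py`). With the three processes in the example config and `name` set to `name_of_your_model`, the configured output folder would end up with files named roughly like this (illustrative, not verified output):

```
lyco_name_of_your_model_fixed_64_32.safetensors
lyco_name_of_your_model_ratio_0.2_0.2.safetensors
lyco_name_of_your_model_quantile_0.5_0.5.safetensors
```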
Will update this later.

46
config/examples/locon_config.example.json Normal file

@@ -0,0 +1,46 @@
{
  "config": {
    "name": "name_of_your_model",
    "base_model": "/path/to/base/model",
    "extract_model": "/path/to/model/to/extract",
    "output_folder": "/path/to/output/folder",
    "is_v2": false,
    "device": "cpu",
    "use_sparse_bias": false,
    "sparsity": 0.98,
    "disable_cp": false,
    "process": [
      {
        "mode": "fixed",
        "linear_dim": 64,
        "conv_dim": 32
      },
      {
        "mode": "ratio",
        "linear_ratio": 0.2,
        "conv_ratio": 0.2
      },
      {
        "mode": "quantile",
        "linear_quantile": 0.5,
        "conv_quantile": 0.5
      }
    ]
  },
  "meta": {
    "name": "[name]",
    "description": "A short description of your model",
    "trigger_words": [
      "put",
      "trigger",
      "words",
      "here"
    ],
    "version": "0.1",
    "creator": {
      "name": "Your Name",
      "email": "your@email.com",
      "website": "https://yourwebsite.com"
    }
  }
}
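The example covers the `fixed`, `ratio`, and `quantile` modes; `scripts/extract_locon.py` also validates a `threshold` mode. A minimal sketch of that process entry, using placeholder threshold values (not tuned recommendations):

```json
{
  "mode": "threshold",
  "linear_threshold": 0.1,
  "conv_threshold": 0.1
}
```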

6
requirements.txt Normal file

@@ -0,0 +1,6 @@
torch
safetensors
diffusers
transformers
lycoris_lora
flatten_json

151
scripts/extract_locon.py Normal file

@@ -0,0 +1,151 @@
import json
import os
import sys

from flatten_json import flatten

sys.path.insert(0, os.getcwd())
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
CONFIG_FOLDER = os.path.join(PROJECT_ROOT, 'config')
sys.path.append(PROJECT_ROOT)

import argparse
from toolkit.lycoris_utils import extract_diff
from toolkit.config import get_config
from toolkit.metadata import create_meta
from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
import torch
from safetensors.torch import save_file


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "config_file",
        help="Name of config file (eg: person_v1 for config/person_v1.json), or full path if it is not in config folder",
        type=str
    )
    return parser.parse_args()


def main():
    args = get_args()
    config_raw = get_config(args.config_file)
    config = config_raw['config'] if 'config' in config_raw else None
    if not config:
        raise ValueError('config file is invalid. Missing "config" key')
    meta = config_raw['meta'] if 'meta' in config_raw else {}

    def get_conf(key, default=None):
        if key in config:
            return config[key]
        else:
            return default

    is_v2 = get_conf('is_v2', False)
    name = get_conf('name', None)
    base_model = get_conf('base_model')
    extract_model = get_conf('extract_model')
    output_folder = get_conf('output_folder')
    process_list = get_conf('process')
    device = get_conf('device', 'cpu')
    use_sparse_bias = get_conf('use_sparse_bias', False)
    sparsity = get_conf('sparsity', 0.98)
    disable_cp = get_conf('disable_cp', False)

    if not name:
        raise ValueError('name is required')
    if not base_model:
        raise ValueError('base_model is required')
    if not extract_model:
        raise ValueError('extract_model is required')
    if not output_folder:
        raise ValueError('output_folder is required')
    if not process_list or len(process_list) == 0:
        raise ValueError('process is required')

    # check processes
    for process in process_list:
        if process['mode'] == 'fixed':
            if not process['linear_dim']:
                raise ValueError('linear_dim is required in fixed mode')
            if not process['conv_dim']:
                raise ValueError('conv_dim is required in fixed mode')
        elif process['mode'] == 'threshold':
            if not process['linear_threshold']:
                raise ValueError('linear_threshold is required in threshold mode')
            if not process['conv_threshold']:
                raise ValueError('conv_threshold is required in threshold mode')
        elif process['mode'] == 'ratio':
            if not process['linear_ratio']:
                raise ValueError('linear_ratio is required in ratio mode')
            if not process['conv_ratio']:
                raise ValueError('conv_ratio is required in ratio mode')
        elif process['mode'] == 'quantile':
            if not process['linear_quantile']:
                raise ValueError('linear_quantile is required in quantile mode')
            if not process['conv_quantile']:
                raise ValueError('conv_quantile is required in quantile mode')
        else:
            raise ValueError('mode is invalid')

    print(f"Loading base model: {base_model}")
    base = load_models_from_stable_diffusion_checkpoint(is_v2, base_model)
    print(f"Loading extract model: {extract_model}")
    extract = load_models_from_stable_diffusion_checkpoint(is_v2, extract_model)

    print(f"Running {len(process_list)} process{'' if len(process_list) == 1 else 'es'}")
    for process in process_list:
        item_meta = json.loads(json.dumps(meta))
        item_meta['process'] = process
        if process['mode'] == 'fixed':
            linear_mode_param = int(process['linear_dim'])
            conv_mode_param = int(process['conv_dim'])
        elif process['mode'] == 'threshold':
            linear_mode_param = float(process['linear_threshold'])
            conv_mode_param = float(process['conv_threshold'])
        elif process['mode'] == 'ratio':
            linear_mode_param = float(process['linear_ratio'])
            conv_mode_param = float(process['conv_ratio'])
        elif process['mode'] == 'quantile':
            linear_mode_param = float(process['linear_quantile'])
            conv_mode_param = float(process['conv_quantile'])
        else:
            raise ValueError(f"Unknown mode: {process['mode']}")
        print(f"Running process: {process['mode']}, lin: {linear_mode_param}, conv: {conv_mode_param}")

        state_dict, extract_diff_meta = extract_diff(
            base,
            extract,
            process['mode'],
            linear_mode_param,
            conv_mode_param,
            device,
            use_sparse_bias,
            sparsity,
            not disable_cp
        )

        save_meta = create_meta([
            item_meta, extract_diff_meta
        ])

        output_file_name = f"lyco_{name}_{process['mode']}_{linear_mode_param}_{conv_mode_param}.safetensors"
        output_path = os.path.join(output_folder, output_file_name)
        os.makedirs(output_folder, exist_ok=True)
        # having issues with meta
        save_file(state_dict, output_path)
        # save_file(state_dict, output_path, {'meta': json.dumps(save_meta, indent=4)})
        print(f"Saved to {output_path}")


if __name__ == '__main__':
    main()

0
toolkit/__init__.py Normal file

39
toolkit/config.py Normal file

@@ -0,0 +1,39 @@
import os
import json

from toolkit.paths import TOOLKIT_ROOT

possible_extensions = ['.json', '.jsonc']


def get_cwd_abs_path(path):
    if not os.path.isabs(path):
        path = os.path.join(os.getcwd(), path)
    return path


def get_config(config_file_path):
    # first check if it is in the config folder
    config_path = os.path.join(TOOLKIT_ROOT, 'config', config_file_path)
    # see if it is in the config folder with any of the possible extensions if it doesn't have one
    real_config_path = None
    if not os.path.exists(config_path):
        for ext in possible_extensions:
            if os.path.exists(config_path + ext):
                real_config_path = config_path + ext
                break

    # if we didn't find it there, check if it is a full path
    if not real_config_path:
        if os.path.exists(config_file_path):
            real_config_path = config_file_path
        elif os.path.exists(get_cwd_abs_path(config_file_path)):
            real_config_path = get_cwd_abs_path(config_file_path)

    if not real_config_path:
        raise ValueError(f"Could not find config file {config_file_path}")

    # load the config
    with open(real_config_path, 'r') as f:
        config = json.load(f)
    return config
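A rough usage sketch of how `get_config` resolves the command-line argument; the config names and paths below are hypothetical:

```python
from toolkit.config import get_config

# "person_v1" -> resolved to <repo>/config/person_v1.json or .jsonc if either exists
config_raw = get_config("person_v1")

# an absolute or cwd-relative path to an existing file also works
config_raw = get_config("/home/user/whatever_you_want.json")

print(config_raw["config"]["name"])
```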

1180
toolkit/kohya_model_util.py Normal file

File diff suppressed because it is too large

512
toolkit/lycoris_utils.py Normal file

@@ -0,0 +1,512 @@
# heavily based on https://github.com/KohakuBlueleaf/LyCORIS/blob/main/lycoris/utils.py
from typing import *
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.linalg as linalg
from tqdm import tqdm
def make_sparse(t: torch.Tensor, sparsity=0.95):
abs_t = torch.abs(t)
np_array = abs_t.detach().cpu().numpy()
quan = float(np.quantile(np_array, sparsity))
sparse_t = t.masked_fill(abs_t < quan, 0)
return sparse_t
def extract_conv(
weight: Union[torch.Tensor, nn.Parameter],
mode='fixed',
mode_param=0,
device='cpu',
is_cp=False,
) -> Tuple[nn.Parameter, nn.Parameter]:
weight = weight.to(device)
out_ch, in_ch, kernel_size, _ = weight.shape
U, S, Vh = linalg.svd(weight.reshape(out_ch, -1))
if mode == 'fixed':
lora_rank = mode_param
elif mode == 'threshold':
assert mode_param >= 0
lora_rank = torch.sum(S > mode_param)
elif mode == 'ratio':
assert 1 >= mode_param >= 0
min_s = torch.max(S) * mode_param
lora_rank = torch.sum(S > min_s)
elif mode == 'quantile' or mode == 'percentile':
assert 1 >= mode_param >= 0
s_cum = torch.cumsum(S, dim=0)
min_cum_sum = mode_param * torch.sum(S)
lora_rank = torch.sum(s_cum < min_cum_sum)
else:
raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
lora_rank = max(1, lora_rank)
lora_rank = min(out_ch, in_ch, lora_rank)
if lora_rank >= out_ch / 2 and not is_cp:
return weight, 'full'
U = U[:, :lora_rank]
S = S[:lora_rank]
U = U @ torch.diag(S)
Vh = Vh[:lora_rank, :]
diff = (weight - (U @ Vh).reshape(out_ch, in_ch, kernel_size, kernel_size)).detach()
extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).detach()
extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).detach()
del U, S, Vh, weight
return (extract_weight_A, extract_weight_B, diff), 'low rank'
def extract_linear(
weight: Union[torch.Tensor, nn.Parameter],
mode='fixed',
mode_param=0,
device='cpu',
) -> Tuple[nn.Parameter, nn.Parameter]:
weight = weight.to(device)
out_ch, in_ch = weight.shape
U, S, Vh = linalg.svd(weight)
if mode == 'fixed':
lora_rank = mode_param
elif mode == 'threshold':
assert mode_param >= 0
lora_rank = torch.sum(S > mode_param)
elif mode == 'ratio':
assert 1 >= mode_param >= 0
min_s = torch.max(S) * mode_param
lora_rank = torch.sum(S > min_s)
elif mode == 'quantile' or mode == 'percentile':
assert 1 >= mode_param >= 0
s_cum = torch.cumsum(S, dim=0)
min_cum_sum = mode_param * torch.sum(S)
lora_rank = torch.sum(s_cum < min_cum_sum)
else:
raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
lora_rank = max(1, lora_rank)
lora_rank = min(out_ch, in_ch, lora_rank)
if lora_rank >= out_ch / 2:
return weight, 'full'
U = U[:, :lora_rank]
S = S[:lora_rank]
U = U @ torch.diag(S)
Vh = Vh[:lora_rank, :]
diff = (weight - U @ Vh).detach()
extract_weight_A = Vh.reshape(lora_rank, in_ch).detach()
extract_weight_B = U.reshape(out_ch, lora_rank).detach()
del U, S, Vh, weight
return (extract_weight_A, extract_weight_B, diff), 'low rank'
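To make the rank-selection modes above concrete, here is a small illustration of how `lora_rank` would be chosen for a toy set of singular values (the numbers are made up; the real values come from the SVD of the weight difference):

```python
import torch

S = torch.tensor([4.0, 2.0, 1.0, 0.5])  # singular values, largest first

# fixed: lora_rank is simply mode_param
# threshold 0.8: count singular values above the threshold -> 3
print(torch.sum(S > 0.8))                     # tensor(3)
# ratio 0.2: count values above 0.2 * max(S) = 0.8 -> 3
print(torch.sum(S > torch.max(S) * 0.2))      # tensor(3)
# quantile 0.5: count how many cumulative sums stay below 50% of the total
s_cum = torch.cumsum(S, dim=0)                # [4.0, 6.0, 7.0, 7.5]
print(torch.sum(s_cum < 0.5 * torch.sum(S)))  # tensor(0), clamped to 1 by max(1, lora_rank)
```

The result is then clamped to `[1, min(out_ch, in_ch)]`, and if it ends up at or above `out_ch / 2` the full weight difference is returned instead of a low-rank pair.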
def extract_diff(
base_model,
db_model,
mode='fixed',
linear_mode_param=0,
conv_mode_param=0,
extract_device='cpu',
use_bias=False,
sparsity=0.98,
small_conv=True
):
meta = {}
UNET_TARGET_REPLACE_MODULE = [
"Transformer2DModel",
"Attention",
"ResnetBlock2D",
"Downsample2D",
"Upsample2D"
]
UNET_TARGET_REPLACE_NAME = [
"conv_in",
"conv_out",
"time_embedding.linear_1",
"time_embedding.linear_2",
]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
LORA_PREFIX_UNET = 'lora_unet'
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
def make_state_dict(
prefix,
root_module: torch.nn.Module,
target_module: torch.nn.Module,
target_replace_modules,
target_replace_names=[]
):
loras = {}
temp = {}
temp_name = {}
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
temp[name] = {}
for child_name, child_module in module.named_modules():
if child_module.__class__.__name__ not in {'Linear', 'Conv2d'}:
continue
temp[name][child_name] = child_module.weight
elif name in target_replace_names:
temp_name[name] = module.weight
for name, module in tqdm(list(target_module.named_modules())):
if name in temp:
weights = temp[name]
for child_name, child_module in module.named_modules():
lora_name = prefix + '.' + name + '.' + child_name
lora_name = lora_name.replace('.', '_')
layer = child_module.__class__.__name__
if layer in {'Linear', 'Conv2d'}:
root_weight = child_module.weight
if torch.allclose(root_weight, weights[child_name]):
continue
if layer == 'Linear':
weight, decompose_mode = extract_linear(
(child_module.weight - weights[child_name]),
mode,
linear_mode_param,
device=extract_device,
)
if decompose_mode == 'low rank':
extract_a, extract_b, diff = weight
elif layer == 'Conv2d':
is_linear = (child_module.weight.shape[2] == 1
and child_module.weight.shape[3] == 1)
weight, decompose_mode = extract_conv(
(child_module.weight - weights[child_name]),
mode,
linear_mode_param if is_linear else conv_mode_param,
device=extract_device,
)
if decompose_mode == 'low rank':
extract_a, extract_b, diff = weight
if small_conv and not is_linear and decompose_mode == 'low rank':
dim = extract_a.size(0)
(extract_c, extract_a, _), _ = extract_conv(
extract_a.transpose(0, 1),
'fixed', dim,
extract_device, True
)
extract_a = extract_a.transpose(0, 1)
extract_c = extract_c.transpose(0, 1)
loras[f'{lora_name}.lora_mid.weight'] = extract_c.detach().cpu().contiguous().half()
diff = child_module.weight - torch.einsum(
'i j k l, j r, p i -> p r k l',
extract_c, extract_a.flatten(1, -1), extract_b.flatten(1, -1)
).detach().cpu().contiguous()
del extract_c
else:
continue
if decompose_mode == 'low rank':
loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().contiguous().half()
loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().contiguous().half()
loras[f'{lora_name}.alpha'] = torch.Tensor([extract_a.shape[0]]).half()
if use_bias:
diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()
indices = sparse_diff.indices().to(torch.int16)
values = sparse_diff.values().half()
loras[f'{lora_name}.bias_indices'] = indices
loras[f'{lora_name}.bias_values'] = values
loras[f'{lora_name}.bias_size'] = torch.tensor(diff.shape).to(torch.int16)
del extract_a, extract_b, diff
elif decompose_mode == 'full':
loras[f'{lora_name}.diff'] = weight.detach().cpu().contiguous().half()
else:
raise NotImplementedError
elif name in temp_name:
weights = temp_name[name]
lora_name = prefix + '.' + name
lora_name = lora_name.replace('.', '_')
layer = module.__class__.__name__
if layer in {'Linear', 'Conv2d'}:
root_weight = module.weight
if torch.allclose(root_weight, weights):
continue
if layer == 'Linear':
weight, decompose_mode = extract_linear(
(root_weight - weights),
mode,
linear_mode_param,
device=extract_device,
)
if decompose_mode == 'low rank':
extract_a, extract_b, diff = weight
elif layer == 'Conv2d':
is_linear = (
root_weight.shape[2] == 1
and root_weight.shape[3] == 1
)
weight, decompose_mode = extract_conv(
(root_weight - weights),
mode,
linear_mode_param if is_linear else conv_mode_param,
device=extract_device,
)
if decompose_mode == 'low rank':
extract_a, extract_b, diff = weight
if small_conv and not is_linear and decompose_mode == 'low rank':
dim = extract_a.size(0)
(extract_c, extract_a, _), _ = extract_conv(
extract_a.transpose(0, 1),
'fixed', dim,
extract_device, True
)
extract_a = extract_a.transpose(0, 1)
extract_c = extract_c.transpose(0, 1)
loras[f'{lora_name}.lora_mid.weight'] = extract_c.detach().cpu().contiguous().half()
diff = root_weight - torch.einsum(
'i j k l, j r, p i -> p r k l',
extract_c, extract_a.flatten(1, -1), extract_b.flatten(1, -1)
).detach().cpu().contiguous()
del extract_c
else:
continue
if decompose_mode == 'low rank':
loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().contiguous().half()
loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().contiguous().half()
loras[f'{lora_name}.alpha'] = torch.Tensor([extract_a.shape[0]]).half()
if use_bias:
diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()
indices = sparse_diff.indices().to(torch.int16)
values = sparse_diff.values().half()
loras[f'{lora_name}.bias_indices'] = indices
loras[f'{lora_name}.bias_values'] = values
loras[f'{lora_name}.bias_size'] = torch.tensor(diff.shape).to(torch.int16)
del extract_a, extract_b, diff
elif decompose_mode == 'full':
loras[f'{lora_name}.diff'] = weight.detach().cpu().contiguous().half()
else:
raise NotImplementedError
return loras
text_encoder_loras = make_state_dict(
LORA_PREFIX_TEXT_ENCODER,
base_model[0], db_model[0],
TEXT_ENCODER_TARGET_REPLACE_MODULE
)
unet_loras = make_state_dict(
LORA_PREFIX_UNET,
base_model[2], db_model[2],
UNET_TARGET_REPLACE_MODULE,
UNET_TARGET_REPLACE_NAME
)
print(len(text_encoder_loras), len(unet_loras))
# the | operator merges the two dicts (Python 3.9+)
return (text_encoder_loras | unet_loras), meta
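For context, `scripts/extract_locon.py` drives this function roughly like so after loading both checkpoints; the literal values mirror the example config (this is a condensed restatement of that call, not a separate API):

```python
state_dict, extract_diff_meta = extract_diff(
    base,       # tuple of models loaded from the base checkpoint
    extract,    # tuple of models loaded from the checkpoint to extract from
    'fixed',    # mode
    64,         # linear_mode_param
    32,         # conv_mode_param
    'cpu',      # extract_device
    False,      # use_bias (the script's use_sparse_bias)
    0.98,       # sparsity
    True        # small_conv (the script passes `not disable_cp`)
)
```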
def get_module(
lyco_state_dict: Dict,
lora_name
):
if f'{lora_name}.lora_up.weight' in lyco_state_dict:
up = lyco_state_dict[f'{lora_name}.lora_up.weight']
down = lyco_state_dict[f'{lora_name}.lora_down.weight']
mid = lyco_state_dict.get(f'{lora_name}.lora_mid.weight', None)
alpha = lyco_state_dict.get(f'{lora_name}.alpha', None)
return 'locon', (up, down, mid, alpha)
elif f'{lora_name}.hada_w1_a' in lyco_state_dict:
w1a = lyco_state_dict[f'{lora_name}.hada_w1_a']
w1b = lyco_state_dict[f'{lora_name}.hada_w1_b']
w2a = lyco_state_dict[f'{lora_name}.hada_w2_a']
w2b = lyco_state_dict[f'{lora_name}.hada_w2_b']
t1 = lyco_state_dict.get(f'{lora_name}.hada_t1', None)
t2 = lyco_state_dict.get(f'{lora_name}.hada_t2', None)
alpha = lyco_state_dict.get(f'{lora_name}.alpha', None)
return 'hada', (w1a, w1b, w2a, w2b, t1, t2, alpha)
elif f'{lora_name}.weight' in lyco_state_dict:
weight = lyco_state_dict[f'{lora_name}.weight']
on_input = lyco_state_dict.get(f'{lora_name}.on_input', False)
return 'ia3', (weight, on_input)
elif (f'{lora_name}.lokr_w1' in lyco_state_dict
or f'{lora_name}.lokr_w1_a' in lyco_state_dict):
w1 = lyco_state_dict.get(f'{lora_name}.lokr_w1', None)
w1a = lyco_state_dict.get(f'{lora_name}.lokr_w1_a', None)
w1b = lyco_state_dict.get(f'{lora_name}.lokr_w1_b', None)
w2 = lyco_state_dict.get(f'{lora_name}.lokr_w2', None)
w2a = lyco_state_dict.get(f'{lora_name}.lokr_w2_a', None)
w2b = lyco_state_dict.get(f'{lora_name}.lokr_w2_b', None)
t1 = lyco_state_dict.get(f'{lora_name}.lokr_t1', None)
t2 = lyco_state_dict.get(f'{lora_name}.lokr_t2', None)
alpha = lyco_state_dict.get(f'{lora_name}.alpha', None)
return 'kron', (w1, w1a, w1b, w2, w2a, w2b, t1, t2, alpha)
elif f'{lora_name}.diff' in lyco_state_dict:
return 'full', lyco_state_dict[f'{lora_name}.diff']
else:
return 'None', ()
def cp_weight_from_conv(
up, down, mid
):
up = up.reshape(up.size(0), up.size(1))
down = down.reshape(down.size(0), down.size(1))
return torch.einsum('m n w h, i m, n j -> i j w h', mid, up, down)
def cp_weight(
wa, wb, t
):
temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
return torch.einsum('i j k l, i r -> r j k l', temp, wa)
@torch.no_grad()
def rebuild_weight(module_type, params, orig_weight, scale=1):
if orig_weight is None:
return orig_weight
merged = orig_weight
if module_type == 'locon':
up, down, mid, alpha = params
if alpha is not None:
scale *= alpha / up.size(1)
if mid is not None:
rebuild = cp_weight_from_conv(up, down, mid)
else:
rebuild = up.reshape(up.size(0), -1) @ down.reshape(down.size(0), -1)
merged = orig_weight + rebuild.reshape(orig_weight.shape) * scale
del up, down, mid, alpha, params, rebuild
elif module_type == 'hada':
w1a, w1b, w2a, w2b, t1, t2, alpha = params
if alpha is not None:
scale *= alpha / w1b.size(0)
if t1 is not None:
rebuild1 = cp_weight(w1a, w1b, t1)
else:
rebuild1 = w1a @ w1b
if t2 is not None:
rebuild2 = cp_weight(w2a, w2b, t2)
else:
rebuild2 = w2a @ w2b
rebuild = (rebuild1 * rebuild2).reshape(orig_weight.shape)
merged = orig_weight + rebuild * scale
del w1a, w1b, w2a, w2b, t1, t2, alpha, params, rebuild, rebuild1, rebuild2
elif module_type == 'ia3':
weight, on_input = params
if not on_input:
weight = weight.reshape(-1, 1)
merged = orig_weight + weight * orig_weight * scale
del weight, on_input, params
elif module_type == 'kron':
w1, w1a, w1b, w2, w2a, w2b, t1, t2, alpha = params
if alpha is not None and (w1b is not None or w2b is not None):
scale *= alpha / (w1b.size(0) if w1b is not None else w2b.size(0))
if w1a is not None and w1b is not None:
if t1 is not None:
w1 = cp_weight(w1a, w1b, t1)
else:
w1 = w1a @ w1b
if w2a is not None and w2b is not None:
if t2 is not None:
w2 = cp_weight(w2a, w2b, t2)
else:
w2 = w2a @ w2b
rebuild = torch.kron(w1, w2).reshape(orig_weight.shape)
merged = orig_weight + rebuild * scale
del w1, w1a, w1b, w2, w2a, w2b, t1, t2, alpha, params, rebuild
elif module_type == 'full':
rebuild = params.reshape(orig_weight.shape)
merged = orig_weight + rebuild * scale
del params, rebuild
return merged
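As a sanity check on the `'locon'` branch above: a low-rank pair rebuilds the weight as `orig + (up @ down) * (alpha / rank)`. A minimal sketch with made-up shapes:

```python
import torch

out_ch, in_ch, rank = 8, 16, 4
orig = torch.randn(out_ch, in_ch)
up = torch.randn(out_ch, rank)    # lora_up.weight
down = torch.randn(rank, in_ch)   # lora_down.weight
alpha = torch.tensor([float(rank)])

merged = rebuild_weight('locon', (up, down, None, alpha), orig, scale=1.0)
expected = orig + (up @ down) * (alpha / rank)
print(torch.allclose(merged, expected))  # True
```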
def merge(
base_model,
lyco_state_dict,
scale: float = 1.0,
device='cpu'
):
UNET_TARGET_REPLACE_MODULE = [
"Transformer2DModel",
"Attention",
"ResnetBlock2D",
"Downsample2D",
"Upsample2D"
]
UNET_TARGET_REPLACE_NAME = [
"conv_in",
"conv_out",
"time_embedding.linear_1",
"time_embedding.linear_2",
]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
LORA_PREFIX_UNET = 'lora_unet'
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
merged = 0
def merge_state_dict(
prefix,
root_module: torch.nn.Module,
lyco_state_dict: Dict[str, torch.Tensor],
target_replace_modules,
target_replace_names=[]
):
nonlocal merged
for name, module in tqdm(list(root_module.named_modules()), desc=f'Merging {prefix}'):
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
if child_module.__class__.__name__ not in {'Linear', 'Conv2d'}:
continue
lora_name = prefix + '.' + name + '.' + child_name
lora_name = lora_name.replace('.', '_')
result = rebuild_weight(*get_module(
lyco_state_dict, lora_name
), getattr(child_module, 'weight'), scale)
if result is not None:
merged += 1
child_module.requires_grad_(False)
child_module.weight.copy_(result)
elif name in target_replace_names:
lora_name = prefix + '.' + name
lora_name = lora_name.replace('.', '_')
result = rebuild_weight(*get_module(
lyco_state_dict, lora_name
), getattr(module, 'weight'), scale)
if result is not None:
merged += 1
module.requires_grad_(False)
module.weight.copy_(result)
if device == 'cpu':
for k, v in tqdm(list(lyco_state_dict.items()), desc='Converting Dtype'):
lyco_state_dict[k] = v.float()
merge_state_dict(
LORA_PREFIX_TEXT_ENCODER,
base_model[0],
lyco_state_dict,
TEXT_ENCODER_TARGET_REPLACE_MODULE,
UNET_TARGET_REPLACE_NAME
)
merge_state_dict(
LORA_PREFIX_UNET,
base_model[2],
lyco_state_dict,
UNET_TARGET_REPLACE_MODULE,
UNET_TARGET_REPLACE_NAME
)
print(f'{merged} modules have been merged')
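`merge` is not wired to a script in this commit; a hedged sketch of how it could be driven using the same loader the extract script uses (the checkpoint paths are placeholders, and whether your checkpoint format is supported depends on `kohya_model_util`):

```python
from safetensors.torch import load_file

from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
from toolkit.lycoris_utils import merge

# index 0 of the returned tuple is the text encoder, index 2 is the unet (as used above)
base = load_models_from_stable_diffusion_checkpoint(False, '/path/to/base_model.ckpt')
lyco_state_dict = load_file('/path/to/lyco_extracted.safetensors')

merge(base, lyco_state_dict, scale=1.0, device='cpu')
```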

24
toolkit/metadata.py Normal file

@@ -0,0 +1,24 @@
import json

software_meta = {
    "name": "ai-toolkit",
    "url": "https://github.com/ostris/ai-toolkit"
}


def create_meta(dict_list, name=None):
    meta = {}
    for d in dict_list:
        for key, value in d.items():
            meta[key] = value

    if "name" not in meta:
        meta["name"] = "[name]"

    meta["software"] = software_meta

    # convert to string to handle replacements
    meta_string = json.dumps(meta)
    if name is not None:
        meta_string = meta_string.replace("[name]", name)

    return json.loads(meta_string)
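A small sketch of what `create_meta` does: it merges the dicts in order, appends the `software` block, and swaps every `"[name]"` placeholder for the given name (the inputs below are made up):

```python
from toolkit.metadata import create_meta

meta = create_meta(
    [{"name": "[name]", "description": "A short description of your model"},
     {"process": {"mode": "fixed"}}],
    name="name_of_your_model"
)
print(meta["name"])      # name_of_your_model
print(meta["software"])  # {'name': 'ai-toolkit', 'url': 'https://github.com/ostris/ai-toolkit'}
```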

4
toolkit/paths.py Normal file

@@ -0,0 +1,4 @@
import os
TOOLKIT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
CONFIG_ROOT = os.path.join(TOOLKIT_ROOT, 'config')