mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-05-01 03:31:35 +00:00)
Initial commit
171  .gitignore  vendored  Normal file
@@ -0,0 +1,171 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

/env.sh
/models
/custom/*
!/custom/.gitkeep
/.tmp
/venv.bkp
/venv.*
/config/*
!/config/examples
!/config/_PUT_YOUR_CONFIGS_HERE).txt
52  README.md  Normal file
@@ -0,0 +1,52 @@
# AI Toolkit by Ostris

WIP for now, but this will grow into a collection of AI tools as I need them.

## Installation

I will try to update this to be more beginner-friendly, but for now I am assuming
a general understanding of python, pip, pytorch, and using virtual environments:

Linux:

```bash
python3 -m venv venv
source venv/bin/activate
pip install -r requirements.txt
```

Windows:

```bash
python3 -m venv venv
venv\Scripts\activate
pip install -r requirements.txt
```

## Current Tools

### LyCORIS extractor

It is similar to the [LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS) tool, but adds some QOL features.
It all runs off a config file, which you can find an example of in `config/examples/locon_config.example.json`.
Just copy that file into the `config` folder and rename it to `whatever_you_want.json`.
Then you can edit the file to your liking and call it like so:

```bash
python3 scripts/extract_locon.py "whatever_you_want"
```

You can also pass a full path to a config file if you want to keep it somewhere else.

```bash
python3 scripts/extract_locon.py "/home/user/whatever_you_want.json"
```

The file name is auto-generated and the result is dumped into the output folder you configured. You can put whatever meta you want in the
`meta` section of the config file, and it will be added to the metadata of the output file. I just have
some recommended fields in the example file. The script will add some other useful metadata as well.

`process` is an array of different processes to run on the conversion, so you can test several settings in one run. You will normally just need one though.

Will update this later.
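A minimal sketch of the config-driven flow described above, using hypothetical names and paths (the full example ships in `config/examples/locon_config.example.json` below):

```python
# sketch: write a config file for the extractor, then run the script with its name
# every name and path here is a placeholder, not something in the repo
import json
import os

config = {
    "config": {
        "name": "my_model",  # used in the output file name
        "base_model": "/path/to/base_model.safetensors",
        "extract_model": "/path/to/tuned_model.safetensors",
        "output_folder": "output",
        "process": [
            {"mode": "fixed", "linear_dim": 64, "conv_dim": 32}
        ]
    },
    "meta": {"name": "[name]", "description": "example extraction"}
}

os.makedirs("config", exist_ok=True)
with open(os.path.join("config", "my_model.json"), "w") as f:
    json.dump(config, f, indent=2)

# then, from the repo root:
#   python3 scripts/extract_locon.py "my_model"
```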
46  config/examples/locon_config.example.json  Normal file
@@ -0,0 +1,46 @@
{
  "config": {
    "name": "name_of_your_model",
    "base_model": "/path/to/base/model",
    "extract_model": "/path/to/model/to/extract",
    "output_folder": "/path/to/output/folder",
    "is_v2": false,
    "device": "cpu",
    "use_sparse_bias": false,
    "sparsity": 0.98,
    "disable_cp": false,
    "process": [
      {
        "mode": "fixed",
        "linear_dim": 64,
        "conv_dim": 32
      },
      {
        "mode": "ratio",
        "linear_ratio": 0.2,
        "conv_ratio": 0.2
      },
      {
        "mode": "quantile",
        "linear_quantile": 0.5,
        "conv_quantile": 0.5
      }
    ]
  },
  "meta": {
    "name": "[name]",
    "description": "A short description of your model",
    "trigger_words": [
      "put",
      "trigger",
      "words",
      "here"
    ],
    "version": "0.1",
    "creator": {
      "name": "Your Name",
      "email": "your@email.com",
      "website": "https://yourwebsite.com"
    }
  }
}
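For reference, each entry in `process` produces its own output file; a rough sketch (assuming the example values above) of the file names that `scripts/extract_locon.py` below generates:

```python
# sketch: the naming scheme used by scripts/extract_locon.py, applied to the
# three example process entries above: lyco_{name}_{mode}_{linear}_{conv}.safetensors
name = "name_of_your_model"
processes = [
    ("fixed", 64, 32),
    ("ratio", 0.2, 0.2),
    ("quantile", 0.5, 0.5),
]
for mode, linear, conv in processes:
    print(f"lyco_{name}_{mode}_{linear}_{conv}.safetensors")
# lyco_name_of_your_model_fixed_64_32.safetensors
# lyco_name_of_your_model_ratio_0.2_0.2.safetensors
# lyco_name_of_your_model_quantile_0.5_0.5.safetensors
```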
6  requirements.txt  Normal file
@@ -0,0 +1,6 @@
torch
safetensors
diffusers
transformers
lycoris_lora
flatten_json
151  scripts/extract_locon.py  Normal file
@@ -0,0 +1,151 @@
import json
import os
import sys

from flatten_json import flatten

sys.path.insert(0, os.getcwd())
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
CONFIG_FOLDER = os.path.join(PROJECT_ROOT, 'config')
sys.path.append(PROJECT_ROOT)

import argparse

from toolkit.lycoris_utils import extract_diff
from toolkit.config import get_config
from toolkit.metadata import create_meta
from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint

import torch
from safetensors.torch import save_file


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "config_file",
        help="Name of config file (eg: person_v1 for config/person_v1.json), or full path if it is not in config folder",
        type=str
    )
    return parser.parse_args()


def main():
    args = get_args()

    config_raw = get_config(args.config_file)
    config = config_raw['config'] if 'config' in config_raw else None
    if not config:
        raise ValueError('config file is invalid. Missing "config" key')

    meta = config_raw['meta'] if 'meta' in config_raw else {}

    def get_conf(key, default=None):
        if key in config:
            return config[key]
        else:
            return default

    is_v2 = get_conf('is_v2', False)
    name = get_conf('name', None)
    base_model = get_conf('base_model')
    extract_model = get_conf('extract_model')
    output_folder = get_conf('output_folder')
    process_list = get_conf('process')
    device = get_conf('device', 'cpu')
    use_sparse_bias = get_conf('use_sparse_bias', False)
    sparsity = get_conf('sparsity', 0.98)
    disable_cp = get_conf('disable_cp', False)

    if not name:
        raise ValueError('name is required')
    if not base_model:
        raise ValueError('base_model is required')
    if not extract_model:
        raise ValueError('extract_model is required')
    if not output_folder:
        raise ValueError('output_folder is required')
    if not process_list or len(process_list) == 0:
        raise ValueError('process is required')

    # check processes
    for process in process_list:
        if process['mode'] == 'fixed':
            if not process['linear_dim']:
                raise ValueError('linear_dim is required in fixed mode')
            if not process['conv_dim']:
                raise ValueError('conv_dim is required in fixed mode')
        elif process['mode'] == 'threshold':
            if not process['linear_threshold']:
                raise ValueError('linear_threshold is required in threshold mode')
            if not process['conv_threshold']:
                raise ValueError('conv_threshold is required in threshold mode')
        elif process['mode'] == 'ratio':
            if not process['linear_ratio']:
                raise ValueError('linear_ratio is required in ratio mode')
            if not process['conv_ratio']:
                raise ValueError('conv_ratio is required in ratio mode')
        elif process['mode'] == 'quantile':
            if not process['linear_quantile']:
                raise ValueError('linear_quantile is required in quantile mode')
            if not process['conv_quantile']:
                raise ValueError('conv_quantile is required in quantile mode')
        else:
            raise ValueError('mode is invalid')

    print(f"Loading base model: {base_model}")
    base = load_models_from_stable_diffusion_checkpoint(is_v2, base_model)
    print(f"Loading extract model: {extract_model}")
    extract = load_models_from_stable_diffusion_checkpoint(is_v2, extract_model)

    print(f"Running {len(process_list)} process{'' if len(process_list) == 1 else 'es'}")

    for process in process_list:
        item_meta = json.loads(json.dumps(meta))
        item_meta['process'] = process
        if process['mode'] == 'fixed':
            linear_mode_param = int(process['linear_dim'])
            conv_mode_param = int(process['conv_dim'])
        elif process['mode'] == 'threshold':
            linear_mode_param = float(process['linear_threshold'])
            conv_mode_param = float(process['conv_threshold'])
        elif process['mode'] == 'ratio':
            linear_mode_param = float(process['linear_ratio'])
            conv_mode_param = float(process['conv_ratio'])
        elif process['mode'] == 'quantile':
            linear_mode_param = float(process['linear_quantile'])
            conv_mode_param = float(process['conv_quantile'])
        else:
            raise ValueError(f"Unknown mode: {process['mode']}")

        print(f"Running process: {process['mode']}, lin: {linear_mode_param}, conv: {conv_mode_param}")

        state_dict, extract_diff_meta = extract_diff(
            base,
            extract,
            process['mode'],
            linear_mode_param,
            conv_mode_param,
            device,
            use_sparse_bias,
            sparsity,
            not disable_cp
        )

        save_meta = create_meta([
            item_meta, extract_diff_meta
        ])

        output_file_name = f"lyco_{name}_{process['mode']}_{linear_mode_param}_{conv_mode_param}.safetensors"
        output_path = os.path.join(output_folder, output_file_name)
        os.makedirs(output_folder, exist_ok=True)

        # having issues with meta
        save_file(state_dict, output_path)
        # save_file(state_dict, output_path, {'meta': json.dumps(save_meta, indent=4)})

        print(f"Saved to {output_path}")


if __name__ == '__main__':
    main()
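The same pipeline can also be driven directly from Python without a config file; a minimal sketch (hypothetical checkpoint paths) that mirrors what `main()` does for a single fixed-rank process:

```python
# sketch: programmatic use of the pieces main() wires together (paths are placeholders)
import os
from safetensors.torch import save_file
from toolkit.lycoris_utils import extract_diff
from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint

base = load_models_from_stable_diffusion_checkpoint(False, "/path/to/base.ckpt")
tuned = load_models_from_stable_diffusion_checkpoint(False, "/path/to/tuned.ckpt")

# fixed mode: rank 64 for linear layers, rank 32 for conv layers
state_dict, extract_meta = extract_diff(
    base, tuned, 'fixed', 64, 32,
    'cpu',   # extract_device
    False,   # use_bias (sparse bias)
    0.98,    # sparsity
    True     # small_conv (CP decomposition enabled)
)

os.makedirs("output", exist_ok=True)
save_file(state_dict, os.path.join("output", "lyco_my_model_fixed_64_32.safetensors"))
```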
0  toolkit/__init__.py  Normal file
39  toolkit/config.py  Normal file
@@ -0,0 +1,39 @@
import os
import json
from toolkit.paths import TOOLKIT_ROOT

possible_extensions = ['.json', '.jsonc']


def get_cwd_abs_path(path):
    if not os.path.isabs(path):
        path = os.path.join(os.getcwd(), path)
    return path


def get_config(config_file_path):
    # first check if it is in the config folder
    config_path = os.path.join(TOOLKIT_ROOT, 'config', config_file_path)
    # see if it is in the config folder with any of the possible extensions if it doesn't have one
    real_config_path = None
    if not os.path.exists(config_path):
        for ext in possible_extensions:
            if os.path.exists(config_path + ext):
                real_config_path = config_path + ext
                break

    # if we didn't find it there, check if it is a full path
    if not real_config_path:
        if os.path.exists(config_file_path):
            real_config_path = config_file_path
        elif os.path.exists(get_cwd_abs_path(config_file_path)):
            real_config_path = get_cwd_abs_path(config_file_path)

    if not real_config_path:
        raise ValueError(f"Could not find config file {config_file_path}")

    # load the config
    with open(real_config_path, 'r') as f:
        config = json.load(f)

    return config
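A short sketch (hypothetical file names) of the lookup order `get_config` implements: the repo's `config/` folder first, trying the `.json`/`.jsonc` extensions, then the argument as a literal or cwd-relative path:

```python
# sketch: how config names resolve (all files here are placeholders)
from toolkit.config import get_config

cfg = get_config('person_v1')                          # -> <repo>/config/person_v1.json (or .jsonc)
cfg = get_config('/home/user/person_v1.json')          # absolute path used as-is
cfg = get_config('configs_elsewhere/person_v1.json')   # resolved relative to the current working directory
print(cfg['config']['name'])
```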
1180  toolkit/kohya_model_util.py  Normal file
File diff suppressed because it is too large
512  toolkit/lycoris_utils.py  Normal file
@@ -0,0 +1,512 @@
# heavily based on https://github.com/KohakuBlueleaf/LyCORIS/blob/main/lycoris/utils.py

from typing import *

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.linalg as linalg

from tqdm import tqdm


def make_sparse(t: torch.Tensor, sparsity=0.95):
    abs_t = torch.abs(t)
    np_array = abs_t.detach().cpu().numpy()
    quan = float(np.quantile(np_array, sparsity))
    sparse_t = t.masked_fill(abs_t < quan, 0)
    return sparse_t


def extract_conv(
        weight: Union[torch.Tensor, nn.Parameter],
        mode='fixed',
        mode_param=0,
        device='cpu',
        is_cp=False,
) -> Tuple[nn.Parameter, nn.Parameter]:
    weight = weight.to(device)
    out_ch, in_ch, kernel_size, _ = weight.shape

    U, S, Vh = linalg.svd(weight.reshape(out_ch, -1))

    if mode == 'fixed':
        lora_rank = mode_param
    elif mode == 'threshold':
        assert mode_param >= 0
        lora_rank = torch.sum(S > mode_param)
    elif mode == 'ratio':
        assert 1 >= mode_param >= 0
        min_s = torch.max(S) * mode_param
        lora_rank = torch.sum(S > min_s)
    elif mode == 'quantile' or mode == 'percentile':
        assert 1 >= mode_param >= 0
        s_cum = torch.cumsum(S, dim=0)
        min_cum_sum = mode_param * torch.sum(S)
        lora_rank = torch.sum(s_cum < min_cum_sum)
    else:
        raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
    lora_rank = max(1, lora_rank)
    lora_rank = min(out_ch, in_ch, lora_rank)
    if lora_rank >= out_ch / 2 and not is_cp:
        return weight, 'full'

    U = U[:, :lora_rank]
    S = S[:lora_rank]
    U = U @ torch.diag(S)
    Vh = Vh[:lora_rank, :]

    diff = (weight - (U @ Vh).reshape(out_ch, in_ch, kernel_size, kernel_size)).detach()
    extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).detach()
    extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).detach()
    del U, S, Vh, weight
    return (extract_weight_A, extract_weight_B, diff), 'low rank'


def extract_linear(
        weight: Union[torch.Tensor, nn.Parameter],
        mode='fixed',
        mode_param=0,
        device='cpu',
) -> Tuple[nn.Parameter, nn.Parameter]:
    weight = weight.to(device)
    out_ch, in_ch = weight.shape

    U, S, Vh = linalg.svd(weight)

    if mode == 'fixed':
        lora_rank = mode_param
    elif mode == 'threshold':
        assert mode_param >= 0
        lora_rank = torch.sum(S > mode_param)
    elif mode == 'ratio':
        assert 1 >= mode_param >= 0
        min_s = torch.max(S) * mode_param
        lora_rank = torch.sum(S > min_s)
    elif mode == 'quantile' or mode == 'percentile':
        assert 1 >= mode_param >= 0
        s_cum = torch.cumsum(S, dim=0)
        min_cum_sum = mode_param * torch.sum(S)
        lora_rank = torch.sum(s_cum < min_cum_sum)
    else:
        raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
    lora_rank = max(1, lora_rank)
    lora_rank = min(out_ch, in_ch, lora_rank)
    if lora_rank >= out_ch / 2:
        return weight, 'full'

    U = U[:, :lora_rank]
    S = S[:lora_rank]
    U = U @ torch.diag(S)
    Vh = Vh[:lora_rank, :]

    diff = (weight - U @ Vh).detach()
    extract_weight_A = Vh.reshape(lora_rank, in_ch).detach()
    extract_weight_B = U.reshape(out_ch, lora_rank).detach()
    del U, S, Vh, weight
    return (extract_weight_A, extract_weight_B, diff), 'low rank'


def extract_diff(
        base_model,
        db_model,
        mode='fixed',
        linear_mode_param=0,
        conv_mode_param=0,
        extract_device='cpu',
        use_bias=False,
        sparsity=0.98,
        small_conv=True
):
    meta = {}

    UNET_TARGET_REPLACE_MODULE = [
        "Transformer2DModel",
        "Attention",
        "ResnetBlock2D",
        "Downsample2D",
        "Upsample2D"
    ]
    UNET_TARGET_REPLACE_NAME = [
        "conv_in",
        "conv_out",
        "time_embedding.linear_1",
        "time_embedding.linear_2",
    ]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = 'lora_unet'
    LORA_PREFIX_TEXT_ENCODER = 'lora_te'

    def make_state_dict(
            prefix,
            root_module: torch.nn.Module,
            target_module: torch.nn.Module,
            target_replace_modules,
            target_replace_names=[]
    ):
        loras = {}
        temp = {}
        temp_name = {}

        for name, module in root_module.named_modules():
            if module.__class__.__name__ in target_replace_modules:
                temp[name] = {}
                for child_name, child_module in module.named_modules():
                    if child_module.__class__.__name__ not in {'Linear', 'Conv2d'}:
                        continue
                    temp[name][child_name] = child_module.weight
            elif name in target_replace_names:
                temp_name[name] = module.weight

        for name, module in tqdm(list(target_module.named_modules())):
            if name in temp:
                weights = temp[name]
                for child_name, child_module in module.named_modules():
                    lora_name = prefix + '.' + name + '.' + child_name
                    lora_name = lora_name.replace('.', '_')
                    layer = child_module.__class__.__name__
                    if layer in {'Linear', 'Conv2d'}:
                        root_weight = child_module.weight
                        if torch.allclose(root_weight, weights[child_name]):
                            continue

                    if layer == 'Linear':
                        weight, decompose_mode = extract_linear(
                            (child_module.weight - weights[child_name]),
                            mode,
                            linear_mode_param,
                            device=extract_device,
                        )
                        if decompose_mode == 'low rank':
                            extract_a, extract_b, diff = weight
                    elif layer == 'Conv2d':
                        is_linear = (child_module.weight.shape[2] == 1
                                     and child_module.weight.shape[3] == 1)
                        weight, decompose_mode = extract_conv(
                            (child_module.weight - weights[child_name]),
                            mode,
                            linear_mode_param if is_linear else conv_mode_param,
                            device=extract_device,
                        )
                        if decompose_mode == 'low rank':
                            extract_a, extract_b, diff = weight
                        if small_conv and not is_linear and decompose_mode == 'low rank':
                            dim = extract_a.size(0)
                            (extract_c, extract_a, _), _ = extract_conv(
                                extract_a.transpose(0, 1),
                                'fixed', dim,
                                extract_device, True
                            )
                            extract_a = extract_a.transpose(0, 1)
                            extract_c = extract_c.transpose(0, 1)
                            loras[f'{lora_name}.lora_mid.weight'] = extract_c.detach().cpu().contiguous().half()
                            diff = child_module.weight - torch.einsum(
                                'i j k l, j r, p i -> p r k l',
                                extract_c, extract_a.flatten(1, -1), extract_b.flatten(1, -1)
                            ).detach().cpu().contiguous()
                            del extract_c
                    else:
                        continue
                    if decompose_mode == 'low rank':
                        loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().contiguous().half()
                        loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().contiguous().half()
                        loras[f'{lora_name}.alpha'] = torch.Tensor([extract_a.shape[0]]).half()
                        if use_bias:
                            diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
                            sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()

                            indices = sparse_diff.indices().to(torch.int16)
                            values = sparse_diff.values().half()
                            loras[f'{lora_name}.bias_indices'] = indices
                            loras[f'{lora_name}.bias_values'] = values
                            loras[f'{lora_name}.bias_size'] = torch.tensor(diff.shape).to(torch.int16)
                        del extract_a, extract_b, diff
                    elif decompose_mode == 'full':
                        loras[f'{lora_name}.diff'] = weight.detach().cpu().contiguous().half()
                    else:
                        raise NotImplementedError
            elif name in temp_name:
                weights = temp_name[name]
                lora_name = prefix + '.' + name
                lora_name = lora_name.replace('.', '_')
                layer = module.__class__.__name__

                if layer in {'Linear', 'Conv2d'}:
                    root_weight = module.weight
                    if torch.allclose(root_weight, weights):
                        continue

                if layer == 'Linear':
                    weight, decompose_mode = extract_linear(
                        (root_weight - weights),
                        mode,
                        linear_mode_param,
                        device=extract_device,
                    )
                    if decompose_mode == 'low rank':
                        extract_a, extract_b, diff = weight
                elif layer == 'Conv2d':
                    is_linear = (
                        root_weight.shape[2] == 1
                        and root_weight.shape[3] == 1
                    )
                    weight, decompose_mode = extract_conv(
                        (root_weight - weights),
                        mode,
                        linear_mode_param if is_linear else conv_mode_param,
                        device=extract_device,
                    )
                    if decompose_mode == 'low rank':
                        extract_a, extract_b, diff = weight
                    if small_conv and not is_linear and decompose_mode == 'low rank':
                        dim = extract_a.size(0)
                        (extract_c, extract_a, _), _ = extract_conv(
                            extract_a.transpose(0, 1),
                            'fixed', dim,
                            extract_device, True
                        )
                        extract_a = extract_a.transpose(0, 1)
                        extract_c = extract_c.transpose(0, 1)
                        loras[f'{lora_name}.lora_mid.weight'] = extract_c.detach().cpu().contiguous().half()
                        diff = root_weight - torch.einsum(
                            'i j k l, j r, p i -> p r k l',
                            extract_c, extract_a.flatten(1, -1), extract_b.flatten(1, -1)
                        ).detach().cpu().contiguous()
                        del extract_c
                else:
                    continue
                if decompose_mode == 'low rank':
                    loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().contiguous().half()
                    loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().contiguous().half()
                    loras[f'{lora_name}.alpha'] = torch.Tensor([extract_a.shape[0]]).half()
                    if use_bias:
                        diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
                        sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()

                        indices = sparse_diff.indices().to(torch.int16)
                        values = sparse_diff.values().half()
                        loras[f'{lora_name}.bias_indices'] = indices
                        loras[f'{lora_name}.bias_values'] = values
                        loras[f'{lora_name}.bias_size'] = torch.tensor(diff.shape).to(torch.int16)
                    del extract_a, extract_b, diff
                elif decompose_mode == 'full':
                    loras[f'{lora_name}.diff'] = weight.detach().cpu().contiguous().half()
                else:
                    raise NotImplementedError
        return loras

    text_encoder_loras = make_state_dict(
        LORA_PREFIX_TEXT_ENCODER,
        base_model[0], db_model[0],
        TEXT_ENCODER_TARGET_REPLACE_MODULE
    )

    unet_loras = make_state_dict(
        LORA_PREFIX_UNET,
        base_model[2], db_model[2],
        UNET_TARGET_REPLACE_MODULE,
        UNET_TARGET_REPLACE_NAME
    )
    print(len(text_encoder_loras), len(unet_loras))
    # the | operator merges the text encoder and unet state dicts (requires python 3.9+)
    return (text_encoder_loras | unet_loras), meta


def get_module(
        lyco_state_dict: Dict,
        lora_name
):
    if f'{lora_name}.lora_up.weight' in lyco_state_dict:
        up = lyco_state_dict[f'{lora_name}.lora_up.weight']
        down = lyco_state_dict[f'{lora_name}.lora_down.weight']
        mid = lyco_state_dict.get(f'{lora_name}.lora_mid.weight', None)
        alpha = lyco_state_dict.get(f'{lora_name}.alpha', None)
        return 'locon', (up, down, mid, alpha)
    elif f'{lora_name}.hada_w1_a' in lyco_state_dict:
        w1a = lyco_state_dict[f'{lora_name}.hada_w1_a']
        w1b = lyco_state_dict[f'{lora_name}.hada_w1_b']
        w2a = lyco_state_dict[f'{lora_name}.hada_w2_a']
        w2b = lyco_state_dict[f'{lora_name}.hada_w2_b']
        t1 = lyco_state_dict.get(f'{lora_name}.hada_t1', None)
        t2 = lyco_state_dict.get(f'{lora_name}.hada_t2', None)
        alpha = lyco_state_dict.get(f'{lora_name}.alpha', None)
        return 'hada', (w1a, w1b, w2a, w2b, t1, t2, alpha)
    elif f'{lora_name}.weight' in lyco_state_dict:
        weight = lyco_state_dict[f'{lora_name}.weight']
        on_input = lyco_state_dict.get(f'{lora_name}.on_input', False)
        return 'ia3', (weight, on_input)
    elif (f'{lora_name}.lokr_w1' in lyco_state_dict
          or f'{lora_name}.lokr_w1_a' in lyco_state_dict):
        w1 = lyco_state_dict.get(f'{lora_name}.lokr_w1', None)
        w1a = lyco_state_dict.get(f'{lora_name}.lokr_w1_a', None)
        w1b = lyco_state_dict.get(f'{lora_name}.lokr_w1_b', None)
        w2 = lyco_state_dict.get(f'{lora_name}.lokr_w2', None)
        w2a = lyco_state_dict.get(f'{lora_name}.lokr_w2_a', None)
        w2b = lyco_state_dict.get(f'{lora_name}.lokr_w2_b', None)
        t1 = lyco_state_dict.get(f'{lora_name}.lokr_t1', None)
        t2 = lyco_state_dict.get(f'{lora_name}.lokr_t2', None)
        alpha = lyco_state_dict.get(f'{lora_name}.alpha', None)
        return 'kron', (w1, w1a, w1b, w2, w2a, w2b, t1, t2, alpha)
    elif f'{lora_name}.diff' in lyco_state_dict:
        return 'full', lyco_state_dict[f'{lora_name}.diff']
    else:
        return 'None', ()


def cp_weight_from_conv(
        up, down, mid
):
    up = up.reshape(up.size(0), up.size(1))
    down = down.reshape(down.size(0), down.size(1))
    return torch.einsum('m n w h, i m, n j -> i j w h', mid, up, down)


def cp_weight(
        wa, wb, t
):
    temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
    return torch.einsum('i j k l, i r -> r j k l', temp, wa)


@torch.no_grad()
def rebuild_weight(module_type, params, orig_weight, scale=1):
    if orig_weight is None:
        return orig_weight
    merged = orig_weight
    if module_type == 'locon':
        up, down, mid, alpha = params
        if alpha is not None:
            scale *= alpha / up.size(1)
        if mid is not None:
            rebuild = cp_weight_from_conv(up, down, mid)
        else:
            rebuild = up.reshape(up.size(0), -1) @ down.reshape(down.size(0), -1)
        merged = orig_weight + rebuild.reshape(orig_weight.shape) * scale
        del up, down, mid, alpha, params, rebuild
    elif module_type == 'hada':
        w1a, w1b, w2a, w2b, t1, t2, alpha = params
        if alpha is not None:
            scale *= alpha / w1b.size(0)
        if t1 is not None:
            rebuild1 = cp_weight(w1a, w1b, t1)
        else:
            rebuild1 = w1a @ w1b
        if t2 is not None:
            rebuild2 = cp_weight(w2a, w2b, t2)
        else:
            rebuild2 = w2a @ w2b
        rebuild = (rebuild1 * rebuild2).reshape(orig_weight.shape)
        merged = orig_weight + rebuild * scale
        del w1a, w1b, w2a, w2b, t1, t2, alpha, params, rebuild, rebuild1, rebuild2
    elif module_type == 'ia3':
        weight, on_input = params
        if not on_input:
            weight = weight.reshape(-1, 1)
        merged = orig_weight + weight * orig_weight * scale
        del weight, on_input, params
    elif module_type == 'kron':
        w1, w1a, w1b, w2, w2a, w2b, t1, t2, alpha = params
        if alpha is not None and (w1b is not None or w2b is not None):
            scale *= alpha / (w1b.size(0) if w1b else w2b.size(0))
        if w1a is not None and w1b is not None:
            if t1:
                w1 = cp_weight(w1a, w1b, t1)
            else:
                w1 = w1a @ w1b
        if w2a is not None and w2b is not None:
            if t2:
                w2 = cp_weight(w2a, w2b, t2)
            else:
                w2 = w2a @ w2b
        rebuild = torch.kron(w1, w2).reshape(orig_weight.shape)
        merged = orig_weight + rebuild * scale
        del w1, w1a, w1b, w2, w2a, w2b, t1, t2, alpha, params, rebuild
    elif module_type == 'full':
        rebuild = params.reshape(orig_weight.shape)
        merged = orig_weight + rebuild * scale
        del params, rebuild

    return merged


def merge(
        base_model,
        lyco_state_dict,
        scale: float = 1.0,
        device='cpu'
):
    UNET_TARGET_REPLACE_MODULE = [
        "Transformer2DModel",
        "Attention",
        "ResnetBlock2D",
        "Downsample2D",
        "Upsample2D"
    ]
    UNET_TARGET_REPLACE_NAME = [
        "conv_in",
        "conv_out",
        "time_embedding.linear_1",
        "time_embedding.linear_2",
    ]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = 'lora_unet'
    LORA_PREFIX_TEXT_ENCODER = 'lora_te'
    merged = 0

    def merge_state_dict(
            prefix,
            root_module: torch.nn.Module,
            lyco_state_dict: Dict[str, torch.Tensor],
            target_replace_modules,
            target_replace_names=[]
    ):
        nonlocal merged
        for name, module in tqdm(list(root_module.named_modules()), desc=f'Merging {prefix}'):
            if module.__class__.__name__ in target_replace_modules:
                for child_name, child_module in module.named_modules():
                    if child_module.__class__.__name__ not in {'Linear', 'Conv2d'}:
                        continue
                    lora_name = prefix + '.' + name + '.' + child_name
                    lora_name = lora_name.replace('.', '_')

                    result = rebuild_weight(*get_module(
                        lyco_state_dict, lora_name
                    ), getattr(child_module, 'weight'), scale)
                    if result is not None:
                        merged += 1
                        child_module.requires_grad_(False)
                        child_module.weight.copy_(result)
            elif name in target_replace_names:
                lora_name = prefix + '.' + name
                lora_name = lora_name.replace('.', '_')

                result = rebuild_weight(*get_module(
                    lyco_state_dict, lora_name
                ), getattr(module, 'weight'), scale)
                if result is not None:
                    merged += 1
                    module.requires_grad_(False)
                    module.weight.copy_(result)

    if device == 'cpu':
        for k, v in tqdm(list(lyco_state_dict.items()), desc='Converting Dtype'):
            lyco_state_dict[k] = v.float()

    merge_state_dict(
        LORA_PREFIX_TEXT_ENCODER,
        base_model[0],
        lyco_state_dict,
        TEXT_ENCODER_TARGET_REPLACE_MODULE,
        UNET_TARGET_REPLACE_NAME
    )
    merge_state_dict(
        LORA_PREFIX_UNET,
        base_model[2],
        lyco_state_dict,
        UNET_TARGET_REPLACE_MODULE,
        UNET_TARGET_REPLACE_NAME
    )
    print(f'{merged} modules have been merged')
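A small sketch of the low-rank extraction primitive used above, run on a toy weight difference (synthetic tensors, not repo data):

```python
# sketch: SVD-based low-rank extraction of a weight delta, as extract_linear does
import torch
from toolkit.lycoris_utils import extract_linear

delta = torch.randn(320, 768)  # pretend this is (tuned_weight - base_weight)
result, decompose_mode = extract_linear(delta, mode='fixed', mode_param=16)

if decompose_mode == 'low rank':
    down_a, up_b, residual = result           # A is (rank x in), B is (out x rank)
    print(up_b.shape, down_a.shape)           # torch.Size([320, 16]) torch.Size([16, 768])
    print((delta - up_b @ down_a).norm())     # reconstruction error the low-rank pair leaves behind
else:
    print("rank too high relative to the layer, kept as a full diff")
```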
24  toolkit/metadata.py  Normal file
@@ -0,0 +1,24 @@
import json

software_meta = {
    "name": "ai-toolkit",
    "url": "https://github.com/ostris/ai-toolkit"
}


def create_meta(dict_list, name=None):
    meta = {}
    for d in dict_list:
        for key, value in d.items():
            meta[key] = value

    if "name" not in meta:
        meta["name"] = "[name]"

    meta["software"] = software_meta

    # convert to string to handle replacements
    meta_string = json.dumps(meta)
    if name is not None:
        meta_string = meta_string.replace("[name]", name)
    return json.loads(meta_string)
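A quick sketch of how `create_meta` behaves (made-up dicts): later dicts win on key conflicts, the software block is appended, and `[name]` placeholders are filled in when a name is passed:

```python
# sketch: merging metadata dicts and resolving the [name] placeholder
from toolkit.metadata import create_meta

meta = create_meta(
    [{"name": "[name]", "description": "demo"}, {"version": "0.1"}],
    name="my_model",
)
print(meta["name"])      # my_model
print(meta["software"])  # {'name': 'ai-toolkit', 'url': 'https://github.com/ostris/ai-toolkit'}
```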
4  toolkit/paths.py  Normal file
@@ -0,0 +1,4 @@
import os

TOOLKIT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
CONFIG_ROOT = os.path.join(TOOLKIT_ROOT, 'config')