Reworked so everything is in classes for easy expansion. Single entry point for all config files now.

This commit is contained in:
Jaret Burkett
2023-07-08 09:51:42 -06:00
parent 27df03a486
commit 37354b006e
16 changed files with 424 additions and 189 deletions

View File

@@ -33,13 +33,13 @@ Just copy that file, into the `config` folder, and rename it to `whatever_you_wa
Then you can edit the file to your liking, and call it like so:
```bash
python3 scripts/extract_locon.py "whatever_you_want"
python3 run.py "whatever_you_want"
```
You can also put a full path to a config file, if you want to keep it somewhere else.
```bash
python3 scripts/extract_locon.py "/home/user/whatever_you_want.json"
python3 run.py "/home/user/whatever_you_want.json"
```
File name is auto generated and dumped into the `output` folder. You can put whatever meta you want in the

View File

@@ -1,4 +1,5 @@
{
"job": "extract",
"config": {
"name": "name_of_your_model",
"base_model": "/path/to/base/model",
@@ -6,24 +7,26 @@
"output_folder": "/path/to/output/folder",
"is_v2": false,
"device": "cpu",
"use_sparse_bias": false,
"sparsity": 0.98,
"disable_cp": false,
"process": [
{
"filename":"[name]_64_32.safetensors",
"type": "locon",
"mode": "fixed",
"linear_dim": 64,
"conv_dim": 32
"linear": 64,
"conv": 32
},
{
"output_path": "/absolute/path/for/this/output.safetensors",
"type": "locon",
"mode": "ratio",
"linear_ratio": 0.2,
"conv_ratio": 0.2
"linear": 0.2,
"conv": 0.2
},
{
"type": "locon",
"mode": "quantile",
"linear_quantile": 0.5,
"conv_quantile": 0.5
"linear": 0.5,
"conv": 0.5
}
]
},
@@ -41,6 +44,7 @@
"name": "Your Name",
"email": "your@email.com",
"website": "https://yourwebsite.com"
}
},
"any": "All meta data above is arbitrary, it can be whatever you want."
}
}

8
info.py Normal file
View File

@@ -0,0 +1,8 @@
from collections import OrderedDict

# Software identification metadata, embedded into saved model files.
software_meta = OrderedDict([
    ("name", "ai-toolkit"),
    ("repo", "https://github.com/ostris/ai-toolkit"),
    ("version", "0.0.1"),
])

43
jobs/BaseJob.py Normal file
View File

@@ -0,0 +1,43 @@
from collections import OrderedDict
class BaseJob:
    """Common base for all job types loaded from a single config file.

    Expects a top-level dict with "job", "config" (containing at least
    "name") and an optional "meta" section.
    """
    config: OrderedDict
    job: str
    name: str
    meta: OrderedDict

    def __init__(self, config: OrderedDict):
        if not config:
            raise ValueError('config is required')
        self.config = config['config']
        self.job = config['job']
        self.name = self.get_conf('name', required=True)
        # metadata is optional; default to an empty ordered mapping
        self.meta = config['meta'] if 'meta' in config else OrderedDict()

    def get_conf(self, key, default=None, required=False):
        """Read `key` from the job's config section, falling back to `default`."""
        if key not in self.config:
            if required:
                raise ValueError(f'config file error. Missing "config.{key}" key')
            return default
        return self.config[key]

    def run(self):
        # Child classes implement the real work and must call super().run() first.
        banner = "#############################################"
        print("")
        print(banner)
        print(f"# Running job: {self.name}")
        print(banner)
        print("")

    def cleanup(self):
        # If you implement this in a child class, call super().cleanup() LAST.
        del self

53
jobs/ExtractJob.py Normal file
View File

@@ -0,0 +1,53 @@
from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
from .BaseJob import BaseJob
from collections import OrderedDict
from typing import List
from jobs.process import BaseExtractProcess
class ExtractJob(BaseJob):
    """Job that extracts networks by diffing an extract model against a base model."""
    process: List[BaseExtractProcess]

    def __init__(self, config: OrderedDict):
        super().__init__(config)
        self.base_model_path = self.get_conf('base_model', required=True)
        self.base_model = None
        self.extract_model_path = self.get_conf('extract_model', required=True)
        self.extract_model = None
        self.output_folder = self.get_conf('output_folder', required=True)
        self.is_v2 = self.get_conf('is_v2', False)
        self.device = self.get_conf('device', 'cpu')

        if 'process' not in self.config:
            raise ValueError('config file is invalid. Missing "config.process" key')
        if len(self.config['process']) == 0:
            raise ValueError('config file is invalid. "config.process" must be a list of processes')

        # Build one process object per entry in config.process.
        self.process = []
        for i, process_config in enumerate(self.config['process']):
            if 'type' not in process_config:
                raise ValueError(f'config file is invalid. Missing "config.process[{i}].type" key')
            process_type = process_config['type']
            if process_type != 'locon':
                raise ValueError(f'config file is invalid. Unknown process type: {process_type}')
            # imported lazily, matching the original, to avoid import cycles
            from jobs.process import LoconExtractProcess
            self.process.append(LoconExtractProcess(i, self, process_config))

    def run(self):
        super().run()
        # Both checkpoints are loaded up front; every process reuses them.
        print(f"Loading models for extraction")
        print(f" - Loading base model: {self.base_model_path}")
        self.base_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.base_model_path)
        print(f" - Loading extract model: {self.extract_model_path}")
        self.extract_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.extract_model_path)
        print("")
        num_processes = len(self.process)
        print(f"Running {num_processes} process{'' if num_processes == 1 else 'es'}")
        for proc in self.process:
            proc.run()

2
jobs/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
from .BaseJob import BaseJob
from .ExtractJob import ExtractJob

View File

@@ -0,0 +1,76 @@
import os
from collections import OrderedDict
from safetensors.torch import save_file
from jobs import ExtractJob
from jobs.process.BaseProcess import BaseProcess
from toolkit.metadata import get_meta_for_safetensors
class BaseExtractProcess(BaseProcess):
    """Base class for extraction processes.

    Resolves the output file path for a process and saves the extracted
    state dict (with safetensors metadata) to disk.
    """
    job: ExtractJob
    process_id: int
    config: OrderedDict
    output_folder: str
    # NOTE(review): output_filename is declared but never assigned in this
    # class hierarchy as visible here — confirm a child/caller sets it before
    # the auto-generated filename path is used.
    output_filename: str
    output_path: str

    def __init__(
            self,
            process_id: int,
            job: ExtractJob,
            config: OrderedDict
    ):
        super().__init__(process_id, job, config)
        self.process_id = process_id
        self.job = job
        self.config = config

    def run(self):
        # here instead of init because child init needs to go first
        self.output_path = self.get_output_path()
        # implement in child class
        # be sure to call super().run() first

    # you can override this in the child class if you want
    # call super().get_output_path(prefix="your_prefix_", suffix="_your_suffix") to extend this
    def get_output_path(self, prefix=None, suffix=None):
        """Resolve the output path for this process.

        Priority: explicit "output_path" in the process config, then
        "filename" joined onto the job's output folder, then an
        auto-generated name built from prefix/suffix.
        """
        config_output_path = self.get_conf('output_path', None)
        config_filename = self.get_conf('filename', None)

        if config_output_path is not None:
            # replace [name] with the job name
            return config_output_path.replace('[name]', self.job.name)

        if config_filename is not None:
            # build the output path from the output folder and filename
            return os.path.join(self.job.output_folder, config_filename)

        # build our own
        if suffix is None:
            # add the process id to the end of the filename if there is more
            # than one process and no other suffix was given.
            # BUGFIX: the process list lives on the job's config, not on this
            # process's own config dict (which has no "process" key and would
            # raise KeyError here).
            suffix = f"_{self.process_id}" if len(self.job.config['process']) > 1 else ''
        if prefix is None:
            prefix = ''
        output_filename = f"{prefix}{self.output_filename}{suffix}"
        return os.path.join(self.job.output_folder, output_filename)

    def save(self, state_dict):
        """Save `state_dict` to self.output_path with prepared metadata."""
        # prepare meta
        save_meta = get_meta_for_safetensors(self.meta, self.job.name)
        # ensure the target directory exists; dirname is empty for bare
        # relative filenames, and makedirs('') would raise
        out_dir = os.path.dirname(self.output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        # having issues with meta
        save_file(state_dict, self.output_path, save_meta)
        print(f"Saved to {self.output_path}")

View File

@@ -0,0 +1,42 @@
import copy
import json
from collections import OrderedDict
from jobs import BaseJob
class BaseProcess:
    """Base class for a single process within a job.

    Holds the owning job, the per-process config dict, and a deep copy of
    the job's metadata that the process may extend via add_meta().
    """
    meta: OrderedDict

    def __init__(
            self,
            process_id: int,
            job: BaseJob,
            config: OrderedDict
    ):
        self.process_id = process_id
        self.job = job
        self.config = config
        # deep copy so per-process metadata additions don't leak into the job
        self.meta = copy.deepcopy(self.job.meta)

    def get_conf(self, key, default=None, required=False, as_type=None):
        """Read `key` from this process's config.

        Args:
            key: config key to look up.
            default: value returned when the key is absent and not required.
            required: when True, raise ValueError if the key is absent.
            as_type: optional callable used to coerce the value (e.g. int).
        """
        if key in self.config:
            value = self.config[key]
            if as_type is not None:
                value = as_type(value)
            return value
        elif required:
            raise ValueError(f'config file error. Missing "config.process[{self.process_id}].{key}" key')
        else:
            # BUGFIX: don't coerce a None default — as_type(None) (e.g.
            # int(None)) would raise TypeError instead of returning the default
            if as_type is not None and default is not None:
                return as_type(default)
            return default

    def run(self):
        # implement in child class
        # be sure to call super().run() first in case something is added here
        pass

    def add_meta(self, additional_meta: OrderedDict):
        """Merge additional metadata into this process's meta dict."""
        self.meta.update(additional_meta)

View File

@@ -0,0 +1,67 @@
from collections import OrderedDict
from toolkit.lycoris_utils import extract_diff
from .BaseExtractProcess import BaseExtractProcess
from .. import ExtractJob
# Per-mode defaults for the linear/conv extraction parameters, plus the
# type each mode's values are coerced to.
mode_dict = {
    'fixed':     {'linear': 64,  'conv': 32,  'type': int},
    'threshold': {'linear': 0,   'conv': 0,   'type': float},
    'ratio':     {'linear': 0.5, 'conv': 0.5, 'type': float},
    'quantile':  {'linear': 0.5, 'conv': 0.5, 'type': float},
}
class LoconExtractProcess(BaseExtractProcess):
    """Extraction process producing a LoCon network from the diff between
    the job's base model and extract model."""

    def __init__(self, process_id: int, job: ExtractJob, config: OrderedDict):
        super().__init__(process_id, job, config)
        self.mode = self.get_conf('mode', 'fixed')
        self.use_sparse_bias = self.get_conf('use_sparse_bias', False)
        self.sparsity = self.get_conf('sparsity', 0.98)
        self.disable_cp = self.get_conf('disable_cp', False)

        # validate mode before reading its defaults (keys of mode_dict are
        # exactly the supported modes)
        if self.mode not in mode_dict:
            raise ValueError(f"Unknown mode: {self.mode}")
        # BUGFIX: the mode's type must be passed as the `as_type` keyword.
        # Passing it as the third positional argument filled `required`
        # instead, which made these keys mandatory (a truthy type object)
        # and silently skipped the intended type coercion.
        self.linear_param = self.get_conf(
            'linear', mode_dict[self.mode]['linear'], as_type=mode_dict[self.mode]['type'])
        self.conv_param = self.get_conf(
            'conv', mode_dict[self.mode]['conv'], as_type=mode_dict[self.mode]['type'])

    def run(self):
        super().run()
        print(f"Running process: {self.mode}, lin: {self.linear_param}, conv: {self.conv_param}")
        # extract the low-rank diff between the base and fine-tuned models
        state_dict, extract_diff_meta = extract_diff(
            self.job.base_model,
            self.job.extract_model,
            self.mode,
            self.linear_param,
            self.conv_param,
            self.job.device,
            self.use_sparse_bias,
            self.sparsity,
            not self.disable_cp
        )
        self.add_meta(extract_diff_meta)
        self.save(state_dict)

    def get_output_path(self, prefix=None, suffix=None):
        # default suffix encodes the mode and its parameters
        if suffix is None:
            suffix = f"_{self.mode}_{self.linear_param}_{self.conv_param}"
        return super().get_output_path(prefix, suffix)

3
jobs/process/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .BaseExtractProcess import BaseExtractProcess
from .LoconExtractProcess import LoconExtractProcess
from .BaseProcess import BaseProcess

67
run.py Normal file
View File

@@ -0,0 +1,67 @@
import os
import sys
from collections import OrderedDict
from jobs import BaseJob
sys.path.insert(0, os.getcwd())
import argparse
from toolkit.job import get_job
def print_end_message(jobs_completed, jobs_failed):
    """Print a summary banner of completed and failed job counts."""
    # completed count is always shown; failures only when there were any
    result_lines = [f"{jobs_completed} completed job{'' if jobs_completed == 1 else 's'}"]
    if jobs_failed > 0:
        result_lines.append(f"{jobs_failed} failure{'' if jobs_failed == 1 else 's'}")
    print("")
    print("========================================")
    print("Result:")
    for line in result_lines:
        print(f" - {line}")
    print("========================================")
def main():
    """Parse CLI arguments and run each given config file as a job, sequentially."""
    parser = argparse.ArgumentParser()
    # require at least one config file
    parser.add_argument(
        'config_file_list',
        nargs='+',
        type=str,
        help='Name of config file (eg: person_v1 for config/person_v1.json), or full path if it is not in config folder, you can pass multiple config files and run them all sequentially'
    )
    # flag to continue if failed job
    parser.add_argument(
        '-r', '--recover',
        action='store_true',
        help='Continue running additional jobs even if a job fails'
    )
    args = parser.parse_args()

    config_file_list = args.config_file_list
    if len(config_file_list) == 0:
        raise Exception("You must provide at least one config file")

    jobs_completed = 0
    jobs_failed = 0
    for config_file in config_file_list:
        try:
            job = get_job(config_file)
            job.run()
            job.cleanup()
            jobs_completed += 1
        except Exception as e:
            print(f"Error running job: {e}")
            jobs_failed += 1
            if not args.recover:
                # without --recover, summarize what ran and re-raise
                print_end_message(jobs_completed, jobs_failed)
                raise e
    # BUGFIX: the summary was only printed on a non-recoverable failure;
    # print it after the loop so successful (and --recover) runs get one too.
    print_end_message(jobs_completed, jobs_failed)
if __name__ == '__main__':
main()

View File

@@ -1,150 +0,0 @@
import json
import os
import sys
from flatten_json import flatten
sys.path.insert(0, os.getcwd())
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
CONFIG_FOLDER = os.path.join(PROJECT_ROOT, 'config')
sys.path.append(PROJECT_ROOT)
import argparse
from toolkit.lycoris_utils import extract_diff
from toolkit.config import get_config
from toolkit.metadata import create_meta, prep_meta_for_safetensors
from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
import torch
from safetensors.torch import save_file
def get_args():
    """Parse the command line: a single config file name or full path."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "config_file",
        type=str,
        help="Name of config file (eg: person_v1 for config/person_v1.json), or full path if it is not in config folder",
    )
    return parser.parse_args()
def main():
    # Legacy single-purpose entry point for LoCon extraction
    # (this file is deleted in this commit, superseded by run.py + job classes).
    args = get_args()
    config_raw = get_config(args.config_file)
    config = config_raw['config'] if 'config' in config_raw else None
    if not config:
        raise ValueError('config file is invalid. Missing "config" key')
    meta = config_raw['meta'] if 'meta' in config_raw else {}

    def get_conf(key, default=None):
        # local helper: read a key from the config section with a default
        if key in config:
            return config[key]
        else:
            return default

    is_v2 = get_conf('is_v2', False)
    name = get_conf('name', None)
    base_model = get_conf('base_model')
    extract_model = get_conf('extract_model')
    output_folder = get_conf('output_folder')
    process_list = get_conf('process')
    device = get_conf('device', 'cpu')
    use_sparse_bias = get_conf('use_sparse_bias', False)
    sparsity = get_conf('sparsity', 0.98)
    disable_cp = get_conf('disable_cp', False)

    if not name:
        raise ValueError('name is required')
    if not base_model:
        raise ValueError('base_model is required')
    if not extract_model:
        raise ValueError('extract_model is required')
    if not output_folder:
        raise ValueError('output_folder is required')
    if not process_list or len(process_list) == 0:
        raise ValueError('process is required')

    # check processes
    # NOTE(review): these truthiness checks raise a bare KeyError (not the
    # intended ValueError) when a key is missing entirely, and wrongly reject
    # legitimate falsy values such as 0
    for process in process_list:
        if process['mode'] == 'fixed':
            if not process['linear_dim']:
                raise ValueError('linear_dim is required in fixed mode')
            if not process['conv_dim']:
                raise ValueError('conv_dim is required in fixed mode')
        elif process['mode'] == 'threshold':
            if not process['linear_threshold']:
                raise ValueError('linear_threshold is required in threshold mode')
            if not process['conv_threshold']:
                raise ValueError('conv_threshold is required in threshold mode')
        elif process['mode'] == 'ratio':
            if not process['linear_ratio']:
                raise ValueError('linear_ratio is required in ratio mode')
            if not process['conv_ratio']:
                # NOTE(review): message mentions conv_threshold/threshold mode,
                # but this is the conv_ratio check in ratio mode
                raise ValueError('conv_threshold is required in threshold mode')
        elif process['mode'] == 'quantile':
            if not process['linear_quantile']:
                raise ValueError('linear_quantile is required in quantile mode')
            if not process['conv_quantile']:
                raise ValueError('conv_quantile is required in quantile mode')
        else:
            raise ValueError('mode is invalid')

    # load both checkpoints up front; they are reused by every process
    print(f"Loading base model: {base_model}")
    base = load_models_from_stable_diffusion_checkpoint(is_v2, base_model)
    print(f"Loading extract model: {extract_model}")
    extract = load_models_from_stable_diffusion_checkpoint(is_v2, extract_model)

    print(f"Running {len(process_list)} process{'' if len(process_list) == 1 else 'es'}")
    for process in process_list:
        # deep-copy meta via a JSON round trip so per-process edits don't leak
        item_meta = json.loads(json.dumps(meta))
        item_meta['process'] = process
        # pick the mode parameters: int dims for fixed, floats otherwise
        if process['mode'] == 'fixed':
            linear_mode_param = int(process['linear_dim'])
            conv_mode_param = int(process['conv_dim'])
        elif process['mode'] == 'threshold':
            linear_mode_param = float(process['linear_threshold'])
            conv_mode_param = float(process['conv_threshold'])
        elif process['mode'] == 'ratio':
            linear_mode_param = float(process['linear_ratio'])
            conv_mode_param = float(process['conv_ratio'])
        elif process['mode'] == 'quantile':
            linear_mode_param = float(process['linear_quantile'])
            conv_mode_param = float(process['conv_quantile'])
        else:
            raise ValueError(f"Unknown mode: {process['mode']}")
        print(f"Running process: {process['mode']}, lin: {linear_mode_param}, conv: {conv_mode_param}")
        state_dict, extract_diff_meta = extract_diff(
            base,
            extract,
            process['mode'],
            linear_mode_param,
            conv_mode_param,
            device,
            use_sparse_bias,
            sparsity,
            not disable_cp
        )
        save_meta = create_meta([
            item_meta, extract_diff_meta
        ], name=name)
        output_file_name = f"lyco_{name}_{process['mode']}_{linear_mode_param}_{conv_mode_param}.safetensors"
        output_path = os.path.join(output_folder, output_file_name)
        os.makedirs(output_folder, exist_ok=True)
        # having issues with meta
        save_file(state_dict, output_path, prep_meta_for_safetensors(save_meta))
        print(f"Saved to {output_path}")
if __name__ == '__main__':
main()

View File

@@ -1,5 +1,7 @@
import os
import json
from collections import OrderedDict
from toolkit.paths import TOOLKIT_ROOT
possible_extensions = ['.json', '.jsonc']
@@ -11,6 +13,21 @@ def get_cwd_abs_path(path):
return path
def preprocess_config(config: OrderedDict):
    """Validate required top-level keys and expand the [name] tag everywhere."""
    # fail fast on structurally invalid configs
    if "job" not in config:
        raise ValueError("config file must have a job key")
    if "config" not in config:
        raise ValueError("config file must have a config section")
    if "name" not in config["config"]:
        raise ValueError("config file must have a config.name key")
    # we need to replace tags. For now just [name]; do the substitution on the
    # serialized JSON text so it reaches every nested value
    name = config["config"]["name"]
    expanded = json.dumps(config).replace("[name]", name)
    return json.loads(expanded, object_pairs_hook=OrderedDict)
def get_config(config_file_path):
# first check if it is in the config folder
config_path = os.path.join(TOOLKIT_ROOT, 'config', config_file_path)
@@ -34,6 +51,6 @@ def get_config(config_file_path):
# load the config
with open(real_config_path, 'r') as f:
config = json.load(f)
config = json.load(f, object_pairs_hook=OrderedDict)
return config
return preprocess_config(config)

15
toolkit/job.py Normal file
View File

@@ -0,0 +1,15 @@
from jobs import BaseJob
from toolkit.config import get_config
def get_job(config_path) -> BaseJob:
    """Load a config file and instantiate the matching job class.

    Args:
        config_path: config name (resolved in the config folder) or full path.
    Returns:
        A constructed job instance for the config's "job" type.
    Raises:
        ValueError: when the "job" key is missing/empty or the type is unknown.
    """
    config = get_config(config_path)
    # BUGFIX: use .get so a missing "job" key raises our ValueError below
    # instead of a bare KeyError from config['job']
    job = config.get('job')
    if not job:
        raise ValueError('config file is invalid. Missing "job" key')
    if job == 'extract':
        # imported lazily to avoid an import cycle with the jobs package
        from jobs import ExtractJob
        return ExtractJob(config)
    raise ValueError(f'Unknown job type {job}')

View File

@@ -11,6 +11,7 @@ import torch.nn.functional as F
import torch.linalg as linalg
from tqdm import tqdm
from collections import OrderedDict
def make_sparse(t: torch.Tensor, sparsity=0.95):
@@ -121,7 +122,7 @@ def extract_diff(
sparsity=0.98,
small_conv=True
):
meta = {}
meta = OrderedDict()
UNET_TARGET_REPLACE_MODULE = [
"Transformer2DModel",

View File

@@ -1,31 +1,18 @@
import json
software_meta = {
"name": "ai-toolkit",
"url": "https://github.com/ostris/ai-toolkit"
}
from collections import OrderedDict
from info import software_meta
def create_meta(dict_list, name=None):
meta = {}
for d in dict_list:
for key, value in d.items():
meta[key] = value
if "name" not in meta:
meta["name"] = "[name]"
meta["software"] = software_meta
# convert to string to handle replacements
def get_meta_for_safetensors(meta: OrderedDict, name=None) -> OrderedDict:
# stringify the meta and reparse OrderedDict to replace [name] with name
meta_string = json.dumps(meta)
if name is not None:
meta_string = meta_string.replace("[name]", name)
return json.loads(meta_string)
def prep_meta_for_safetensors(meta):
save_meta = json.loads(meta_string, object_pairs_hook=OrderedDict)
save_meta["software"] = software_meta
# safetensors can only be one level deep
for key, value in meta.items():
meta[key] = json.dumps(value)
return meta
for key, value in save_meta.items():
# if not float, int, bool, or str, convert to json string
if not isinstance(value, (float, int, bool, str)):
save_meta[key] = json.dumps(value)
return save_meta