From 37354b006e53479ae981cc871c77d0887b1eb8ae Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 8 Jul 2023 09:51:42 -0600 Subject: [PATCH] Reworked so everything is in classes for easy expansion. Single entry point for all config files now. --- README.md | 4 +- ...ample.json => extract_config.example.json} | 24 +-- info.py | 8 + jobs/BaseJob.py | 43 +++++ jobs/ExtractJob.py | 53 +++++++ jobs/__init__.py | 2 + jobs/process/BaseExtractProcess.py | 76 +++++++++ jobs/process/BaseProcess.py | 42 +++++ jobs/process/LoconExtractProcess.py | 67 ++++++++ jobs/process/__init__.py | 3 + run.py | 67 ++++++++ scripts/extract_locon.py | 150 ------------------ toolkit/config.py | 21 ++- toolkit/job.py | 15 ++ toolkit/lycoris_utils.py | 3 +- toolkit/metadata.py | 35 ++-- 16 files changed, 424 insertions(+), 189 deletions(-) rename config/examples/{locon_config.example.json => extract_config.example.json} (62%) create mode 100644 info.py create mode 100644 jobs/BaseJob.py create mode 100644 jobs/ExtractJob.py create mode 100644 jobs/__init__.py create mode 100644 jobs/process/BaseExtractProcess.py create mode 100644 jobs/process/BaseProcess.py create mode 100644 jobs/process/LoconExtractProcess.py create mode 100644 jobs/process/__init__.py create mode 100644 run.py delete mode 100644 scripts/extract_locon.py create mode 100644 toolkit/job.py diff --git a/README.md b/README.md index f73770b0..d2ec7cba 100644 --- a/README.md +++ b/README.md @@ -33,13 +33,13 @@ Just copy that file, into the `config` folder, and rename it to `whatever_you_wa Then you can edit the file to your liking. and call it like so: ```bash -python3 scripts/extract_locon.py "whatever_you_want" +python3 run.py "whatever_you_want" ``` You can also put a full path to a config file, if you want to keep it somewhere else. ```bash -python3 scripts/extract_locon.py "/home/user/whatever_you_want.json" +python3 run.py "/home/user/whatever_you_want.json" ``` File name is auto generated and dumped into the `output` folder. You can put whatever meta you want in the diff --git a/config/examples/locon_config.example.json b/config/examples/extract_config.example.json similarity index 62% rename from config/examples/locon_config.example.json rename to config/examples/extract_config.example.json index 142919af..e08a639f 100644 --- a/config/examples/locon_config.example.json +++ b/config/examples/extract_config.example.json @@ -1,4 +1,5 @@ { + "job": "extract", "config": { "name": "name_of_your_model", "base_model": "/path/to/base/model", @@ -6,24 +7,26 @@ "output_folder": "/path/to/output/folder", "is_v2": false, "device": "cpu", - "use_sparse_bias": false, - "sparsity": 0.98, - "disable_cp": false, "process": [ { + "filename":"[name]_64_32.safetensors", + "type": "locon", "mode": "fixed", - "linear_dim": 64, - "conv_dim": 32 + "linear": 64, + "conv": 32 }, { + "output_path": "/absolute/path/for/this/output.safetensors", + "type": "locon", "mode": "ratio", - "linear_ratio": 0.2, - "conv_ratio": 0.2 + "linear": 0.2, + "conv": 0.2 }, { + "type": "locon", "mode": "quantile", - "linear_quantile": 0.5, - "conv_quantile": 0.5 + "linear": 0.5, + "conv": 0.5 } ] }, @@ -41,6 +44,7 @@ "name": "Your Name", "email": "your@email.com", "website": "https://yourwebsite.com" - } + }, + "any": "All meta data above is arbitrary, it can be whatever you want." } } \ No newline at end of file diff --git a/info.py b/info.py new file mode 100644 index 00000000..2e3c824c --- /dev/null +++ b/info.py @@ -0,0 +1,8 @@ +from collections import OrderedDict + +v = OrderedDict() +v["name"] = "ai-toolkit" +v["repo"] = "https://github.com/ostris/ai-toolkit" +v["version"] = "0.0.1" + +software_meta = v diff --git a/jobs/BaseJob.py b/jobs/BaseJob.py new file mode 100644 index 00000000..c5a52c69 --- /dev/null +++ b/jobs/BaseJob.py @@ -0,0 +1,43 @@ +from collections import OrderedDict + + +class BaseJob: + config: OrderedDict + job: str + name: str + meta: OrderedDict + + def __init__(self, config: OrderedDict): + if not config: + raise ValueError('config is required') + + self.config = config['config'] + self.job = config['job'] + self.name = self.get_conf('name', required=True) + if 'meta' in config: + self.meta = config['meta'] + else: + self.meta = OrderedDict() + + def get_conf(self, key, default=None, required=False): + if key in self.config: + return self.config[key] + elif required: + raise ValueError(f'config file error. Missing "config.{key}" key') + else: + return default + + def run(self): + print("") + print(f"#############################################") + print(f"# Running job: {self.name}") + print(f"#############################################") + print("") + # implement in child class + # be sure to call super().run() first + pass + + def cleanup(self): + # if you implement this in child clas, + # be sure to call super().cleanup() LAST + del self diff --git a/jobs/ExtractJob.py b/jobs/ExtractJob.py new file mode 100644 index 00000000..1b6af8fb --- /dev/null +++ b/jobs/ExtractJob.py @@ -0,0 +1,53 @@ +from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint +from .BaseJob import BaseJob +from collections import OrderedDict +from typing import List + +from jobs.process import BaseExtractProcess + + +class ExtractJob(BaseJob): + process: List[BaseExtractProcess] + + def __init__(self, config: OrderedDict): + super().__init__(config) + self.base_model_path = self.get_conf('base_model', required=True) + self.base_model = None + self.extract_model_path = self.get_conf('extract_model', required=True) + self.extract_model = None + self.output_folder = self.get_conf('output_folder', required=True) + self.is_v2 = self.get_conf('is_v2', False) + self.device = self.get_conf('device', 'cpu') + + if 'process' not in self.config: + raise ValueError('config file is invalid. Missing "config.process" key') + if len(self.config['process']) == 0: + raise ValueError('config file is invalid. "config.process" must be a list of processes') + + # add the processes + self.process = [] + for i, process in enumerate(self.config['process']): + if 'type' not in process: + raise ValueError(f'config file is invalid. Missing "config.process[{i}].type" key') + if process['type'] == 'locon': + from jobs.process import LoconExtractProcess + self.process.append(LoconExtractProcess(i, self, process)) + else: + raise ValueError(f'config file is invalid. Unknown process type: {process["type"]}') + + def run(self): + super().run() + # load models + print(f"Loading models for extraction") + print(f" - Loading base model: {self.base_model_path}") + self.base_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.base_model_path) + + print(f" - Loading extract model: {self.extract_model_path}") + self.extract_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.extract_model_path) + + print("") + print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}") + + for process in self.process: + process.run() + diff --git a/jobs/__init__.py b/jobs/__init__.py new file mode 100644 index 00000000..09be1770 --- /dev/null +++ b/jobs/__init__.py @@ -0,0 +1,2 @@ +from .BaseJob import BaseJob +from .ExtractJob import ExtractJob diff --git a/jobs/process/BaseExtractProcess.py b/jobs/process/BaseExtractProcess.py new file mode 100644 index 00000000..1f141bb5 --- /dev/null +++ b/jobs/process/BaseExtractProcess.py @@ -0,0 +1,76 @@ +import os +from collections import OrderedDict + +from safetensors.torch import save_file + +from jobs import ExtractJob +from jobs.process.BaseProcess import BaseProcess +from toolkit.metadata import get_meta_for_safetensors + + +class BaseExtractProcess(BaseProcess): + job: ExtractJob + process_id: int + config: OrderedDict + output_folder: str + output_filename: str + output_path: str + + def __init__( + self, + process_id: int, + job: ExtractJob, + config: OrderedDict + ): + super().__init__(process_id, job, config) + self.process_id = process_id + self.job = job + self.config = config + + def run(self): + # here instead of init because child init needs to go first + self.output_path = self.get_output_path() + # implement in child class + # be sure to call super().run() first + pass + + # you can override this in the child class if you want + # call super().get_output_path(prefix="your_prefix_", suffix="_your_suffix") to extend this + def get_output_path(self, prefix=None, suffix=None): + config_output_path = self.get_conf('output_path', None) + config_filename = self.get_conf('filename', None) + # replace [name] with name + + if config_output_path is not None: + config_output_path = config_output_path.replace('[name]', self.job.name) + return config_output_path + + if config_output_path is None and config_filename is not None: + # build the output path from the output folder and filename + return os.path.join(self.job.output_folder, config_filename) + + # build our own + + if suffix is None: + # we will just add process it to the end of the filename if there is more than one process + # and no other suffix was given + suffix = f"_{self.process_id}" if len(self.config['process']) > 1 else '' + + if prefix is None: + prefix = '' + + output_filename = f"{prefix}{self.output_filename}{suffix}" + + return os.path.join(self.job.output_folder, output_filename) + + def save(self, state_dict): + # prepare meta + save_meta = get_meta_for_safetensors(self.meta, self.job.name) + + # save + os.makedirs(os.path.dirname(self.output_path), exist_ok=True) + + # having issues with meta + save_file(state_dict, self.output_path, save_meta) + + print(f"Saved to {self.output_path}") diff --git a/jobs/process/BaseProcess.py b/jobs/process/BaseProcess.py new file mode 100644 index 00000000..a09de583 --- /dev/null +++ b/jobs/process/BaseProcess.py @@ -0,0 +1,42 @@ +import copy +import json +from collections import OrderedDict + +from jobs import BaseJob + + +class BaseProcess: + meta: OrderedDict + + def __init__( + self, + process_id: int, + job: BaseJob, + config: OrderedDict + ): + self.process_id = process_id + self.job = job + self.config = config + self.meta = copy.deepcopy(self.job.meta) + + def get_conf(self, key, default=None, required=False, as_type=None): + if key in self.config: + value = self.config[key] + if as_type is not None: + value = as_type(value) + return value + elif required: + raise ValueError(f'config file error. Missing "config.process[{self.process_id}].{key}" key') + else: + if as_type is not None: + return as_type(default) + return default + + def run(self): + # implement in child class + # be sure to call super().run() first incase something is added here + pass + + def add_meta(self, additional_meta: OrderedDict): + self.meta.update(additional_meta) + diff --git a/jobs/process/LoconExtractProcess.py b/jobs/process/LoconExtractProcess.py new file mode 100644 index 00000000..c2133bdf --- /dev/null +++ b/jobs/process/LoconExtractProcess.py @@ -0,0 +1,67 @@ +from collections import OrderedDict +from toolkit.lycoris_utils import extract_diff +from .BaseExtractProcess import BaseExtractProcess +from .. import ExtractJob + +mode_dict = { + 'fixed': { + 'linear': 64, + 'conv': 32, + 'type': int + }, + 'threshold': { + 'linear': 0, + 'conv': 0, + 'type': float + }, + 'ratio': { + 'linear': 0.5, + 'conv': 0.5, + 'type': float + }, + 'quantile': { + 'linear': 0.5, + 'conv': 0.5, + 'type': float + } +} + + +class LoconExtractProcess(BaseExtractProcess): + def __init__(self, process_id: int, job: ExtractJob, config: OrderedDict): + super().__init__(process_id, job, config) + self.mode = self.get_conf('mode', 'fixed') + self.use_sparse_bias = self.get_conf('use_sparse_bias', False) + self.sparsity = self.get_conf('sparsity', 0.98) + self.disable_cp = self.get_conf('disable_cp', False) + + # set modes + if self.mode not in ['fixed', 'threshold', 'ratio', 'quantile']: + raise ValueError(f"Unknown mode: {self.mode}") + self.linear_param = self.get_conf('linear', mode_dict[self.mode]['linear'], mode_dict[self.mode]['type']) + self.conv_param = self.get_conf('conv', mode_dict[self.mode]['conv'], mode_dict[self.mode]['type']) + + def run(self): + super().run() + print(f"Running process: {self.mode}, lin: {self.linear_param}, conv: {self.conv_param}") + + state_dict, extract_diff_meta = extract_diff( + self.job.base_model, + self.job.extract_model, + self.mode, + self.linear_param, + self.conv_param, + self.job.device, + self.use_sparse_bias, + self.sparsity, + not self.disable_cp + ) + + self.add_meta(extract_diff_meta) + self.save(state_dict) + + def get_output_path(self, prefix=None, suffix=None): + if suffix is None: + suffix = f"_{self.mode}_{self.linear_param}_{self.conv_param}" + return super().get_output_path(prefix, suffix) + diff --git a/jobs/process/__init__.py b/jobs/process/__init__.py new file mode 100644 index 00000000..c480e0ac --- /dev/null +++ b/jobs/process/__init__.py @@ -0,0 +1,3 @@ +from .BaseExtractProcess import BaseExtractProcess +from .LoconExtractProcess import LoconExtractProcess +from .BaseProcess import BaseProcess diff --git a/run.py b/run.py new file mode 100644 index 00000000..e269bbb6 --- /dev/null +++ b/run.py @@ -0,0 +1,67 @@ +import os +import sys +from collections import OrderedDict + +from jobs import BaseJob + +sys.path.insert(0, os.getcwd()) +import argparse +from toolkit.job import get_job + + +def print_end_message(jobs_completed, jobs_failed): + failure_string = f"{jobs_failed} failure{'' if jobs_failed == 1 else 's'}" if jobs_failed > 0 else "" + completed_string = f"{jobs_completed} completed job{'' if jobs_completed == 1 else 's'}" + + print("") + print("========================================") + print("Result:") + if len(completed_string) > 0: + print(f" - {completed_string}") + if len(failure_string) > 0: + print(f" - {failure_string}") + print("========================================") + + +def main(): + parser = argparse.ArgumentParser() + + # require at lease one config file + parser.add_argument( + 'config_file_list', + nargs='+', + type=str, + help='Name of config file (eg: person_v1 for config/person_v1.json), or full path if it is not in config folder, you can pass multiple config files and run them all sequentially' + ) + + # flag to continue if failed job + parser.add_argument( + '-r', '--recover', + action='store_true', + help='Continue running additional jobs even if a job fails' + ) + args = parser.parse_args() + + config_file_list = args.config_file_list + if len(config_file_list) == 0: + raise Exception("You must provide at least one config file") + + jobs_completed = 0 + jobs_failed = 0 + + for config_file in config_file_list: + try: + job = get_job(config_file) + job.run() + job.cleanup() + jobs_completed += 1 + except Exception as e: + print(f"Error running job: {e}") + jobs_failed += 1 + if not args.recover: + print_end_message(jobs_completed, jobs_failed) + raise e + + +if __name__ == '__main__': + main() diff --git a/scripts/extract_locon.py b/scripts/extract_locon.py deleted file mode 100644 index 33d9df65..00000000 --- a/scripts/extract_locon.py +++ /dev/null @@ -1,150 +0,0 @@ -import json -import os -import sys - -from flatten_json import flatten - -sys.path.insert(0, os.getcwd()) -PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__)) -CONFIG_FOLDER = os.path.join(PROJECT_ROOT, 'config') -sys.path.append(PROJECT_ROOT) - -import argparse - -from toolkit.lycoris_utils import extract_diff -from toolkit.config import get_config -from toolkit.metadata import create_meta, prep_meta_for_safetensors -from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint - -import torch -from safetensors.torch import save_file - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "config_file", - help="Name of config file (eg: person_v1 for config/person_v1.json), or full path if it is not in config folder", - type=str - ) - return parser.parse_args() - - -def main(): - args = get_args() - - config_raw = get_config(args.config_file) - config = config_raw['config'] if 'config' in config_raw else None - if not config: - raise ValueError('config file is invalid. Missing "config" key') - - meta = config_raw['meta'] if 'meta' in config_raw else {} - - def get_conf(key, default=None): - if key in config: - return config[key] - else: - return default - - is_v2 = get_conf('is_v2', False) - name = get_conf('name', None) - base_model = get_conf('base_model') - extract_model = get_conf('extract_model') - output_folder = get_conf('output_folder') - process_list = get_conf('process') - device = get_conf('device', 'cpu') - use_sparse_bias = get_conf('use_sparse_bias', False) - sparsity = get_conf('sparsity', 0.98) - disable_cp = get_conf('disable_cp', False) - - if not name: - raise ValueError('name is required') - if not base_model: - raise ValueError('base_model is required') - if not extract_model: - raise ValueError('extract_model is required') - if not output_folder: - raise ValueError('output_folder is required') - if not process_list or len(process_list) == 0: - raise ValueError('process is required') - - # check processes - for process in process_list: - if process['mode'] == 'fixed': - if not process['linear_dim']: - raise ValueError('linear_dim is required in fixed mode') - if not process['conv_dim']: - raise ValueError('conv_dim is required in fixed mode') - elif process['mode'] == 'threshold': - if not process['linear_threshold']: - raise ValueError('linear_threshold is required in threshold mode') - if not process['conv_threshold']: - raise ValueError('conv_threshold is required in threshold mode') - elif process['mode'] == 'ratio': - if not process['linear_ratio']: - raise ValueError('linear_ratio is required in ratio mode') - if not process['conv_ratio']: - raise ValueError('conv_threshold is required in threshold mode') - elif process['mode'] == 'quantile': - if not process['linear_quantile']: - raise ValueError('linear_quantile is required in quantile mode') - if not process['conv_quantile']: - raise ValueError('conv_quantile is required in quantile mode') - else: - raise ValueError('mode is invalid') - - print(f"Loading base model: {base_model}") - base = load_models_from_stable_diffusion_checkpoint(is_v2, base_model) - print(f"Loading extract model: {extract_model}") - extract = load_models_from_stable_diffusion_checkpoint(is_v2, extract_model) - - print(f"Running {len(process_list)} process{'' if len(process_list) == 1 else 'es'}") - - for process in process_list: - item_meta = json.loads(json.dumps(meta)) - item_meta['process'] = process - if process['mode'] == 'fixed': - linear_mode_param = int(process['linear_dim']) - conv_mode_param = int(process['conv_dim']) - elif process['mode'] == 'threshold': - linear_mode_param = float(process['linear_threshold']) - conv_mode_param = float(process['conv_threshold']) - elif process['mode'] == 'ratio': - linear_mode_param = float(process['linear_ratio']) - conv_mode_param = float(process['conv_ratio']) - elif process['mode'] == 'quantile': - linear_mode_param = float(process['linear_quantile']) - conv_mode_param = float(process['conv_quantile']) - else: - raise ValueError(f"Unknown mode: {process['mode']}") - - print(f"Running process: {process['mode']}, lin: {linear_mode_param}, conv: {conv_mode_param}") - - state_dict, extract_diff_meta = extract_diff( - base, - extract, - process['mode'], - linear_mode_param, - conv_mode_param, - device, - use_sparse_bias, - sparsity, - not disable_cp - ) - - save_meta = create_meta([ - item_meta, extract_diff_meta - ], name=name) - - output_file_name = f"lyco_{name}_{process['mode']}_{linear_mode_param}_{conv_mode_param}.safetensors" - output_path = os.path.join(output_folder, output_file_name) - os.makedirs(output_folder, exist_ok=True) - - # having issues with meta - save_file(state_dict, output_path, prep_meta_for_safetensors(save_meta)) - - print(f"Saved to {output_path}") - - -if __name__ == '__main__': - main() diff --git a/toolkit/config.py b/toolkit/config.py index 168a2bc1..b3116bac 100644 --- a/toolkit/config.py +++ b/toolkit/config.py @@ -1,5 +1,7 @@ import os import json +from collections import OrderedDict + from toolkit.paths import TOOLKIT_ROOT possible_extensions = ['.json', '.jsonc'] @@ -11,6 +13,21 @@ def get_cwd_abs_path(path): return path +def preprocess_config(config: OrderedDict): + if "job" not in config: + raise ValueError("config file must have a job key") + if "config" not in config: + raise ValueError("config file must have a config section") + if "name" not in config["config"]: + raise ValueError("config file must have a config.name key") + # we need to replace tags. For now just [name] + name = config["config"]["name"] + config_string = json.dumps(config) + config_string = config_string.replace("[name]", name) + config = json.loads(config_string, object_pairs_hook=OrderedDict) + return config + + def get_config(config_file_path): # first check if it is in the config folder config_path = os.path.join(TOOLKIT_ROOT, 'config', config_file_path) @@ -34,6 +51,6 @@ def get_config(config_file_path): # load the config with open(real_config_path, 'r') as f: - config = json.load(f) + config = json.load(f, object_pairs_hook=OrderedDict) - return config + return preprocess_config(config) diff --git a/toolkit/job.py b/toolkit/job.py new file mode 100644 index 00000000..5ac7e0c5 --- /dev/null +++ b/toolkit/job.py @@ -0,0 +1,15 @@ +from jobs import BaseJob +from toolkit.config import get_config + + +def get_job(config_path) -> BaseJob: + config = get_config(config_path) + if not config['job']: + raise ValueError('config file is invalid. Missing "job" key') + + job = config['job'] + if job == 'extract': + from jobs import ExtractJob + return ExtractJob(config) + else: + raise ValueError(f'Unknown job type {job}') diff --git a/toolkit/lycoris_utils.py b/toolkit/lycoris_utils.py index a852e09c..c85db570 100644 --- a/toolkit/lycoris_utils.py +++ b/toolkit/lycoris_utils.py @@ -11,6 +11,7 @@ import torch.nn.functional as F import torch.linalg as linalg from tqdm import tqdm +from collections import OrderedDict def make_sparse(t: torch.Tensor, sparsity=0.95): @@ -121,7 +122,7 @@ def extract_diff( sparsity=0.98, small_conv=True ): - meta = {} + meta = OrderedDict() UNET_TARGET_REPLACE_MODULE = [ "Transformer2DModel", diff --git a/toolkit/metadata.py b/toolkit/metadata.py index d30acefb..0a99da70 100644 --- a/toolkit/metadata.py +++ b/toolkit/metadata.py @@ -1,31 +1,18 @@ import json - -software_meta = { - "name": "ai-toolkit", - "url": "https://github.com/ostris/ai-toolkit" -} +from collections import OrderedDict +from info import software_meta -def create_meta(dict_list, name=None): - meta = {} - for d in dict_list: - for key, value in d.items(): - meta[key] = value - - if "name" not in meta: - meta["name"] = "[name]" - - meta["software"] = software_meta - - # convert to string to handle replacements +def get_meta_for_safetensors(meta: OrderedDict, name=None) -> OrderedDict: + # stringify the meta and reparse OrderedDict to replace [name] with name meta_string = json.dumps(meta) if name is not None: meta_string = meta_string.replace("[name]", name) - return json.loads(meta_string) - - -def prep_meta_for_safetensors(meta): + save_meta = json.loads(meta_string, object_pairs_hook=OrderedDict) + save_meta["software"] = software_meta # safetensors can only be one level deep - for key, value in meta.items(): - meta[key] = json.dumps(value) - return meta + for key, value in save_meta.items(): + # if not float, int, bool, or str, convert to json string + if not isinstance(value, (float, int, bool, str)): + save_meta[key] = json.dumps(value) + return save_meta