mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Reworked so everything is in classes for easy expansion. Single entry point for all config files now.
This commit is contained in:
@@ -33,13 +33,13 @@ Just copy that file, into the `config` folder, and rename it to `whatever_you_wa
|
||||
Then you can edit the file to your liking. and call it like so:
|
||||
|
||||
```bash
|
||||
python3 scripts/extract_locon.py "whatever_you_want"
|
||||
python3 run.py "whatever_you_want"
|
||||
```
|
||||
|
||||
You can also put a full path to a config file, if you want to keep it somewhere else.
|
||||
|
||||
```bash
|
||||
python3 scripts/extract_locon.py "/home/user/whatever_you_want.json"
|
||||
python3 run.py "/home/user/whatever_you_want.json"
|
||||
```
|
||||
|
||||
File name is auto generated and dumped into the `output` folder. You can put whatever meta you want in the
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
{
|
||||
"job": "extract",
|
||||
"config": {
|
||||
"name": "name_of_your_model",
|
||||
"base_model": "/path/to/base/model",
|
||||
@@ -6,24 +7,26 @@
|
||||
"output_folder": "/path/to/output/folder",
|
||||
"is_v2": false,
|
||||
"device": "cpu",
|
||||
"use_sparse_bias": false,
|
||||
"sparsity": 0.98,
|
||||
"disable_cp": false,
|
||||
"process": [
|
||||
{
|
||||
"filename":"[name]_64_32.safetensors",
|
||||
"type": "locon",
|
||||
"mode": "fixed",
|
||||
"linear_dim": 64,
|
||||
"conv_dim": 32
|
||||
"linear": 64,
|
||||
"conv": 32
|
||||
},
|
||||
{
|
||||
"output_path": "/absolute/path/for/this/output.safetensors",
|
||||
"type": "locon",
|
||||
"mode": "ratio",
|
||||
"linear_ratio": 0.2,
|
||||
"conv_ratio": 0.2
|
||||
"linear": 0.2,
|
||||
"conv": 0.2
|
||||
},
|
||||
{
|
||||
"type": "locon",
|
||||
"mode": "quantile",
|
||||
"linear_quantile": 0.5,
|
||||
"conv_quantile": 0.5
|
||||
"linear": 0.5,
|
||||
"conv": 0.5
|
||||
}
|
||||
]
|
||||
},
|
||||
@@ -41,6 +44,7 @@
|
||||
"name": "Your Name",
|
||||
"email": "your@email.com",
|
||||
"website": "https://yourwebsite.com"
|
||||
}
|
||||
},
|
||||
"any": "All meta data above is arbitrary, it can be whatever you want."
|
||||
}
|
||||
}
|
||||
8
info.py
Normal file
8
info.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
v = OrderedDict()
|
||||
v["name"] = "ai-toolkit"
|
||||
v["repo"] = "https://github.com/ostris/ai-toolkit"
|
||||
v["version"] = "0.0.1"
|
||||
|
||||
software_meta = v
|
||||
43
jobs/BaseJob.py
Normal file
43
jobs/BaseJob.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class BaseJob:
|
||||
config: OrderedDict
|
||||
job: str
|
||||
name: str
|
||||
meta: OrderedDict
|
||||
|
||||
def __init__(self, config: OrderedDict):
|
||||
if not config:
|
||||
raise ValueError('config is required')
|
||||
|
||||
self.config = config['config']
|
||||
self.job = config['job']
|
||||
self.name = self.get_conf('name', required=True)
|
||||
if 'meta' in config:
|
||||
self.meta = config['meta']
|
||||
else:
|
||||
self.meta = OrderedDict()
|
||||
|
||||
def get_conf(self, key, default=None, required=False):
|
||||
if key in self.config:
|
||||
return self.config[key]
|
||||
elif required:
|
||||
raise ValueError(f'config file error. Missing "config.{key}" key')
|
||||
else:
|
||||
return default
|
||||
|
||||
def run(self):
|
||||
print("")
|
||||
print(f"#############################################")
|
||||
print(f"# Running job: {self.name}")
|
||||
print(f"#############################################")
|
||||
print("")
|
||||
# implement in child class
|
||||
# be sure to call super().run() first
|
||||
pass
|
||||
|
||||
def cleanup(self):
|
||||
# if you implement this in child clas,
|
||||
# be sure to call super().cleanup() LAST
|
||||
del self
|
||||
53
jobs/ExtractJob.py
Normal file
53
jobs/ExtractJob.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
|
||||
from .BaseJob import BaseJob
|
||||
from collections import OrderedDict
|
||||
from typing import List
|
||||
|
||||
from jobs.process import BaseExtractProcess
|
||||
|
||||
|
||||
class ExtractJob(BaseJob):
|
||||
process: List[BaseExtractProcess]
|
||||
|
||||
def __init__(self, config: OrderedDict):
|
||||
super().__init__(config)
|
||||
self.base_model_path = self.get_conf('base_model', required=True)
|
||||
self.base_model = None
|
||||
self.extract_model_path = self.get_conf('extract_model', required=True)
|
||||
self.extract_model = None
|
||||
self.output_folder = self.get_conf('output_folder', required=True)
|
||||
self.is_v2 = self.get_conf('is_v2', False)
|
||||
self.device = self.get_conf('device', 'cpu')
|
||||
|
||||
if 'process' not in self.config:
|
||||
raise ValueError('config file is invalid. Missing "config.process" key')
|
||||
if len(self.config['process']) == 0:
|
||||
raise ValueError('config file is invalid. "config.process" must be a list of processes')
|
||||
|
||||
# add the processes
|
||||
self.process = []
|
||||
for i, process in enumerate(self.config['process']):
|
||||
if 'type' not in process:
|
||||
raise ValueError(f'config file is invalid. Missing "config.process[{i}].type" key')
|
||||
if process['type'] == 'locon':
|
||||
from jobs.process import LoconExtractProcess
|
||||
self.process.append(LoconExtractProcess(i, self, process))
|
||||
else:
|
||||
raise ValueError(f'config file is invalid. Unknown process type: {process["type"]}')
|
||||
|
||||
def run(self):
|
||||
super().run()
|
||||
# load models
|
||||
print(f"Loading models for extraction")
|
||||
print(f" - Loading base model: {self.base_model_path}")
|
||||
self.base_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.base_model_path)
|
||||
|
||||
print(f" - Loading extract model: {self.extract_model_path}")
|
||||
self.extract_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.extract_model_path)
|
||||
|
||||
print("")
|
||||
print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")
|
||||
|
||||
for process in self.process:
|
||||
process.run()
|
||||
|
||||
2
jobs/__init__.py
Normal file
2
jobs/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .BaseJob import BaseJob
|
||||
from .ExtractJob import ExtractJob
|
||||
76
jobs/process/BaseExtractProcess.py
Normal file
76
jobs/process/BaseExtractProcess.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from jobs import ExtractJob
|
||||
from jobs.process.BaseProcess import BaseProcess
|
||||
from toolkit.metadata import get_meta_for_safetensors
|
||||
|
||||
|
||||
class BaseExtractProcess(BaseProcess):
|
||||
job: ExtractJob
|
||||
process_id: int
|
||||
config: OrderedDict
|
||||
output_folder: str
|
||||
output_filename: str
|
||||
output_path: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
process_id: int,
|
||||
job: ExtractJob,
|
||||
config: OrderedDict
|
||||
):
|
||||
super().__init__(process_id, job, config)
|
||||
self.process_id = process_id
|
||||
self.job = job
|
||||
self.config = config
|
||||
|
||||
def run(self):
|
||||
# here instead of init because child init needs to go first
|
||||
self.output_path = self.get_output_path()
|
||||
# implement in child class
|
||||
# be sure to call super().run() first
|
||||
pass
|
||||
|
||||
# you can override this in the child class if you want
|
||||
# call super().get_output_path(prefix="your_prefix_", suffix="_your_suffix") to extend this
|
||||
def get_output_path(self, prefix=None, suffix=None):
|
||||
config_output_path = self.get_conf('output_path', None)
|
||||
config_filename = self.get_conf('filename', None)
|
||||
# replace [name] with name
|
||||
|
||||
if config_output_path is not None:
|
||||
config_output_path = config_output_path.replace('[name]', self.job.name)
|
||||
return config_output_path
|
||||
|
||||
if config_output_path is None and config_filename is not None:
|
||||
# build the output path from the output folder and filename
|
||||
return os.path.join(self.job.output_folder, config_filename)
|
||||
|
||||
# build our own
|
||||
|
||||
if suffix is None:
|
||||
# we will just add process it to the end of the filename if there is more than one process
|
||||
# and no other suffix was given
|
||||
suffix = f"_{self.process_id}" if len(self.config['process']) > 1 else ''
|
||||
|
||||
if prefix is None:
|
||||
prefix = ''
|
||||
|
||||
output_filename = f"{prefix}{self.output_filename}{suffix}"
|
||||
|
||||
return os.path.join(self.job.output_folder, output_filename)
|
||||
|
||||
def save(self, state_dict):
|
||||
# prepare meta
|
||||
save_meta = get_meta_for_safetensors(self.meta, self.job.name)
|
||||
|
||||
# save
|
||||
os.makedirs(os.path.dirname(self.output_path), exist_ok=True)
|
||||
|
||||
# having issues with meta
|
||||
save_file(state_dict, self.output_path, save_meta)
|
||||
|
||||
print(f"Saved to {self.output_path}")
|
||||
42
jobs/process/BaseProcess.py
Normal file
42
jobs/process/BaseProcess.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import copy
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
|
||||
from jobs import BaseJob
|
||||
|
||||
|
||||
class BaseProcess:
|
||||
meta: OrderedDict
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
process_id: int,
|
||||
job: BaseJob,
|
||||
config: OrderedDict
|
||||
):
|
||||
self.process_id = process_id
|
||||
self.job = job
|
||||
self.config = config
|
||||
self.meta = copy.deepcopy(self.job.meta)
|
||||
|
||||
def get_conf(self, key, default=None, required=False, as_type=None):
|
||||
if key in self.config:
|
||||
value = self.config[key]
|
||||
if as_type is not None:
|
||||
value = as_type(value)
|
||||
return value
|
||||
elif required:
|
||||
raise ValueError(f'config file error. Missing "config.process[{self.process_id}].{key}" key')
|
||||
else:
|
||||
if as_type is not None:
|
||||
return as_type(default)
|
||||
return default
|
||||
|
||||
def run(self):
|
||||
# implement in child class
|
||||
# be sure to call super().run() first incase something is added here
|
||||
pass
|
||||
|
||||
def add_meta(self, additional_meta: OrderedDict):
|
||||
self.meta.update(additional_meta)
|
||||
|
||||
67
jobs/process/LoconExtractProcess.py
Normal file
67
jobs/process/LoconExtractProcess.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from collections import OrderedDict
|
||||
from toolkit.lycoris_utils import extract_diff
|
||||
from .BaseExtractProcess import BaseExtractProcess
|
||||
from .. import ExtractJob
|
||||
|
||||
mode_dict = {
|
||||
'fixed': {
|
||||
'linear': 64,
|
||||
'conv': 32,
|
||||
'type': int
|
||||
},
|
||||
'threshold': {
|
||||
'linear': 0,
|
||||
'conv': 0,
|
||||
'type': float
|
||||
},
|
||||
'ratio': {
|
||||
'linear': 0.5,
|
||||
'conv': 0.5,
|
||||
'type': float
|
||||
},
|
||||
'quantile': {
|
||||
'linear': 0.5,
|
||||
'conv': 0.5,
|
||||
'type': float
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class LoconExtractProcess(BaseExtractProcess):
|
||||
def __init__(self, process_id: int, job: ExtractJob, config: OrderedDict):
|
||||
super().__init__(process_id, job, config)
|
||||
self.mode = self.get_conf('mode', 'fixed')
|
||||
self.use_sparse_bias = self.get_conf('use_sparse_bias', False)
|
||||
self.sparsity = self.get_conf('sparsity', 0.98)
|
||||
self.disable_cp = self.get_conf('disable_cp', False)
|
||||
|
||||
# set modes
|
||||
if self.mode not in ['fixed', 'threshold', 'ratio', 'quantile']:
|
||||
raise ValueError(f"Unknown mode: {self.mode}")
|
||||
self.linear_param = self.get_conf('linear', mode_dict[self.mode]['linear'], mode_dict[self.mode]['type'])
|
||||
self.conv_param = self.get_conf('conv', mode_dict[self.mode]['conv'], mode_dict[self.mode]['type'])
|
||||
|
||||
def run(self):
|
||||
super().run()
|
||||
print(f"Running process: {self.mode}, lin: {self.linear_param}, conv: {self.conv_param}")
|
||||
|
||||
state_dict, extract_diff_meta = extract_diff(
|
||||
self.job.base_model,
|
||||
self.job.extract_model,
|
||||
self.mode,
|
||||
self.linear_param,
|
||||
self.conv_param,
|
||||
self.job.device,
|
||||
self.use_sparse_bias,
|
||||
self.sparsity,
|
||||
not self.disable_cp
|
||||
)
|
||||
|
||||
self.add_meta(extract_diff_meta)
|
||||
self.save(state_dict)
|
||||
|
||||
def get_output_path(self, prefix=None, suffix=None):
|
||||
if suffix is None:
|
||||
suffix = f"_{self.mode}_{self.linear_param}_{self.conv_param}"
|
||||
return super().get_output_path(prefix, suffix)
|
||||
|
||||
3
jobs/process/__init__.py
Normal file
3
jobs/process/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .BaseExtractProcess import BaseExtractProcess
|
||||
from .LoconExtractProcess import LoconExtractProcess
|
||||
from .BaseProcess import BaseProcess
|
||||
67
run.py
Normal file
67
run.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import os
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
|
||||
from jobs import BaseJob
|
||||
|
||||
sys.path.insert(0, os.getcwd())
|
||||
import argparse
|
||||
from toolkit.job import get_job
|
||||
|
||||
|
||||
def print_end_message(jobs_completed, jobs_failed):
|
||||
failure_string = f"{jobs_failed} failure{'' if jobs_failed == 1 else 's'}" if jobs_failed > 0 else ""
|
||||
completed_string = f"{jobs_completed} completed job{'' if jobs_completed == 1 else 's'}"
|
||||
|
||||
print("")
|
||||
print("========================================")
|
||||
print("Result:")
|
||||
if len(completed_string) > 0:
|
||||
print(f" - {completed_string}")
|
||||
if len(failure_string) > 0:
|
||||
print(f" - {failure_string}")
|
||||
print("========================================")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# require at lease one config file
|
||||
parser.add_argument(
|
||||
'config_file_list',
|
||||
nargs='+',
|
||||
type=str,
|
||||
help='Name of config file (eg: person_v1 for config/person_v1.json), or full path if it is not in config folder, you can pass multiple config files and run them all sequentially'
|
||||
)
|
||||
|
||||
# flag to continue if failed job
|
||||
parser.add_argument(
|
||||
'-r', '--recover',
|
||||
action='store_true',
|
||||
help='Continue running additional jobs even if a job fails'
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
config_file_list = args.config_file_list
|
||||
if len(config_file_list) == 0:
|
||||
raise Exception("You must provide at least one config file")
|
||||
|
||||
jobs_completed = 0
|
||||
jobs_failed = 0
|
||||
|
||||
for config_file in config_file_list:
|
||||
try:
|
||||
job = get_job(config_file)
|
||||
job.run()
|
||||
job.cleanup()
|
||||
jobs_completed += 1
|
||||
except Exception as e:
|
||||
print(f"Error running job: {e}")
|
||||
jobs_failed += 1
|
||||
if not args.recover:
|
||||
print_end_message(jobs_completed, jobs_failed)
|
||||
raise e
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -1,150 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
from flatten_json import flatten
|
||||
|
||||
sys.path.insert(0, os.getcwd())
|
||||
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
|
||||
CONFIG_FOLDER = os.path.join(PROJECT_ROOT, 'config')
|
||||
sys.path.append(PROJECT_ROOT)
|
||||
|
||||
import argparse
|
||||
|
||||
from toolkit.lycoris_utils import extract_diff
|
||||
from toolkit.config import get_config
|
||||
from toolkit.metadata import create_meta, prep_meta_for_safetensors
|
||||
from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
|
||||
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"config_file",
|
||||
help="Name of config file (eg: person_v1 for config/person_v1.json), or full path if it is not in config folder",
|
||||
type=str
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
|
||||
config_raw = get_config(args.config_file)
|
||||
config = config_raw['config'] if 'config' in config_raw else None
|
||||
if not config:
|
||||
raise ValueError('config file is invalid. Missing "config" key')
|
||||
|
||||
meta = config_raw['meta'] if 'meta' in config_raw else {}
|
||||
|
||||
def get_conf(key, default=None):
|
||||
if key in config:
|
||||
return config[key]
|
||||
else:
|
||||
return default
|
||||
|
||||
is_v2 = get_conf('is_v2', False)
|
||||
name = get_conf('name', None)
|
||||
base_model = get_conf('base_model')
|
||||
extract_model = get_conf('extract_model')
|
||||
output_folder = get_conf('output_folder')
|
||||
process_list = get_conf('process')
|
||||
device = get_conf('device', 'cpu')
|
||||
use_sparse_bias = get_conf('use_sparse_bias', False)
|
||||
sparsity = get_conf('sparsity', 0.98)
|
||||
disable_cp = get_conf('disable_cp', False)
|
||||
|
||||
if not name:
|
||||
raise ValueError('name is required')
|
||||
if not base_model:
|
||||
raise ValueError('base_model is required')
|
||||
if not extract_model:
|
||||
raise ValueError('extract_model is required')
|
||||
if not output_folder:
|
||||
raise ValueError('output_folder is required')
|
||||
if not process_list or len(process_list) == 0:
|
||||
raise ValueError('process is required')
|
||||
|
||||
# check processes
|
||||
for process in process_list:
|
||||
if process['mode'] == 'fixed':
|
||||
if not process['linear_dim']:
|
||||
raise ValueError('linear_dim is required in fixed mode')
|
||||
if not process['conv_dim']:
|
||||
raise ValueError('conv_dim is required in fixed mode')
|
||||
elif process['mode'] == 'threshold':
|
||||
if not process['linear_threshold']:
|
||||
raise ValueError('linear_threshold is required in threshold mode')
|
||||
if not process['conv_threshold']:
|
||||
raise ValueError('conv_threshold is required in threshold mode')
|
||||
elif process['mode'] == 'ratio':
|
||||
if not process['linear_ratio']:
|
||||
raise ValueError('linear_ratio is required in ratio mode')
|
||||
if not process['conv_ratio']:
|
||||
raise ValueError('conv_threshold is required in threshold mode')
|
||||
elif process['mode'] == 'quantile':
|
||||
if not process['linear_quantile']:
|
||||
raise ValueError('linear_quantile is required in quantile mode')
|
||||
if not process['conv_quantile']:
|
||||
raise ValueError('conv_quantile is required in quantile mode')
|
||||
else:
|
||||
raise ValueError('mode is invalid')
|
||||
|
||||
print(f"Loading base model: {base_model}")
|
||||
base = load_models_from_stable_diffusion_checkpoint(is_v2, base_model)
|
||||
print(f"Loading extract model: {extract_model}")
|
||||
extract = load_models_from_stable_diffusion_checkpoint(is_v2, extract_model)
|
||||
|
||||
print(f"Running {len(process_list)} process{'' if len(process_list) == 1 else 'es'}")
|
||||
|
||||
for process in process_list:
|
||||
item_meta = json.loads(json.dumps(meta))
|
||||
item_meta['process'] = process
|
||||
if process['mode'] == 'fixed':
|
||||
linear_mode_param = int(process['linear_dim'])
|
||||
conv_mode_param = int(process['conv_dim'])
|
||||
elif process['mode'] == 'threshold':
|
||||
linear_mode_param = float(process['linear_threshold'])
|
||||
conv_mode_param = float(process['conv_threshold'])
|
||||
elif process['mode'] == 'ratio':
|
||||
linear_mode_param = float(process['linear_ratio'])
|
||||
conv_mode_param = float(process['conv_ratio'])
|
||||
elif process['mode'] == 'quantile':
|
||||
linear_mode_param = float(process['linear_quantile'])
|
||||
conv_mode_param = float(process['conv_quantile'])
|
||||
else:
|
||||
raise ValueError(f"Unknown mode: {process['mode']}")
|
||||
|
||||
print(f"Running process: {process['mode']}, lin: {linear_mode_param}, conv: {conv_mode_param}")
|
||||
|
||||
state_dict, extract_diff_meta = extract_diff(
|
||||
base,
|
||||
extract,
|
||||
process['mode'],
|
||||
linear_mode_param,
|
||||
conv_mode_param,
|
||||
device,
|
||||
use_sparse_bias,
|
||||
sparsity,
|
||||
not disable_cp
|
||||
)
|
||||
|
||||
save_meta = create_meta([
|
||||
item_meta, extract_diff_meta
|
||||
], name=name)
|
||||
|
||||
output_file_name = f"lyco_{name}_{process['mode']}_{linear_mode_param}_{conv_mode_param}.safetensors"
|
||||
output_path = os.path.join(output_folder, output_file_name)
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
|
||||
# having issues with meta
|
||||
save_file(state_dict, output_path, prep_meta_for_safetensors(save_meta))
|
||||
|
||||
print(f"Saved to {output_path}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -1,5 +1,7 @@
|
||||
import os
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
|
||||
from toolkit.paths import TOOLKIT_ROOT
|
||||
|
||||
possible_extensions = ['.json', '.jsonc']
|
||||
@@ -11,6 +13,21 @@ def get_cwd_abs_path(path):
|
||||
return path
|
||||
|
||||
|
||||
def preprocess_config(config: OrderedDict):
|
||||
if "job" not in config:
|
||||
raise ValueError("config file must have a job key")
|
||||
if "config" not in config:
|
||||
raise ValueError("config file must have a config section")
|
||||
if "name" not in config["config"]:
|
||||
raise ValueError("config file must have a config.name key")
|
||||
# we need to replace tags. For now just [name]
|
||||
name = config["config"]["name"]
|
||||
config_string = json.dumps(config)
|
||||
config_string = config_string.replace("[name]", name)
|
||||
config = json.loads(config_string, object_pairs_hook=OrderedDict)
|
||||
return config
|
||||
|
||||
|
||||
def get_config(config_file_path):
|
||||
# first check if it is in the config folder
|
||||
config_path = os.path.join(TOOLKIT_ROOT, 'config', config_file_path)
|
||||
@@ -34,6 +51,6 @@ def get_config(config_file_path):
|
||||
|
||||
# load the config
|
||||
with open(real_config_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
config = json.load(f, object_pairs_hook=OrderedDict)
|
||||
|
||||
return config
|
||||
return preprocess_config(config)
|
||||
|
||||
15
toolkit/job.py
Normal file
15
toolkit/job.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from jobs import BaseJob
|
||||
from toolkit.config import get_config
|
||||
|
||||
|
||||
def get_job(config_path) -> BaseJob:
|
||||
config = get_config(config_path)
|
||||
if not config['job']:
|
||||
raise ValueError('config file is invalid. Missing "job" key')
|
||||
|
||||
job = config['job']
|
||||
if job == 'extract':
|
||||
from jobs import ExtractJob
|
||||
return ExtractJob(config)
|
||||
else:
|
||||
raise ValueError(f'Unknown job type {job}')
|
||||
@@ -11,6 +11,7 @@ import torch.nn.functional as F
|
||||
import torch.linalg as linalg
|
||||
|
||||
from tqdm import tqdm
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
def make_sparse(t: torch.Tensor, sparsity=0.95):
|
||||
@@ -121,7 +122,7 @@ def extract_diff(
|
||||
sparsity=0.98,
|
||||
small_conv=True
|
||||
):
|
||||
meta = {}
|
||||
meta = OrderedDict()
|
||||
|
||||
UNET_TARGET_REPLACE_MODULE = [
|
||||
"Transformer2DModel",
|
||||
|
||||
@@ -1,31 +1,18 @@
|
||||
import json
|
||||
|
||||
software_meta = {
|
||||
"name": "ai-toolkit",
|
||||
"url": "https://github.com/ostris/ai-toolkit"
|
||||
}
|
||||
from collections import OrderedDict
|
||||
from info import software_meta
|
||||
|
||||
|
||||
def create_meta(dict_list, name=None):
|
||||
meta = {}
|
||||
for d in dict_list:
|
||||
for key, value in d.items():
|
||||
meta[key] = value
|
||||
|
||||
if "name" not in meta:
|
||||
meta["name"] = "[name]"
|
||||
|
||||
meta["software"] = software_meta
|
||||
|
||||
# convert to string to handle replacements
|
||||
def get_meta_for_safetensors(meta: OrderedDict, name=None) -> OrderedDict:
|
||||
# stringify the meta and reparse OrderedDict to replace [name] with name
|
||||
meta_string = json.dumps(meta)
|
||||
if name is not None:
|
||||
meta_string = meta_string.replace("[name]", name)
|
||||
return json.loads(meta_string)
|
||||
|
||||
|
||||
def prep_meta_for_safetensors(meta):
|
||||
save_meta = json.loads(meta_string, object_pairs_hook=OrderedDict)
|
||||
save_meta["software"] = software_meta
|
||||
# safetensors can only be one level deep
|
||||
for key, value in meta.items():
|
||||
meta[key] = json.dumps(value)
|
||||
return meta
|
||||
for key, value in save_meta.items():
|
||||
# if not float, int, bool, or str, convert to json string
|
||||
if not isinstance(value, (float, int, bool, str)):
|
||||
save_meta[key] = json.dumps(value)
|
||||
return save_meta
|
||||
|
||||
Reference in New Issue
Block a user