Reworked so everything is in classes for easy expansion. Single entry point for all config files now.

This commit is contained in:
Jaret Burkett
2023-07-08 09:51:42 -06:00
parent 27df03a486
commit 37354b006e
16 changed files with 424 additions and 189 deletions

View File

@@ -1,5 +1,7 @@
import os
import json
from collections import OrderedDict
from toolkit.paths import TOOLKIT_ROOT
possible_extensions = ['.json', '.jsonc']
@@ -11,6 +13,21 @@ def get_cwd_abs_path(path):
return path
def preprocess_config(config: OrderedDict):
if "job" not in config:
raise ValueError("config file must have a job key")
if "config" not in config:
raise ValueError("config file must have a config section")
if "name" not in config["config"]:
raise ValueError("config file must have a config.name key")
# we need to replace tags. For now just [name]
name = config["config"]["name"]
config_string = json.dumps(config)
config_string = config_string.replace("[name]", name)
config = json.loads(config_string, object_pairs_hook=OrderedDict)
return config
def get_config(config_file_path):
# first check if it is in the config folder
config_path = os.path.join(TOOLKIT_ROOT, 'config', config_file_path)
@@ -34,6 +51,6 @@ def get_config(config_file_path):
# load the config
with open(real_config_path, 'r') as f:
config = json.load(f)
config = json.load(f, object_pairs_hook=OrderedDict)
return config
return preprocess_config(config)

15
toolkit/job.py Normal file
View File

@@ -0,0 +1,15 @@
from jobs import BaseJob
from toolkit.config import get_config
def get_job(config_path) -> BaseJob:
config = get_config(config_path)
if not config['job']:
raise ValueError('config file is invalid. Missing "job" key')
job = config['job']
if job == 'extract':
from jobs import ExtractJob
return ExtractJob(config)
else:
raise ValueError(f'Unknown job type {job}')

View File

@@ -11,6 +11,7 @@ import torch.nn.functional as F
import torch.linalg as linalg
from tqdm import tqdm
from collections import OrderedDict
def make_sparse(t: torch.Tensor, sparsity=0.95):
@@ -121,7 +122,7 @@ def extract_diff(
sparsity=0.98,
small_conv=True
):
meta = {}
meta = OrderedDict()
UNET_TARGET_REPLACE_MODULE = [
"Transformer2DModel",

View File

@@ -1,31 +1,18 @@
import json
software_meta = {
"name": "ai-toolkit",
"url": "https://github.com/ostris/ai-toolkit"
}
from collections import OrderedDict
from info import software_meta
def create_meta(dict_list, name=None):
meta = {}
for d in dict_list:
for key, value in d.items():
meta[key] = value
if "name" not in meta:
meta["name"] = "[name]"
meta["software"] = software_meta
# convert to string to handle replacements
def get_meta_for_safetensors(meta: OrderedDict, name=None) -> OrderedDict:
# stringify the meta and reparse OrderedDict to replace [name] with name
meta_string = json.dumps(meta)
if name is not None:
meta_string = meta_string.replace("[name]", name)
return json.loads(meta_string)
def prep_meta_for_safetensors(meta):
save_meta = json.loads(meta_string, object_pairs_hook=OrderedDict)
save_meta["software"] = software_meta
# safetensors can only be one level deep
for key, value in meta.items():
meta[key] = json.dumps(value)
return meta
for key, value in save_meta.items():
# if not float, int, bool, or str, convert to json string
if not isinstance(value, (float, int, bool, str)):
save_meta[key] = json.dumps(value)
return save_meta