mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 10:41:28 +00:00
Reworked so everything is in classes for easy expansion. Single entry point for all config files now.
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
import os
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
|
||||
from toolkit.paths import TOOLKIT_ROOT
|
||||
|
||||
possible_extensions = ['.json', '.jsonc']
|
||||
@@ -11,6 +13,21 @@ def get_cwd_abs_path(path):
|
||||
return path
|
||||
|
||||
|
||||
def preprocess_config(config: OrderedDict):
|
||||
if "job" not in config:
|
||||
raise ValueError("config file must have a job key")
|
||||
if "config" not in config:
|
||||
raise ValueError("config file must have a config section")
|
||||
if "name" not in config["config"]:
|
||||
raise ValueError("config file must have a config.name key")
|
||||
# we need to replace tags. For now just [name]
|
||||
name = config["config"]["name"]
|
||||
config_string = json.dumps(config)
|
||||
config_string = config_string.replace("[name]", name)
|
||||
config = json.loads(config_string, object_pairs_hook=OrderedDict)
|
||||
return config
|
||||
|
||||
|
||||
def get_config(config_file_path):
|
||||
# first check if it is in the config folder
|
||||
config_path = os.path.join(TOOLKIT_ROOT, 'config', config_file_path)
|
||||
@@ -34,6 +51,6 @@ def get_config(config_file_path):
|
||||
|
||||
# load the config
|
||||
with open(real_config_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
config = json.load(f, object_pairs_hook=OrderedDict)
|
||||
|
||||
return config
|
||||
return preprocess_config(config)
|
||||
|
||||
15
toolkit/job.py
Normal file
15
toolkit/job.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from jobs import BaseJob
|
||||
from toolkit.config import get_config
|
||||
|
||||
|
||||
def get_job(config_path) -> BaseJob:
|
||||
config = get_config(config_path)
|
||||
if not config['job']:
|
||||
raise ValueError('config file is invalid. Missing "job" key')
|
||||
|
||||
job = config['job']
|
||||
if job == 'extract':
|
||||
from jobs import ExtractJob
|
||||
return ExtractJob(config)
|
||||
else:
|
||||
raise ValueError(f'Unknown job type {job}')
|
||||
@@ -11,6 +11,7 @@ import torch.nn.functional as F
|
||||
import torch.linalg as linalg
|
||||
|
||||
from tqdm import tqdm
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
def make_sparse(t: torch.Tensor, sparsity=0.95):
|
||||
@@ -121,7 +122,7 @@ def extract_diff(
|
||||
sparsity=0.98,
|
||||
small_conv=True
|
||||
):
|
||||
meta = {}
|
||||
meta = OrderedDict()
|
||||
|
||||
UNET_TARGET_REPLACE_MODULE = [
|
||||
"Transformer2DModel",
|
||||
|
||||
@@ -1,31 +1,18 @@
|
||||
import json
|
||||
|
||||
software_meta = {
|
||||
"name": "ai-toolkit",
|
||||
"url": "https://github.com/ostris/ai-toolkit"
|
||||
}
|
||||
from collections import OrderedDict
|
||||
from info import software_meta
|
||||
|
||||
|
||||
def create_meta(dict_list, name=None):
|
||||
meta = {}
|
||||
for d in dict_list:
|
||||
for key, value in d.items():
|
||||
meta[key] = value
|
||||
|
||||
if "name" not in meta:
|
||||
meta["name"] = "[name]"
|
||||
|
||||
meta["software"] = software_meta
|
||||
|
||||
# convert to string to handle replacements
|
||||
def get_meta_for_safetensors(meta: OrderedDict, name=None) -> OrderedDict:
|
||||
# stringify the meta and reparse OrderedDict to replace [name] with name
|
||||
meta_string = json.dumps(meta)
|
||||
if name is not None:
|
||||
meta_string = meta_string.replace("[name]", name)
|
||||
return json.loads(meta_string)
|
||||
|
||||
|
||||
def prep_meta_for_safetensors(meta):
|
||||
save_meta = json.loads(meta_string, object_pairs_hook=OrderedDict)
|
||||
save_meta["software"] = software_meta
|
||||
# safetensors can only be one level deep
|
||||
for key, value in meta.items():
|
||||
meta[key] = json.dumps(value)
|
||||
return meta
|
||||
for key, value in save_meta.items():
|
||||
# if not float, int, bool, or str, convert to json string
|
||||
if not isinstance(value, (float, int, bool, str)):
|
||||
save_meta[key] = json.dumps(value)
|
||||
return save_meta
|
||||
|
||||
Reference in New Issue
Block a user