mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Merge remote-tracking branch 'origin/main' into WIP
This commit is contained in:
10
run.py
10
run.py
@@ -36,6 +36,14 @@ def main():
|
||||
action='store_true',
|
||||
help='Continue running additional jobs even if a job fails'
|
||||
)
|
||||
|
||||
# flag to continue if failed job
|
||||
parser.add_argument(
|
||||
'-n', '--name',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Name to replace [name] tag in config file, useful for shared config file'
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
config_file_list = args.config_file_list
|
||||
@@ -49,7 +57,7 @@ def main():
|
||||
|
||||
for config_file in config_file_list:
|
||||
try:
|
||||
job = get_job(config_file)
|
||||
job = get_job(config_file, args.name)
|
||||
job.run()
|
||||
job.cleanup()
|
||||
jobs_completed += 1
|
||||
|
||||
89
testing/compare_keys.py
Normal file
89
testing/compare_keys.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import argparse
import os

import torch
from safetensors.torch import load_file
from collections import OrderedDict
import json

# One-off utility to diff the key sets of two safetensors checkpoints
# (e.g. a VAE in ldm layout vs the diffusers layout).  Writes the three
# key lists (shared / only-in-1 / only-in-2) to config/keys.json.
# You probably won't need this. Unless they change the key names...
# again... again.  On second thought, you probably will.

device = torch.device('cpu')
dtype = torch.float32

parser = argparse.ArgumentParser()

# require both safetensors files as positionals; nargs='+' is kept for
# backward compatibility, but only the first path of each is used
parser.add_argument(
    'file_1',
    nargs='+',
    type=str,
    help='Path to first safe tensor file'
)

parser.add_argument(
    'file_2',
    nargs='+',
    type=str,
    help='Path to second safe tensor file'
)

args = parser.parse_args()

state_dict_file_1 = load_file(args.file_1[0])
state_dict_1_keys = list(state_dict_file_1.keys())

state_dict_file_2 = load_file(args.file_2[0])
state_dict_2_keys = list(state_dict_file_2.keys())

# Sets give O(1) membership / O(n) difference; the original per-key list
# scans were O(n^2) for large checkpoints.
key_set_1 = set(state_dict_1_keys)
key_set_2 = set(state_dict_2_keys)

# sorted() returns sorted lists, matching the original sort() output
keys_not_in_state_dict_2 = sorted(key_set_1 - key_set_2)
keys_not_in_state_dict_1 = sorted(key_set_2 - key_set_1)
keys_in_both = sorted(key_set_1 & key_set_2)

# NOTE(review): the "state_dict_2" entry holds keys *missing from*
# state_dict_2 (and vice versa).  Key names kept as-is so existing
# consumers of keys.json keep working.
json_data = {
    "both": keys_in_both,
    "state_dict_2": keys_not_in_state_dict_2,
    "state_dict_1": keys_not_in_state_dict_1
}
json_data = json.dumps(json_data, indent=4)

# tensors that exist only in one of the two files — not written anywhere,
# handy to inspect from a debugger / REPL
remaining_diffusers_values = OrderedDict()
for key in keys_not_in_state_dict_1:
    remaining_diffusers_values[key] = state_dict_file_2[key]

# print(remaining_diffusers_values.keys())

remaining_ldm_values = OrderedDict()
for key in keys_not_in_state_dict_2:
    remaining_ldm_values[key] = state_dict_file_1[key]

# print(json_data)

# this script lives in testing/, so the project root is one level up
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
json_save_path = os.path.join(project_root, 'config', 'keys.json')

with open(json_save_path, 'w') as f:
    f.write(json_data)
|
||||
@@ -15,22 +15,24 @@ def get_cwd_abs_path(path):
|
||||
return path
|
||||
|
||||
|
||||
def preprocess_config(config: OrderedDict, name=None):
    """Validate a parsed job config and substitute the ``[name]`` tag.

    Args:
        config: Parsed config file as an OrderedDict. Must contain a
            ``job`` key and a ``config`` section.
        name: Optional job name (e.g. from the ``-n``/``--name`` CLI
            flag). When given it overrides ``config.name`` in the file;
            when omitted the file must provide ``config.name`` itself.

    Returns:
        The config as an OrderedDict with every ``[name]`` occurrence
        replaced by the resolved name.

    Raises:
        ValueError: if ``job``, the ``config`` section, or a resolvable
            name is missing.
    """
    if "job" not in config:
        raise ValueError("config file must have a job key")
    if "config" not in config:
        raise ValueError("config file must have a config section")
    if "name" not in config["config"] and name is None:
        raise ValueError("config file must have a config.name key")
    # we need to replace tags. For now just [name]
    # an explicitly passed name wins over the one in the file
    if name is not None:
        config["config"]["name"] = name
    else:
        name = config["config"]["name"]
    # round-trip through JSON so the tag is replaced everywhere in the
    # tree; object_pairs_hook keeps key order intact
    config_string = json.dumps(config)
    config_string = config_string.replace("[name]", name)
    config = json.loads(config_string, object_pairs_hook=OrderedDict)
    return config
|
||||
|
||||
|
||||
|
||||
# Fixes issue where yaml doesnt load exponents correctly
|
||||
fixed_loader = yaml.SafeLoader
|
||||
fixed_loader.add_implicit_resolver(
|
||||
@@ -44,7 +46,8 @@ fixed_loader.add_implicit_resolver(
|
||||
|\\.(?:nan|NaN|NAN))$''', re.X),
|
||||
list(u'-+0123456789.'))
|
||||
|
||||
def get_config(config_file_path):
|
||||
|
||||
def get_config(config_file_path, name=None):
|
||||
# first check if it is in the config folder
|
||||
config_path = os.path.join(TOOLKIT_ROOT, 'config', config_file_path)
|
||||
# see if it is in the config folder with any of the possible extensions if it doesnt have one
|
||||
@@ -75,4 +78,4 @@ def get_config(config_file_path):
|
||||
else:
|
||||
raise ValueError(f"Config file {config_file_path} must be a json or yaml file")
|
||||
|
||||
return preprocess_config(config)
|
||||
return preprocess_config(config, name)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from toolkit.config import get_config
|
||||
|
||||
|
||||
def get_job(config_path):
|
||||
config = get_config(config_path)
|
||||
def get_job(config_path, name=None):
|
||||
config = get_config(config_path, name)
|
||||
if not config['job']:
|
||||
raise ValueError('config file is invalid. Missing "job" key')
|
||||
|
||||
|
||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user