mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Reworked so everything is in classes for easy expansion. Single entry point for all config files now.
This commit is contained in:
@@ -1,150 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
from flatten_json import flatten
|
||||
|
||||
sys.path.insert(0, os.getcwd())
|
||||
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
|
||||
CONFIG_FOLDER = os.path.join(PROJECT_ROOT, 'config')
|
||||
sys.path.append(PROJECT_ROOT)
|
||||
|
||||
import argparse
|
||||
|
||||
from toolkit.lycoris_utils import extract_diff
|
||||
from toolkit.config import get_config
|
||||
from toolkit.metadata import create_meta, prep_meta_for_safetensors
|
||||
from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
|
||||
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"config_file",
|
||||
help="Name of config file (eg: person_v1 for config/person_v1.json), or full path if it is not in config folder",
|
||||
type=str
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
|
||||
config_raw = get_config(args.config_file)
|
||||
config = config_raw['config'] if 'config' in config_raw else None
|
||||
if not config:
|
||||
raise ValueError('config file is invalid. Missing "config" key')
|
||||
|
||||
meta = config_raw['meta'] if 'meta' in config_raw else {}
|
||||
|
||||
def get_conf(key, default=None):
|
||||
if key in config:
|
||||
return config[key]
|
||||
else:
|
||||
return default
|
||||
|
||||
is_v2 = get_conf('is_v2', False)
|
||||
name = get_conf('name', None)
|
||||
base_model = get_conf('base_model')
|
||||
extract_model = get_conf('extract_model')
|
||||
output_folder = get_conf('output_folder')
|
||||
process_list = get_conf('process')
|
||||
device = get_conf('device', 'cpu')
|
||||
use_sparse_bias = get_conf('use_sparse_bias', False)
|
||||
sparsity = get_conf('sparsity', 0.98)
|
||||
disable_cp = get_conf('disable_cp', False)
|
||||
|
||||
if not name:
|
||||
raise ValueError('name is required')
|
||||
if not base_model:
|
||||
raise ValueError('base_model is required')
|
||||
if not extract_model:
|
||||
raise ValueError('extract_model is required')
|
||||
if not output_folder:
|
||||
raise ValueError('output_folder is required')
|
||||
if not process_list or len(process_list) == 0:
|
||||
raise ValueError('process is required')
|
||||
|
||||
# check processes
|
||||
for process in process_list:
|
||||
if process['mode'] == 'fixed':
|
||||
if not process['linear_dim']:
|
||||
raise ValueError('linear_dim is required in fixed mode')
|
||||
if not process['conv_dim']:
|
||||
raise ValueError('conv_dim is required in fixed mode')
|
||||
elif process['mode'] == 'threshold':
|
||||
if not process['linear_threshold']:
|
||||
raise ValueError('linear_threshold is required in threshold mode')
|
||||
if not process['conv_threshold']:
|
||||
raise ValueError('conv_threshold is required in threshold mode')
|
||||
elif process['mode'] == 'ratio':
|
||||
if not process['linear_ratio']:
|
||||
raise ValueError('linear_ratio is required in ratio mode')
|
||||
if not process['conv_ratio']:
|
||||
raise ValueError('conv_threshold is required in threshold mode')
|
||||
elif process['mode'] == 'quantile':
|
||||
if not process['linear_quantile']:
|
||||
raise ValueError('linear_quantile is required in quantile mode')
|
||||
if not process['conv_quantile']:
|
||||
raise ValueError('conv_quantile is required in quantile mode')
|
||||
else:
|
||||
raise ValueError('mode is invalid')
|
||||
|
||||
print(f"Loading base model: {base_model}")
|
||||
base = load_models_from_stable_diffusion_checkpoint(is_v2, base_model)
|
||||
print(f"Loading extract model: {extract_model}")
|
||||
extract = load_models_from_stable_diffusion_checkpoint(is_v2, extract_model)
|
||||
|
||||
print(f"Running {len(process_list)} process{'' if len(process_list) == 1 else 'es'}")
|
||||
|
||||
for process in process_list:
|
||||
item_meta = json.loads(json.dumps(meta))
|
||||
item_meta['process'] = process
|
||||
if process['mode'] == 'fixed':
|
||||
linear_mode_param = int(process['linear_dim'])
|
||||
conv_mode_param = int(process['conv_dim'])
|
||||
elif process['mode'] == 'threshold':
|
||||
linear_mode_param = float(process['linear_threshold'])
|
||||
conv_mode_param = float(process['conv_threshold'])
|
||||
elif process['mode'] == 'ratio':
|
||||
linear_mode_param = float(process['linear_ratio'])
|
||||
conv_mode_param = float(process['conv_ratio'])
|
||||
elif process['mode'] == 'quantile':
|
||||
linear_mode_param = float(process['linear_quantile'])
|
||||
conv_mode_param = float(process['conv_quantile'])
|
||||
else:
|
||||
raise ValueError(f"Unknown mode: {process['mode']}")
|
||||
|
||||
print(f"Running process: {process['mode']}, lin: {linear_mode_param}, conv: {conv_mode_param}")
|
||||
|
||||
state_dict, extract_diff_meta = extract_diff(
|
||||
base,
|
||||
extract,
|
||||
process['mode'],
|
||||
linear_mode_param,
|
||||
conv_mode_param,
|
||||
device,
|
||||
use_sparse_bias,
|
||||
sparsity,
|
||||
not disable_cp
|
||||
)
|
||||
|
||||
save_meta = create_meta([
|
||||
item_meta, extract_diff_meta
|
||||
], name=name)
|
||||
|
||||
output_file_name = f"lyco_{name}_{process['mode']}_{linear_mode_param}_{conv_mode_param}.safetensors"
|
||||
output_path = os.path.join(output_folder, output_file_name)
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
|
||||
# having issues with meta
|
||||
save_file(state_dict, output_path, prep_meta_for_safetensors(save_meta))
|
||||
|
||||
print(f"Saved to {output_path}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Reference in New Issue
Block a user