Files
ai-toolkit/scripts/extract_locon.py
2023-07-05 16:49:59 -06:00

151 lines
5.4 KiB
Python

import json
import os
import sys
from flatten_json import flatten
# Put the current working directory at the front of the import search path.
sys.path.insert(0, os.getcwd())
# Absolute path of the directory that contains this script file.
# NOTE(review): if this script lives in a subfolder of the repository, this is
# that subfolder rather than the repository root — confirm the name is accurate.
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
# "config" directory next to this script; not referenced by the code visible in
# this file (presumably config lookup is handled by toolkit.config — verify).
CONFIG_FOLDER = os.path.join(PROJECT_ROOT, 'config')
# Also make this script's directory importable; appended, so it has the lowest
# priority. Presumably needed so the `toolkit` imports below resolve.
sys.path.append(PROJECT_ROOT)
import argparse
from toolkit.lycoris_utils import extract_diff
from toolkit.config import get_config
from toolkit.metadata import create_meta, prep_meta_for_safetensors
from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
import torch
from safetensors.torch import save_file
def get_args():
    """Parse the command line and return the resulting namespace.

    The script takes a single positional argument, ``config_file``, which is
    exposed as a string attribute on the returned ``argparse.Namespace``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "config_file",
        type=str,
        help="Name of config file (eg: person_v1 for config/person_v1.json), or full path if it is not in config folder",
    )
    parsed = cli.parse_args()
    return parsed
# Single source of truth for the supported extraction modes. Each mode maps to
# (key of the linear parameter, key of the conv parameter, type the values are
# cast to before being handed to extract_diff).
_MODE_PARAMS = {
    'fixed': ('linear_dim', 'conv_dim', int),
    'threshold': ('linear_threshold', 'conv_threshold', float),
    'ratio': ('linear_ratio', 'conv_ratio', float),
    'quantile': ('linear_quantile', 'conv_quantile', float),
}


def _resolve_mode_params(process):
    """Validate one process entry and return ``(linear_param, conv_param)``.

    Raises:
        ValueError: if ``mode`` is missing/unknown, or if either parameter the
            mode requires is missing or falsy.
    """
    mode = process.get('mode')
    if not isinstance(mode, str) or mode not in _MODE_PARAMS:
        raise ValueError('mode is invalid')
    linear_key, conv_key, cast = _MODE_PARAMS[mode]
    for key in (linear_key, conv_key):
        # .get() so that a missing key is reported as the intended ValueError
        # rather than leaking a bare KeyError.
        if not process.get(key):
            raise ValueError(f'{key} is required in {mode} mode')
    return cast(process[linear_key]), cast(process[conv_key])


def main():
    """Run every extraction process described by the given config file.

    Reads the config named on the command line, validates it, loads the base
    and extract checkpoints once, then for each entry of ``process`` calls
    ``extract_diff`` and writes the result to
    ``<output_folder>/lyco_<name>_<mode>_<linear>_<conv>.safetensors``.

    Config keys read from the ``config`` section: ``name``, ``base_model``,
    ``extract_model``, ``output_folder`` and ``process`` (all required), plus
    the optional ``is_v2`` (default False), ``device`` (default 'cpu'),
    ``use_sparse_bias`` (default False), ``sparsity`` (default 0.98) and
    ``disable_cp`` (default False). The optional top-level ``meta`` section is
    copied into each output's metadata.

    Raises:
        ValueError: if the ``config`` section or a required key is missing, or
            if a process entry has an invalid mode or lacks a parameter its
            mode requires.
    """
    args = get_args()

    config_raw = get_config(args.config_file)
    config = config_raw.get('config')
    if not config:
        raise ValueError('config file is invalid. Missing "config" key')

    # "or {}" also covers a meta key that is present but null.
    meta = config_raw.get('meta') or {}

    is_v2 = config.get('is_v2', False)
    name = config.get('name')
    base_model = config.get('base_model')
    extract_model = config.get('extract_model')
    output_folder = config.get('output_folder')
    process_list = config.get('process')
    device = config.get('device', 'cpu')
    use_sparse_bias = config.get('use_sparse_bias', False)
    sparsity = config.get('sparsity', 0.98)
    disable_cp = config.get('disable_cp', False)

    required = {
        'name': name,
        'base_model': base_model,
        'extract_model': extract_model,
        'output_folder': output_folder,
    }
    for key, value in required.items():
        if not value:
            raise ValueError(f'{key} is required')
    if not process_list:
        raise ValueError('process is required')

    # Validate and cast every process up front so a malformed entry fails
    # before the (expensive) checkpoints are loaded.
    jobs = [(process, *_resolve_mode_params(process)) for process in process_list]

    print(f"Loading base model: {base_model}")
    base = load_models_from_stable_diffusion_checkpoint(is_v2, base_model)

    print(f"Loading extract model: {extract_model}")
    extract = load_models_from_stable_diffusion_checkpoint(is_v2, extract_model)

    print(f"Running {len(process_list)} process{'' if len(process_list) == 1 else 'es'}")

    for process, linear_mode_param, conv_mode_param in jobs:
        mode = process['mode']

        # JSON round-trip gives each process its own independent copy of meta.
        item_meta = json.loads(json.dumps(meta))
        item_meta['process'] = process

        print(f"Running process: {mode}, lin: {linear_mode_param}, conv: {conv_mode_param}")

        state_dict, extract_diff_meta = extract_diff(
            base,
            extract,
            mode,
            linear_mode_param,
            conv_mode_param,
            device,
            use_sparse_bias,
            sparsity,
            not disable_cp
        )

        save_meta = create_meta([
            item_meta, extract_diff_meta
        ])

        output_file_name = f"lyco_{name}_{mode}_{linear_mode_param}_{conv_mode_param}.safetensors"
        output_path = os.path.join(output_folder, output_file_name)
        os.makedirs(output_folder, exist_ok=True)

        # having issues with meta
        save_file(state_dict, output_path, prep_meta_for_safetensors(save_meta))
        print(f"Saved to {output_path}")
# Script entry point: run the extraction only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()