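"""Entry point for running one or more ai-toolkit jobs defined by config files.

See main() below for the supported command-line flags.
"""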
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
import sys
from typing import Union, OrderedDict
from dotenv import load_dotenv
# Load the .env file if it exists
load_dotenv()

sys.path.insert(0, os.getcwd())
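# prepending the working directory lets the local `toolkit` package
# (imported below) resolve when the script is run from the repo root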
# must come before ANY torch or fastai imports
# import toolkit.cuda_malloc

# turn off diffusers telemetry until I can figure out how to make it opt-in
os.environ['DISABLE_TELEMETRY'] = 'YES'

# check if we have DEBUG_TOOLKIT in env
if os.environ.get("DEBUG_TOOLKIT", "0") == "1":
    # enable autograd anomaly detection to help trace NaN/inf gradients
    import torch
    torch.autograd.set_detect_anomaly(True)
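# usage sketch (assumed shell invocation): DEBUG_TOOLKIT=1 python run.py <config>
# anomaly detection records forward-op tracebacks and checks backward outputs
# for NaN/inf, which slows training noticeably, so it stays off by default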

import argparse
from toolkit.job import get_job
from toolkit.accelerator import get_accelerator
from toolkit.print import print_acc, setup_log_to_file

accelerator = get_accelerator()


def print_end_message(jobs_completed, jobs_failed):
    if not accelerator.is_main_process:
        return
    failure_string = f"{jobs_failed} failure{'' if jobs_failed == 1 else 's'}" if jobs_failed > 0 else ""
    completed_string = f"{jobs_completed} completed job{'' if jobs_completed == 1 else 's'}"

    print_acc("")
    print_acc("========================================")
    print_acc("Result:")
    if len(completed_string) > 0:
        print_acc(f" - {completed_string}")
    if len(failure_string) > 0:
        print_acc(f" - {failure_string}")
    print_acc("========================================")
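# sample summary from print_end_message (illustrative, 2 completed jobs, 1 failure):
# ========================================
# Result:
#  - 2 completed jobs
#  - 1 failure
# ========================================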

def main():
    parser = argparse.ArgumentParser()

    # require at least one config file
    parser.add_argument(
        'config_file_list',
        nargs='+',
        type=str,
        help='Name of a config file (eg: person_v1 for config/person_v1.json/yaml) or a full path if it is not in the config folder. You can pass multiple config files; they will run sequentially.'
    )

    # flag to continue with the remaining jobs if one fails
    parser.add_argument(
        '-r', '--recover',
        action='store_true',
        help='Continue running additional jobs even if a job fails'
    )

    # name to substitute for the [name] tag in the config file
    parser.add_argument(
        '-n', '--name',
        type=str,
        default=None,
        help='Name to replace the [name] tag in the config file, useful for a shared config file'
    )

    parser.add_argument(
        '-l', '--log',
        type=str,
        default=None,
        help='Log file to write output to'
    )
    args = parser.parse_args()

    if args.log is not None:
        setup_log_to_file(args.log)

    config_file_list = args.config_file_list
    if len(config_file_list) == 0:
        raise Exception("You must provide at least one config file")

    jobs_completed = 0
    jobs_failed = 0

    if accelerator.is_main_process:
        print_acc(f"Running {len(config_file_list)} job{'' if len(config_file_list) == 1 else 's'}")

    for config_file in config_file_list:
        # reset so the error handlers below never touch a job left over from a previous iteration
        job = None
        try:
            job = get_job(config_file, args.name)
            job.run()
            job.cleanup()
            jobs_completed += 1
        except Exception as e:
            print_acc(f"Error running job: {e}")
            jobs_failed += 1
            try:
                if job is not None:
                    job.process[0].on_error(e)
            except Exception as e2:
                print_acc(f"Error running on_error: {e2}")
            if not args.recover:
                print_end_message(jobs_completed, jobs_failed)
                raise e
        except KeyboardInterrupt as e:
            try:
                if job is not None:
                    job.process[0].on_error(e)
            except Exception as e2:
                print_acc(f"Error running on_error: {e2}")
            if not args.recover:
                print_end_message(jobs_completed, jobs_failed)
                raise e

    # print the summary once all jobs have run
    print_end_message(jobs_completed, jobs_failed)


if __name__ == '__main__':
    main()
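
# example invocations (illustrative; assumes this script is saved as run.py
# and config names are placeholders):
#   python run.py person_v1               # runs config/person_v1.yaml (or .json)
#   python run.py job1.yaml job2.yaml -r  # run both, continuing past a failure
#   python run.py shared.yaml -n my_run   # substitutes my_run for the [name] tag
#   python run.py job.yaml -l out.log     # also writes output to out.log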