mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
91 lines
2.7 KiB
Python
# Entry-point script: set process-wide environment flags, make the working
# directory importable, then (in main) run one job per config file given on
# the command line.
import os

# Set before any Hugging Face library gets imported (e.g. via toolkit.job
# below); presumably this switches huggingface_hub downloads to the
# hf_transfer backend — TODO confirm.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import sys

# NOTE(review): Union and OrderedDict are not referenced anywhere in this file.
from typing import Union, OrderedDict

from dotenv import load_dotenv

# Load the .env file if it exists
load_dotenv()

# Put the current working directory first on the import path so that the
# `toolkit` package imported at the bottom of this preamble is resolved from
# where the script is launched (presumably the repository root).
sys.path.insert(0, os.getcwd())

# must come before ANY torch or fastai imports
# import toolkit.cuda_malloc

# turn off diffusers telemetry until I can figure out how to make it opt-in
os.environ['DISABLE_TELEMETRY'] = 'YES'

# check if we have DEBUG_TOOLKIT in env
if os.environ.get("DEBUG_TOOLKIT", "0") == "1":
    # Debug mode: turn on torch autograd anomaly detection. torch is only
    # imported here when DEBUG_TOOLKIT=1.
    import torch
    torch.autograd.set_detect_anomaly(True)

import argparse

# Must stay after the sys.path insertion and environment setup above.
from toolkit.job import get_job
|
|
|
|
|
|
def print_end_message(jobs_completed, jobs_failed):
    """Print a banner summarizing how many jobs completed and how many failed.

    Args:
        jobs_completed: Number of jobs that ran to completion.
        jobs_failed: Number of jobs that raised an exception.

    The "completed" line is always printed (including "0 completed jobs");
    the "failure" line is printed only when ``jobs_failed > 0``.
    """
    completed_string = f"{jobs_completed} completed job{'' if jobs_completed == 1 else 's'}"
    # Empty when there were no failures, so the failure line is skipped below.
    failure_string = f"{jobs_failed} failure{'' if jobs_failed == 1 else 's'}" if jobs_failed > 0 else ""

    print("")
    print("========================================")
    print("Result:")
    # completed_string is never empty, so it needs no guard.
    print(f" - {completed_string}")
    if failure_string:
        print(f" - {failure_string}")
    print("========================================")
|
|
|
|
|
|
def main():
    """Parse command-line arguments and run each config file's job in order.

    Each positional argument is a config name or path passed to ``get_job``
    together with the optional ``--name`` value. For each job, ``run()`` and
    then ``cleanup()`` are called.

    If a job raises and ``--recover`` is not set, a summary is printed and the
    exception is re-raised, which stops the remaining jobs. With ``--recover``,
    the error is printed and the next job runs. A summary of completed/failed
    jobs is printed once all jobs have been attempted.
    """
    parser = argparse.ArgumentParser()

    # require at least one config file
    parser.add_argument(
        'config_file_list',
        nargs='+',
        type=str,
        help='Name of config file (eg: person_v1 for config/person_v1.json/yaml), or full path if it is not in config folder, you can pass multiple config files and run them all sequentially'
    )

    # flag to continue if failed job
    parser.add_argument(
        '-r', '--recover',
        action='store_true',
        help='Continue running additional jobs even if a job fails'
    )

    # optional value substituted for the [name] tag in config files
    parser.add_argument(
        '-n', '--name',
        type=str,
        default=None,
        help='Name to replace [name] tag in config file, useful for shared config file'
    )
    args = parser.parse_args()

    config_file_list = args.config_file_list
    # nargs='+' already makes argparse reject an empty list; kept as a
    # defensive check.
    if not config_file_list:
        raise Exception("You must provide at least one config file")

    jobs_completed = 0
    jobs_failed = 0

    print(f"Running {len(config_file_list)} job{'' if len(config_file_list) == 1 else 's'}")

    for config_file in config_file_list:
        try:
            job = get_job(config_file, args.name)
            job.run()
            job.cleanup()
            jobs_completed += 1
        except Exception as e:
            print(f"Error running job: {e}")
            jobs_failed += 1
            if not args.recover:
                print_end_message(jobs_completed, jobs_failed)
                # Bare raise re-raises the active exception with its traceback.
                raise

    # Previously the summary was printed only when aborting on a failure, so
    # successful runs and --recover runs ended without it.
    print_end_message(jobs_completed, jobs_failed)
|
|
|
|
|
|
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
|