mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-20 04:13:57 +00:00
Added a way to run as a library by passing job dict
This commit is contained in:
13
run.py
13
run.py
@@ -1,5 +1,7 @@
|
||||
import os
|
||||
import sys
|
||||
from typing import Union, OrderedDict
|
||||
|
||||
sys.path.insert(0, os.getcwd())
|
||||
import argparse
|
||||
from toolkit.job import get_job
|
||||
@@ -19,6 +21,17 @@ def print_end_message(jobs_completed, jobs_failed):
|
||||
print("========================================")
|
||||
|
||||
|
||||
def run_job(
|
||||
config: Union[str, dict, OrderedDict],
|
||||
name=None
|
||||
):
|
||||
from toolkit.job import get_job
|
||||
|
||||
job = get_job(config, name)
|
||||
job.run()
|
||||
job.cleanup()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import os
|
||||
import json
|
||||
from typing import Union
|
||||
|
||||
import oyaml as yaml
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
@@ -47,7 +49,17 @@ fixed_loader.add_implicit_resolver(
|
||||
list(u'-+0123456789.'))
|
||||
|
||||
|
||||
def get_config(config_file_path, name=None):
|
||||
def get_config(
|
||||
config_file_path_or_dict: Union[str, dict, OrderedDict],
|
||||
name=None
|
||||
):
|
||||
# if we got a dict, process it and return it
|
||||
if isinstance(config_file_path_or_dict, dict) or isinstance(config_file_path_or_dict, OrderedDict):
|
||||
config = config_file_path_or_dict
|
||||
return preprocess_config(config, name)
|
||||
|
||||
config_file_path = config_file_path_or_dict
|
||||
|
||||
# first check if it is in the config folder
|
||||
config_path = os.path.join(TOOLKIT_ROOT, 'config', config_file_path)
|
||||
# see if it is in the config folder with any of the possible extensions if it doesnt have one
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
from typing import Union, OrderedDict
|
||||
|
||||
from toolkit.config import get_config
|
||||
|
||||
|
||||
def get_job(config_path, name=None):
|
||||
def get_job(
|
||||
config_path: Union[str, dict, OrderedDict],
|
||||
name=None
|
||||
):
|
||||
config = get_config(config_path, name)
|
||||
if not config['job']:
|
||||
raise ValueError('config file is invalid. Missing "job" key')
|
||||
|
||||
Reference in New Issue
Block a user