mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Added a way to run as a library by passing job dict
This commit is contained in:
13
run.py
13
run.py
@@ -1,5 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
from typing import Union, OrderedDict
|
||||||
|
|
||||||
sys.path.insert(0, os.getcwd())
|
sys.path.insert(0, os.getcwd())
|
||||||
import argparse
|
import argparse
|
||||||
from toolkit.job import get_job
|
from toolkit.job import get_job
|
||||||
@@ -19,6 +21,17 @@ def print_end_message(jobs_completed, jobs_failed):
|
|||||||
print("========================================")
|
print("========================================")
|
||||||
|
|
||||||
|
|
||||||
|
def run_job(
|
||||||
|
config: Union[str, dict, OrderedDict],
|
||||||
|
name=None
|
||||||
|
):
|
||||||
|
from toolkit.job import get_job
|
||||||
|
|
||||||
|
job = get_job(config, name)
|
||||||
|
job.run()
|
||||||
|
job.cleanup()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
import oyaml as yaml
|
import oyaml as yaml
|
||||||
import re
|
import re
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
@@ -47,7 +49,17 @@ fixed_loader.add_implicit_resolver(
|
|||||||
list(u'-+0123456789.'))
|
list(u'-+0123456789.'))
|
||||||
|
|
||||||
|
|
||||||
def get_config(config_file_path, name=None):
|
def get_config(
|
||||||
|
config_file_path_or_dict: Union[str, dict, OrderedDict],
|
||||||
|
name=None
|
||||||
|
):
|
||||||
|
# if we got a dict, process it and return it
|
||||||
|
if isinstance(config_file_path_or_dict, dict) or isinstance(config_file_path_or_dict, OrderedDict):
|
||||||
|
config = config_file_path_or_dict
|
||||||
|
return preprocess_config(config, name)
|
||||||
|
|
||||||
|
config_file_path = config_file_path_or_dict
|
||||||
|
|
||||||
# first check if it is in the config folder
|
# first check if it is in the config folder
|
||||||
config_path = os.path.join(TOOLKIT_ROOT, 'config', config_file_path)
|
config_path = os.path.join(TOOLKIT_ROOT, 'config', config_file_path)
|
||||||
# see if it is in the config folder with any of the possible extensions if it doesnt have one
|
# see if it is in the config folder with any of the possible extensions if it doesnt have one
|
||||||
|
|||||||
@@ -1,7 +1,12 @@
|
|||||||
|
from typing import Union, OrderedDict
|
||||||
|
|
||||||
from toolkit.config import get_config
|
from toolkit.config import get_config
|
||||||
|
|
||||||
|
|
||||||
def get_job(config_path, name=None):
|
def get_job(
|
||||||
|
config_path: Union[str, dict, OrderedDict],
|
||||||
|
name=None
|
||||||
|
):
|
||||||
config = get_config(config_path, name)
|
config = get_config(config_path, name)
|
||||||
if not config['job']:
|
if not config['job']:
|
||||||
raise ValueError('config file is invalid. Missing "job" key')
|
raise ValueError('config file is invalid. Missing "job" key')
|
||||||
|
|||||||
Reference in New Issue
Block a user