refactor config loading

- improve DRY
- alter logging
- allow extensibility
- add foundation for environment variables as config
This commit is contained in:
Jake
2024-09-04 12:22:49 +01:00
parent 8854269121
commit fa6404a95a
3 changed files with 48 additions and 42 deletions

View File

@@ -1,6 +1,8 @@
import yaml import yaml
import pathlib import pathlib
from loguru import logger from loguru import logger
from mergedeep import merge, Strategy
from typing import Any
from common.utils import unwrap from common.utils import unwrap
@@ -8,61 +10,66 @@ from common.utils import unwrap
GLOBAL_CONFIG: dict = {} GLOBAL_CONFIG: dict = {}
def from_file(config_path: pathlib.Path): def load(arguments: dict[str, Any]):
"""Sets the global config from a given file path""" """load the global application config"""
global GLOBAL_CONFIG global GLOBAL_CONFIG
# config is applied in order of items in the list
configs = [
from_file(pathlib.Path("config.yml")),
from_environment(),
from_args(arguments),
]
GLOBAL_CONFIG = merge({}, *configs, strategy=Strategy.REPLACE)
def from_file(config_path: pathlib.Path) -> dict[str, Any]:
"""loads config from a given file path"""
# try loading from file
try: try:
with open(str(config_path.resolve()), "r", encoding="utf8") as config_file: with open(str(config_path.resolve()), "r", encoding="utf8") as config_file:
GLOBAL_CONFIG = unwrap(yaml.safe_load(config_file), {}) return unwrap(yaml.safe_load(config_file), {})
except FileNotFoundError:
logger.info("The config.yml file cannot be found")
except Exception as exc: except Exception as exc:
logger.error( logger.error(
"The YAML config couldn't load because of the following error: " f"The YAML config couldn't load because of the following error:\n\n{exc}"
f"\n\n{exc}"
"\n\nTabbyAPI will start anyway and not parse this config file."
) )
GLOBAL_CONFIG = {}
# if no config file was loaded
return {}
def from_args(args: dict): def from_args(args: dict[str, Any]) -> dict[str, Any]:
"""Overrides the config based on a dict representation of args""" """loads config from the provided arguments"""
config = {}
config_override = unwrap(args.get("options", {}).get("config")) config_override = unwrap(args.get("options", {}).get("config"))
if config_override: if config_override:
logger.info("Attempting to override config.yml from args.") logger.info("Config file override detected in args.")
from_file(pathlib.Path(config_override)) config = from_file(pathlib.Path(config_override))
return return config # Return early if loading from file
# Network config for key in ["network", "model", "logging", "developer", "embeddings"]:
network_override = args.get("network") override = args.get(key)
if network_override: if override:
cur_network_config = network_config() if key == "logging":
GLOBAL_CONFIG["network"] = {**cur_network_config, **network_override} # Strip the "log_" prefix from logging keys if present
override = {k.replace("log_", ""): v for k, v in override.items()}
config[key] = override
# Model config return config
model_override = args.get("model")
if model_override:
cur_model_config = model_config()
GLOBAL_CONFIG["model"] = {**cur_model_config, **model_override}
# Generation Logging config
logging_override = args.get("logging")
if logging_override:
cur_logging_config = logging_config()
GLOBAL_CONFIG["logging"] = {
**cur_logging_config,
**{k.replace("log_", ""): logging_override[k] for k in logging_override},
}
developer_override = args.get("developer") def from_environment() -> dict[str, Any]:
if developer_override: """loads configuration from environment variables"""
cur_developer_config = developer_config()
GLOBAL_CONFIG["developer"] = {**cur_developer_config, **developer_override}
embeddings_override = args.get("embeddings") # TODO: load config from environment variables
if embeddings_override: # this means that we can have host default to 0.0.0.0 in docker for example
cur_embeddings_config = embeddings_config() # this would also mean that docker containers no longer require a non
GLOBAL_CONFIG["embeddings"] = {**cur_embeddings_config, **embeddings_override} # default config file to be used
return {}
def sampling_config(): def sampling_config():

View File

@@ -110,15 +110,13 @@ def entrypoint(arguments: Optional[dict] = None):
signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGTERM, signal_handler)
# Load from YAML config
config.from_file(pathlib.Path("config.yml"))
# Parse and override config from args # Parse and override config from args
if arguments is None: if arguments is None:
parser = init_argparser() parser = init_argparser()
arguments = convert_args_to_dict(parser.parse_args(), parser) arguments = convert_args_to_dict(parser.parse_args(), parser)
config.from_args(arguments) # load config
config.load(arguments)
if do_export_openapi: if do_export_openapi:
openapi_json = export_openapi() openapi_json = export_openapi()

View File

@@ -32,6 +32,7 @@ dependencies = [
"huggingface_hub", "huggingface_hub",
"psutil", "psutil",
"httptools>=0.5.0", "httptools>=0.5.0",
"mergedeep",
# Improved asyncio loops # Improved asyncio loops
"uvloop ; platform_system == 'Linux' and platform_machine == 'x86_64'", "uvloop ; platform_system == 'Linux' and platform_machine == 'x86_64'",