mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-29 02:31:48 +00:00
refactor config loading
- improve DRY - alter logging - allow extensibility - add foundation for environment variables as config
This commit is contained in:
@@ -1,6 +1,8 @@
|
|||||||
import yaml
|
import yaml
|
||||||
import pathlib
|
import pathlib
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
from mergedeep import merge, Strategy
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from common.utils import unwrap
|
from common.utils import unwrap
|
||||||
|
|
||||||
@@ -8,61 +10,66 @@ from common.utils import unwrap
|
|||||||
GLOBAL_CONFIG: dict = {}
|
GLOBAL_CONFIG: dict = {}
|
||||||
|
|
||||||
|
|
||||||
def from_file(config_path: pathlib.Path):
|
def load(arguments: dict[str, Any]):
|
||||||
"""Sets the global config from a given file path"""
|
"""load the global application config"""
|
||||||
global GLOBAL_CONFIG
|
global GLOBAL_CONFIG
|
||||||
|
|
||||||
|
# config is applied in order of items in the list
|
||||||
|
configs = [
|
||||||
|
from_file(pathlib.Path("config.yml")),
|
||||||
|
from_environment(),
|
||||||
|
from_args(arguments),
|
||||||
|
]
|
||||||
|
|
||||||
|
GLOBAL_CONFIG = merge({}, *configs, strategy=Strategy.REPLACE)
|
||||||
|
|
||||||
|
def from_file(config_path: pathlib.Path) -> dict[str, Any]:
|
||||||
|
"""loads config from a given file path"""
|
||||||
|
|
||||||
|
# try loading from file
|
||||||
try:
|
try:
|
||||||
with open(str(config_path.resolve()), "r", encoding="utf8") as config_file:
|
with open(str(config_path.resolve()), "r", encoding="utf8") as config_file:
|
||||||
GLOBAL_CONFIG = unwrap(yaml.safe_load(config_file), {})
|
return unwrap(yaml.safe_load(config_file), {})
|
||||||
|
except FileNotFoundError:
|
||||||
|
logger.info("The config.yml file cannot be found")
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error(
|
logger.error(
|
||||||
"The YAML config couldn't load because of the following error: "
|
f"The YAML config couldn't load because of the following error:\n\n{exc}"
|
||||||
f"\n\n{exc}"
|
|
||||||
"\n\nTabbyAPI will start anyway and not parse this config file."
|
|
||||||
)
|
)
|
||||||
GLOBAL_CONFIG = {}
|
|
||||||
|
# if no config file was loaded
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def from_args(args: dict):
|
def from_args(args: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Overrides the config based on a dict representation of args"""
|
"""loads config from the provided arguments"""
|
||||||
|
config = {}
|
||||||
|
|
||||||
config_override = unwrap(args.get("options", {}).get("config"))
|
config_override = unwrap(args.get("options", {}).get("config"))
|
||||||
if config_override:
|
if config_override:
|
||||||
logger.info("Attempting to override config.yml from args.")
|
logger.info("Config file override detected in args.")
|
||||||
from_file(pathlib.Path(config_override))
|
config = from_file(pathlib.Path(config_override))
|
||||||
return
|
return config # Return early if loading from file
|
||||||
|
|
||||||
# Network config
|
for key in ["network", "model", "logging", "developer", "embeddings"]:
|
||||||
network_override = args.get("network")
|
override = args.get(key)
|
||||||
if network_override:
|
if override:
|
||||||
cur_network_config = network_config()
|
if key == "logging":
|
||||||
GLOBAL_CONFIG["network"] = {**cur_network_config, **network_override}
|
# Strip the "log_" prefix from logging keys if present
|
||||||
|
override = {k.replace("log_", ""): v for k, v in override.items()}
|
||||||
|
config[key] = override
|
||||||
|
|
||||||
# Model config
|
return config
|
||||||
model_override = args.get("model")
|
|
||||||
if model_override:
|
|
||||||
cur_model_config = model_config()
|
|
||||||
GLOBAL_CONFIG["model"] = {**cur_model_config, **model_override}
|
|
||||||
|
|
||||||
# Generation Logging config
|
|
||||||
logging_override = args.get("logging")
|
|
||||||
if logging_override:
|
|
||||||
cur_logging_config = logging_config()
|
|
||||||
GLOBAL_CONFIG["logging"] = {
|
|
||||||
**cur_logging_config,
|
|
||||||
**{k.replace("log_", ""): logging_override[k] for k in logging_override},
|
|
||||||
}
|
|
||||||
|
|
||||||
developer_override = args.get("developer")
|
def from_environment() -> dict[str, Any]:
|
||||||
if developer_override:
|
"""loads configuration from environment variables"""
|
||||||
cur_developer_config = developer_config()
|
|
||||||
GLOBAL_CONFIG["developer"] = {**cur_developer_config, **developer_override}
|
|
||||||
|
|
||||||
embeddings_override = args.get("embeddings")
|
# TODO: load config from environment variables
|
||||||
if embeddings_override:
|
# this means that we can have host default to 0.0.0.0 in docker for example
|
||||||
cur_embeddings_config = embeddings_config()
|
# this would also mean that docker containers no longer require a non
|
||||||
GLOBAL_CONFIG["embeddings"] = {**cur_embeddings_config, **embeddings_override}
|
# default config file to be used
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def sampling_config():
|
def sampling_config():
|
||||||
|
|||||||
6
main.py
6
main.py
@@ -110,15 +110,13 @@ def entrypoint(arguments: Optional[dict] = None):
|
|||||||
signal.signal(signal.SIGINT, signal_handler)
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
signal.signal(signal.SIGTERM, signal_handler)
|
signal.signal(signal.SIGTERM, signal_handler)
|
||||||
|
|
||||||
# Load from YAML config
|
|
||||||
config.from_file(pathlib.Path("config.yml"))
|
|
||||||
|
|
||||||
# Parse and override config from args
|
# Parse and override config from args
|
||||||
if arguments is None:
|
if arguments is None:
|
||||||
parser = init_argparser()
|
parser = init_argparser()
|
||||||
arguments = convert_args_to_dict(parser.parse_args(), parser)
|
arguments = convert_args_to_dict(parser.parse_args(), parser)
|
||||||
|
|
||||||
config.from_args(arguments)
|
# load config
|
||||||
|
config.load(arguments)
|
||||||
|
|
||||||
if do_export_openapi:
|
if do_export_openapi:
|
||||||
openapi_json = export_openapi()
|
openapi_json = export_openapi()
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ dependencies = [
|
|||||||
"huggingface_hub",
|
"huggingface_hub",
|
||||||
"psutil",
|
"psutil",
|
||||||
"httptools>=0.5.0",
|
"httptools>=0.5.0",
|
||||||
|
"mergedeep",
|
||||||
|
|
||||||
# Improved asyncio loops
|
# Improved asyncio loops
|
||||||
"uvloop ; platform_system == 'Linux' and platform_machine == 'x86_64'",
|
"uvloop ; platform_system == 'Linux' and platform_machine == 'x86_64'",
|
||||||
|
|||||||
Reference in New Issue
Block a user