Files
tabbyAPI/common/tabby_config.py
kingbri b9e5693c1b API + Model: Apply config.yml defaults for all load paths
There are two ways to load a model:
1. Via the load endpoint
2. Inline with a completion

The defaults were not being applied on the inline load path, so rewrite
the loading logic to fix that. While doing this, also set up a defaults
dictionary once instead of comparing values at runtime, and remove the
pydantic default lambdas from all of the model load fields.

This makes the code cleaner and establishes a clear config tree for
loading models.

Signed-off-by: kingbri <bdashore3@proton.me>
2024-09-10 23:35:35 -04:00

103 lines
3.6 KiB
Python

import yaml
import pathlib
from loguru import logger
from typing import Optional
from common.utils import unwrap, merge_dicts
class TabbyConfig:
    """Common config class for TabbyAPI. Loaded into sub-dictionaries from YAML file.

    ``load()`` merges ``config.yml`` (from the working directory) with the
    provided arguments and splits the result into per-section dicts.
    """

    # Sub-blocks of yaml.
    # These are class-level placeholders; ``load()`` replaces each one with an
    # instance attribute, so they are only shared between instances before
    # ``load()`` has run.
    network: dict = {}
    logging: dict = {}
    model: dict = {}
    draft_model: dict = {}
    lora: dict = {}
    sampling: dict = {}
    developer: dict = {}
    embeddings: dict = {}

    # Persistent defaults (rebuilt as a fresh dict on every ``load()``)
    model_defaults: dict = {}

    def load(self, arguments: Optional[dict] = None):
        """Synchronously loads the global application config

        Args:
            arguments: Optional dict of argument overrides grouped by section
                (e.g. ``"network"``, ``"model"``, ``"options"``). ``None`` is
                treated as no overrides.
        """

        # config is applied in order of items in the list
        # (precedence between entries is decided by ``merge_dicts``)
        configs = [
            self._from_file(pathlib.Path("config.yml")),
            self._from_args(unwrap(arguments, {})),
        ]

        merged_config = merge_dicts(*configs)

        self.network = unwrap(merged_config.get("network"), {})
        self.logging = unwrap(merged_config.get("logging"), {})
        self.model = unwrap(merged_config.get("model"), {})
        # draft and lora settings are nested inside the model block
        self.draft_model = unwrap(self.model.get("draft"), {})
        self.lora = unwrap(self.model.get("lora"), {})
        self.sampling = unwrap(merged_config.get("sampling"), {})
        self.developer = unwrap(merged_config.get("developer"), {})
        self.embeddings = unwrap(merged_config.get("embeddings"), {})

        # Set model defaults dict once to prevent on-demand reconstruction.
        # Values are read from *this* instance (not a module-level global),
        # and a new dict is assigned so that keys from a previous load do not
        # linger in the shared class-level placeholder.
        default_keys = unwrap(self.model.get("use_as_default"), [])
        self.model_defaults = self._build_model_defaults(
            default_keys, self.model, self.draft_model
        )

    @staticmethod
    def _build_model_defaults(default_keys, model: dict, draft_model: dict) -> dict:
        """Collect the values named by ``default_keys`` from the model blocks.

        Args:
            default_keys: Iterable of option names to promote to defaults.
            model: The ``model`` config block.
            draft_model: The ``model.draft`` config block.

        Returns:
            A new dict mapping each found key to its value. A key present in
            ``model`` wins over ``draft_model``; keys found in neither block
            are skipped.
        """

        defaults = {}
        for key in default_keys:
            if key in model:
                defaults[key] = model[key]
            elif key in draft_model:
                defaults[key] = draft_model[key]

        return defaults

    def _from_file(self, config_path: pathlib.Path):
        """loads config from a given file path

        Returns:
            The parsed YAML document, or ``{}`` when the file is missing,
            empty, or fails to load. Failures are logged, not raised.
        """

        # try loading from file
        try:
            with open(str(config_path.resolve()), "r", encoding="utf8") as config_file:
                # safe_load: never construct arbitrary Python objects from YAML
                return unwrap(yaml.safe_load(config_file), {})
        except FileNotFoundError:
            logger.info(f"The '{config_path.name}' file cannot be found")
        except Exception as exc:
            # Best-effort: a broken config file falls back to an empty config
            logger.error(
                f"The YAML config from '{config_path.name}' couldn't load because of "
                f"the following error:\n\n{exc}"
            )

        # if no config file was loaded
        return {}

    @staticmethod
    def _strip_log_prefix(key: str) -> str:
        """Return ``key`` without a leading ``log_`` (only at the start)."""

        prefix = "log_"
        return key[len(prefix):] if key.startswith(prefix) else key

    def _from_args(self, args: dict):
        """loads config from the provided arguments

        If ``args["options"]["config"]`` names a file, that file's contents are
        returned and all other argument sections are ignored. Otherwise the
        non-empty ``network``, ``model``, ``logging``, ``developer`` and
        ``embeddings`` sections are copied into the returned dict.
        """

        arg_config = {}

        # "options" may be absent or explicitly None
        config_override = unwrap(args.get("options"), {}).get("config")
        if config_override:
            logger.info("Config file override detected in args.")
            arg_config = self._from_file(pathlib.Path(config_override))
            return arg_config  # Return early if loading from file

        for key in ["network", "model", "logging", "developer", "embeddings"]:
            override = args.get(key)
            if override:
                if key == "logging":
                    # Strip the "log_" prefix from logging keys if present
                    override = {
                        self._strip_log_prefix(k): v for k, v in override.items()
                    }
                arg_config[key] = override

        return arg_config

    def _from_environment(self):
        """loads configuration from environment variables"""

        # TODO: load config from environment variables
        # this means that we can have host default to 0.0.0.0 in docker for example
        # this would also mean that docker containers no longer require a non
        # default config file to be used
        pass
# Create an empty instance of the config class.
# Its section dicts stay at the empty class-level defaults until
# ``config.load(...)`` is called to populate them.
config: TabbyConfig = TabbyConfig()