mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-20 14:28:54 +00:00
Tree: Update to cleanup globals
Use the module singleton pattern to share global state. This can also be a modified version of the Global Object Pattern. The main reason this pattern is used is for ease of use when handling global state rather than adding extra dependencies for a DI parameter. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -4,10 +4,11 @@ from loguru import logger
|
||||
|
||||
from common.utils import unwrap
|
||||
|
||||
# Global config dictionary constant
|
||||
GLOBAL_CONFIG: dict = {}
|
||||
|
||||
|
||||
def read_config_from_file(config_path: pathlib.Path):
|
||||
def from_file(config_path: pathlib.Path):
|
||||
"""Sets the global config from a given file path"""
|
||||
global GLOBAL_CONFIG
|
||||
|
||||
@@ -23,74 +24,77 @@ def read_config_from_file(config_path: pathlib.Path):
|
||||
GLOBAL_CONFIG = {}
|
||||
|
||||
|
||||
def override_config_from_args(args: dict):
|
||||
def from_args(args: dict):
|
||||
"""Overrides the config based on a dict representation of args"""
|
||||
|
||||
config_override = unwrap(args.get("options", {}).get("config"))
|
||||
if config_override:
|
||||
logger.info("Attempting to override config.yml from args.")
|
||||
read_config_from_file(pathlib.Path(config_override))
|
||||
from_file(pathlib.Path(config_override))
|
||||
return
|
||||
|
||||
# Network config
|
||||
network_override = args.get("network")
|
||||
if network_override:
|
||||
network_config = get_network_config()
|
||||
GLOBAL_CONFIG["network"] = {**network_config, **network_override}
|
||||
cur_network_config = network_config()
|
||||
GLOBAL_CONFIG["network"] = {**cur_network_config, **network_override}
|
||||
|
||||
# Model config
|
||||
model_override = args.get("model")
|
||||
if model_override:
|
||||
model_config = get_model_config()
|
||||
GLOBAL_CONFIG["model"] = {**model_config, **model_override}
|
||||
cur_model_config = model_config()
|
||||
GLOBAL_CONFIG["model"] = {**cur_model_config, **model_override}
|
||||
|
||||
# Logging config
|
||||
logging_override = args.get("logging")
|
||||
if logging_override:
|
||||
logging_config = get_gen_logging_config()
|
||||
# Generation Logging config
|
||||
gen_logging_override = args.get("logging")
|
||||
if gen_logging_override:
|
||||
cur_gen_logging_config = gen_logging_config()
|
||||
GLOBAL_CONFIG["logging"] = {
|
||||
**logging_config,
|
||||
**{k.replace("log_", ""): logging_override[k] for k in logging_override},
|
||||
**cur_gen_logging_config,
|
||||
**{
|
||||
k.replace("log_", ""): gen_logging_override[k]
|
||||
for k in gen_logging_override
|
||||
},
|
||||
}
|
||||
|
||||
developer_override = args.get("developer")
|
||||
if developer_override:
|
||||
developer_config = get_developer_config()
|
||||
GLOBAL_CONFIG["developer"] = {**developer_config, **developer_override}
|
||||
cur_developer_config = developer_config()
|
||||
GLOBAL_CONFIG["developer"] = {**cur_developer_config, **developer_override}
|
||||
|
||||
|
||||
def get_sampling_config():
|
||||
def sampling_config():
|
||||
"""Returns the sampling parameter config from the global config"""
|
||||
return unwrap(GLOBAL_CONFIG.get("sampling"), {})
|
||||
|
||||
|
||||
def get_model_config():
|
||||
def model_config():
|
||||
"""Returns the model config from the global config"""
|
||||
return unwrap(GLOBAL_CONFIG.get("model"), {})
|
||||
|
||||
|
||||
def get_draft_model_config():
|
||||
def draft_model_config():
|
||||
"""Returns the draft model config from the global config"""
|
||||
model_config = unwrap(GLOBAL_CONFIG.get("model"), {})
|
||||
return unwrap(model_config.get("draft"), {})
|
||||
|
||||
|
||||
def get_lora_config():
|
||||
def lora_config():
|
||||
"""Returns the lora config from the global config"""
|
||||
model_config = unwrap(GLOBAL_CONFIG.get("model"), {})
|
||||
return unwrap(model_config.get("lora"), {})
|
||||
|
||||
|
||||
def get_network_config():
|
||||
def network_config():
|
||||
"""Returns the network config from the global config"""
|
||||
return unwrap(GLOBAL_CONFIG.get("network"), {})
|
||||
|
||||
|
||||
def get_gen_logging_config():
|
||||
def gen_logging_config():
|
||||
"""Returns the generation logging config from the global config"""
|
||||
return unwrap(GLOBAL_CONFIG.get("logging"), {})
|
||||
|
||||
|
||||
def get_developer_config():
|
||||
def developer_config():
|
||||
"""Returns the developer specific config from the global config"""
|
||||
return unwrap(GLOBAL_CONFIG.get("developer"), {})
|
||||
|
||||
Reference in New Issue
Block a user