Tree: Update to cleanup globals

Use the module singleton pattern to share global state. This can also
be a modified version of the Global Object Pattern. The main reason
this pattern is used is for ease of use when handling global state
rather than adding extra dependencies for a DI parameter.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-03-09 22:31:47 -05:00
committed by Brian Dashore
parent b373b25235
commit 5a2de30066
7 changed files with 68 additions and 82 deletions

View File

@@ -4,10 +4,11 @@ from loguru import logger
from common.utils import unwrap
# Global config dictionary constant
GLOBAL_CONFIG: dict = {}
def read_config_from_file(config_path: pathlib.Path):
def from_file(config_path: pathlib.Path):
"""Sets the global config from a given file path"""
global GLOBAL_CONFIG
@@ -23,74 +24,77 @@ def read_config_from_file(config_path: pathlib.Path):
GLOBAL_CONFIG = {}
def override_config_from_args(args: dict):
def from_args(args: dict):
"""Overrides the config based on a dict representation of args"""
config_override = unwrap(args.get("options", {}).get("config"))
if config_override:
logger.info("Attempting to override config.yml from args.")
read_config_from_file(pathlib.Path(config_override))
from_file(pathlib.Path(config_override))
return
# Network config
network_override = args.get("network")
if network_override:
network_config = get_network_config()
GLOBAL_CONFIG["network"] = {**network_config, **network_override}
cur_network_config = network_config()
GLOBAL_CONFIG["network"] = {**cur_network_config, **network_override}
# Model config
model_override = args.get("model")
if model_override:
model_config = get_model_config()
GLOBAL_CONFIG["model"] = {**model_config, **model_override}
cur_model_config = model_config()
GLOBAL_CONFIG["model"] = {**cur_model_config, **model_override}
# Logging config
logging_override = args.get("logging")
if logging_override:
logging_config = get_gen_logging_config()
# Generation Logging config
gen_logging_override = args.get("logging")
if gen_logging_override:
cur_gen_logging_config = gen_logging_config()
GLOBAL_CONFIG["logging"] = {
**logging_config,
**{k.replace("log_", ""): logging_override[k] for k in logging_override},
**cur_gen_logging_config,
**{
k.replace("log_", ""): gen_logging_override[k]
for k in gen_logging_override
},
}
developer_override = args.get("developer")
if developer_override:
developer_config = get_developer_config()
GLOBAL_CONFIG["developer"] = {**developer_config, **developer_override}
cur_developer_config = developer_config()
GLOBAL_CONFIG["developer"] = {**cur_developer_config, **developer_override}
def get_sampling_config():
def sampling_config():
"""Returns the sampling parameter config from the global config"""
return unwrap(GLOBAL_CONFIG.get("sampling"), {})
def get_model_config():
def model_config():
"""Returns the model config from the global config"""
return unwrap(GLOBAL_CONFIG.get("model"), {})
def get_draft_model_config():
def draft_model_config():
"""Returns the draft model config from the global config"""
model_config = unwrap(GLOBAL_CONFIG.get("model"), {})
return unwrap(model_config.get("draft"), {})
def get_lora_config():
def lora_config():
"""Returns the lora config from the global config"""
model_config = unwrap(GLOBAL_CONFIG.get("model"), {})
return unwrap(model_config.get("lora"), {})
def get_network_config():
def network_config():
"""Returns the network config from the global config"""
return unwrap(GLOBAL_CONFIG.get("network"), {})
def get_gen_logging_config():
def gen_logging_config():
"""Returns the generation logging config from the global config"""
return unwrap(GLOBAL_CONFIG.get("logging"), {})
def get_developer_config():
def developer_config():
"""Returns the developer specific config from the global config"""
return unwrap(GLOBAL_CONFIG.get("developer"), {})