diff --git a/common/config.py b/common/config.py index 2d1c980..b1b251b 100644 --- a/common/config.py +++ b/common/config.py @@ -1,10 +1,9 @@ import yaml import pathlib from loguru import logger -from mergedeep import merge, Strategy from typing import Any -from common.utils import unwrap +from common.utils import unwrap, merge_dicts # Global config dictionary constant GLOBAL_CONFIG: dict = {} @@ -21,7 +20,8 @@ def load(arguments: dict[str, Any]): from_args(arguments), ] - GLOBAL_CONFIG = merge({}, *configs, strategy=Strategy.REPLACE) + GLOBAL_CONFIG = merge_dicts(*configs) + def from_file(config_path: pathlib.Path) -> dict[str, Any]: """loads config from a given file path""" @@ -73,15 +73,16 @@ def from_environment() -> dict[str, Any]: # refactor the get_config functions -def get_config(config: dict[str, any], topic: str) -> callable : +def get_config(config: dict[str, any], topic: str) -> callable: return lambda: unwrap(config.get(topic), {}) + # each of these is a function -model_config = get_config(GLOBAL_CONFIG, "model") -sampling_config = get_config(GLOBAL_CONFIG, "sampling") -draft_model_config = get_config(model_config(), "draft") -lora_config = get_config(model_config(), "lora") -network_config = get_config(GLOBAL_CONFIG, "network") -logging_config = get_config(GLOBAL_CONFIG, "logging") -developer_config = get_config(GLOBAL_CONFIG, "developer") -embeddings_config = get_config(GLOBAL_CONFIG, "embeddings") +model_config = get_config(GLOBAL_CONFIG, "model") +sampling_config = get_config(GLOBAL_CONFIG, "sampling") +draft_model_config = get_config(model_config(), "draft") +lora_config = get_config(model_config(), "lora") +network_config = get_config(GLOBAL_CONFIG, "network") +logging_config = get_config(GLOBAL_CONFIG, "logging") +developer_config = get_config(GLOBAL_CONFIG, "developer") +embeddings_config = get_config(GLOBAL_CONFIG, "embeddings") diff --git a/common/utils.py b/common/utils.py index b120022..5133ed8 100644 --- a/common/utils.py +++ b/common/utils.py @@ -20,6 +20,23 @@ def prune_dict(input_dict): return {k: v for k, v in input_dict.items() if v is not None} +def merge_dict(dict1, dict2): + """Merge 2 dictionaries""" + for key, value in dict2.items(): + if isinstance(value, dict) and key in dict1 and isinstance(dict1[key], dict): + merge_dict(dict1[key], value) + else: + dict1[key] = value + return dict1 + + +def merge_dicts(*dicts): + """Merge an arbitrary amount of dictionaries""" + result = {} + for dictionary in dicts: + result = merge_dict(result, dictionary) + + def flat_map(input_list): """Flattens a list of lists into a single list.""" diff --git a/pyproject.toml b/pyproject.toml index 89c9661..b9e80fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,6 @@ dependencies = [ "huggingface_hub", "psutil", "httptools>=0.5.0", - "mergedeep", # Improved asyncio loops "uvloop ; platform_system == 'Linux' and platform_machine == 'x86_64'",