diff --git a/config.py b/config.py index b6694b2..93b7ea0 100644 --- a/config.py +++ b/config.py @@ -8,6 +8,7 @@ logger = init_logger(__name__) GLOBAL_CONFIG: dict = {} + def read_config_from_file(config_path: pathlib.Path): """Sets the global config from a given file path""" global GLOBAL_CONFIG @@ -23,24 +24,29 @@ def read_config_from_file(config_path: pathlib.Path): ) GLOBAL_CONFIG = {} + def get_model_config(): """Returns the model config from the global config""" return unwrap(GLOBAL_CONFIG.get("model"), {}) + def get_draft_model_config(): """Returns the draft model config from the global config""" model_config = unwrap(GLOBAL_CONFIG.get("model"), {}) return unwrap(model_config.get("draft"), {}) + def get_lora_config(): """Returns the lora config from the global config""" model_config = unwrap(GLOBAL_CONFIG.get("model"), {}) return unwrap(model_config.get("lora"), {}) + def get_network_config(): """Returns the network config from the global config""" return unwrap(GLOBAL_CONFIG.get("network"), {}) + def get_gen_logging_config(): """Returns the generation logging config from the global config""" return unwrap(GLOBAL_CONFIG.get("logging"), {}) diff --git a/main.py b/main.py index 46baa47..5095bba 100644 --- a/main.py +++ b/main.py @@ -19,7 +19,7 @@ from config import ( get_model_config, get_draft_model_config, get_lora_config, - get_network_config + get_network_config, ) from generators import call_with_semaphore, generate_with_semaphore from model import ModelContainer diff --git a/model.py b/model.py index 2a3749b..60f7e9d 100644 --- a/model.py +++ b/model.py @@ -523,9 +523,7 @@ class ModelContainer: "installed ExLlamaV2 version." ) - if (unwrap(kwargs.get("top_a"), False)) and not hasattr ( - gen_settings, "top_a" - ): + if (unwrap(kwargs.get("top_a"), False)) and not hasattr(gen_settings, "top_a"): logger.warning( "Top-A is not supported by the currently " "installed ExLlamaV2 version."