diff --git a/common/args.py b/common/args.py index 7d2427f..7c09646 100644 --- a/common/args.py +++ b/common/args.py @@ -4,7 +4,7 @@ import argparse from pydantic import BaseModel from common.config_models import TabbyConfigModel -from common.utils import is_list_type +from common.utils import is_list_type, unwrap_optional def add_field_to_group(group, field_name, field_type, field) -> None: @@ -32,7 +32,7 @@ def init_argparser() -> argparse.ArgumentParser: # Loop through each top-level field in the config for field_name, field_info in TabbyConfigModel.model_fields.items(): - field_type = field_info.annotation + field_type = unwrap_optional(field_info.annotation) group = parser.add_argument_group( field_name, description=f"Arguments for {field_name}" ) diff --git a/common/config_models.py b/common/config_models.py index 2c888f8..653280b 100644 --- a/common/config_models.py +++ b/common/config_models.py @@ -439,22 +439,32 @@ class DeveloperConfig(BaseConfigModel): class TabbyConfigModel(BaseModel): """Base model for a TabbyConfig.""" - config: ConfigOverrideConfig = Field( + config: Optional[ConfigOverrideConfig] = Field( default_factory=ConfigOverrideConfig.model_construct ) - network: NetworkConfig = Field(default_factory=NetworkConfig.model_construct) - logging: LoggingConfig = Field(default_factory=LoggingConfig.model_construct) - model: ModelConfig = Field(default_factory=ModelConfig.model_construct) - draft_model: DraftModelConfig = Field( + network: Optional[NetworkConfig] = Field( + default_factory=NetworkConfig.model_construct + ) + logging: Optional[LoggingConfig] = Field( + default_factory=LoggingConfig.model_construct + ) + model: Optional[ModelConfig] = Field(default_factory=ModelConfig.model_construct) + draft_model: Optional[DraftModelConfig] = Field( default_factory=DraftModelConfig.model_construct ) - lora: LoraConfig = Field(default_factory=LoraConfig.model_construct) - embeddings: EmbeddingsConfig = Field( + lora: Optional[LoraConfig] = Field(default_factory=LoraConfig.model_construct) + embeddings: Optional[EmbeddingsConfig] = Field( default_factory=EmbeddingsConfig.model_construct ) - sampling: SamplingConfig = Field(default_factory=SamplingConfig.model_construct) - developer: DeveloperConfig = Field(default_factory=DeveloperConfig.model_construct) - actions: UtilityActions = Field(default_factory=UtilityActions.model_construct) + sampling: Optional[SamplingConfig] = Field( + default_factory=SamplingConfig.model_construct + ) + developer: Optional[DeveloperConfig] = Field( + default_factory=DeveloperConfig.model_construct + ) + actions: Optional[UtilityActions] = Field( + default_factory=UtilityActions.model_construct + ) model_config = ConfigDict(validate_assignment=True, protected_namespaces=()) diff --git a/common/tabby_config.py b/common/tabby_config.py index 2f0481d..1dacac0 100644 --- a/common/tabby_config.py +++ b/common/tabby_config.py @@ -2,10 +2,10 @@ import yaml import pathlib from loguru import logger from typing import Optional -from os import getenv +from os import getenv, replace from common.utils import unwrap, merge_dicts -from common.config_models import TabbyConfigModel +from common.config_models import TabbyConfigModel, generate_config_file class TabbyConfig(TabbyConfigModel): @@ -46,10 +46,25 @@ class TabbyConfig(TabbyConfigModel): def _from_file(self, config_path: pathlib.Path): """loads config from a given file path""" + legacy = False + cfg = {} + # try loading from file try: with open(str(config_path.resolve()), "r", encoding="utf8") as config_file: - return unwrap(yaml.safe_load(config_file), {}) + cfg = yaml.safe_load(config_file) + + # FIXME: remove legacy config mapper + # load legacy config files + model = cfg.get("model", {}) + + if model.get("draft"): + legacy = True + cfg["draft"] = model["draft"] + if model.get("lora"): + legacy = True + cfg["lora"] = model["lora"] + except FileNotFoundError: logger.info(f"The '{config_path.name}' file cannot be found") except Exception as exc: @@ -58,8 +73,21 @@ class TabbyConfig(TabbyConfigModel): f"the following error:\n\n{exc}" ) - # if no config file was loaded - return {} + if legacy: + logger.warning( + "legacy config.yml files are deprecated" + "Please upadte to the new version" + "Attempting auto migrationy" + ) + new_cfg = TabbyConfigModel.model_validate(cfg) + + try: + replace(config_path, f"{config_path}.bak") + generate_config_file(model=new_cfg, filename=config_path) + except Exception as e: + logger.error(f"Auto migration failed: {e}") + + return unwrap(cfg, {}) def _from_args(self, args: dict): """loads config from the provided arguments""" diff --git a/common/utils.py b/common/utils.py index d933fb6..acc0fc9 100644 --- a/common/utils.py +++ b/common/utils.py @@ -1,6 +1,7 @@ """Common utility functions""" -from typing import get_args, get_origin +from types import NoneType +from typing import Optional, Type, Union, get_args, get_origin def unwrap(wrapped, default=None): @@ -47,7 +48,7 @@ def flat_map(input_list): return [item for sublist in input_list for item in sublist] -def is_list_type(type_hint): +def is_list_type(type_hint) -> bool: """Checks if a type contains a list.""" if get_origin(type_hint) is list: @@ -59,3 +60,16 @@ def is_list_type(type_hint): return any(is_list_type(arg) for arg in type_args) return False + + +def unwrap_optional(type_hint) -> Type: + """unwrap Optional[type] annotations""" + + if get_origin(type_hint) is Union: + args = get_args(type_hint) + if NoneType in args: + for arg in args: + if arg is not NoneType: + return arg + + return type_hint