diff --git a/common/args.py b/common/args.py index 0a888af..b991103 100644 --- a/common/args.py +++ b/common/args.py @@ -1,8 +1,7 @@ """Argparser for overriding config values""" import argparse -from typing import Any, Type, get_origin, get_args, Union, List -from inspect import get_annotations, isclass +from typing import Any, get_origin, get_args, Union, List from pydantic import BaseModel @@ -42,6 +41,7 @@ def map_pydantic_type_to_argparse(pydantic_type: Any): Maps Pydantic types to argparse compatible types. Handles special cases like Union and List. """ + origin = get_origin(pydantic_type) # Handle optional types @@ -65,6 +65,7 @@ def add_field_to_group(group, field_name, field_type, field) -> None: """ Adds a Pydantic field to an argparse argument group. """ + arg_type = map_pydantic_type_to_argparse(field_type) help_text = field.description if field.description else "No description available" @@ -75,23 +76,26 @@ def init_argparser() -> argparse.ArgumentParser: """ Initializes an argparse parser based on a Pydantic config schema. """ + parser = argparse.ArgumentParser(description="TabbyAPI server") - field_type: Union[Type[BaseModel], Any] - # Loop through each top-level field in the config - for field_name, field_type in get_annotations(TabbyConfigModel).items(): + for field_name, field_info in TabbyConfigModel.model_fields.items(): + field_type = field_info.annotation group = parser.add_argument_group( field_name, description=f"Arguments for {field_name}" ) # Check if the field_type is a Pydantic model - if isclass(field_type): - for sub_field_name, sub_field_type in get_annotations(field_type).items(): - field = field_type.model_fields[sub_field_name] - add_field_to_group(group, sub_field_name, sub_field_type, field) + if issubclass(field_type, BaseModel): + for sub_field_name, sub_field_info in field_type.model_fields.items(): + sub_field_name = sub_field_name.replace("_", "-") + sub_field_type = sub_field_info.annotation + add_field_to_group( + group, sub_field_name, sub_field_type, sub_field_info + ) else: - # Handle cases where the field_type is not a Pydantic mode + field_name = field_name.replace("_", "-") arg_type = map_pydantic_type_to_argparse(field_type) group.add_argument( f"--{field_name}", type=arg_type, help=f"Argument for {field_name}" diff --git a/common/config_models.py b/common/config_models.py index 5e5b5a2..ced18f9 100644 --- a/common/config_models.py +++ b/common/config_models.py @@ -4,7 +4,7 @@ from typing import List, Optional, Union from common.utils import unwrap -class ConfigConfig(BaseModel): +class ConfigOverrideConfig(BaseModel): config: Optional[str] = Field( None, description=("Path to an overriding config.yml file") ) @@ -279,7 +279,9 @@ class EmbeddingsConfig(BaseModel): class TabbyConfigModel(BaseModel): - config: ConfigConfig = Field(default_factory=ConfigConfig.model_construct) + config: ConfigOverrideConfig = Field( + default_factory=ConfigOverrideConfig.model_construct + ) network: NetworkConfig = Field(default_factory=NetworkConfig.model_construct) logging: LoggingConfig = Field(default_factory=LoggingConfig.model_construct) model: ModelConfig = Field(default_factory=ModelConfig.model_construct)