diff --git a/common/args.py b/common/args.py index 400af02..0a888af 100644 --- a/common/args.py +++ b/common/args.py @@ -1,8 +1,12 @@ """Argparser for overriding config values""" import argparse -from typing import get_origin, get_args, Union, List -from common.tabby_config import config +from typing import Any, Type, get_origin, get_args, Union, List +from inspect import get_annotations, isclass + +from pydantic import BaseModel + +from common.config_models import TabbyConfigModel def str_to_bool(value): @@ -33,7 +37,7 @@ def argument_with_auto(value): ) from ex -def map_pydantic_type_to_argparse(pydantic_type): +def map_pydantic_type_to_argparse(pydantic_type: Any): """ Maps Pydantic types to argparse compatible types. Handles special cases like Union and List. @@ -57,7 +61,7 @@ def map_pydantic_type_to_argparse(pydantic_type): return str -def add_field_to_group(group, field_name, field_type, field): +def add_field_to_group(group, field_name, field_type, field) -> None: """ Adds a Pydantic field to an argparse argument group. """ @@ -67,22 +71,24 @@ def add_field_to_group(group, field_name, field_type, field): group.add_argument(f"--{field_name}", type=arg_type, help=help_text) -def init_argparser(): +def init_argparser() -> argparse.ArgumentParser: """ Initializes an argparse parser based on a Pydantic config schema. """ parser = argparse.ArgumentParser(description="TabbyAPI server") + field_type: Union[Type[BaseModel], Any] + # Loop through each top-level field in the config - for field_name, field_type in config.__annotations__.items(): + for field_name, field_type in get_annotations(TabbyConfigModel).items(): group = parser.add_argument_group( field_name, description=f"Arguments for {field_name}" ) # Check if the field_type is a Pydantic model - if hasattr(field_type, "__annotations__"): - for sub_field_name, sub_field_type in field_type.__annotations__.items(): - field = field_type.__fields__[sub_field_name] + if isclass(field_type): + for sub_field_name, sub_field_type in get_annotations(field_type).items(): + field = field_type.model_fields[sub_field_name] add_field_to_group(group, sub_field_name, sub_field_type, field) else: # Handle cases where the field_type is not a Pydantic mode @@ -94,7 +100,9 @@ def init_argparser(): return parser -def convert_args_to_dict(args: argparse.Namespace, parser: argparse.ArgumentParser): +def convert_args_to_dict( + args: argparse.Namespace, parser: argparse.ArgumentParser +) -> dict[str, dict[str, Any]]: """Broad conversion of surface level arg groups to dictionaries""" arg_groups = {}