From 0d7459191c77a7e12795a0b9e702cfaa66df76b5 Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Wed, 11 Sep 2024 16:13:31 +0100 Subject: [PATCH] fix arg parser for dict types --- common/args.py | 71 ++++++++++++++++++++++++++++-------------- common/tabby_config.py | 1 - 2 files changed, 48 insertions(+), 24 deletions(-) diff --git a/common/args.py b/common/args.py index 42f795a..f737a0c 100644 --- a/common/args.py +++ b/common/args.py @@ -34,39 +34,64 @@ def argument_with_auto(value): ) from ex +def map_pydantic_type_to_argparse(pydantic_type): + """ + Maps Pydantic types to argparse compatible types. + Handles special cases like Union and List. + """ + origin = get_origin(pydantic_type) + + # Handle optional types + if origin is Union: + # Filter out NoneType + pydantic_type = next(t for t in get_args(pydantic_type) if t is not type(None)) + + elif origin is List: + pydantic_type = get_args(pydantic_type)[0] # Get the list item type + + # Map basic types (int, float, str, bool) + if isinstance(pydantic_type, type) and issubclass( + pydantic_type, (int, float, str, bool) + ): + return pydantic_type + + return str + + +def add_field_to_group(group, field_name, field_type, field): + """ + Adds a Pydantic field to an argparse argument group. + """ + arg_type = map_pydantic_type_to_argparse(field_type) + help_text = field.description if field.description else "No description available" + + group.add_argument(f"--{field_name}", type=arg_type, help=help_text) + + def init_argparser(): + """ + Initializes an argparse parser based on a Pydantic config schema. + """ parser = argparse.ArgumentParser(description="TabbyAPI server") + # Loop through each top-level field in the config for field_name, field_type in config.__annotations__.items(): group = parser.add_argument_group( field_name, description=f"Arguments for {field_name}" ) - # Loop through each field in the sub-model - for sub_field_name, sub_field_type in field_type.__annotations__.items(): - field = field_type.__fields__[sub_field_name] - help_text = ( - field.description if field.description else "No description available" + # Check if the field_type is a Pydantic model + if hasattr(field_type, "__annotations__"): + for sub_field_name, sub_field_type in field_type.__annotations__.items(): + field = field_type.__fields__[sub_field_name] + add_field_to_group(group, sub_field_name, sub_field_type, field) + else: + # Handle cases where the field_type is not a Pydantic mode + arg_type = map_pydantic_type_to_argparse(field_type) + group.add_argument( + f"--{field_name}", type=arg_type, help=f"Argument for {field_name}" ) - origin = get_origin(sub_field_type) - if origin is Union: - sub_field_type = next( - t for t in get_args(sub_field_type) if t is not type(None) - ) - elif origin is List: - sub_field_type = get_args(sub_field_type)[0] - - # Map Pydantic types to argparse types - if isinstance(sub_field_type, type) and issubclass( - sub_field_type, (int, float, str, bool) - ): - arg_type = sub_field_type - else: - arg_type = str # Default to string for unknown types - - group.add_argument(f"--{sub_field_name}", type=arg_type, help=help_text) - return parser diff --git a/common/tabby_config.py b/common/tabby_config.py index cd7cb14..d571319 100644 --- a/common/tabby_config.py +++ b/common/tabby_config.py @@ -10,7 +10,6 @@ import common.config_models class TabbyConfig(tabby_config_model): - # Persistent defaults # TODO: make this pydantic? model_defaults: dict = {}