diff --git a/common/args.py b/common/args.py index 4755c39..ffecaaa 100644 --- a/common/args.py +++ b/common/args.py @@ -1,7 +1,9 @@ """Argparser for overriding config values""" import argparse - +from typing import get_origin, get_args, Optional, Union, List +from pydantic import BaseModel +from common.tabby_config import config def str_to_bool(value): """Converts a string into a boolean value""" @@ -32,24 +34,40 @@ def argument_with_auto(value): def init_argparser(): - """Creates an argument parser that any function can use""" + parser = argparse.ArgumentParser(description="TabbyAPI server") - parser = argparse.ArgumentParser( - epilog="NOTE: These args serve to override parts of the config. " - + "It's highly recommended to edit config.yml for all options and " - + "better descriptions!" - ) - add_network_args(parser) - add_model_args(parser) - add_embeddings_args(parser) - add_logging_args(parser) - add_developer_args(parser) - add_sampling_args(parser) - add_config_args(parser) + # Loop through the fields in the top-level model (ModelX in this case) + for field_name, field_type in config.__annotations__.items(): + # Get the sub-model type (e.g., ModelA, ModelB) + sub_model = field_type.__base__ + + # Create argument group for the sub-model + group = parser.add_argument_group(field_name, description=f"Arguments for {field_name}") + + # Loop through each field in the sub-model (e.g., ModelA, ModelB) + for sub_field_name, sub_field_type in field_type.__annotations__.items(): + field = field_type.__fields__[sub_field_name] + help_text = field.description if field.description else "No description available" + + # Handle Optional types or other generic types + origin = get_origin(sub_field_type) + if origin is Union: # Check if the type is Union (which includes Optional) + sub_field_type = next(t for t in get_args(sub_field_type) if t is not type(None)) + elif origin is List : sub_field_type = get_args(sub_field_type)[0] + + + # Map Pydantic types to argparse types + print(sub_field_type, type(sub_field_type)) + if isinstance(sub_field_type, type) and issubclass(sub_field_type, (int, float, str, bool)): + arg_type = sub_field_type + else: + arg_type = str # Default to string for unknown types + + # Add the argument for each field in the sub-model + group.add_argument(f"--{sub_field_name}", type=arg_type, help=help_text) return parser - def convert_args_to_dict(args: argparse.Namespace, parser: argparse.ArgumentParser): """Broad conversion of surface level arg groups to dictionaries""" @@ -63,202 +81,4 @@ def convert_args_to_dict(args: argparse.Namespace, parser: argparse.ArgumentPars arg_groups[group.title] = group_dict - return arg_groups - - -def add_config_args(parser: argparse.ArgumentParser): - """Adds config arguments""" - - parser.add_argument( - "--config", type=str, help="Path to an overriding config.yml file" - ) - - -def add_network_args(parser: argparse.ArgumentParser): - """Adds networking arguments""" - - network_group = parser.add_argument_group("network") - network_group.add_argument("--host", type=str, help="The IP to host on") - network_group.add_argument("--port", type=int, help="The port to host on") - network_group.add_argument( - "--disable-auth", - type=str_to_bool, - help="Disable HTTP token authenticaion with requests", - ) - network_group.add_argument( - "--send-tracebacks", - type=str_to_bool, - help="Decide whether to send error tracebacks over the API", - ) - network_group.add_argument( - "--api-servers", - type=str, - nargs="+", - help="API servers to enable. Options: (OAI, Kobold)", - ) - - -def add_model_args(parser: argparse.ArgumentParser): - """Adds model arguments""" - - model_group = parser.add_argument_group("model") - model_group.add_argument( - "--model-dir", type=str, help="Overrides the directory to look for models" - ) - model_group.add_argument("--model-name", type=str, help="An initial model to load") - model_group.add_argument( - "--use-dummy-models", - type=str_to_bool, - help="Add dummy OAI model names for API queries", - ) - model_group.add_argument( - "--use-as-default", - type=str, - nargs="+", - help="Names of args to use as a default fallback for API load requests ", - ) - model_group.add_argument( - "--max-seq-len", type=int, help="Override the maximum model sequence length" - ) - model_group.add_argument( - "--override-base-seq-len", - type=str_to_bool, - help="Overrides base model context length", - ) - model_group.add_argument( - "--tensor-parallel", - type=str_to_bool, - help="Use tensor parallelism to load models", - ) - model_group.add_argument( - "--gpu-split-auto", - type=str_to_bool, - help="Automatically allocate resources to GPUs", - ) - model_group.add_argument( - "--autosplit-reserve", - type=int, - nargs="+", - help="Reserve VRAM used for autosplit loading (in MBs) ", - ) - model_group.add_argument( - "--gpu-split", - type=float, - nargs="+", - help="An integer array of GBs of vram to split between GPUs. " - + "Ignored if gpu_split_auto is true", - ) - model_group.add_argument( - "--rope-scale", type=float, help="Sets rope_scale or compress_pos_emb" - ) - model_group.add_argument( - "--rope-alpha", - type=argument_with_auto, - help="Sets rope_alpha for NTK", - ) - model_group.add_argument( - "--cache-mode", - type=str, - help="Set the quantization level of the K/V cache. Options: (FP16, Q8, Q6, Q4)", - ) - model_group.add_argument( - "--cache-size", - type=int, - help="The size of the prompt cache (in number of tokens) to allocate", - ) - model_group.add_argument( - "--chunk-size", - type=int, - help="Chunk size for prompt ingestion", - ) - model_group.add_argument( - "--max-batch-size", - type=int, - help="Maximum amount of prompts to process at one time", - ) - model_group.add_argument( - "--prompt-template", - type=str, - help="Set the jinja2 prompt template for chat completions", - ) - model_group.add_argument( - "--num-experts-per-token", - type=int, - help="Number of experts to use per token in MoE models", - ) - model_group.add_argument( - "--fasttensors", - type=str_to_bool, - help="Possibly increases model loading speeds", - ) - - -def add_logging_args(parser: argparse.ArgumentParser): - """Adds logging arguments""" - - logging_group = parser.add_argument_group("logging") - logging_group.add_argument( - "--log-prompt", type=str_to_bool, help="Enable prompt logging" - ) - logging_group.add_argument( - "--log-generation-params", - type=str_to_bool, - help="Enable generation parameter logging", - ) - logging_group.add_argument( - "--log-requests", - type=str_to_bool, - help="Enable request logging", - ) - - -def add_developer_args(parser: argparse.ArgumentParser): - """Adds developer-specific arguments""" - - developer_group = parser.add_argument_group("developer") - developer_group.add_argument( - "--unsafe-launch", type=str_to_bool, help="Skip Exllamav2 version check" - ) - developer_group.add_argument( - "--disable-request-streaming", - type=str_to_bool, - help="Disables API request streaming", - ) - developer_group.add_argument( - "--cuda-malloc-backend", - type=str_to_bool, - help="Runs with the pytorch CUDA malloc backend", - ) - developer_group.add_argument( - "--uvloop", - type=str_to_bool, - help="Run asyncio using Uvloop or Winloop", - ) - - -def add_sampling_args(parser: argparse.ArgumentParser): - """Adds sampling-specific arguments""" - - sampling_group = parser.add_argument_group("sampling") - sampling_group.add_argument( - "--override-preset", type=str, help="Select a sampler override preset" - ) - - -def add_embeddings_args(parser: argparse.ArgumentParser): - """Adds arguments specific to embeddings""" - - embeddings_group = parser.add_argument_group("embeddings") - embeddings_group.add_argument( - "--embedding-model-dir", - type=str, - help="Overrides the directory to look for models", - ) - embeddings_group.add_argument( - "--embedding-model-name", type=str, help="An initial model to load" - ) - embeddings_group.add_argument( - "--embeddings-device", - type=str, - help="Device to use for embeddings. Options: (cpu, auto, cuda)", - ) + return arg_groups \ No newline at end of file