diff --git a/common/args.py b/common/args.py
index b991103..22c7681 100644
--- a/common/args.py
+++ b/common/args.py
@@ -1,75 +1,21 @@
 """Argparser for overriding config values"""
 
 import argparse
-from typing import Any, get_origin, get_args, Union, List
+from typing import Any
 
 from pydantic import BaseModel
 from common.config_models import TabbyConfigModel
 
 
-def str_to_bool(value):
-    """Converts a string into a boolean value"""
-
-    if value.lower() in {"false", "f", "0", "no", "n"}:
-        return False
-    elif value.lower() in {"true", "t", "1", "yes", "y"}:
-        return True
-    raise ValueError(f"{value} is not a valid boolean value")
-
-
-def argument_with_auto(value):
-    """
-    Argparse type wrapper for any argument that has an automatic option.
-
-    Ex. rope_alpha
-    """
-
-    if value == "auto":
-        return "auto"
-
-    try:
-        return float(value)
-    except ValueError as ex:
-        raise argparse.ArgumentTypeError(
-            'This argument only takes a type of float or "auto"'
-        ) from ex
-
-
-def map_pydantic_type_to_argparse(pydantic_type: Any):
-    """
-    Maps Pydantic types to argparse compatible types.
-    Handles special cases like Union and List.
-    """
-
-    origin = get_origin(pydantic_type)
-
-    # Handle optional types
-    if origin is Union:
-        # Filter out NoneType
-        pydantic_type = next(t for t in get_args(pydantic_type) if t is not type(None))
-
-    elif origin is List:
-        pydantic_type = get_args(pydantic_type)[0]  # Get the list item type
-
-    # Map basic types (int, float, str, bool)
-    if isinstance(pydantic_type, type) and issubclass(
-        pydantic_type, (int, float, str, bool)
-    ):
-        return pydantic_type
-
-    return str
-
-
 def add_field_to_group(group, field_name, field_type, field) -> None:
     """
     Adds a Pydantic field to an argparse argument group.
     """
 
-    arg_type = map_pydantic_type_to_argparse(field_type)
     help_text = field.description if field.description else "No description available"
 
-    group.add_argument(f"--{field_name}", type=arg_type, help=help_text)
+    group.add_argument(f"--{field_name}", help=help_text)
 
 
 def init_argparser() -> argparse.ArgumentParser:
@@ -96,10 +42,7 @@ def init_argparser() -> argparse.ArgumentParser:
             )
         else:
             field_name = field_name.replace("_", "-")
-            arg_type = map_pydantic_type_to_argparse(field_type)
-            group.add_argument(
-                f"--{field_name}", type=arg_type, help=f"Argument for {field_name}"
-            )
+            group.add_argument(f"--{field_name}", help=f"Argument for {field_name}")
 
     return parser
 
diff --git a/common/config_models.py b/common/config_models.py
index ecc13c9..8384957 100644
--- a/common/config_models.py
+++ b/common/config_models.py
@@ -1,10 +1,13 @@
 from pydantic import BaseModel, ConfigDict, Field, model_validator
-from typing import List, Optional, Union
+from typing import List, Literal, Optional, Union
 
 from common.utils import unwrap
 
+CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"]
+
 
 class ConfigOverrideConfig(BaseModel):
+    # TODO: convert this to a pathlib.path?
    config: Optional[str] = Field(
         None, description=("Path to an overriding config.yml file")
     )
@@ -20,7 +23,7 @@ class NetworkConfig(BaseModel):
         False,
         description=("Decide whether to send error tracebacks over the API"),
     )
-    api_servers: Optional[List[str]] = Field(
+    api_servers: Optional[List[Literal["OAI", "Kobold"]]] = Field(
         [
             "OAI",
         ],
@@ -37,6 +40,7 @@ class LoggingConfig(BaseModel):
 
 
 class ModelConfig(BaseModel):
+    # TODO: convert this to a pathlib.path?
     model_dir: str = Field(
         "models",
         description=(
@@ -71,6 +75,7 @@ class ModelConfig(BaseModel):
             "Max sequence length. Fetched from the model's base sequence length in "
             "config.json by default."
         ),
+        ge=0,
     )
     override_base_seq_len: Optional[int] = Field(
         None,
@@ -78,6 +83,7 @@ class ModelConfig(BaseModel):
             "Overrides base model context length. WARNING: Only use this if the "
             "model's base sequence length is incorrect."
         ),
+        ge=0,
     )
     tensor_parallel: Optional[bool] = Field(
         False,
@@ -114,18 +120,18 @@ class ModelConfig(BaseModel):
             "model was trained on long context with rope."
         ),
     )
-    rope_alpha: Optional[Union[float, str]] = Field(
+    rope_alpha: Optional[Union[float, Literal["auto"]]] = Field(
         1.0,
         description=(
             "Rope alpha (default: 1.0). Same as alpha_value. Set to 'auto' to auto- "
             "calculate."
         ),
     )
-    cache_mode: Optional[str] = Field(
+    cache_mode: Optional[CACHE_SIZES] = Field(
         "FP16",
         description=(
             "Enable different cache modes for VRAM savings (default: FP16). Possible "
-            "values: FP16, Q8, Q6, Q4."
+            f"values: {str(CACHE_SIZES)[15:-1]}"
         ),
     )
     cache_size: Optional[int] = Field(
@@ -134,6 +140,8 @@ class ModelConfig(BaseModel):
             "Size of the prompt cache to allocate (default: max_seq_len). Must be a "
             "multiple of 256."
         ),
+        multiple_of=256,
+        gt=0,
     )
     chunk_size: Optional[int] = Field(
         2048,
@@ -141,6 +149,7 @@ class ModelConfig(BaseModel):
             "Chunk size for prompt ingestion (default: 2048). A lower value reduces "
             "VRAM usage but decreases ingestion speed."
         ),
+        gt=0,
     )
     max_batch_size: Optional[int] = Field(
         None,
@@ -148,6 +157,7 @@ class ModelConfig(BaseModel):
             "Set the maximum number of prompts to process at one time (default: "
             "None/Automatic). Automatically calculated if left blank."
         ),
+        ge=1,
     )
     prompt_template: Optional[str] = Field(
         None,
@@ -162,6 +172,7 @@ class ModelConfig(BaseModel):
             "Number of experts to use per token. Fetched from the model's "
             "config.json. For MoE models only."
         ),
+        ge=1,
     )
     fasttensors: Optional[bool] = Field(
         False,
@@ -175,6 +186,7 @@ class ModelConfig(BaseModel):
 
 
 class DraftModelConfig(BaseModel):
+    # TODO: convert this to a pathlib.path?
     draft_model_dir: Optional[str] = Field(
         "models",
         description=(
@@ -202,11 +214,11 @@ class DraftModelConfig(BaseModel):
             "blank to auto-calculate the alpha value."
         ),
     )
-    draft_cache_mode: Optional[str] = Field(
+    draft_cache_mode: Optional[CACHE_SIZES] = Field(
         "FP16",
         description=(
             "Cache mode for draft models to save VRAM (default: FP16). Possible "
-            "values: FP16, Q8, Q6, Q4."
+            f"values: {str(CACHE_SIZES)[15:-1]}"
         ),
     )
 
@@ -214,11 +226,14 @@ class DraftModelConfig(BaseModel):
 class LoraInstanceModel(BaseModel):
     name: str = Field(..., description=("Name of the LoRA model"))
     scaling: float = Field(
-        1.0, description=("Scaling factor for the LoRA model (default: 1.0)")
+        1.0,
+        description=("Scaling factor for the LoRA model (default: 1.0)"),
+        ge=0,
     )
 
 
 class LoraConfig(BaseModel):
+    # TODO: convert this to a pathlib.path?
     lora_dir: Optional[str] = Field(
         "loras", description=("Directory to look for LoRAs (default: 'loras')")
     )
@@ -260,13 +275,14 @@ class DeveloperConfig(BaseModel):
 
 
 class EmbeddingsConfig(BaseModel):
+    # TODO: convert this to a pathlib.path?
     embedding_model_dir: Optional[str] = Field(
         "models",
         description=(
             "Overrides directory to look for embedding models (default: models)"
         ),
     )
-    embeddings_device: Optional[str] = Field(
+    embeddings_device: Optional[Literal["cpu", "auto", "cuda"]] = Field(
         "cpu",
         description=(
             "Device to load embedding models on (default: cpu). Possible values: cpu, "