mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-14 15:57:27 +00:00
make pydantic do all the validation
This commit is contained in:
@@ -1,75 +1,21 @@
|
||||
"""Argparser for overriding config values"""
|
||||
|
||||
import argparse
|
||||
from typing import Any, get_origin, get_args, Union, List
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from common.config_models import TabbyConfigModel
|
||||
|
||||
|
||||
def str_to_bool(value):
    """Convert a truthy/falsy keyword string into a boolean.

    Accepts (case-insensitively) "true"/"t"/"1"/"yes"/"y" and
    "false"/"f"/"0"/"no"/"n".

    Raises:
        ValueError: if the string matches neither keyword set.
    """

    normalized = value.lower()
    if normalized in {"true", "t", "1", "yes", "y"}:
        return True
    if normalized in {"false", "f", "0", "no", "n"}:
        return False
    raise ValueError(f"{value} is not a valid boolean value")
|
||||
|
||||
|
||||
def argument_with_auto(value):
    """Argparse type wrapper for arguments that accept "auto" or a float.

    Ex. rope_alpha: the literal string "auto" is passed through
    unchanged; anything else must parse as a float.
    """

    if value != "auto":
        try:
            return float(value)
        except ValueError as ex:
            raise argparse.ArgumentTypeError(
                'This argument only takes a type of float or "auto"'
            ) from ex
    return "auto"
|
||||
|
||||
|
||||
def map_pydantic_type_to_argparse(pydantic_type: Any):
    """
    Maps Pydantic types to argparse compatible types.

    Handles special cases like Union and List; anything that is not a
    plain scalar falls back to str so argparse still accepts the value
    (pydantic validates the real type downstream).
    """

    origin = get_origin(pydantic_type)

    # Handle optional/union types: take the first non-None member
    if origin is Union:
        # Filter out NoneType
        pydantic_type = next(t for t in get_args(pydantic_type) if t is not type(None))

    # BUGFIX: get_origin() normalizes List[...]/list[...] to the builtin
    # `list`, so the comparison must be against `list` — comparing against
    # typing.List never matched, and list item types fell through to str.
    elif origin is list:
        pydantic_type = get_args(pydantic_type)[0]  # Get the list item type

    # Map basic types (int, float, str, bool)
    if isinstance(pydantic_type, type) and issubclass(
        pydantic_type, (int, float, str, bool)
    ):
        return pydantic_type

    return str
|
||||
|
||||
|
||||
def add_field_to_group(group, field_name, field_type, field) -> None:
    """
    Adds a Pydantic field to an argparse argument group.

    The argument is registered without an argparse `type=` converter:
    values stay raw strings and pydantic performs all validation and
    coercion downstream. `field_type` is kept in the signature for
    interface compatibility with existing callers.

    BUGFIX: the previous version registered the same --{field_name}
    twice (once with a type converter, once without), which raises
    argparse.ArgumentError on the second call.
    """

    help_text = field.description if field.description else "No description available"

    group.add_argument(f"--{field_name}", help=help_text)
|
||||
|
||||
|
||||
def init_argparser() -> argparse.ArgumentParser:
|
||||
@@ -96,10 +42,7 @@ def init_argparser() -> argparse.ArgumentParser:
|
||||
)
|
||||
else:
|
||||
field_name = field_name.replace("_", "-")
|
||||
arg_type = map_pydantic_type_to_argparse(field_type)
|
||||
group.add_argument(
|
||||
f"--{field_name}", type=arg_type, help=f"Argument for {field_name}"
|
||||
)
|
||||
group.add_argument(f"--{field_name}", help=f"Argument for {field_name}")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from common.utils import unwrap
|
||||
|
||||
CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"]
|
||||
|
||||
|
||||
class ConfigOverrideConfig(BaseModel):
    """Options for pointing tabbyAPI at an overriding config file."""

    # TODO: convert this to a pathlib.path?
    config: Optional[str] = Field(
        None,
        description="Path to an overriding config.yml file",
    )
|
||||
@@ -20,7 +23,7 @@ class NetworkConfig(BaseModel):
|
||||
False,
|
||||
description=("Decide whether to send error tracebacks over the API"),
|
||||
)
|
||||
api_servers: Optional[List[str]] = Field(
|
||||
api_servers: Optional[List[Literal["OAI", "Kobold"]]] = Field(
|
||||
[
|
||||
"OAI",
|
||||
],
|
||||
@@ -37,6 +40,7 @@ class LoggingConfig(BaseModel):
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
# TODO: convert this to a pathlib.path?
|
||||
model_dir: str = Field(
|
||||
"models",
|
||||
description=(
|
||||
@@ -71,6 +75,7 @@ class ModelConfig(BaseModel):
|
||||
"Max sequence length. Fetched from the model's base sequence length in "
|
||||
"config.json by default."
|
||||
),
|
||||
ge=0,
|
||||
)
|
||||
override_base_seq_len: Optional[int] = Field(
|
||||
None,
|
||||
@@ -78,6 +83,7 @@ class ModelConfig(BaseModel):
|
||||
"Overrides base model context length. WARNING: Only use this if the "
|
||||
"model's base sequence length is incorrect."
|
||||
),
|
||||
ge=0,
|
||||
)
|
||||
tensor_parallel: Optional[bool] = Field(
|
||||
False,
|
||||
@@ -114,18 +120,18 @@ class ModelConfig(BaseModel):
|
||||
"model was trained on long context with rope."
|
||||
),
|
||||
)
|
||||
rope_alpha: Optional[Union[float, str]] = Field(
|
||||
rope_alpha: Optional[Union[float, Literal["auto"]]] = Field(
|
||||
1.0,
|
||||
description=(
|
||||
"Rope alpha (default: 1.0). Same as alpha_value. Set to 'auto' to auto- "
|
||||
"calculate."
|
||||
),
|
||||
)
|
||||
cache_mode: Optional[str] = Field(
|
||||
cache_mode: Optional[CACHE_SIZES] = Field(
|
||||
"FP16",
|
||||
description=(
|
||||
"Enable different cache modes for VRAM savings (default: FP16). Possible "
|
||||
"values: FP16, Q8, Q6, Q4."
|
||||
f"values: {str(CACHE_SIZES)[15:-1]}"
|
||||
),
|
||||
)
|
||||
cache_size: Optional[int] = Field(
|
||||
@@ -134,6 +140,8 @@ class ModelConfig(BaseModel):
|
||||
"Size of the prompt cache to allocate (default: max_seq_len). Must be a "
|
||||
"multiple of 256."
|
||||
),
|
||||
multiple_of=256,
|
||||
gt=0,
|
||||
)
|
||||
chunk_size: Optional[int] = Field(
|
||||
2048,
|
||||
@@ -141,6 +149,7 @@ class ModelConfig(BaseModel):
|
||||
"Chunk size for prompt ingestion (default: 2048). A lower value reduces "
|
||||
"VRAM usage but decreases ingestion speed."
|
||||
),
|
||||
gt=0,
|
||||
)
|
||||
max_batch_size: Optional[int] = Field(
|
||||
None,
|
||||
@@ -148,6 +157,7 @@ class ModelConfig(BaseModel):
|
||||
"Set the maximum number of prompts to process at one time (default: "
|
||||
"None/Automatic). Automatically calculated if left blank."
|
||||
),
|
||||
ge=1,
|
||||
)
|
||||
prompt_template: Optional[str] = Field(
|
||||
None,
|
||||
@@ -162,6 +172,7 @@ class ModelConfig(BaseModel):
|
||||
"Number of experts to use per token. Fetched from the model's "
|
||||
"config.json. For MoE models only."
|
||||
),
|
||||
ge=1,
|
||||
)
|
||||
fasttensors: Optional[bool] = Field(
|
||||
False,
|
||||
@@ -175,6 +186,7 @@ class ModelConfig(BaseModel):
|
||||
|
||||
|
||||
class DraftModelConfig(BaseModel):
|
||||
# TODO: convert this to a pathlib.path?
|
||||
draft_model_dir: Optional[str] = Field(
|
||||
"models",
|
||||
description=(
|
||||
@@ -202,11 +214,11 @@ class DraftModelConfig(BaseModel):
|
||||
"blank to auto-calculate the alpha value."
|
||||
),
|
||||
)
|
||||
draft_cache_mode: Optional[str] = Field(
|
||||
draft_cache_mode: Optional[CACHE_SIZES] = Field(
|
||||
"FP16",
|
||||
description=(
|
||||
"Cache mode for draft models to save VRAM (default: FP16). Possible "
|
||||
"values: FP16, Q8, Q6, Q4."
|
||||
f"values: {str(CACHE_SIZES)[15:-1]}"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -214,11 +226,14 @@ class DraftModelConfig(BaseModel):
|
||||
class LoraInstanceModel(BaseModel):
    """A single LoRA to load, identified by name with an optional scaling.

    Reconstructed from the diff residue: the rendered block merged the
    pre- and post-commit `scaling` Field calls into a duplicated `1.0,`
    positional argument (a syntax error). This is the post-commit form,
    which adds the `ge=0` constraint so pydantic rejects negative scales.
    """

    name: str = Field(..., description=("Name of the LoRA model"))
    scaling: float = Field(
        1.0,
        description=("Scaling factor for the LoRA model (default: 1.0)"),
        ge=0,
    )
|
||||
|
||||
|
||||
class LoraConfig(BaseModel):
|
||||
# TODO: convert this to a pathlib.path?
|
||||
lora_dir: Optional[str] = Field(
|
||||
"loras", description=("Directory to look for LoRAs (default: 'loras')")
|
||||
)
|
||||
@@ -260,13 +275,14 @@ class DeveloperConfig(BaseModel):
|
||||
|
||||
|
||||
class EmbeddingsConfig(BaseModel):
|
||||
# TODO: convert this to a pathlib.path?
|
||||
embedding_model_dir: Optional[str] = Field(
|
||||
"models",
|
||||
description=(
|
||||
"Overrides directory to look for embedding models (default: models)"
|
||||
),
|
||||
)
|
||||
embeddings_device: Optional[str] = Field(
|
||||
embeddings_device: Optional[Literal["cpu", "auto", "cuda"]] = Field(
|
||||
"cpu",
|
||||
description=(
|
||||
"Device to load embedding models on (default: cpu). Possible values: cpu, "
|
||||
|
||||
Reference in New Issue
Block a user