make pydantic do all the validation

This commit is contained in:
TerminalMan
2024-09-13 10:21:27 +01:00
parent d5b3fde319
commit dc4946b565
2 changed files with 28 additions and 69 deletions

View File

@@ -1,75 +1,21 @@
"""Argparser for overriding config values"""
import argparse
from typing import Any, get_origin, get_args, Union, List
from typing import Any
from pydantic import BaseModel
from common.config_models import TabbyConfigModel
def str_to_bool(value):
    """Convert a string into a boolean value.

    Args:
        value: Case-insensitive string such as "true"/"false", "t"/"f",
            "1"/"0", "yes"/"no", "y"/"n".

    Returns:
        The corresponding bool.

    Raises:
        ValueError: If the string matches neither the truthy nor falsy set.
    """
    # Normalize once instead of calling .lower() in each membership test.
    normalized = value.lower()
    if normalized in {"false", "f", "0", "no", "n"}:
        return False
    if normalized in {"true", "t", "1", "yes", "y"}:
        return True
    raise ValueError(f"{value} is not a valid boolean value")
def argument_with_auto(value):
    """Argparse type converter for values that are either a float or "auto".

    Example: rope_alpha
    """
    # Pass the sentinel straight through; anything else must parse as float.
    if value != "auto":
        try:
            value = float(value)
        except ValueError as ex:
            raise argparse.ArgumentTypeError(
                'This argument only takes a type of float or "auto"'
            ) from ex
    return value
def map_pydantic_type_to_argparse(pydantic_type: Any):
    """
    Maps Pydantic types to argparse compatible types.

    Handles special cases like Union and List.

    Args:
        pydantic_type: The annotation taken from a Pydantic model field.

    Returns:
        A type callable suitable for argparse's ``type=`` keyword; falls back
        to ``str`` for anything that is not a plain int/float/str/bool.
    """
    origin = get_origin(pydantic_type)

    # Handle optional types
    if origin is Union:
        # Filter out NoneType
        pydantic_type = next(t for t in get_args(pydantic_type) if t is not type(None))
    # BUG FIX: get_origin(List[x]) returns the builtin ``list``, never
    # ``typing.List``, so the previous ``origin is List`` check could not
    # match and list item types were never unwrapped.
    elif origin is list:
        pydantic_type = get_args(pydantic_type)[0]  # Get the list item type

    # Map basic types (int, float, str, bool)
    if isinstance(pydantic_type, type) and issubclass(
        pydantic_type, (int, float, str, bool)
    ):
        return pydantic_type

    return str
def add_field_to_group(group, field_name, field_type, field) -> None:
    """
    Adds a Pydantic field to an argparse argument group.

    Args:
        group: The argparse argument group to extend.
        field_name: Field name; exposed as the ``--{field_name}`` flag.
        field_type: Annotation of the field. Unused here — kept for
            interface compatibility; pydantic now performs all coercion
            and validation of the raw string value.
        field: Pydantic field info; only ``description`` is read.
    """
    help_text = field.description if field.description else "No description available"

    # No ``type=`` is passed on purpose: the raw string is handed to
    # pydantic, which does all the validation. The diff residue that
    # registered the same option twice (once with type=, once without)
    # is resolved to this single registration — argparse raises
    # ArgumentError on a duplicate option string.
    group.add_argument(f"--{field_name}", help=help_text)
def init_argparser() -> argparse.ArgumentParser:
@@ -96,10 +42,7 @@ def init_argparser() -> argparse.ArgumentParser:
)
else:
field_name = field_name.replace("_", "-")
arg_type = map_pydantic_type_to_argparse(field_type)
group.add_argument(
f"--{field_name}", type=arg_type, help=f"Argument for {field_name}"
)
group.add_argument(f"--{field_name}", help=f"Argument for {field_name}")
return parser

View File

@@ -1,10 +1,13 @@
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing import List, Optional, Union
from typing import List, Literal, Optional, Union
from common.utils import unwrap
CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"]
class ConfigOverrideConfig(BaseModel):
# TODO: convert this to a pathlib.path?
config: Optional[str] = Field(
None, description=("Path to an overriding config.yml file")
)
@@ -20,7 +23,7 @@ class NetworkConfig(BaseModel):
False,
description=("Decide whether to send error tracebacks over the API"),
)
api_servers: Optional[List[str]] = Field(
api_servers: Optional[List[Literal["OAI", "Kobold"]]] = Field(
[
"OAI",
],
@@ -37,6 +40,7 @@ class LoggingConfig(BaseModel):
class ModelConfig(BaseModel):
# TODO: convert this to a pathlib.path?
model_dir: str = Field(
"models",
description=(
@@ -71,6 +75,7 @@ class ModelConfig(BaseModel):
"Max sequence length. Fetched from the model's base sequence length in "
"config.json by default."
),
ge=0,
)
override_base_seq_len: Optional[int] = Field(
None,
@@ -78,6 +83,7 @@ class ModelConfig(BaseModel):
"Overrides base model context length. WARNING: Only use this if the "
"model's base sequence length is incorrect."
),
ge=0,
)
tensor_parallel: Optional[bool] = Field(
False,
@@ -114,18 +120,18 @@ class ModelConfig(BaseModel):
"model was trained on long context with rope."
),
)
rope_alpha: Optional[Union[float, str]] = Field(
rope_alpha: Optional[Union[float, Literal["auto"]]] = Field(
1.0,
description=(
"Rope alpha (default: 1.0). Same as alpha_value. Set to 'auto' to auto- "
"calculate."
),
)
cache_mode: Optional[str] = Field(
cache_mode: Optional[CACHE_SIZES] = Field(
"FP16",
description=(
"Enable different cache modes for VRAM savings (default: FP16). Possible "
"values: FP16, Q8, Q6, Q4."
f"values: {str(CACHE_SIZES)[15:-1]}"
),
)
cache_size: Optional[int] = Field(
@@ -134,6 +140,8 @@ class ModelConfig(BaseModel):
"Size of the prompt cache to allocate (default: max_seq_len). Must be a "
"multiple of 256."
),
multiple_of=256,
gt=0,
)
chunk_size: Optional[int] = Field(
2048,
@@ -141,6 +149,7 @@ class ModelConfig(BaseModel):
"Chunk size for prompt ingestion (default: 2048). A lower value reduces "
"VRAM usage but decreases ingestion speed."
),
gt=0,
)
max_batch_size: Optional[int] = Field(
None,
@@ -148,6 +157,7 @@ class ModelConfig(BaseModel):
"Set the maximum number of prompts to process at one time (default: "
"None/Automatic). Automatically calculated if left blank."
),
ge=1,
)
prompt_template: Optional[str] = Field(
None,
@@ -162,6 +172,7 @@ class ModelConfig(BaseModel):
"Number of experts to use per token. Fetched from the model's "
"config.json. For MoE models only."
),
ge=1,
)
fasttensors: Optional[bool] = Field(
False,
@@ -175,6 +186,7 @@ class ModelConfig(BaseModel):
class DraftModelConfig(BaseModel):
# TODO: convert this to a pathlib.path?
draft_model_dir: Optional[str] = Field(
"models",
description=(
@@ -202,11 +214,11 @@ class DraftModelConfig(BaseModel):
"blank to auto-calculate the alpha value."
),
)
draft_cache_mode: Optional[str] = Field(
draft_cache_mode: Optional[CACHE_SIZES] = Field(
"FP16",
description=(
"Cache mode for draft models to save VRAM (default: FP16). Possible "
"values: FP16, Q8, Q6, Q4."
f"values: {str(CACHE_SIZES)[15:-1]}"
),
)
@@ -214,11 +226,14 @@ class DraftModelConfig(BaseModel):
class LoraInstanceModel(BaseModel):
name: str = Field(..., description=("Name of the LoRA model"))
scaling: float = Field(
1.0, description=("Scaling factor for the LoRA model (default: 1.0)")
1.0,
description=("Scaling factor for the LoRA model (default: 1.0)"),
ge=0,
)
class LoraConfig(BaseModel):
# TODO: convert this to a pathlib.path?
lora_dir: Optional[str] = Field(
"loras", description=("Directory to look for LoRAs (default: 'loras')")
)
@@ -260,13 +275,14 @@ class DeveloperConfig(BaseModel):
class EmbeddingsConfig(BaseModel):
# TODO: convert this to a pathlib.path?
embedding_model_dir: Optional[str] = Field(
"models",
description=(
"Overrides directory to look for embedding models (default: models)"
),
)
embeddings_device: Optional[str] = Field(
embeddings_device: Optional[Literal["cpu", "auto", "cuda"]] = Field(
"cpu",
description=(
"Device to load embedding models on (default: cpu). Possible values: cpu, "