Mirror of https://github.com/theroyallab/tabbyAPI.git, synced 2026-03-15 00:07:28 +00:00

Commit: make pydantic do all the validation
@@ -1,10 +1,13 @@
 from pydantic import BaseModel, ConfigDict, Field, model_validator
-from typing import List, Optional, Union
+from typing import List, Literal, Optional, Union
 
 from common.utils import unwrap
 
+CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"]
+
+
 class ConfigOverrideConfig(BaseModel):
     # TODO: convert this to a pathlib.path?
     config: Optional[str] = Field(
         None, description=("Path to an overriding config.yml file")
     )
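
The point of the new CACHE_SIZES alias: once a field is typed with a Literal, pydantic rejects out-of-range values at parse time, so no hand-written check is needed. A minimal sketch of the behavior (the Demo model is a stand-in, not part of this diff):

    from typing import Literal, Optional

    from pydantic import BaseModel, ValidationError

    CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"]


    class Demo(BaseModel):
        # Same shape as the cache_mode field below: Literal members only.
        cache_mode: Optional[CACHE_SIZES] = "FP16"


    Demo(cache_mode="Q4")  # accepted

    try:
        Demo(cache_mode="Q5")  # not a member of the Literal
    except ValidationError as err:
        print(err)  # input should be 'FP16', 'Q8', 'Q6' or 'Q4'
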
@@ -20,7 +23,7 @@ class NetworkConfig(BaseModel):
         False,
         description=("Decide whether to send error tracebacks over the API"),
     )
-    api_servers: Optional[List[str]] = Field(
+    api_servers: Optional[List[Literal["OAI", "Kobold"]]] = Field(
         [
             "OAI",
         ],
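
For api_servers the same idea applies element-wise: List[Literal[...]] makes pydantic check every entry of the list, not just the field as a whole. A quick sketch, again with a stand-in model:

    from typing import List, Literal, Optional

    from pydantic import BaseModel, ValidationError


    class Demo(BaseModel):
        api_servers: Optional[List[Literal["OAI", "Kobold"]]] = ["OAI"]


    Demo(api_servers=["OAI", "Kobold"])  # every element is an allowed literal

    try:
        Demo(api_servers=["OAI", "Ollama"])  # "Ollama" fails element validation
    except ValidationError as err:
        print(err)  # the error points at index 1 of the list
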
@@ -37,6 +40,7 @@ class LoggingConfig(BaseModel):
 
 
 class ModelConfig(BaseModel):
     # TODO: convert this to a pathlib.path?
     model_dir: str = Field(
         "models",
         description=(
@@ -71,6 +75,7 @@ class ModelConfig(BaseModel):
             "Max sequence length. Fetched from the model's base sequence length in "
             "config.json by default."
         ),
+        ge=0,
     )
     override_base_seq_len: Optional[int] = Field(
         None,
@@ -78,6 +83,7 @@ class ModelConfig(BaseModel):
             "Overrides base model context length. WARNING: Only use this if the "
             "model's base sequence length is incorrect."
         ),
+        ge=0,
     )
     tensor_parallel: Optional[bool] = Field(
         False,
@@ -114,18 +120,18 @@ class ModelConfig(BaseModel):
             "model was trained on long context with rope."
         ),
     )
-    rope_alpha: Optional[Union[float, str]] = Field(
+    rope_alpha: Optional[Union[float, Literal["auto"]]] = Field(
         1.0,
         description=(
             "Rope alpha (default: 1.0). Same as alpha_value. Set to 'auto' to auto- "
             "calculate."
         ),
     )
-    cache_mode: Optional[str] = Field(
+    cache_mode: Optional[CACHE_SIZES] = Field(
         "FP16",
         description=(
             "Enable different cache modes for VRAM savings (default: FP16). Possible "
-            "values: FP16, Q8, Q6, Q4."
+            f"values: {str(CACHE_SIZES)[15:-1]}"
         ),
     )
     cache_size: Optional[int] = Field(
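
The f-string is the subtle change here: instead of hard-coding the list of cache modes in the description, it derives the text from the alias itself, so the docs and the validator can never drift apart. str() of a Literal alias renders with a module-qualified wrapper, and the [15:-1] slice cuts the 15-character "typing.Literal[" prefix plus the trailing bracket:

    from typing import Literal

    CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"]

    print(str(CACHE_SIZES))         # typing.Literal['FP16', 'Q8', 'Q6', 'Q4']
    print(str(CACHE_SIZES)[15:-1])  # 'FP16', 'Q8', 'Q6', 'Q4'

The slice is tied to the repr of typing.Literal, so it is brittle if that repr ever changes, but it keeps the description in sync with the allowed values for free.
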
@@ -134,6 +140,8 @@ class ModelConfig(BaseModel):
             "Size of the prompt cache to allocate (default: max_seq_len). Must be a "
             "multiple of 256."
         ),
+        multiple_of=256,
+        gt=0,
     )
     chunk_size: Optional[int] = Field(
         2048,
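
These numeric keywords (ge, gt, multiple_of on Field) are the other half of the commit: the "must be a multiple of 256" rule moves out of the prose description and into the schema, so pydantic enforces it on load. A sketch with a stand-in model:

    from typing import Optional

    from pydantic import BaseModel, Field, ValidationError


    class Demo(BaseModel):
        cache_size: Optional[int] = Field(None, multiple_of=256, gt=0)


    Demo(cache_size=4096)  # accepted: positive and 16 * 256

    try:
        Demo(cache_size=1000)  # rejected: 1000 is not a multiple of 256
    except ValidationError as err:
        print(err)  # input should be a multiple of 256
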
@@ -141,6 +149,7 @@ class ModelConfig(BaseModel):
             "Chunk size for prompt ingestion (default: 2048). A lower value reduces "
             "VRAM usage but decreases ingestion speed."
         ),
+        gt=0,
     )
     max_batch_size: Optional[int] = Field(
         None,
@@ -148,6 +157,7 @@ class ModelConfig(BaseModel):
             "Set the maximum number of prompts to process at one time (default: "
             "None/Automatic). Automatically calculated if left blank."
         ),
+        ge=1,
     )
     prompt_template: Optional[str] = Field(
         None,
@@ -162,6 +172,7 @@ class ModelConfig(BaseModel):
             "Number of experts to use per token. Fetched from the model's "
             "config.json. For MoE models only."
         ),
+        ge=1,
     )
     fasttensors: Optional[bool] = Field(
         False,
@@ -175,6 +186,7 @@ class ModelConfig(BaseModel):
 
 
 class DraftModelConfig(BaseModel):
     # TODO: convert this to a pathlib.path?
     draft_model_dir: Optional[str] = Field(
         "models",
         description=(
@@ -202,11 +214,11 @@ class DraftModelConfig(BaseModel):
             "blank to auto-calculate the alpha value."
         ),
     )
-    draft_cache_mode: Optional[str] = Field(
+    draft_cache_mode: Optional[CACHE_SIZES] = Field(
         "FP16",
         description=(
             "Cache mode for draft models to save VRAM (default: FP16). Possible "
-            "values: FP16, Q8, Q6, Q4."
+            f"values: {str(CACHE_SIZES)[15:-1]}"
         ),
     )
 
@@ -214,11 +226,14 @@ class DraftModelConfig(BaseModel):
 class LoraInstanceModel(BaseModel):
     name: str = Field(..., description=("Name of the LoRA model"))
     scaling: float = Field(
-        1.0, description=("Scaling factor for the LoRA model (default: 1.0)")
+        1.0,
+        description=("Scaling factor for the LoRA model (default: 1.0)"),
+        ge=0,
     )
 
 
 class LoraConfig(BaseModel):
     # TODO: convert this to a pathlib.path?
     lora_dir: Optional[str] = Field(
         "loras", description=("Directory to look for LoRAs (default: 'loras')")
     )
@@ -260,13 +275,14 @@ class DeveloperConfig(BaseModel):
 
 
 class EmbeddingsConfig(BaseModel):
     # TODO: convert this to a pathlib.path?
     embedding_model_dir: Optional[str] = Field(
         "models",
         description=(
             "Overrides directory to look for embedding models (default: models)"
         ),
     )
-    embeddings_device: Optional[str] = Field(
+    embeddings_device: Optional[Literal["cpu", "auto", "cuda"]] = Field(
         "cpu",
         description=(
             "Device to load embedding models on (default: cpu). Possible values: cpu, "