from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing import List, Optional, Union

from common.utils import unwrap


class ConfigOverrideConfig(BaseModel):
    config: Optional[str] = Field(
        None, description=("Path to an overriding config.yml file")
    )


class NetworkConfig(BaseModel):
    host: Optional[str] = Field("127.0.0.1", description=("The IP to host on"))
    port: Optional[int] = Field(5000, description=("The port to host on"))
    disable_auth: Optional[bool] = Field(
        False, description=("Disable HTTP token authentication with requests")
    )
    send_tracebacks: Optional[bool] = Field(
        False,
        description=("Decide whether to send error tracebacks over the API"),
    )
    api_servers: Optional[List[str]] = Field(
        ["OAI"],
        description=("API servers to enable. Options: (OAI, Kobold)"),
    )


class LoggingConfig(BaseModel):
    log_prompt: Optional[bool] = Field(False, description=("Enable prompt logging"))
    log_generation_params: Optional[bool] = Field(
        False, description=("Enable generation parameter logging")
    )
    log_requests: Optional[bool] = Field(False, description=("Enable request logging"))


class ModelConfig(BaseModel):
    model_dir: str = Field(
        "models",
        description=(
            "Overrides the directory to look for models (default: models). Windows "
            "users, do NOT put this path in quotes."
        ),
    )
    use_dummy_models: Optional[bool] = Field(
        False,
        description=(
            "Sends dummy model names when the models endpoint is queried. Enable "
            "this if looking for specific OAI models."
        ),
    )
    model_name: Optional[str] = Field(
        None,
        description=(
            "An initial model to load. Make sure the model is located in the model "
            "directory! REQUIRED: This must be filled out to load a model on startup."
        ),
    )
    use_as_default: List[str] = Field(
        default_factory=list,
        description=(
            "Names of args to use as a default fallback for API load requests "
            "(default: []). Example: ['max_seq_len', 'cache_mode']"
        ),
    )
    max_seq_len: Optional[int] = Field(
        None,
        description=(
            "Max sequence length. Fetched from the model's base sequence length in "
            "config.json by default."
        ),
    )
    override_base_seq_len: Optional[int] = Field(
        None,
        description=(
            "Overrides base model context length. WARNING: Only use this if the "
            "model's base sequence length is incorrect."
        ),
    )
    tensor_parallel: Optional[bool] = Field(
        False,
        description=(
            "Load model with tensor parallelism. Falls back to autosplit if a GPU "
            "split isn't provided."
        ),
    )
    gpu_split_auto: Optional[bool] = Field(
        True,
        description=(
            "Automatically allocate resources to GPUs (default: True). Not parsed "
            "for single GPU users."
        ),
    )
    autosplit_reserve: List[int] = Field(
        [96],
        description=(
            "Reserve VRAM used for autosplit loading (default: 96 MB on GPU 0). "
            "Represented as an array of MB per GPU."
        ),
    )
    gpu_split: List[float] = Field(
        default_factory=list,
        description=(
            "An integer array of GBs of VRAM to split between GPUs (default: []). "
            "Used with tensor parallelism."
        ),
    )
    rope_scale: Optional[float] = Field(
        1.0,
        description=(
            "Rope scale (default: 1.0). Same as compress_pos_emb. Only use if the "
            "model was trained on long context with rope."
        ),
    )
    rope_alpha: Optional[Union[float, str]] = Field(
        1.0,
        description=(
            "Rope alpha (default: 1.0). Same as alpha_value. Set to 'auto' to "
            "auto-calculate."
        ),
    )
    cache_mode: Optional[str] = Field(
        "FP16",
        description=(
            "Enable different cache modes for VRAM savings (default: FP16). "
            "Possible values: FP16, Q8, Q6, Q4."
        ),
    )
    cache_size: Optional[int] = Field(
        None,
        description=(
            "Size of the prompt cache to allocate (default: max_seq_len). Must be "
            "a multiple of 256."
        ),
    )
    chunk_size: Optional[int] = Field(
        2048,
        description=(
            "Chunk size for prompt ingestion (default: 2048). A lower value "
            "reduces VRAM usage but decreases ingestion speed."
        ),
    )
    max_batch_size: Optional[int] = Field(
        None,
        description=(
            "Set the maximum number of prompts to process at one time (default: "
            "None/Automatic). Automatically calculated if left blank."
        ),
    )
    prompt_template: Optional[str] = Field(
        None,
        description=(
            "Set the prompt template for this model. If empty, attempts to look "
            "for the model's chat template."
        ),
    )
    num_experts_per_token: Optional[int] = Field(
        None,
        description=(
            "Number of experts to use per token. Fetched from the model's "
            "config.json. For MoE models only."
        ),
    )
    fasttensors: Optional[bool] = Field(
        False,
        description=(
            "Enables fasttensors to possibly increase model loading speeds "
            "(default: False)."
        ),
    )

    model_config = ConfigDict(protected_namespaces=())


class DraftModelConfig(BaseModel):
    draft_model_dir: Optional[str] = Field(
        "models",
        description=(
            "Overrides the directory to look for draft models (default: models)"
        ),
    )
    draft_model_name: Optional[str] = Field(
        None,
        description=(
            "An initial draft model to load. Ensure the model is in the model "
            "directory."
        ),
    )
    draft_rope_scale: Optional[float] = Field(
        1.0,
        description=(
            "Rope scale for draft models (default: 1.0). Same as compress_pos_emb. "
            "Use if the draft model was trained on long context with rope."
        ),
    )
    draft_rope_alpha: Optional[float] = Field(
        None,
        description=(
            "Rope alpha for draft models (default: None). Same as alpha_value. "
            "Leave blank to auto-calculate the alpha value."
        ),
    )
    draft_cache_mode: Optional[str] = Field(
        "FP16",
        description=(
            "Cache mode for draft models to save VRAM (default: FP16). Possible "
            "values: FP16, Q8, Q6, Q4."
        ),
    )


class LoraInstanceModel(BaseModel):
    name: str = Field(..., description=("Name of the LoRA model"))
    scaling: float = Field(
        1.0, description=("Scaling factor for the LoRA model (default: 1.0)")
    )


class LoraConfig(BaseModel):
    lora_dir: Optional[str] = Field(
        "loras", description=("Directory to look for LoRAs (default: 'loras')")
    )
    loras: Optional[List[LoraInstanceModel]] = Field(
        None,
        description=(
            "List of LoRAs to load and associated scaling factors "
            "(default scaling: 1.0)"
        ),
    )


class SamplingConfig(BaseModel):
    override_preset: Optional[str] = Field(
        None, description=("Select a sampler override preset")
    )


class DeveloperConfig(BaseModel):
    unsafe_launch: Optional[bool] = Field(
        False, description=("Skip Exllamav2 version check")
    )
    disable_request_streaming: Optional[bool] = Field(
        False, description=("Disables API request streaming")
    )
    cuda_malloc_backend: Optional[bool] = Field(
        False, description=("Runs with the PyTorch CUDA malloc backend")
    )
    uvloop: Optional[bool] = Field(
        False, description=("Run asyncio using Uvloop or Winloop")
    )
    realtime_process_priority: Optional[bool] = Field(
        False,
        description=(
            "Set the process to use a higher priority. For realtime process "
            "priority, run as administrator or sudo. Otherwise, the priority will "
            "be set to high."
        ),
    )


class EmbeddingsConfig(BaseModel):
    embedding_model_dir: Optional[str] = Field(
        "models",
        description=(
            "Overrides directory to look for embedding models (default: models)"
        ),
    )
    embeddings_device: Optional[str] = Field(
        "cpu",
        description=(
            "Device to load embedding models on (default: cpu). Possible values: "
            "cpu, auto, cuda. If using an AMD GPU, set this value to 'cuda'."
        ),
    )
    embedding_model_name: Optional[str] = Field(
        None, description=("The embeddings model to load")
    )


class TabbyConfigModel(BaseModel):
    config: ConfigOverrideConfig = Field(
        default_factory=ConfigOverrideConfig.model_construct
    )
    network: NetworkConfig = Field(default_factory=NetworkConfig.model_construct)
    logging: LoggingConfig = Field(default_factory=LoggingConfig.model_construct)
    model: ModelConfig = Field(default_factory=ModelConfig.model_construct)
    draft_model: DraftModelConfig = Field(
        default_factory=DraftModelConfig.model_construct
    )
    lora: LoraConfig = Field(default_factory=LoraConfig.model_construct)
    sampling: SamplingConfig = Field(default_factory=SamplingConfig.model_construct)
    developer: DeveloperConfig = Field(default_factory=DeveloperConfig.model_construct)
    embeddings: EmbeddingsConfig = Field(
        default_factory=EmbeddingsConfig.model_construct
    )

    @model_validator(mode="before")
    @classmethod
    def set_defaults(cls, values):
        """Replace sections that are explicitly None with default instances."""
        for field_name, field_value in values.items():
            if field_value is None:
                # Build a default instance of the annotated section model and
                # re-validate it (model_dump replaces the deprecated .dict())
                section_model = cls.__annotations__[field_name]
                values[field_name] = section_model(**section_model().model_dump())

        return values

    model_config = ConfigDict(validate_assignment=True, protected_namespaces=())


def generate_config_file(filename="config_sample.yml", indentation=2):
    """Generate a commented sample YAML config from the model's JSON schema."""
    schema = TabbyConfigModel.model_json_schema()

    def dump_def(id: str, indent=2):
        yaml = ""
        indent = " " * indentation * indent
        id = id.split("/")[-1]

        section = schema["$defs"][id]["properties"]
        for property in section.keys():
            # Write the field's description as a YAML comment
            comment = section[property]["description"]
            yaml += f"{indent}# {comment}\n"

            # Fall back to an empty value if the field has no default
            value = unwrap(section[property].get("default"), "")
            yaml += f"{indent}{property}: {value}\n\n"

        return yaml + "\n"

    yaml = ""
    for section in schema["properties"].keys():
        yaml += f"{section}:\n"
        yaml += dump_def(schema["properties"][section]["$ref"])
        yaml += "\n"

    with open(filename, "w") as f:
        f.write(yaml)


# generate_config_file("test.yml")
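
# Example usage (a minimal sketch): loading a user config.yml and validating
# it against TabbyConfigModel. Assumes PyYAML is available; the "config.yml"
# path is illustrative. Sections set to None in the YAML are filled in by the
# set_defaults validator above.
#
#     import yaml
#
#     with open("config.yml") as f:
#         raw = yaml.safe_load(f) or {}
#
#     config = TabbyConfigModel.model_validate(raw)
#     print(config.network.host, config.network.port)  # 127.0.0.1 5000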