Merge pull request #189 from SecretiveShell/pydantic-config

Update the config system to use Pydantic internally, bridging the gap between the YAML config and CLI args. YAML is still the preferred method to configure TabbyAPI, but args are no longer maintained separately.
Authored by Brian Dashore on 2024-09-18 20:41:41 -04:00, committed by GitHub.
22 changed files with 1095 additions and 635 deletions
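For orientation, a minimal sketch of the pattern this PR adopts: a single Pydantic schema acts as the source of truth, YAML is validated against it, and CLI flags are generated from its fields. The names below (`AppConfig`, a two-field `NetworkConfig`) are illustrative stand-ins, not TabbyAPI's actual models.

```python
import argparse
from typing import Optional

from pydantic import BaseModel, Field

class NetworkConfig(BaseModel):
    host: Optional[str] = Field("127.0.0.1", description="The IP to host on")
    port: Optional[int] = Field(5000, description="The port to host on")

class AppConfig(BaseModel):
    network: NetworkConfig = NetworkConfig()

def build_parser() -> argparse.ArgumentParser:
    """Derive CLI flags from the schema instead of maintaining them by hand."""
    parser = argparse.ArgumentParser()
    for name, info in NetworkConfig.model_fields.items():
        parser.add_argument(f"--{name.replace('_', '-')}", help=info.description)
    return parser

# CLI overrides are pruned of unset values and validated by the same schema
# that parses the YAML file, so both inputs stay in sync automatically.
args = vars(build_parser().parse_args(["--port", "5001"]))
overrides = {k: v for k, v in args.items() if v is not None}
config = AppConfig.model_validate({"network": overrides})
print(config.network.port)  # 5001, coerced from the CLI string by Pydantic
```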


@@ -48,10 +48,8 @@ jobs:
npm install @redocly/cli -g
- name: Export OpenAPI docs
run: |
EXPORT_OPENAPI=1 python main.py
mv openapi.json openapi-oai.json
EXPORT_OPENAPI=1 python main.py --api-servers kobold
mv openapi.json openapi-kobold.json
python main.py --export-openapi true --openapi-export-path "openapi-oai.json" --api-servers OAI
python main.py --export-openapi true --openapi-export-path "openapi-kobold.json" --api-servers kobold
- name: Build and store Redocly site
run: |
mkdir static

.gitignore (vendored, 3 changes)

@@ -213,3 +213,6 @@ openapi.json
# Infinity-emb cache
.infinity_cache/
# Backup files
*.bak


@@ -30,7 +30,7 @@ from itertools import zip_longest
from loguru import logger
from typing import List, Optional, Union
import yaml
from ruamel.yaml import YAML
from backends.exllamav2.grammar import (
ExLlamaV2Grammar,
@@ -379,7 +379,10 @@ class ExllamaV2Container:
override_config_path, "r", encoding="utf8"
) as override_config_file:
contents = await override_config_file.read()
override_args = unwrap(yaml.safe_load(contents), {})
# Create a temporary YAML parser
yaml = YAML(typ="safe")
override_args = unwrap(yaml.load(contents), {})
# Merge draft overrides beforehand
draft_override_args = unwrap(override_args.get("draft"), {})
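For reference, a rough standalone equivalent of the swap this hunk makes, replacing pyyaml's `yaml.safe_load` with a ruamel.yaml safe parser:

```python
from ruamel.yaml import YAML

yaml = YAML(typ="safe")
contents = "draft:\n  draft_model_name: some-model\n"  # stand-in file contents
override_args = yaml.load(contents) or {}
print(override_args["draft"]["draft_model_name"])  # some-model
```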

common/actions.py (new file, 27 lines)

@@ -0,0 +1,27 @@
import json
from loguru import logger
from common.tabby_config import config, generate_config_file
from endpoints.server import export_openapi
def branch_to_actions() -> bool:
"""Checks if a optional action needs to be run."""
if config.actions.export_openapi:
openapi_json = export_openapi()
with open(config.actions.openapi_export_path, "w") as f:
f.write(json.dumps(openapi_json))
logger.info(
"Successfully wrote OpenAPI spec to "
+ f"{config.actions.openapi_export_path}"
)
elif config.actions.export_config:
generate_config_file(filename=config.actions.config_export_path)
else:
# did not branch
return False
# branched and ran an action
return True
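A hypothetical sketch of how an entry point could consume `branch_to_actions`; the real main.py wiring lives elsewhere in this PR and is not shown here:

```python
from common.actions import branch_to_actions

def main() -> None:
    # Hypothetical startup flow: run a one-shot export action if requested,
    # otherwise continue booting the API server.
    if branch_to_actions():
        return
    ...  # normal server startup
```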


@@ -1,56 +1,60 @@
"""Argparser for overriding config values"""
import argparse
from pydantic import BaseModel
from common.config_models import TabbyConfigModel
from common.utils import is_list_type, unwrap_optional_type
def str_to_bool(value):
"""Converts a string into a boolean value"""
if value.lower() in {"false", "f", "0", "no", "n"}:
return False
elif value.lower() in {"true", "t", "1", "yes", "y"}:
return True
raise ValueError(f"{value} is not a valid boolean value")
def argument_with_auto(value):
def add_field_to_group(group, field_name, field_type, field) -> None:
"""
Argparse type wrapper for any argument that has an automatic option.
Ex. rope_alpha
Adds a Pydantic field to an argparse argument group.
"""
if value == "auto":
return "auto"
kwargs = {
"help": field.description if field.description else "No description available",
}
try:
return float(value)
except ValueError as ex:
raise argparse.ArgumentTypeError(
'This argument only takes a type of float or "auto"'
) from ex
# If the inner type contains a list, specify argparse as such
if is_list_type(field_type):
kwargs["nargs"] = "+"
group.add_argument(f"--{field_name}", **kwargs)
def init_argparser():
"""Creates an argument parser that any function can use"""
def init_argparser() -> argparse.ArgumentParser:
"""
Initializes an argparse parser based on a Pydantic config schema.
"""
parser = argparse.ArgumentParser(
epilog="NOTE: These args serve to override parts of the config. "
+ "It's highly recommended to edit config.yml for all options and "
+ "better descriptions!"
)
add_network_args(parser)
add_model_args(parser)
add_embeddings_args(parser)
add_logging_args(parser)
add_developer_args(parser)
add_sampling_args(parser)
add_config_args(parser)
parser = argparse.ArgumentParser(description="TabbyAPI server")
# Loop through each top-level field in the config
for field_name, field_info in TabbyConfigModel.model_fields.items():
field_type = unwrap_optional_type(field_info.annotation)
group = parser.add_argument_group(
field_name, description=f"Arguments for {field_name}"
)
# Check if the field_type is a Pydantic model
if issubclass(field_type, BaseModel):
for sub_field_name, sub_field_info in field_type.model_fields.items():
sub_field_name = sub_field_name.replace("_", "-")
sub_field_type = sub_field_info.annotation
add_field_to_group(
group, sub_field_name, sub_field_type, sub_field_info
)
else:
field_name = field_name.replace("_", "-")
group.add_argument(f"--{field_name}", help=f"Argument for {field_name}")
return parser
def convert_args_to_dict(args: argparse.Namespace, parser: argparse.ArgumentParser):
def convert_args_to_dict(
args: argparse.Namespace, parser: argparse.ArgumentParser
) -> dict:
"""Broad conversion of surface level arg groups to dictionaries"""
arg_groups = {}
@@ -64,201 +68,3 @@ def convert_args_to_dict(args: argparse.Namespace, parser: argparse.ArgumentPars
arg_groups[group.title] = group_dict
return arg_groups
def add_config_args(parser: argparse.ArgumentParser):
"""Adds config arguments"""
parser.add_argument(
"--config", type=str, help="Path to an overriding config.yml file"
)
def add_network_args(parser: argparse.ArgumentParser):
"""Adds networking arguments"""
network_group = parser.add_argument_group("network")
network_group.add_argument("--host", type=str, help="The IP to host on")
network_group.add_argument("--port", type=int, help="The port to host on")
network_group.add_argument(
"--disable-auth",
type=str_to_bool,
help="Disable HTTP token authenticaion with requests",
)
network_group.add_argument(
"--send-tracebacks",
type=str_to_bool,
help="Decide whether to send error tracebacks over the API",
)
network_group.add_argument(
"--api-servers",
type=str,
nargs="+",
help="API servers to enable. Options: (OAI, Kobold)",
)
def add_model_args(parser: argparse.ArgumentParser):
"""Adds model arguments"""
model_group = parser.add_argument_group("model")
model_group.add_argument(
"--model-dir", type=str, help="Overrides the directory to look for models"
)
model_group.add_argument("--model-name", type=str, help="An initial model to load")
model_group.add_argument(
"--use-dummy-models",
type=str_to_bool,
help="Add dummy OAI model names for API queries",
)
model_group.add_argument(
"--use-as-default",
type=str,
nargs="+",
help="Names of args to use as a default fallback for API load requests ",
)
model_group.add_argument(
"--max-seq-len", type=int, help="Override the maximum model sequence length"
)
model_group.add_argument(
"--override-base-seq-len",
type=str_to_bool,
help="Overrides base model context length",
)
model_group.add_argument(
"--tensor-parallel",
type=str_to_bool,
help="Use tensor parallelism to load models",
)
model_group.add_argument(
"--gpu-split-auto",
type=str_to_bool,
help="Automatically allocate resources to GPUs",
)
model_group.add_argument(
"--autosplit-reserve",
type=int,
nargs="+",
help="Reserve VRAM used for autosplit loading (in MBs) ",
)
model_group.add_argument(
"--gpu-split",
type=float,
nargs="+",
help="An integer array of GBs of vram to split between GPUs. "
+ "Ignored if gpu_split_auto is true",
)
model_group.add_argument(
"--rope-scale", type=float, help="Sets rope_scale or compress_pos_emb"
)
model_group.add_argument(
"--rope-alpha",
type=argument_with_auto,
help="Sets rope_alpha for NTK",
)
model_group.add_argument(
"--cache-mode",
type=str,
help="Set the quantization level of the K/V cache. Options: (FP16, Q8, Q6, Q4)",
)
model_group.add_argument(
"--cache-size",
type=int,
help="The size of the prompt cache (in number of tokens) to allocate",
)
model_group.add_argument(
"--chunk-size",
type=int,
help="Chunk size for prompt ingestion",
)
model_group.add_argument(
"--max-batch-size",
type=int,
help="Maximum amount of prompts to process at one time",
)
model_group.add_argument(
"--prompt-template",
type=str,
help="Set the jinja2 prompt template for chat completions",
)
model_group.add_argument(
"--num-experts-per-token",
type=int,
help="Number of experts to use per token in MoE models",
)
model_group.add_argument(
"--fasttensors",
type=str_to_bool,
help="Possibly increases model loading speeds",
)
def add_logging_args(parser: argparse.ArgumentParser):
"""Adds logging arguments"""
logging_group = parser.add_argument_group("logging")
logging_group.add_argument(
"--log-prompt", type=str_to_bool, help="Enable prompt logging"
)
logging_group.add_argument(
"--log-generation-params",
type=str_to_bool,
help="Enable generation parameter logging",
)
logging_group.add_argument(
"--log-requests",
type=str_to_bool,
help="Enable request logging",
)
def add_developer_args(parser: argparse.ArgumentParser):
"""Adds developer-specific arguments"""
developer_group = parser.add_argument_group("developer")
developer_group.add_argument(
"--unsafe-launch", type=str_to_bool, help="Skip Exllamav2 version check"
)
developer_group.add_argument(
"--disable-request-streaming",
type=str_to_bool,
help="Disables API request streaming",
)
developer_group.add_argument(
"--cuda-malloc-backend",
type=str_to_bool,
help="Runs with the pytorch CUDA malloc backend",
)
developer_group.add_argument(
"--uvloop",
type=str_to_bool,
help="Run asyncio using Uvloop or Winloop",
)
def add_sampling_args(parser: argparse.ArgumentParser):
"""Adds sampling-specific arguments"""
sampling_group = parser.add_argument_group("sampling")
sampling_group.add_argument(
"--override-preset", type=str, help="Select a sampler override preset"
)
def add_embeddings_args(parser: argparse.ArgumentParser):
"""Adds arguments specific to embeddings"""
embeddings_group = parser.add_argument_group("embeddings")
embeddings_group.add_argument(
"--embedding-model-dir",
type=str,
help="Overrides the directory to look for models",
)
embeddings_group.add_argument(
"--embedding-model-name", type=str, help="An initial model to load"
)
embeddings_group.add_argument(
"--embeddings-device",
type=str,
help="Device to use for embeddings. Options: (cpu, auto, cuda)",
)
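To illustrate the schema-driven replacement: the parser is now generated from `TabbyConfigModel`'s fields, flags arrive untyped (as strings, since the generated arguments set no `type`), and results are grouped by config section. The exact shape depends on model fields not shown in this hunk:

```python
# Rough illustration, assuming NetworkConfig contributes --port and
# --api-servers as shown above:
parser = init_argparser()
args = parser.parse_args(["--port", "5001", "--api-servers", "OAI", "Kobold"])
arg_groups = convert_args_to_dict(args, parser)
# Expected shape (unset flags are pruned downstream):
# {"network": {"port": "5001", "api_servers": ["OAI", "Kobold"]}, ...}
```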


@@ -4,8 +4,9 @@ application, it should be fine.
"""
import aiofiles
import io
import secrets
import yaml
from ruamel.yaml import YAML
from fastapi import Header, HTTPException, Request
from pydantic import BaseModel
from loguru import logger
@@ -57,10 +58,13 @@ async def load_auth_keys(disable_from_config: bool):
return
# Create a temporary YAML parser
yaml = YAML(typ=["rt", "safe"])
try:
async with aiofiles.open("api_tokens.yml", "r", encoding="utf8") as auth_file:
contents = await auth_file.read()
auth_keys_dict = yaml.safe_load(contents)
auth_keys_dict = yaml.load(contents)
AUTH_KEYS = AuthKeys.model_validate(auth_keys_dict)
except FileNotFoundError:
new_auth_keys = AuthKeys(
@@ -69,10 +73,10 @@ async def load_auth_keys(disable_from_config: bool):
AUTH_KEYS = new_auth_keys
async with aiofiles.open("api_tokens.yml", "w", encoding="utf8") as auth_file:
new_auth_yaml = yaml.safe_dump(
AUTH_KEYS.model_dump(), default_flow_style=False
)
await auth_file.write(new_auth_yaml)
string_stream = io.StringIO()
yaml.dump(AUTH_KEYS.model_dump(), string_stream)
await auth_file.write(string_stream.getvalue())
logger.info(
f"Your API key is: {AUTH_KEYS.api_key}\n"

common/config_models.py (new file, 469 lines)

@@ -0,0 +1,469 @@
from pathlib import Path
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
from typing import List, Literal, Optional, Union
CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"]
class Metadata(BaseModel):
"""metadata model for config options"""
include_in_config: Optional[bool] = Field(True)
class BaseConfigModel(BaseModel):
"""Base model for config models with added metadata"""
_metadata: Metadata = PrivateAttr(Metadata())
class ConfigOverrideConfig(BaseConfigModel):
"""Model for overriding a provided config file."""
# TODO: convert this to a pathlib.path?
config: Optional[str] = Field(
None, description=("Path to an overriding config.yml file")
)
_metadata: Metadata = PrivateAttr(Metadata(include_in_config=False))
class UtilityActions(BaseConfigModel):
"""Model used for arg actions."""
# YAML export options
export_config: Optional[str] = Field(
None, description=("Generate a template config file")
)
config_export_path: Optional[Path] = Field(
"config_sample.yml", description=("Path to export the configuration file to")
)
# OpenAPI JSON export options
export_openapi: Optional[bool] = Field(
False, description=("Export OpenAPI schema files")
)
openapi_export_path: Optional[Path] = Field(
"openapi.json", description=("Path to export the OpenAPI schema to")
)
_metadata: Metadata = PrivateAttr(Metadata(include_in_config=False))
class NetworkConfig(BaseConfigModel):
"""Options for networking"""
host: Optional[str] = Field(
"127.0.0.1",
description=(
"The IP to host on (default: 127.0.0.1).\n"
"Use 0.0.0.0 to expose on all network adapters."
),
)
port: Optional[int] = Field(
5000, description=("The port to host on (default: 5000).")
)
disable_auth: Optional[bool] = Field(
False,
description=(
"Disable HTTP token authentication with requests.\n"
"WARNING: This will make your instance vulnerable!\n"
"Turn on this option if you are ONLY connecting from localhost."
),
)
send_tracebacks: Optional[bool] = Field(
False,
description=(
"Send tracebacks over the API (default: False).\n"
"NOTE: Only enable this for debug purposes."
),
)
api_servers: Optional[List[Literal["OAI", "Kobold"]]] = Field(
["OAI"],
description=(
'Select API servers to enable (default: ["OAI"]).\n'
"Possible values: OAI, Kobold."
),
)
# TODO: Migrate config.yml to have the log_ prefix
# This is a breaking change.
class LoggingConfig(BaseConfigModel):
"""Options for logging"""
log_prompt: Optional[bool] = Field(
False,
description=("Enable prompt logging (default: False)."),
)
log_generation_params: Optional[bool] = Field(
False,
description=("Enable generation parameter logging (default: False)."),
)
log_requests: Optional[bool] = Field(
False,
description=(
"Enable request logging (default: False).\n"
"NOTE: Only use this for debugging!"
),
)
class ModelConfig(BaseConfigModel):
"""
Options for model overrides and loading
Please read the comments to understand how arguments are handled
between initial and API loads
"""
# TODO: convert this to a pathlib.path?
model_dir: str = Field(
"models",
description=(
"Directory to look for models (default: models).\n"
"Windows users, do NOT put this path in quotes!"
),
)
inline_model_loading: Optional[bool] = Field(
False,
description=(
"Allow direct loading of models "
"from a completion or chat completion request (default: False)."
),
)
use_dummy_models: Optional[bool] = Field(
False,
description=(
"Sends dummy model names when the models endpoint is queried.\n"
"Enable this if the client is looking for specific OAI models."
),
)
model_name: Optional[str] = Field(
None,
description=(
"An initial model to load.\n"
"Make sure the model is located in the model directory!\n"
"REQUIRED: This must be filled out to load a model on startup."
),
)
use_as_default: List[str] = Field(
default_factory=list,
description=(
"Names of args to use as a fallback for API load requests (default: []).\n"
"For example, if you always want cache_mode to be Q4 "
'instead of on the initial model load, add "cache_mode" to this array.\n'
"Example: ['max_seq_len', 'cache_mode']."
),
)
max_seq_len: Optional[int] = Field(
None,
description=(
"Max sequence length (default: Empty).\n"
"Fetched from the model's base sequence length in config.json by default."
),
ge=0,
)
override_base_seq_len: Optional[int] = Field(
None,
description=(
"Overrides base model context length (default: Empty).\n"
"WARNING: Don't set this unless you know what you're doing!\n"
"Again, do NOT use this for configuring context length, "
"use max_seq_len above ^"
),
ge=0,
)
tensor_parallel: Optional[bool] = Field(
False,
description=(
"Load model with tensor parallelism.\n"
"Falls back to autosplit if GPU split isn't provided.\n"
"This ignores the gpu_split_auto value."
),
)
gpu_split_auto: Optional[bool] = Field(
True,
description=(
"Automatically allocate resources to GPUs (default: True).\n"
"Not parsed for single GPU users."
),
)
autosplit_reserve: List[int] = Field(
[96],
description=(
"Reserve VRAM used for autosplit loading (default: 96 MB on GPU 0).\n"
"Represented as an array of MB per GPU."
),
)
gpu_split: List[float] = Field(
default_factory=list,
description=(
"An integer array of GBs of VRAM to split between GPUs (default: []).\n"
"Used with tensor parallelism."
),
)
rope_scale: Optional[float] = Field(
1.0,
description=(
"Rope scale (default: 1.0).\n"
"Same as compress_pos_emb.\n"
"Use if the model was trained on long context with rope.\n"
"Leave blank to pull the value from the model."
),
)
rope_alpha: Optional[Union[float, Literal["auto"]]] = Field(
None,
description=(
"Rope alpha (default: None).\n"
'Same as alpha_value. Set to "auto" to auto-calculate.\n'
"Leaving this value blank will either pull from the model "
"or auto-calculate."
),
)
cache_mode: Optional[CACHE_SIZES] = Field(
"FP16",
description=(
"Enable different cache modes for VRAM savings (default: FP16).\n"
f"Possible values: {str(CACHE_SIZES)[15:-1]}."
),
)
cache_size: Optional[int] = Field(
None,
description=(
"Size of the prompt cache to allocate (default: max_seq_len).\n"
"Must be a multiple of 256 and can't be less than max_seq_len.\n"
"For CFG, set this to 2 * max_seq_len."
),
multiple_of=256,
gt=0,
)
chunk_size: Optional[int] = Field(
2048,
description=(
"Chunk size for prompt ingestion (default: 2048).\n"
"A lower value reduces VRAM usage but decreases ingestion speed.\n"
"NOTE: Effects vary depending on the model.\n"
"An ideal value is between 512 and 4096."
),
gt=0,
)
max_batch_size: Optional[int] = Field(
None,
description=(
"Set the maximum number of prompts to process at one time "
"(default: None/Automatic).\n"
"Automatically calculated if left blank.\n"
"NOTE: Only available for Nvidia ampere (30 series) and above GPUs."
),
ge=1,
)
prompt_template: Optional[str] = Field(
None,
description=(
"Set the prompt template for this model. (default: None)\n"
"If empty, attempts to look for the model's chat template.\n"
"If a model contains multiple templates in its tokenizer_config.json,\n"
"set prompt_template to the name of the template you want to use.\n"
"NOTE: Only works with chat completion message lists!"
),
)
num_experts_per_token: Optional[int] = Field(
None,
description=(
"Number of experts to use per token.\n"
"Fetched from the model's config.json if empty.\n"
"NOTE: For MoE models only.\n"
"WARNING: Don't set this unless you know what you're doing!"
),
ge=1,
)
fasttensors: Optional[bool] = Field(
False,
description=(
"Enables fasttensors to possibly increase model loading speeds "
"(default: False)."
),
)
_metadata: Metadata = PrivateAttr(Metadata())
model_config = ConfigDict(protected_namespaces=())
class DraftModelConfig(BaseConfigModel):
"""
Options for draft models (speculative decoding)
This will use more VRAM!
"""
# TODO: convert this to a pathlib.path?
draft_model_dir: Optional[str] = Field(
"models",
description=("Directory to look for draft models (default: models)"),
)
draft_model_name: Optional[str] = Field(
None,
description=(
"An initial draft model to load.\n"
"Ensure the model is in the model directory."
),
)
draft_rope_scale: Optional[float] = Field(
1.0,
description=(
"Rope scale for draft models (default: 1.0).\n"
"Same as compress_pos_emb.\n"
"Use if the draft model was trained on long context with rope."
),
)
draft_rope_alpha: Optional[float] = Field(
None,
description=(
"Rope alpha for draft models (default: None).\n"
'Same as alpha_value. Set to "auto" to auto-calculate.\n'
"Leaving this value blank will either pull from the model "
"or auto-calculate."
),
)
draft_cache_mode: Optional[CACHE_SIZES] = Field(
"FP16",
description=(
"Cache mode for draft models to save VRAM (default: FP16).\n"
f"Possible values: {str(CACHE_SIZES)[15:-1]}."
),
)
class LoraInstanceModel(BaseConfigModel):
"""Model representing an instance of a Lora."""
name: Optional[str] = None
scaling: float = Field(1.0, ge=0)
class LoraConfig(BaseConfigModel):
"""Options for Loras"""
# TODO: convert this to a pathlib.path?
lora_dir: Optional[str] = Field(
"loras", description=("Directory to look for LoRAs (default: loras).")
)
loras: Optional[List[LoraInstanceModel]] = Field(
None,
description=(
"List of LoRAs to load and associated scaling factors "
"(default scale: 1.0).\n"
"For the YAML file, add each entry as a YAML list:\n"
"- name: lora1\n"
" scaling: 1.0"
),
)
class EmbeddingsConfig(BaseConfigModel):
"""
Options for embedding models and loading.
NOTE: Embeddings requires the "extras" feature to be installed
Install it via "pip install .[extras]"
"""
# TODO: convert this to a pathlib.path?
embedding_model_dir: Optional[str] = Field(
"models",
description=("Directory to look for embedding models (default: models)."),
)
embeddings_device: Optional[Literal["cpu", "auto", "cuda"]] = Field(
"cpu",
description=(
"Device to load embedding models on (default: cpu).\n"
"Possible values: cpu, auto, cuda.\n"
"NOTE: It's recommended to load embedding models on the CPU.\n"
"If using an AMD GPU, set this value to 'cuda'."
),
)
embedding_model_name: Optional[str] = Field(
None,
description=("An initial embedding model to load on the infinity backend."),
)
class SamplingConfig(BaseConfigModel):
"""Options for Sampling"""
override_preset: Optional[str] = Field(
None,
description=(
"Select a sampler override preset (default: None).\n"
"Find this in the sampler-overrides folder.\n"
"This overrides default fallbacks for sampler values "
"that are passed to the API."
),
)
class DeveloperConfig(BaseConfigModel):
"""Options for development and experimentation"""
unsafe_launch: Optional[bool] = Field(
False,
description=(
"Skip Exllamav2 version check (default: False).\n"
"WARNING: It's highly recommended to update your dependencies rather "
"than enabling this flag."
),
)
disable_request_streaming: Optional[bool] = Field(
False, description=("Disable API request streaming (default: False).")
)
cuda_malloc_backend: Optional[bool] = Field(
False, description=("Enable the torch CUDA malloc backend (default: False).")
)
uvloop: Optional[bool] = Field(
False,
description=(
"Run asyncio using Uvloop or Winloop which can improve performance.\n"
"NOTE: It's recommended to enable this, but if something breaks "
"turn this off."
),
)
realtime_process_priority: Optional[bool] = Field(
False,
description=(
"Set process to use a higher priority.\n"
"For realtime process priority, run as administrator or sudo.\n"
"Otherwise, the priority will be set to high."
),
)
class TabbyConfigModel(BaseModel):
"""Base model for a TabbyConfig."""
config: Optional[ConfigOverrideConfig] = Field(
default_factory=ConfigOverrideConfig.model_construct
)
network: Optional[NetworkConfig] = Field(
default_factory=NetworkConfig.model_construct
)
logging: Optional[LoggingConfig] = Field(
default_factory=LoggingConfig.model_construct
)
model: Optional[ModelConfig] = Field(default_factory=ModelConfig.model_construct)
draft_model: Optional[DraftModelConfig] = Field(
default_factory=DraftModelConfig.model_construct
)
lora: Optional[LoraConfig] = Field(default_factory=LoraConfig.model_construct)
embeddings: Optional[EmbeddingsConfig] = Field(
default_factory=EmbeddingsConfig.model_construct
)
sampling: Optional[SamplingConfig] = Field(
default_factory=SamplingConfig.model_construct
)
developer: Optional[DeveloperConfig] = Field(
default_factory=DeveloperConfig.model_construct
)
actions: Optional[UtilityActions] = Field(
default_factory=UtilityActions.model_construct
)
model_config = ConfigDict(validate_assignment=True, protected_namespaces=())
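A small sketch of exercising this schema, assuming Pydantic v2 semantics: scalar types are coerced during validation and omitted sections fall back to their defaults:

```python
raw = {"network": {"port": "5001"}, "model": {"model_dir": "models"}}
cfg = TabbyConfigModel.model_validate(raw)
print(cfg.network.port)      # 5001, coerced from the string
print(cfg.model.cache_mode)  # FP16, from the field default
print(cfg.developer.uvloop)  # False, omitted sections use defaults
```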


@@ -76,9 +76,9 @@ def _get_download_folder(repo_id: str, repo_type: str, folder_name: Optional[str
"""Gets the download folder for the repo."""
if repo_type == "lora":
download_path = pathlib.Path(config.lora.get("lora_dir") or "loras")
download_path = pathlib.Path(config.lora.lora_dir)
else:
download_path = pathlib.Path(config.model.get("model_dir") or "models")
download_path = pathlib.Path(config.model.model_dir)
download_path = download_path / (folder_name or repo_id.split("/")[-1])
return download_path


@@ -2,41 +2,19 @@
Functions for logging generation events.
"""
from pydantic import BaseModel
from loguru import logger
from typing import Dict, Optional
from typing import Optional
class GenLogPreferences(BaseModel):
"""Logging preference config."""
prompt: bool = False
generation_params: bool = False
# Global logging preferences constant
PREFERENCES = GenLogPreferences()
def update_from_dict(options_dict: Dict[str, bool]):
"""Wrapper to set the logging config for generations"""
global PREFERENCES
# Force bools on the dict
for value in options_dict.values():
if value is None:
value = False
PREFERENCES = GenLogPreferences.model_validate(options_dict)
from common.tabby_config import config
def broadcast_status():
"""Broadcasts the current logging status"""
enabled = []
if PREFERENCES.prompt:
if config.logging.log_prompt:
enabled.append("prompts")
if PREFERENCES.generation_params:
if config.logging.log_generation_params:
enabled.append("generation params")
if len(enabled) > 0:
@@ -47,13 +25,13 @@ def broadcast_status():
def log_generation_params(**kwargs):
"""Logs generation parameters to console."""
if PREFERENCES.generation_params:
if config.logging.log_generation_params:
logger.info(f"Generation options: {kwargs}\n")
def log_prompt(prompt: str, request_id: str, negative_prompt: Optional[str]):
"""Logs the prompt to console."""
if PREFERENCES.prompt:
if config.logging.log_prompt:
formatted_prompt = "\n" + prompt
logger.info(
f"Prompt (ID: {request_id}): {formatted_prompt if prompt else 'Empty'}\n"
@@ -66,7 +44,7 @@ def log_prompt(prompt: str, request_id: str, negative_prompt: Optional[str]):
def log_response(request_id: str, response: str):
"""Logs the response to console."""
if PREFERENCES.prompt:
if config.logging.log_prompt:
formatted_response = "\n" + response
logger.info(
f"Response (ID: {request_id}): "


@@ -11,7 +11,6 @@ from typing import Optional
from uuid import uuid4
from common.tabby_config import config
from common.utils import unwrap
class TabbyRequestErrorMessage(BaseModel):
@@ -39,7 +38,7 @@ def handle_request_error(message: str, exc_info: bool = True):
"""Log a request error to the console."""
trace = traceback.format_exc()
send_trace = unwrap(config.network.get("send_tracebacks"), False)
send_trace = config.network.send_tracebacks
error_message = TabbyRequestErrorMessage(
message=message, trace=trace if send_trace else None
@@ -134,7 +133,7 @@ def get_global_depends():
depends = [Depends(add_request_id)]
if config.logging.get("requests"):
if config.logging.log_requests:
depends.append(Depends(log_request))
return depends


@@ -3,7 +3,7 @@
import aiofiles
import json
import pathlib
import yaml
from ruamel.yaml import YAML
from copy import deepcopy
from loguru import logger
from pydantic import AliasChoices, BaseModel, Field
@@ -416,7 +416,10 @@ async def overrides_from_file(preset_name: str):
overrides_container.selected_preset = preset_path.stem
async with aiofiles.open(preset_path, "r", encoding="utf8") as raw_preset:
contents = await raw_preset.read()
preset = yaml.safe_load(contents)
# Create a temporary YAML parser
yaml = YAML(typ="safe")
preset = yaml.load(contents)
overrides_from_dict(preset)
logger.info("Applied sampler overrides from file.")


@@ -1,62 +1,94 @@
import yaml
import pathlib
from loguru import logger
from inspect import getdoc
from os import getenv
from textwrap import dedent
from typing import Optional
from common.utils import unwrap, merge_dicts
from loguru import logger
from pydantic import BaseModel
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap, CommentedSeq
from ruamel.yaml.scalarstring import PreservedScalarString
from common.config_models import BaseConfigModel, TabbyConfigModel
from common.utils import merge_dicts, unwrap
yaml = YAML(typ=["rt", "safe"])
class TabbyConfig:
"""Common config class for TabbyAPI. Loaded into sub-dictionaries from YAML file."""
# Sub-blocks of yaml
network: dict = {}
logging: dict = {}
model: dict = {}
draft_model: dict = {}
lora: dict = {}
sampling: dict = {}
developer: dict = {}
embeddings: dict = {}
class TabbyConfig(TabbyConfigModel):
# Persistent defaults
# TODO: make this pydantic?
model_defaults: dict = {}
def load(self, arguments: Optional[dict] = None):
"""Synchronously loads the global application config"""
# config is applied in order of items in the list
configs = [
self._from_file(pathlib.Path("config.yml")),
self._from_args(unwrap(arguments, {})),
]
arguments_dict = unwrap(arguments, {})
configs = [self._from_environment(), self._from_args(arguments_dict)]
# If actions aren't present, also load from the config file
# TODO: Change logic if file loading requires actions in the future
if not arguments_dict.get("actions"):
configs.insert(0, self._from_file(pathlib.Path("config.yml")))
merged_config = merge_dicts(*configs)
self.network = unwrap(merged_config.get("network"), {})
self.logging = unwrap(merged_config.get("logging"), {})
self.model = unwrap(merged_config.get("model"), {})
self.draft_model = unwrap(self.model.get("draft"), {})
self.lora = unwrap(self.model.get("lora"), {})
self.sampling = unwrap(merged_config.get("sampling"), {})
self.developer = unwrap(merged_config.get("developer"), {})
self.embeddings = unwrap(merged_config.get("embeddings"), {})
# validate and update config
merged_config_model = TabbyConfigModel.model_validate(merged_config)
for field in TabbyConfigModel.model_fields.keys():
value = getattr(merged_config_model, field)
setattr(self, field, value)
# Set model defaults dict once to prevent on-demand reconstruction
default_keys = unwrap(self.model.get("use_as_default"), [])
for key in default_keys:
if key in self.model:
self.model_defaults[key] = config.model[key]
elif key in self.draft_model:
self.model_defaults[key] = config.draft_model[key]
# TODO: clean this up a bit
for field in self.model.use_as_default:
if hasattr(self.model, field):
self.model_defaults[field] = getattr(config.model, field)
elif hasattr(self.draft_model, field):
self.model_defaults[field] = getattr(config.draft_model, field)
else:
logger.error(
f"invalid item {field} in config option `model.use_as_default`"
)
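A worked example of the fallback mechanism above, with illustrative values:

```python
# Given this config.yml:
#   model:
#     cache_mode: Q4
#     use_as_default: ["cache_mode"]
# the loop above produces:
#   config.model_defaults == {"cache_mode": "Q4"}
# so API load requests that omit cache_mode fall back to Q4.
```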
def _from_file(self, config_path: pathlib.Path):
"""loads config from a given file path"""
legacy = False
cfg = {}
# try loading from file
try:
with open(str(config_path.resolve()), "r", encoding="utf8") as config_file:
return unwrap(yaml.safe_load(config_file), {})
cfg = yaml.load(config_file)
# NOTE: Remove migration wrapper after a period of time
# load legacy config files
# Model config migration
model_cfg = unwrap(cfg.get("model"), {})
if model_cfg.get("draft"):
legacy = True
cfg["draft_model"] = model_cfg["draft"]
if model_cfg.get("lora"):
legacy = True
cfg["lora"] = model_cfg["lora"]
# Logging config migration
# This will catch the majority of legacy config files
logging_cfg = unwrap(cfg.get("logging"), {})
unmigrated_log_keys = [
key for key in logging_cfg.keys() if not key.startswith("log_")
]
if unmigrated_log_keys:
legacy = True
for key in unmigrated_log_keys:
cfg["logging"][f"log_{key}"] = cfg["logging"][key]
del cfg["logging"][key]
except FileNotFoundError:
logger.info(f"The '{config_path.name}' file cannot be found")
except Exception as exc:
@@ -65,25 +97,53 @@ class TabbyConfig:
f"the following error:\n\n{exc}"
)
# if no config file was loaded
return {}
if legacy:
logger.warning(
"Legacy config.yml file detected. Attempting auto-migration."
)
# Create a temporary base config model
new_cfg = TabbyConfigModel.model_validate(cfg)
try:
config_path.rename(f"{config_path}.bak")
generate_config_file(model=new_cfg, filename=config_path)
logger.info(
"Auto-migration successful. "
'The old configuration is stored in "config.yml.bak".'
)
except Exception as e:
logger.error(
f"Auto-migration failed because of: {e}\n\n"
"Reverted all changes.\n"
"Either fix your config.yml and restart or\n"
"Delete your old YAML file and create a new "
'config by copying "config_sample.yml" to "config.yml".'
)
# Restore the old config
config_path.unlink(missing_ok=True)
pathlib.Path(f"{config_path}.bak").rename(config_path)
# Don't use the partially loaded config
logger.warning("Starting with no config loaded.")
return {}
return unwrap(cfg, {})
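A worked example of the legacy auto-migration above, with illustrative values:

```python
legacy = {
    "model": {"draft": {"draft_model_name": "a-draft-model"}},
    "logging": {"prompt": True},
}
# After migration, model.draft is lifted into a top-level draft_model block
# and logging keys gain the log_ prefix:
# {
#     "model": {"draft": {...}},  # original key is left in place
#     "draft_model": {"draft_model_name": "a-draft-model"},
#     "logging": {"log_prompt": True},
# }
```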
def _from_args(self, args: dict):
"""loads config from the provided arguments"""
config = {}
config_override = unwrap(args.get("options", {}).get("config"))
config_override = args.get("options", {}).get("config", None)
if config_override:
logger.info("Config file override detected in args.")
config = self._from_file(pathlib.Path(config_override))
return config # Return early if loading from file
for key in ["network", "model", "logging", "developer", "embeddings"]:
for key in TabbyConfigModel.model_fields.keys():
override = args.get(key)
if override:
if key == "logging":
# Strip the "log_" prefix from logging keys if present
override = {k.replace("log_", ""): v for k, v in override.items()}
config[key] = override
return config
@@ -91,12 +151,116 @@ class TabbyConfig:
def _from_environment(self):
"""loads configuration from environment variables"""
# TODO: load config from environment variables
# this means that we can have host default to 0.0.0.0 in docker for example
# this would also mean that docker containers no longer require a non
# default config file to be used
pass
config = {}
for field_name in TabbyConfigModel.model_fields.keys():
section_config = {}
for sub_field_name in getattr(
TabbyConfigModel(), field_name
).model_fields.keys():
setting = getenv(f"TABBY_{field_name}_{sub_field_name}".upper(), None)
if setting is not None:
section_config[sub_field_name] = setting
config[field_name] = section_config
return config
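The loop above derives variable names from the schema, so every config field maps to a `TABBY_<SECTION>_<FIELD>` environment variable:

```python
# For example:
#   TABBY_NETWORK_HOST=0.0.0.0 TABBY_NETWORK_PORT=5001 python main.py
# is collected as:
#   {"network": {"host": "0.0.0.0", "port": "5001"}, ...}
# before Pydantic validation coerces the string values.
```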
# Create an empty instance of the config class
config: TabbyConfig = TabbyConfig()
def generate_config_file(
model: BaseModel = None,
filename: str = "config_sample.yml",
) -> None:
"""Creates a config.yml file from Pydantic models."""
schema = unwrap(model, TabbyConfigModel())
preamble = """
# Sample YAML file for configuration.
# Comment and uncomment values as needed.
# Every value has a default within the application.
# This file serves to be a drop in for config.yml
# Unless specified in the comments, DO NOT put these options in quotes!
# You can use https://www.yamllint.com/ if you want to check your YAML formatting.\n
"""
yaml_content = pydantic_model_to_yaml(schema)
with open(filename, "w") as f:
f.write(dedent(preamble).lstrip())
yaml.dump(yaml_content, f)
def pydantic_model_to_yaml(model: BaseModel, indentation: int = 0) -> CommentedMap:
"""
Recursively converts a Pydantic model into a CommentedMap,
with descriptions as comments in YAML.
"""
# Create a CommentedMap to hold the output data
yaml_data = CommentedMap()
# Loop through all fields in the model
iteration = 1
for field_name, field_info in model.model_fields.items():
# Get the inner pydantic model
value = getattr(model, field_name)
if isinstance(value, BaseConfigModel):
# If the field is another Pydantic model
if not value._metadata.include_in_config:
continue
yaml_data[field_name] = pydantic_model_to_yaml(
value, indentation=indentation + 2
)
comment = getdoc(value)
elif isinstance(value, list) and len(value) > 0:
# If the field is a list
yaml_list = CommentedSeq()
if isinstance(value[0], BaseModel):
# If the field is a list of Pydantic models
# Do not add comments for these items
for item in value:
yaml_list.append(
pydantic_model_to_yaml(item, indentation=indentation + 2)
)
else:
# If the field is a normal list, prefer the YAML flow style
yaml_list.fa.set_flow_style()
yaml_list += [
PreservedScalarString(element)
if isinstance(element, str)
else element
for element in value
]
yaml_data[field_name] = yaml_list
comment = field_info.description
else:
# Otherwise, just assign the value
yaml_data[field_name] = value
comment = field_info.description
if comment:
# Add a newline to every comment but the first one
if iteration != 1:
comment = f"\n{comment}"
yaml_data.yaml_set_comment_before_after_key(
field_name, before=comment, indent=indentation
)
# Increment the iteration counter
iteration += 1
return yaml_data
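A short sketch of driving this generator directly, which is what the new `--export-config` action does via `generate_config_file`:

```python
from common.config_models import TabbyConfigModel
from common.tabby_config import generate_config_file, pydantic_model_to_yaml

# Write a commented config_sample.yml from the schema defaults
generate_config_file(filename="config_sample.yml")

# Or inspect the intermediate ruamel CommentedMap directly
yaml_tree = pydantic_model_to_yaml(TabbyConfigModel())
```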


@@ -1,7 +1,12 @@
"""Common utility functions"""
from types import NoneType
from typing import Dict, Optional, Type, Union, get_args, get_origin, TypeVar
def unwrap(wrapped, default=None):
T = TypeVar("T")
def unwrap(wrapped: Optional[T], default: T = None) -> T:
"""Unwrap function for Optionals."""
if wrapped is None:
return default
@@ -14,13 +19,13 @@ def coalesce(*args):
return next((arg for arg in args if arg is not None), None)
def prune_dict(input_dict):
def prune_dict(input_dict: Dict) -> Dict:
"""Trim out instances of None from a dictionary."""
return {k: v for k, v in input_dict.items() if v is not None}
def merge_dict(dict1, dict2):
def merge_dict(dict1: Dict, dict2: Dict) -> Dict:
"""Merge 2 dictionaries"""
for key, value in dict2.items():
if isinstance(value, dict) and key in dict1 and isinstance(dict1[key], dict):
@@ -30,7 +35,7 @@ def merge_dict(dict1, dict2):
return dict1
def merge_dicts(*dicts):
def merge_dicts(*dicts: Dict) -> Dict:
"""Merge an arbitrary amount of dictionaries"""
result = {}
for dictionary in dicts:
@@ -43,3 +48,33 @@ def flat_map(input_list):
"""Flattens a list of lists into a single list."""
return [item for sublist in input_list for item in sublist]
def is_list_type(type_hint) -> bool:
"""Checks if a type contains a list."""
if get_origin(type_hint) is list:
return True
# Recursively check for lists inside type arguments
type_args = get_args(type_hint)
if type_args:
return any(is_list_type(arg) for arg in type_args)
return False
def unwrap_optional_type(type_hint) -> Type:
"""
Unwrap Optional[type] annotations.
This is not the same as unwrap.
"""
if get_origin(type_hint) is Union:
args = get_args(type_hint)
if NoneType in args:
for arg in args:
if arg is not NoneType:
return arg
return type_hint
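A few quick sanity checks of the two new typing helpers:

```python
from typing import List, Optional

from common.utils import is_list_type, unwrap_optional_type

assert is_list_type(List[int])
assert is_list_type(Optional[List[str]])  # list found inside the Union
assert not is_list_type(int)
assert unwrap_optional_type(Optional[int]) is int
assert unwrap_optional_type(str) is str   # non-Optional types pass through
```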


@@ -1,232 +1,218 @@
# Sample YAML file for configuration.
# Comment and uncomment values as needed. Every value has a default within the application.
# This file serves to be a drop in for config.yml
# Unless specified in the comments, DO NOT put these options in quotes!
# You can use https://www.yamllint.com/ if you want to check your YAML formatting.
# Options for networking
network:
# The IP to host on (default: 127.0.0.1).
# Use 0.0.0.0 to expose on all network adapters
host: 127.0.0.1
# The port to host on (default: 5000)
port: 5000
# Disable HTTP token authentication with requests
# WARNING: This will make your instance vulnerable!
# Turn on this option if you are ONLY connecting from localhost
disable_auth: False
# Send tracebacks over the API to clients (default: False)
# NOTE: Only enable this for debug purposes
send_tracebacks: False
# Select API servers to enable (default: ["OAI"])
# Possible values: OAI
api_servers: ["OAI"]
# Options for logging
logging:
# Enable prompt logging (default: False)
prompt: False
# Enable generation parameter logging (default: False)
generation_params: False
# Enable request logging (default: False)
# NOTE: Only use this for debugging!
requests: False
# Options for sampling
sampling:
# Override preset name. Find this in the sampler-overrides folder (default: None)
# This overrides default fallbacks for sampler values that are passed to the API
# Server-side overrides are NOT needed by default
# WARNING: Using this can result in a generation speed penalty
#override_preset:
# Options for development and experimentation
developer:
# Skips exllamav2 version check (default: False)
# It's highly recommended to update your dependencies rather than enabling this flag
# WARNING: Don't set this unless you know what you're doing!
#unsafe_launch: False
# Disable all request streaming (default: False)
# A kill switch for turning off SSE in the API server
#disable_request_streaming: False
# Enable the torch CUDA malloc backend (default: False)
# This can save a few MBs of VRAM, but has a risk of errors. Use at your own risk.
#cuda_malloc_backend: False
# Enable Uvloop or Winloop (default: False)
# Make the program utilize a faster async event loop which can improve performance
# NOTE: It's recommended to enable this, but if something breaks, turn this off.
#uvloop: False
# Set process to use a higher priority
# For realtime process priority, run as administrator or sudo
# Otherwise, the priority will be set to high
#realtime_process_priority: False
# Options for model overrides and loading
# Please read the comments to understand how arguments are handled between initial and API loads
model:
# Overrides the directory to look for models (default: models)
# Windows users, DO NOT put this path in quotes! This directory will be invalid otherwise.
model_dir: models
# Sends dummy model names when the models endpoint is queried
# Enable this if the program is looking for a specific OAI model
#use_dummy_models: False
# Allow direct loading of models from a completion or chat completion request
inline_model_loading: False
# An initial model to load. Make sure the model is located in the model directory!
# A model can be loaded later via the API.
# REQUIRED: This must be filled out to load a model on startup!
model_name:
# The below parameters only apply for initial loads
# All API based loads do NOT inherit these settings unless specified in use_as_default
# Names of args to use as a default fallback for API load requests (default: [])
# For example, if you always want cache_mode to be Q4 instead of on the initial model load,
# Add "cache_mode" to this array
# Ex. ["max_seq_len", "cache_mode"]
#use_as_default: []
# The below parameters apply only if model_name is set
# Max sequence length (default: Empty)
# Fetched from the model's base sequence length in config.json by default
#max_seq_len:
# Overrides base model context length (default: Empty)
# WARNING: Don't set this unless you know what you're doing!
# Again, do NOT use this for configuring context length, use max_seq_len above ^
# Only use this if the model's base sequence length in config.json is incorrect (ex. Mistral 7B)
#override_base_seq_len:
# Load model with tensor parallelism
# If a GPU split isn't provided, the TP loader will fallback to autosplit
# Enabling ignores the gpu_split_auto and autosplit_reserve values
#tensor_parallel: False
# Automatically allocate resources to GPUs (default: True)
# NOTE: Not parsed for single GPU users
#gpu_split_auto: True
# Reserve VRAM used for autosplit loading (default: 96 MB on GPU 0)
# This is represented as an array of MB per GPU used
#autosplit_reserve: [96]
# An integer array of GBs of vram to split between GPUs (default: [])
# Used with tensor parallelism
# NOTE: Not parsed for single GPU users
#gpu_split: [20.6, 24]
# Rope scale (default: 1.0)
# Same thing as compress_pos_emb
# Only use if your model was trained on long context with rope (check config.json)
# Leave blank to pull the value from the model
#rope_scale: 1.0
# Rope alpha (default: 1.0)
# Same thing as alpha_value
# Set to "auto" to automatically calculate
# Leave blank to pull the value from the model
#rope_alpha: 1.0
# Enable different cache modes for VRAM savings (slight performance hit).
# Possible values FP16, Q8, Q6, Q4. (default: FP16)
#cache_mode: FP16
# Size of the prompt cache to allocate (default: max_seq_len)
# This must be a multiple of 256. A larger cache uses more VRAM, but allows for more prompts to be processed at once.
# NOTE: Cache size should not be less than max_seq_len.
# For CFG, set this to 2 * max_seq_len to make room for both positive and negative prompts.
#cache_size:
# Chunk size for prompt ingestion. A lower value reduces VRAM usage at the cost of ingestion speed (default: 2048)
# NOTE: Effects vary depending on the model. An ideal value is between 512 and 4096
#chunk_size: 2048
# Set the maximum amount of prompts to process at one time (default: None/Automatic)
# This will be automatically calculated if left blank.
# A max batch size of 1 processes prompts one at a time.
# NOTE: Only available for Nvidia ampere (30 series) and above GPUs
#max_batch_size:
# Set the prompt template for this model. If empty, attempts to look for the model's chat template. (default: None)
# If a model contains multiple templates in its tokenizer_config.json, set prompt_template to the name
# of the template you want to use.
# NOTE: Only works with chat completion message lists!
#prompt_template:
# Number of experts to use PER TOKEN. Fetched from the model's config.json if not specified (default: Empty)
# WARNING: Don't set this unless you know what you're doing!
# NOTE: For MoE models (ex. Mixtral) only!
#num_experts_per_token:
# Enables fasttensors to possibly increase model loading speeds (default: False)
#fasttensors: true
# Options for draft models (speculative decoding). This will use more VRAM!
#draft:
# Overrides the directory to look for draft (default: models)
#draft_model_dir: models
# An initial draft model to load. Make sure this model is located in the model directory!
# A draft model can be loaded later via the API.
#draft_model_name: A model name
# The below parameters only apply for initial loads
# All API based loads do NOT inherit these settings unless specified in use_as_default
# Rope scale for draft models (default: 1.0)
# Same thing as compress_pos_emb
# Only use if your draft model was trained on long context with rope (check config.json)
#draft_rope_scale: 1.0
# Rope alpha for draft model (default: 1.0)
# Same thing as alpha_value
# Leave blank to automatically calculate alpha value
#draft_rope_alpha: 1.0
# Enable different draft model cache modes for VRAM savings (slight performance hit).
# Possible values FP16, Q8, Q6, Q4. (default: FP16)
#draft_cache_mode: FP16
# Options for loras
#lora:
# Overrides the directory to look for loras (default: loras)
#lora_dir: loras
# List of loras to load and associated scaling factors (default: 1.0). Comment out unused entries or add more rows as needed.
#loras:
#- name: lora1
# scaling: 1.0
# Options for embedding models and loading.
# NOTE: Embeddings requires the "extras" feature to be installed
# Install it via "pip install .[extras]"
embeddings:
# Overrides directory to look for embedding models (default: models)
embedding_model_dir: models
# Device to load embedding models on (default: cpu)
# Possible values: cpu, auto, cuda
# NOTE: It's recommended to load embedding models on the CPU.
# If you'd like to load on an AMD gpu, set this value to "cuda" as well.
embeddings_device: cpu
# The below parameters only apply for initial loads
# All API based loads do NOT inherit these settings unless specified in use_as_default
# An initial embedding model to load on the infinity backend (default: None)
embedding_model_name:
# Sample YAML file for configuration.
# Comment and uncomment values as needed.
# Every value has a default within the application.
# This file serves to be a drop in for config.yml
# Unless specified in the comments, DO NOT put these options in quotes!
# You can use https://www.yamllint.com/ if you want to check your YAML formatting.
# Options for networking
network:
# The IP to host on (default: 127.0.0.1).
# Use 0.0.0.0 to expose on all network adapters.
host: 127.0.0.1
# The port to host on (default: 5000).
port: 5000
# Disable HTTP token authentication with requests.
# WARNING: This will make your instance vulnerable!
# Turn on this option if you are ONLY connecting from localhost.
disable_auth: false
# Send tracebacks over the API (default: False).
# NOTE: Only enable this for debug purposes.
send_tracebacks: false
# Select API servers to enable (default: ["OAI"]).
# Possible values: OAI, Kobold.
api_servers: ["OAI"]
# Options for logging
logging:
# Enable prompt logging (default: False).
log_prompt: false
# Enable generation parameter logging (default: False).
log_generation_params: false
# Enable request logging (default: False).
# NOTE: Only use this for debugging!
log_requests: false
# Options for model overrides and loading
# Please read the comments to understand how arguments are handled
# between initial and API loads
model:
# Directory to look for models (default: models).
# Windows users, do NOT put this path in quotes!
model_dir: models
# Allow direct loading of models from a completion or chat completion request (default: False).
inline_model_loading: false
# Sends dummy model names when the models endpoint is queried.
# Enable this if the client is looking for specific OAI models.
use_dummy_models: false
# An initial model to load.
# Make sure the model is located in the model directory!
# REQUIRED: This must be filled out to load a model on startup.
model_name:
# Names of args to use as a fallback for API load requests (default: []).
# For example, if you always want cache_mode to be Q4 instead of on the initial model load, add "cache_mode" to this array.
# Example: ['max_seq_len', 'cache_mode'].
use_as_default: []
# Max sequence length (default: Empty).
# Fetched from the model's base sequence length in config.json by default.
max_seq_len:
# Overrides base model context length (default: Empty).
# WARNING: Don't set this unless you know what you're doing!
# Again, do NOT use this for configuring context length, use max_seq_len above ^
override_base_seq_len:
# Load model with tensor parallelism.
# Falls back to autosplit if GPU split isn't provided.
# This ignores the gpu_split_auto value.
tensor_parallel: false
# Automatically allocate resources to GPUs (default: True).
# Not parsed for single GPU users.
gpu_split_auto: true
# Reserve VRAM used for autosplit loading (default: 96 MB on GPU 0).
# Represented as an array of MB per GPU.
autosplit_reserve: [96]
# An integer array of GBs of VRAM to split between GPUs (default: []).
# Used with tensor parallelism.
gpu_split: []
# Rope scale (default: 1.0).
# Same as compress_pos_emb.
# Use if the model was trained on long context with rope.
# Leave blank to pull the value from the model.
rope_scale: 1.0
# Rope alpha (default: None).
# Same as alpha_value. Set to "auto" to auto-calculate.
# Leaving this value blank will either pull from the model or auto-calculate.
rope_alpha:
# Enable different cache modes for VRAM savings (default: FP16).
# Possible values: 'FP16', 'Q8', 'Q6', 'Q4'.
cache_mode: FP16
# Size of the prompt cache to allocate (default: max_seq_len).
# Must be a multiple of 256 and can't be less than max_seq_len.
# For CFG, set this to 2 * max_seq_len.
cache_size:
# Chunk size for prompt ingestion (default: 2048).
# A lower value reduces VRAM usage but decreases ingestion speed.
# NOTE: Effects vary depending on the model.
# An ideal value is between 512 and 4096.
chunk_size: 2048
# Set the maximum number of prompts to process at one time (default: None/Automatic).
# Automatically calculated if left blank.
# NOTE: Only available for Nvidia ampere (30 series) and above GPUs.
max_batch_size:
# Set the prompt template for this model. (default: None)
# If empty, attempts to look for the model's chat template.
# If a model contains multiple templates in its tokenizer_config.json,
# set prompt_template to the name of the template you want to use.
# NOTE: Only works with chat completion message lists!
prompt_template:
# Number of experts to use per token.
# Fetched from the model's config.json if empty.
# NOTE: For MoE models only.
# WARNING: Don't set this unless you know what you're doing!
num_experts_per_token:
# Enables fasttensors to possibly increase model loading speeds (default: False).
fasttensors: false
# Options for draft models (speculative decoding)
# This will use more VRAM!
draft_model:
# Directory to look for draft models (default: models)
draft_model_dir: models
# An initial draft model to load.
# Ensure the model is in the model directory.
draft_model_name:
# Rope scale for draft models (default: 1.0).
# Same as compress_pos_emb.
# Use if the draft model was trained on long context with rope.
draft_rope_scale: 1.0
# Rope alpha for draft models (default: None).
# Same as alpha_value. Set to "auto" to auto-calculate.
# Leaving this value blank will either pull from the model or auto-calculate.
draft_rope_alpha:
# Cache mode for draft models to save VRAM (default: FP16).
# Possible values: 'FP16', 'Q8', 'Q6', 'Q4'.
draft_cache_mode: FP16
# Options for Loras
lora:
# Directory to look for LoRAs (default: loras).
lora_dir: loras
# List of LoRAs to load and associated scaling factors (default scale: 1.0).
# For the YAML file, add each entry as a YAML list:
# - name: lora1
# scaling: 1.0
loras:
# Options for embedding models and loading.
# NOTE: Embeddings requires the "extras" feature to be installed
# Install it via "pip install .[extras]"
embeddings:
# Directory to look for embedding models (default: models).
embedding_model_dir: models
# Device to load embedding models on (default: cpu).
# Possible values: cpu, auto, cuda.
# NOTE: It's recommended to load embedding models on the CPU.
# If using an AMD GPU, set this value to 'cuda'.
embeddings_device: cpu
# An initial embedding model to load on the infinity backend.
embedding_model_name:
# Options for Sampling
sampling:
# Select a sampler override preset (default: None).
# Find this in the sampler-overrides folder.
# This overrides default fallbacks for sampler values that are passed to the API.
override_preset:
# Options for development and experimentation
developer:
# Skip Exllamav2 version check (default: False).
# WARNING: It's highly recommended to update your dependencies rather than enabling this flag.
unsafe_launch: false
# Disable API request streaming (default: False).
disable_request_streaming: false
# Enable the torch CUDA malloc backend (default: False).
cuda_malloc_backend: false
# Run asyncio using Uvloop or Winloop which can improve performance.
# NOTE: It's recommended to enable this, but if something breaks turn this off.
uvloop: false
# Set process to use a higher priority.
# For realtime process priority, run as administrator or sudo.
# Otherwise, the priority will be set to high.
realtime_process_priority: false


@@ -8,7 +8,6 @@ from common.auth import check_api_key
from common.model import check_embeddings_container, check_model_container
from common.networking import handle_request_error, run_with_request_disconnect
from common.tabby_config import config
from common.utils import unwrap
from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
from endpoints.OAI.types.chat_completion import (
ChatCompletionRequest,
@@ -71,9 +70,7 @@ async def completion_request(
if isinstance(data.prompt, list):
data.prompt = "\n".join(data.prompt)
disable_request_streaming = unwrap(
config.developer.get("disable_request_streaming"), False
)
disable_request_streaming = config.developer.disable_request_streaming
# Set an empty JSON schema if the request wants a JSON response
if data.response_format.type == "json":
@@ -135,9 +132,7 @@ async def chat_completion_request(
if data.response_format.type == "json":
data.json_schema = {"type": "object"}
disable_request_streaming = unwrap(
config.developer.get("disable_request_streaming"), False
)
disable_request_streaming = config.developer.disable_request_streaming
if data.stream and not disable_request_streaming:
return EventSourceResponse(


@@ -130,7 +130,7 @@ async def load_inline_model(model_name: str, request: Request):
raise HTTPException(401, error_message)
if not unwrap(config.model.get("inline_model_loading"), False):
if not config.model.inline_model_loading:
logger.warning(
f"Unable to switch model to {model_name} because "
'"inline_model_loading" is not True in config.yml.'
@@ -138,7 +138,7 @@ async def load_inline_model(model_name: str, request: Request):
return
model_path = pathlib.Path(unwrap(config.model.get("model_dir"), "models"))
model_path = pathlib.Path(config.model.model_dir)
model_path = model_path / model_name
# Model path doesn't exist


@@ -62,17 +62,17 @@ async def list_models(request: Request) -> ModelList:
Requires an admin key to see all models.
"""
model_dir = unwrap(config.model.get("model_dir"), "models")
model_dir = config.model.model_dir
model_path = pathlib.Path(model_dir)
draft_model_dir = config.draft_model.get("draft_model_dir")
draft_model_dir = config.draft_model.draft_model_dir
if get_key_permission(request) == "admin":
models = get_model_list(model_path.resolve(), draft_model_dir)
else:
models = await get_current_model_list()
if unwrap(config.model.get("use_dummy_models"), False):
if config.model.use_dummy_models:
models.data.insert(0, ModelCard(id="gpt-3.5-turbo"))
return models
@@ -98,7 +98,7 @@ async def list_draft_models(request: Request) -> ModelList:
"""
if get_key_permission(request) == "admin":
draft_model_dir = unwrap(config.draft_model.get("draft_model_dir"), "models")
draft_model_dir = config.draft_model.draft_model_dir
draft_model_path = pathlib.Path(draft_model_dir)
models = get_model_list(draft_model_path.resolve())
@@ -122,7 +122,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
raise HTTPException(400, error_message)
model_path = pathlib.Path(unwrap(config.model.get("model_dir"), "models"))
model_path = pathlib.Path(config.model.model_dir)
model_path = model_path / data.name
draft_model_path = None
@@ -135,7 +135,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
raise HTTPException(400, error_message)
draft_model_path = unwrap(config.draft_model.get("draft_model_dir"), "models")
draft_model_path = config.draft_model.draft_model_dir
if not model_path.exists():
error_message = handle_request_error(
@@ -192,7 +192,7 @@ async def list_all_loras(request: Request) -> LoraList:
"""
if get_key_permission(request) == "admin":
lora_path = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras"))
lora_path = pathlib.Path(config.lora.lora_dir)
loras = get_lora_list(lora_path.resolve())
else:
loras = get_active_loras()
@@ -227,7 +227,7 @@ async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse:
raise HTTPException(400, error_message)
lora_dir = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras"))
lora_dir = pathlib.Path(config.lora.lora_dir)
if not lora_dir.exists():
error_message = handle_request_error(
"A parent lora directory does not exist for load. Check your config.yml?",
@@ -266,9 +266,7 @@ async def list_embedding_models(request: Request) -> ModelList:
"""
if get_key_permission(request) == "admin":
embedding_model_dir = unwrap(
config.embeddings.get("embedding_model_dir"), "models"
)
embedding_model_dir = config.embeddings.embedding_model_dir
embedding_model_path = pathlib.Path(embedding_model_dir)
models = get_model_list(embedding_model_path.resolve())
@@ -302,9 +300,7 @@ async def load_embedding_model(
raise HTTPException(400, error_message)
embedding_model_dir = pathlib.Path(
unwrap(config.embeddings.get("embedding_model_dir"), "models")
)
embedding_model_dir = pathlib.Path(config.embeddings.embedding_model_dir)
embedding_model_path = embedding_model_dir / data.name
if not embedding_model_path.exists():
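All of the listing endpoints above share one shape: admin keys enumerate a configured directory on disk, while regular keys only see what is currently loaded. A condensed, self-contained sketch of that pattern (the helper and the loaded-model record are stand-ins, not the real endpoint code):

import pathlib
from typing import List

# Stand-in for the server's record of currently loaded models
LOADED_MODELS: List[str] = ["my-local-model"]

def list_visible_models(permission: str, model_dir: str) -> List[str]:
    """Admin keys enumerate the directory; other keys only see loaded models."""
    if permission == "admin":
        model_path = pathlib.Path(model_dir).resolve()
        if not model_path.exists():
            return []
        # Treat each subdirectory of model_dir as one model
        return [p.name for p in model_path.iterdir() if p.is_dir()]
    return LOADED_MODELS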

View File

@@ -4,9 +4,8 @@ from pydantic import BaseModel, Field, ConfigDict
from time import time
from typing import List, Literal, Optional, Union
from common.gen_logging import GenLogPreferences
from common.config_models import LoggingConfig
from common.tabby_config import config
from common.utils import unwrap
class ModelCardParameters(BaseModel):
@@ -34,7 +33,7 @@ class ModelCard(BaseModel):
object: str = "model"
created: int = Field(default_factory=lambda: int(time()))
owned_by: str = "tabbyAPI"
logging: Optional[GenLogPreferences] = None
logging: Optional[LoggingConfig] = None
parameters: Optional[ModelCardParameters] = None
@@ -118,11 +117,7 @@ class EmbeddingModelLoadRequest(BaseModel):
name: str
# Set default from the config
embeddings_device: Optional[str] = Field(
default_factory=lambda: unwrap(
config.embeddings.get("embeddings_device"), "cpu"
)
)
embeddings_device: Optional[str] = Field(config.embeddings.embeddings_device)
class ModelLoadResponse(BaseModel):
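Worth noting on the EmbeddingModelLoadRequest change: Field(config.embeddings.embeddings_device) is evaluated once, when the class body runs at import time, while the old default_factory re-read the config on every instantiation. A standalone sketch of that difference:

from typing import Optional
from pydantic import BaseModel, Field

settings = {"device": "cpu"}

class EagerDefault(BaseModel):
    # Evaluated once, when the class body runs at import time
    device: Optional[str] = Field(settings["device"])

class LazyDefault(BaseModel):
    # Re-evaluated on every instantiation
    device: Optional[str] = Field(default_factory=lambda: settings["device"])

settings["device"] = "cuda"
assert EagerDefault().device == "cpu"   # captured the value at class definition
assert LazyDefault().device == "cuda"   # sees the updated value

The eager form is safe here as long as the config is fully loaded before the types module is imported.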

View File

@@ -2,8 +2,9 @@ import pathlib
from asyncio import CancelledError
from typing import Optional
from common import gen_logging, model
from common import model
from common.networking import get_generator_error, handle_request_disconnect
from common.tabby_config import config
from common.utils import unwrap
from endpoints.core.types.model import (
ModelCard,
@@ -77,7 +78,7 @@ def get_current_model():
model_card = ModelCard(
id=unwrap(model_params.pop("name", None), "unknown"),
parameters=ModelCardParameters.model_validate(model_params),
logging=gen_logging.PREFERENCES,
logging=config.logging,
)
if draft_model_params:

View File

@@ -8,7 +8,6 @@ from loguru import logger
from common.logger import UVICORN_LOG_CONFIG
from common.networking import get_global_depends
from common.tabby_config import config
from common.utils import unwrap
from endpoints.Kobold import router as KoboldRouter
from endpoints.OAI import router as OAIRouter
from endpoints.core.router import router as CoreRouter
@@ -36,7 +35,7 @@ def setup_app(host: Optional[str] = None, port: Optional[int] = None):
allow_headers=["*"],
)
api_servers = unwrap(config.network.get("api_servers"), [])
api_servers = config.network.api_servers
# Map for API id to server router
router_mapping = {"oai": OAIRouter, "kobold": KoboldRouter}
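Since api_servers is now a plain typed list, setup_app can mount routers by id. A sketch of how that mapping might be consumed (the prefixes and bare routers are illustrative, not TabbyAPI's actual route tables):

from fastapi import APIRouter, FastAPI

OAIRouter = APIRouter(prefix="/v1")
KoboldRouter = APIRouter(prefix="/api")

def setup_app(api_servers: list) -> FastAPI:
    app = FastAPI()
    # Map an API id from the config to its router
    router_mapping = {"oai": OAIRouter, "kobold": KoboldRouter}
    for server_id in api_servers:
        router = router_mapping.get(server_id.lower())
        if router:
            app.include_router(router)
    return app

app = setup_app(["OAI"])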

65 main.py
View File

@@ -1,7 +1,6 @@
"""The main tabbyAPI module. Contains the FastAPI server and endpoints."""
import asyncio
import json
import os
import pathlib
import platform
@@ -12,12 +11,12 @@ from typing import Optional
from common import gen_logging, sampling, model
from common.args import convert_args_to_dict, init_argparser
from common.auth import load_auth_keys
from common.actions import branch_to_actions
from common.logger import setup_logger
from common.networking import is_port_in_use
from common.signals import signal_handler
from common.tabby_config import config
from common.utils import unwrap
from endpoints.server import export_openapi, start_api
from endpoints.server import start_api
from endpoints.utils import do_export_openapi
if not do_export_openapi:
@@ -27,8 +26,8 @@ if not do_export_openapi:
async def entrypoint_async():
"""Async entry function for program startup"""
host = unwrap(config.network.get("host"), "127.0.0.1")
port = unwrap(config.network.get("port"), 5000)
host = config.network.host
port = config.network.port
# Check if the port is available and attempt to bind a fallback
if is_port_in_use(port):
@@ -50,16 +49,12 @@ async def entrypoint_async():
port = fallback_port
# Initialize auth keys
await load_auth_keys(unwrap(config.network.get("disable_auth"), False))
# Override the generation log options if given
if config.logging:
gen_logging.update_from_dict(config.logging)
await load_auth_keys(config.network.disable_auth)
gen_logging.broadcast_status()
# Set sampler parameter overrides if provided
sampling_override_preset = config.sampling.get("override_preset")
sampling_override_preset = config.sampling.override_preset
if sampling_override_preset:
try:
await sampling.overrides_from_file(sampling_override_preset)
@@ -68,29 +63,38 @@ async def entrypoint_async():
# If an initial model name is specified, create a container
# and load the model
model_name = config.model.get("model_name")
model_name = config.model.model_name
if model_name:
model_path = pathlib.Path(unwrap(config.model.get("model_dir"), "models"))
model_path = pathlib.Path(config.model.model_dir)
model_path = model_path / model_name
await model.load_model(model_path.resolve(), **config.model)
# TODO: remove model_dump()
await model.load_model(
model_path.resolve(),
**config.model.model_dump(),
draft=config.draft_model.model_dump(),
)
# Load loras after loading the model
if config.lora.get("loras"):
lora_dir = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras"))
await model.container.load_loras(lora_dir.resolve(), **config.lora)
if config.lora.loras:
lora_dir = pathlib.Path(config.lora.lora_dir)
# TODO: remove model_dump()
await model.container.load_loras(
lora_dir.resolve(), **config.lora.model_dump()
)
# If an initial embedding model name is specified, create a separate container
# and load the model
embedding_model_name = config.embeddings.get("embedding_model_name")
embedding_model_name = config.embeddings.embedding_model_name
if embedding_model_name:
embedding_model_path = pathlib.Path(
unwrap(config.embeddings.get("embedding_model_dir"), "models")
)
embedding_model_path = pathlib.Path(config.embeddings.embedding_model_dir)
embedding_model_path = embedding_model_path / embedding_model_name
try:
await model.load_embedding_model(embedding_model_path, **config.embeddings)
# TODO: remove model_dump()
await model.load_embedding_model(
embedding_model_path, **config.embeddings.model_dump()
)
except ImportError as ex:
logger.error(ex.msg)
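The TODO comments above mark a stopgap: the backend loaders still accept loose keyword arguments, so the typed sub-models are flattened back into dicts with model_dump() at the call boundary. A minimal sketch of that hand-off (the field names and loader signature are illustrative):

from pydantic import BaseModel

class ModelConfig(BaseModel):
    name: str = "my-model"
    max_seq_len: int = 4096
    cache_mode: str = "FP16"

def legacy_loader(name: str, max_seq_len: int, cache_mode: str, **kwargs):
    """Stand-in for a backend loader that still takes loose kwargs."""
    print(f"loading {name} (ctx={max_seq_len}, cache={cache_mode})")

# model_dump() flattens the typed model back into plain **kwargs at the boundary
legacy_loader(**ModelConfig().model_dump())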
@@ -112,18 +116,13 @@ def entrypoint(arguments: Optional[dict] = None):
# load config
config.load(arguments)
if do_export_openapi:
openapi_json = export_openapi()
with open("openapi.json", "w") as f:
f.write(json.dumps(openapi_json))
logger.info("Successfully wrote OpenAPI spec to openapi.json")
# branch to default paths if required
if branch_to_actions():
return
# Check exllamav2 version and give a descriptive error if it's too old
# Skip if launching unsafely
if unwrap(config.developer.get("unsafe_launch"), False):
if config.developer.unsafe_launch:
logger.warning(
"UNSAFE: Skipping ExllamaV2 version check.\n"
"If you aren't a developer, please keep this off!"
@@ -132,12 +131,12 @@ def entrypoint(arguments: Optional[dict] = None):
check_exllama_version()
# Enable CUDA malloc backend
if unwrap(config.developer.get("cuda_malloc_backend"), False):
if config.developer.cuda_malloc_backend:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"
logger.warning("EXPERIMENTAL: Enabled the pytorch CUDA malloc backend.")
# Use Uvloop/Winloop
if unwrap(config.developer.get("uvloop"), False):
if config.developer.uvloop:
if platform.system() == "Windows":
from winloop import install
else:
@@ -149,7 +148,7 @@ def entrypoint(arguments: Optional[dict] = None):
logger.warning("EXPERIMENTAL: Running program with Uvloop/Winloop.")
# Set the process priority
if unwrap(config.developer.get("realtime_process_priority"), False):
if config.developer.realtime_process_priority:
import psutil
current_process = psutil.Process(os.getpid())
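The priority hunk is truncated above; a plausible completion built on psutil's cross-platform API could look like this (a sketch, not the code from main.py — realtime needs elevation, so it settles for high priority otherwise):

import os
import platform
import psutil

def set_process_priority():
    """Best effort: realtime needs admin/sudo, otherwise settle for high."""
    current_process = psutil.Process(os.getpid())
    try:
        if platform.system() == "Windows":
            current_process.nice(psutil.REALTIME_PRIORITY_CLASS)
        else:
            current_process.nice(-20)  # most favorable niceness; requires root
    except psutil.AccessDenied:
        if platform.system() == "Windows":
            current_process.nice(psutil.HIGH_PRIORITY_CLASS)
        else:
            print("Not elevated; leaving process priority unchanged")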

View File

@@ -18,7 +18,7 @@ requires-python = ">=3.10"
dependencies = [
"fastapi-slim >= 0.110.0",
"pydantic >= 2.0.0",
"PyYAML",
"ruamel.yaml",
"rich",
"uvicorn >= 0.28.1",
"jinja2 >= 3.0.0",