mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-14 15:57:27 +00:00
Merge pull request #189 from SecretiveShell/pydantic-config
Update the config system to use Pydantic internally, bridging the gap between the YAML and args. YAML is still the preferred method to configure TabbyAPI, but args are no longer separately maintained.
This commit is contained in:
6
.github/workflows/pages.yml
vendored
6
.github/workflows/pages.yml
vendored
@@ -48,10 +48,8 @@ jobs:
|
||||
npm install @redocly/cli -g
|
||||
- name: Export OpenAPI docs
|
||||
run: |
|
||||
EXPORT_OPENAPI=1 python main.py
|
||||
mv openapi.json openapi-oai.json
|
||||
EXPORT_OPENAPI=1 python main.py --api-servers kobold
|
||||
mv openapi.json openapi-kobold.json
|
||||
python main.py --export-openapi true --openapi-export-path "openapi-oai.json" --api-servers OAI
|
||||
python main.py --export-openapi true --openapi-export-path "openapi-kobold.json" --api-servers kobold
|
||||
- name: Build and store Redocly site
|
||||
run: |
|
||||
mkdir static
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -213,3 +213,6 @@ openapi.json
|
||||
|
||||
# Infinity-emb cache
|
||||
.infinity_cache/
|
||||
|
||||
# Backup files
|
||||
*.bak
|
||||
|
||||
@@ -30,7 +30,7 @@ from itertools import zip_longest
|
||||
from loguru import logger
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import yaml
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
from backends.exllamav2.grammar import (
|
||||
ExLlamaV2Grammar,
|
||||
@@ -379,7 +379,10 @@ class ExllamaV2Container:
|
||||
override_config_path, "r", encoding="utf8"
|
||||
) as override_config_file:
|
||||
contents = await override_config_file.read()
|
||||
override_args = unwrap(yaml.safe_load(contents), {})
|
||||
|
||||
# Create a temporary YAML parser
|
||||
yaml = YAML(typ="safe")
|
||||
override_args = unwrap(yaml.load(contents), {})
|
||||
|
||||
# Merge draft overrides beforehand
|
||||
draft_override_args = unwrap(override_args.get("draft"), {})
|
||||
|
||||
27
common/actions.py
Normal file
27
common/actions.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import json
|
||||
from loguru import logger
|
||||
|
||||
from common.tabby_config import config, generate_config_file
|
||||
from endpoints.server import export_openapi
|
||||
|
||||
|
||||
def branch_to_actions() -> bool:
    """Run at most one optional utility action selected in the config.

    The OpenAPI export takes precedence over the config export; when both
    are requested only the OpenAPI export runs.

    Returns:
        True if an action was executed, False if none was requested.
    """

    actions = config.actions

    # OpenAPI schema export
    if actions.export_openapi:
        spec = export_openapi()
        target = actions.openapi_export_path

        with open(target, "w") as spec_file:
            spec_file.write(json.dumps(spec))
            logger.info(f"Successfully wrote OpenAPI spec to {target}")

        return True

    # Template config export
    if actions.export_config:
        generate_config_file(filename=actions.config_export_path)
        return True

    # Nothing was requested
    return False
|
||||
276
common/args.py
276
common/args.py
@@ -1,56 +1,60 @@
|
||||
"""Argparser for overriding config values"""
|
||||
|
||||
import argparse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from common.config_models import TabbyConfigModel
|
||||
from common.utils import is_list_type, unwrap_optional_type
|
||||
|
||||
|
||||
def str_to_bool(value):
    """Convert a command-line style string into a boolean.

    Matching is case-insensitive and ignores surrounding whitespace.
    Values that are already booleans are returned unchanged.

    Args:
        value: One of "true"/"t"/"1"/"yes"/"y" or "false"/"f"/"0"/"no"/"n"
            (in any letter case), or an actual bool.

    Returns:
        The boolean that the input represents.

    Raises:
        ValueError: If the text is not one of the recognised spellings.
    """

    # Pass real booleans straight through instead of failing on .lower()
    if isinstance(value, bool):
        return value

    # Normalise once so both membership tests see the same text
    normalized = value.strip().lower()

    if normalized in {"false", "f", "0", "no", "n"}:
        return False
    if normalized in {"true", "t", "1", "yes", "y"}:
        return True

    raise ValueError(f"{value} is not a valid boolean value")
|
||||
|
||||
|
||||
def argument_with_auto(value):
|
||||
def add_field_to_group(group, field_name, field_type, field) -> None:
|
||||
"""
|
||||
Argparse type wrapper for any argument that has an automatic option.
|
||||
|
||||
Ex. rope_alpha
|
||||
Adds a Pydantic field to an argparse argument group.
|
||||
"""
|
||||
|
||||
if value == "auto":
|
||||
return "auto"
|
||||
kwargs = {
|
||||
"help": field.description if field.description else "No description available",
|
||||
}
|
||||
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError as ex:
|
||||
raise argparse.ArgumentTypeError(
|
||||
'This argument only takes a type of float or "auto"'
|
||||
) from ex
|
||||
# If the inner type contains a list, specify argparse as such
|
||||
if is_list_type(field_type):
|
||||
kwargs["nargs"] = "+"
|
||||
|
||||
group.add_argument(f"--{field_name}", **kwargs)
|
||||
|
||||
|
||||
def init_argparser():
|
||||
"""Creates an argument parser that any function can use"""
|
||||
def init_argparser() -> argparse.ArgumentParser:
|
||||
"""
|
||||
Initializes an argparse parser based on a Pydantic config schema.
|
||||
"""
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
epilog="NOTE: These args serve to override parts of the config. "
|
||||
+ "It's highly recommended to edit config.yml for all options and "
|
||||
+ "better descriptions!"
|
||||
)
|
||||
add_network_args(parser)
|
||||
add_model_args(parser)
|
||||
add_embeddings_args(parser)
|
||||
add_logging_args(parser)
|
||||
add_developer_args(parser)
|
||||
add_sampling_args(parser)
|
||||
add_config_args(parser)
|
||||
parser = argparse.ArgumentParser(description="TabbyAPI server")
|
||||
|
||||
# Loop through each top-level field in the config
|
||||
for field_name, field_info in TabbyConfigModel.model_fields.items():
|
||||
field_type = unwrap_optional_type(field_info.annotation)
|
||||
group = parser.add_argument_group(
|
||||
field_name, description=f"Arguments for {field_name}"
|
||||
)
|
||||
|
||||
# Check if the field_type is a Pydantic model
|
||||
if issubclass(field_type, BaseModel):
|
||||
for sub_field_name, sub_field_info in field_type.model_fields.items():
|
||||
sub_field_name = sub_field_name.replace("_", "-")
|
||||
sub_field_type = sub_field_info.annotation
|
||||
add_field_to_group(
|
||||
group, sub_field_name, sub_field_type, sub_field_info
|
||||
)
|
||||
else:
|
||||
field_name = field_name.replace("_", "-")
|
||||
group.add_argument(f"--{field_name}", help=f"Argument for {field_name}")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def convert_args_to_dict(args: argparse.Namespace, parser: argparse.ArgumentParser):
|
||||
def convert_args_to_dict(
|
||||
args: argparse.Namespace, parser: argparse.ArgumentParser
|
||||
) -> dict:
|
||||
"""Broad conversion of surface level arg groups to dictionaries"""
|
||||
|
||||
arg_groups = {}
|
||||
@@ -64,201 +68,3 @@ def convert_args_to_dict(args: argparse.Namespace, parser: argparse.ArgumentPars
|
||||
arg_groups[group.title] = group_dict
|
||||
|
||||
return arg_groups
|
||||
|
||||
|
||||
def add_config_args(parser: argparse.ArgumentParser):
|
||||
"""Adds config arguments"""
|
||||
|
||||
parser.add_argument(
|
||||
"--config", type=str, help="Path to an overriding config.yml file"
|
||||
)
|
||||
|
||||
|
||||
def add_network_args(parser: argparse.ArgumentParser):
|
||||
"""Adds networking arguments"""
|
||||
|
||||
network_group = parser.add_argument_group("network")
|
||||
network_group.add_argument("--host", type=str, help="The IP to host on")
|
||||
network_group.add_argument("--port", type=int, help="The port to host on")
|
||||
network_group.add_argument(
|
||||
"--disable-auth",
|
||||
type=str_to_bool,
|
||||
help="Disable HTTP token authenticaion with requests",
|
||||
)
|
||||
network_group.add_argument(
|
||||
"--send-tracebacks",
|
||||
type=str_to_bool,
|
||||
help="Decide whether to send error tracebacks over the API",
|
||||
)
|
||||
network_group.add_argument(
|
||||
"--api-servers",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="API servers to enable. Options: (OAI, Kobold)",
|
||||
)
|
||||
|
||||
|
||||
def add_model_args(parser: argparse.ArgumentParser):
|
||||
"""Adds model arguments"""
|
||||
|
||||
model_group = parser.add_argument_group("model")
|
||||
model_group.add_argument(
|
||||
"--model-dir", type=str, help="Overrides the directory to look for models"
|
||||
)
|
||||
model_group.add_argument("--model-name", type=str, help="An initial model to load")
|
||||
model_group.add_argument(
|
||||
"--use-dummy-models",
|
||||
type=str_to_bool,
|
||||
help="Add dummy OAI model names for API queries",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--use-as-default",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="Names of args to use as a default fallback for API load requests ",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--max-seq-len", type=int, help="Override the maximum model sequence length"
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--override-base-seq-len",
|
||||
type=str_to_bool,
|
||||
help="Overrides base model context length",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--tensor-parallel",
|
||||
type=str_to_bool,
|
||||
help="Use tensor parallelism to load models",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--gpu-split-auto",
|
||||
type=str_to_bool,
|
||||
help="Automatically allocate resources to GPUs",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--autosplit-reserve",
|
||||
type=int,
|
||||
nargs="+",
|
||||
help="Reserve VRAM used for autosplit loading (in MBs) ",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--gpu-split",
|
||||
type=float,
|
||||
nargs="+",
|
||||
help="An integer array of GBs of vram to split between GPUs. "
|
||||
+ "Ignored if gpu_split_auto is true",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--rope-scale", type=float, help="Sets rope_scale or compress_pos_emb"
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--rope-alpha",
|
||||
type=argument_with_auto,
|
||||
help="Sets rope_alpha for NTK",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--cache-mode",
|
||||
type=str,
|
||||
help="Set the quantization level of the K/V cache. Options: (FP16, Q8, Q6, Q4)",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--cache-size",
|
||||
type=int,
|
||||
help="The size of the prompt cache (in number of tokens) to allocate",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--chunk-size",
|
||||
type=int,
|
||||
help="Chunk size for prompt ingestion",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--max-batch-size",
|
||||
type=int,
|
||||
help="Maximum amount of prompts to process at one time",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--prompt-template",
|
||||
type=str,
|
||||
help="Set the jinja2 prompt template for chat completions",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--num-experts-per-token",
|
||||
type=int,
|
||||
help="Number of experts to use per token in MoE models",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--fasttensors",
|
||||
type=str_to_bool,
|
||||
help="Possibly increases model loading speeds",
|
||||
)
|
||||
|
||||
|
||||
def add_logging_args(parser: argparse.ArgumentParser):
|
||||
"""Adds logging arguments"""
|
||||
|
||||
logging_group = parser.add_argument_group("logging")
|
||||
logging_group.add_argument(
|
||||
"--log-prompt", type=str_to_bool, help="Enable prompt logging"
|
||||
)
|
||||
logging_group.add_argument(
|
||||
"--log-generation-params",
|
||||
type=str_to_bool,
|
||||
help="Enable generation parameter logging",
|
||||
)
|
||||
logging_group.add_argument(
|
||||
"--log-requests",
|
||||
type=str_to_bool,
|
||||
help="Enable request logging",
|
||||
)
|
||||
|
||||
|
||||
def add_developer_args(parser: argparse.ArgumentParser):
|
||||
"""Adds developer-specific arguments"""
|
||||
|
||||
developer_group = parser.add_argument_group("developer")
|
||||
developer_group.add_argument(
|
||||
"--unsafe-launch", type=str_to_bool, help="Skip Exllamav2 version check"
|
||||
)
|
||||
developer_group.add_argument(
|
||||
"--disable-request-streaming",
|
||||
type=str_to_bool,
|
||||
help="Disables API request streaming",
|
||||
)
|
||||
developer_group.add_argument(
|
||||
"--cuda-malloc-backend",
|
||||
type=str_to_bool,
|
||||
help="Runs with the pytorch CUDA malloc backend",
|
||||
)
|
||||
developer_group.add_argument(
|
||||
"--uvloop",
|
||||
type=str_to_bool,
|
||||
help="Run asyncio using Uvloop or Winloop",
|
||||
)
|
||||
|
||||
|
||||
def add_sampling_args(parser: argparse.ArgumentParser):
|
||||
"""Adds sampling-specific arguments"""
|
||||
|
||||
sampling_group = parser.add_argument_group("sampling")
|
||||
sampling_group.add_argument(
|
||||
"--override-preset", type=str, help="Select a sampler override preset"
|
||||
)
|
||||
|
||||
|
||||
def add_embeddings_args(parser: argparse.ArgumentParser):
|
||||
"""Adds arguments specific to embeddings"""
|
||||
|
||||
embeddings_group = parser.add_argument_group("embeddings")
|
||||
embeddings_group.add_argument(
|
||||
"--embedding-model-dir",
|
||||
type=str,
|
||||
help="Overrides the directory to look for models",
|
||||
)
|
||||
embeddings_group.add_argument(
|
||||
"--embedding-model-name", type=str, help="An initial model to load"
|
||||
)
|
||||
embeddings_group.add_argument(
|
||||
"--embeddings-device",
|
||||
type=str,
|
||||
help="Device to use for embeddings. Options: (cpu, auto, cuda)",
|
||||
)
|
||||
|
||||
@@ -4,8 +4,9 @@ application, it should be fine.
|
||||
"""
|
||||
|
||||
import aiofiles
|
||||
import io
|
||||
import secrets
|
||||
import yaml
|
||||
from ruamel.yaml import YAML
|
||||
from fastapi import Header, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
from loguru import logger
|
||||
@@ -57,10 +58,13 @@ async def load_auth_keys(disable_from_config: bool):
|
||||
|
||||
return
|
||||
|
||||
# Create a temporary YAML parser
|
||||
yaml = YAML(typ=["rt", "safe"])
|
||||
|
||||
try:
|
||||
async with aiofiles.open("api_tokens.yml", "r", encoding="utf8") as auth_file:
|
||||
contents = await auth_file.read()
|
||||
auth_keys_dict = yaml.safe_load(contents)
|
||||
auth_keys_dict = yaml.load(contents)
|
||||
AUTH_KEYS = AuthKeys.model_validate(auth_keys_dict)
|
||||
except FileNotFoundError:
|
||||
new_auth_keys = AuthKeys(
|
||||
@@ -69,10 +73,10 @@ async def load_auth_keys(disable_from_config: bool):
|
||||
AUTH_KEYS = new_auth_keys
|
||||
|
||||
async with aiofiles.open("api_tokens.yml", "w", encoding="utf8") as auth_file:
|
||||
new_auth_yaml = yaml.safe_dump(
|
||||
AUTH_KEYS.model_dump(), default_flow_style=False
|
||||
)
|
||||
await auth_file.write(new_auth_yaml)
|
||||
string_stream = io.StringIO()
|
||||
yaml.dump(AUTH_KEYS.model_dump(), string_stream)
|
||||
|
||||
await auth_file.write(string_stream.getvalue())
|
||||
|
||||
logger.info(
|
||||
f"Your API key is: {AUTH_KEYS.api_key}\n"
|
||||
|
||||
469
common/config_models.py
Normal file
469
common/config_models.py
Normal file
@@ -0,0 +1,469 @@
|
||||
from pathlib import Path
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
|
||||
CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"]
|
||||
|
||||
|
||||
class Metadata(BaseModel):
    """metadata model for config options"""

    # Defaults to True; individual config sections override it through their
    # private `_metadata` attribute.
    # NOTE(review): presumably consulted by the config exporter to decide
    # whether a section is written to the generated file - confirm.
    include_in_config: Optional[bool] = Field(default=True)
|
||||
|
||||
|
||||
class BaseConfigModel(BaseModel):
    """Base model for config models with added metadata"""

    # Private attribute, so it is not one of the model's fields. Subclasses
    # redeclare it when they need non-default metadata.
    _metadata: Metadata = PrivateAttr(default=Metadata())
|
||||
|
||||
|
||||
class ConfigOverrideConfig(BaseConfigModel):
    """Model for overriding a provided config file."""

    # TODO: convert this to a pathlib.path?
    config: Optional[str] = Field(
        default=None,
        description="Path to an overriding config.yml file",
    )

    # Flagged with include_in_config=False
    _metadata: Metadata = PrivateAttr(default=Metadata(include_in_config=False))
|
||||
|
||||
|
||||
class UtilityActions(BaseConfigModel):
    """Model used for arg actions."""

    # YAML export options
    # NOTE(review): typed as a string, yet common/actions.py only tests its
    # truthiness, so any non-empty text (even "false") enables the export.
    # Confirm whether Optional[bool] was intended.
    export_config: Optional[str] = Field(
        default=None,
        description="generate a template config file",
    )
    # Defaults are real Path objects so the attribute has the annotated type
    # even when the model is built without validation (model_construct).
    config_export_path: Optional[Path] = Field(
        default=Path("config_sample.yml"),
        description="path to export configuration file to",
    )

    # OpenAPI JSON export options
    export_openapi: Optional[bool] = Field(
        default=False,
        description="export openapi schema files",
    )
    openapi_export_path: Optional[Path] = Field(
        default=Path("openapi.json"),
        description="path to export openapi schema to",
    )

    # Flagged with include_in_config=False
    _metadata: Metadata = PrivateAttr(default=Metadata(include_in_config=False))
|
||||
|
||||
|
||||
class NetworkConfig(BaseConfigModel):
    """Options for networking"""

    # Address and port the server is hosted on
    host: Optional[str] = Field(
        default="127.0.0.1",
        description=(
            "The IP to host on (default: 127.0.0.1).\n"
            "Use 0.0.0.0 to expose on all network adapters."
        ),
    )
    port: Optional[int] = Field(
        default=5000,
        description="The port to host on (default: 5000).",
    )

    # Security-related switches; both default to off
    disable_auth: Optional[bool] = Field(
        default=False,
        description=(
            "Disable HTTP token authentication with requests.\n"
            "WARNING: This will make your instance vulnerable!\n"
            "Turn on this option if you are ONLY connecting from localhost."
        ),
    )
    send_tracebacks: Optional[bool] = Field(
        default=False,
        description=(
            "Send tracebacks over the API (default: False).\n"
            "NOTE: Only enable this for debug purposes."
        ),
    )

    # NOTE(review): Literal values are matched exactly, so a lower-case
    # spelling such as "kobold" would be rejected unless it is normalised
    # before validation - confirm against the callers.
    api_servers: Optional[List[Literal["OAI", "Kobold"]]] = Field(
        default=["OAI"],
        description=(
            'Select API servers to enable (default: ["OAI"]).\n'
            "Possible values: OAI, Kobold."
        ),
    )
|
||||
|
||||
|
||||
# TODO: Migrate config.yml to have the log_ prefix
|
||||
# This is a breaking change.
|
||||
class LoggingConfig(BaseConfigModel):
    """Options for logging"""

    # Every switch is off by default
    log_prompt: Optional[bool] = Field(
        default=False,
        description="Enable prompt logging (default: False).",
    )
    log_generation_params: Optional[bool] = Field(
        default=False,
        description="Enable generation parameter logging (default: False).",
    )
    log_requests: Optional[bool] = Field(
        default=False,
        description=(
            "Enable request logging (default: False).\n"
            "NOTE: Only use this for debugging!"
        ),
    )
|
||||
|
||||
|
||||
class ModelConfig(BaseConfigModel):
    """
    Options for model overrides and loading
    Please read the comments to understand how arguments are handled
    between initial and API loads
    """

    # TODO: convert this to a pathlib.path?
    model_dir: str = Field(
        default="models",
        description=(
            "Directory to look for models (default: models).\n"
            "Windows users, do NOT put this path in quotes!"
        ),
    )
    inline_model_loading: Optional[bool] = Field(
        default=False,
        description=(
            "Allow direct loading of models "
            "from a completion or chat completion request (default: False)."
        ),
    )
    use_dummy_models: Optional[bool] = Field(
        default=False,
        description=(
            "Sends dummy model names when the models endpoint is queried.\n"
            "Enable this if the client is looking for specific OAI models."
        ),
    )
    model_name: Optional[str] = Field(
        default=None,
        description=(
            "An initial model to load.\n"
            "Make sure the model is located in the model directory!\n"
            "REQUIRED: This must be filled out to load a model on startup."
        ),
    )
    use_as_default: List[str] = Field(
        default_factory=list,
        description=(
            "Names of args to use as a fallback for API load requests (default: []).\n"
            "For example, if you always want cache_mode to be Q4 "
            'instead of only on the initial model load, add "cache_mode" '
            "to this array.\n"
            "Example: ['max_seq_len', 'cache_mode']."
        ),
    )
    max_seq_len: Optional[int] = Field(
        default=None,
        description=(
            "Max sequence length (default: Empty).\n"
            "Fetched from the model's base sequence length in config.json by default."
        ),
        ge=0,
    )
    override_base_seq_len: Optional[int] = Field(
        default=None,
        description=(
            "Overrides base model context length (default: Empty).\n"
            "WARNING: Don't set this unless you know what you're doing!\n"
            "Again, do NOT use this for configuring context length, "
            "use max_seq_len above ^"
        ),
        ge=0,
    )
    tensor_parallel: Optional[bool] = Field(
        default=False,
        description=(
            "Load model with tensor parallelism.\n"
            "Falls back to autosplit if GPU split isn't provided.\n"
            "This ignores the gpu_split_auto value."
        ),
    )
    gpu_split_auto: Optional[bool] = Field(
        default=True,
        description=(
            "Automatically allocate resources to GPUs (default: True).\n"
            "Not parsed for single GPU users."
        ),
    )
    autosplit_reserve: List[int] = Field(
        default=[96],
        description=(
            "Reserve VRAM used for autosplit loading (default: 96 MB on GPU 0).\n"
            "Represented as an array of MB per GPU."
        ),
    )
    # Typed as floats, so the help text no longer claims an integer array
    gpu_split: List[float] = Field(
        default_factory=list,
        description=(
            "An array of GBs of VRAM to split between GPUs (default: []).\n"
            "Used with tensor parallelism."
        ),
    )
    rope_scale: Optional[float] = Field(
        default=1.0,
        description=(
            "Rope scale (default: 1.0).\n"
            "Same as compress_pos_emb.\n"
            "Use if the model was trained on long context with rope.\n"
            "Leave blank to pull the value from the model."
        ),
    )
    rope_alpha: Optional[Union[float, Literal["auto"]]] = Field(
        default=None,
        description=(
            "Rope alpha (default: None).\n"
            'Same as alpha_value. Set to "auto" to auto-calculate.\n'
            "Leaving this value blank will either pull from the model "
            "or auto-calculate."
        ),
    )
    # str(CACHE_SIZES) renders as "typing.Literal[...]"; the [15:-1] slice
    # drops the "typing.Literal[" prefix and the closing "]".
    cache_mode: Optional[CACHE_SIZES] = Field(
        default="FP16",
        description=(
            "Enable different cache modes for VRAM savings (default: FP16).\n"
            f"Possible values: {str(CACHE_SIZES)[15:-1]}."
        ),
    )
    cache_size: Optional[int] = Field(
        default=None,
        description=(
            "Size of the prompt cache to allocate (default: max_seq_len).\n"
            "Must be a multiple of 256 and can't be less than max_seq_len.\n"
            "For CFG, set this to 2 * max_seq_len."
        ),
        multiple_of=256,
        gt=0,
    )
    chunk_size: Optional[int] = Field(
        default=2048,
        description=(
            "Chunk size for prompt ingestion (default: 2048).\n"
            "A lower value reduces VRAM usage but decreases ingestion speed.\n"
            "NOTE: Effects vary depending on the model.\n"
            "An ideal value is between 512 and 4096."
        ),
        gt=0,
    )
    max_batch_size: Optional[int] = Field(
        default=None,
        description=(
            "Set the maximum number of prompts to process at one time "
            "(default: None/Automatic).\n"
            "Automatically calculated if left blank.\n"
            "NOTE: Only available for Nvidia ampere (30 series) and above GPUs."
        ),
        ge=1,
    )
    prompt_template: Optional[str] = Field(
        default=None,
        description=(
            "Set the prompt template for this model. (default: None)\n"
            "If empty, attempts to look for the model's chat template.\n"
            "If a model contains multiple templates in its tokenizer_config.json,\n"
            "set prompt_template to the name of the template you want to use.\n"
            "NOTE: Only works with chat completion message lists!"
        ),
    )
    num_experts_per_token: Optional[int] = Field(
        default=None,
        description=(
            "Number of experts to use per token.\n"
            "Fetched from the model's config.json if empty.\n"
            "NOTE: For MoE models only.\n"
            "WARNING: Don't set this unless you know what you're doing!"
        ),
        ge=1,
    )
    fasttensors: Optional[bool] = Field(
        default=False,
        description=(
            "Enables fasttensors to possibly increase model loading speeds "
            "(default: False)."
        ),
    )

    _metadata: Metadata = PrivateAttr(default=Metadata())

    # Several fields start with "model_", which pydantic reserves by default;
    # clearing protected_namespaces permits those names.
    model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class DraftModelConfig(BaseConfigModel):
    """
    Options for draft models (speculative decoding)
    This will use more VRAM!
    """

    # TODO: convert this to a pathlib.path?
    draft_model_dir: Optional[str] = Field(
        default="models",
        description="Directory to look for draft models (default: models)",
    )
    draft_model_name: Optional[str] = Field(
        default=None,
        description=(
            "An initial draft model to load.\n"
            "Ensure the model is in the model directory."
        ),
    )
    draft_rope_scale: Optional[float] = Field(
        default=1.0,
        description=(
            "Rope scale for draft models (default: 1.0).\n"
            "Same as compress_pos_emb.\n"
            "Use if the draft model was trained on long context with rope."
        ),
    )
    # Accepts "auto" as the help text promises, mirroring
    # ModelConfig.rope_alpha; a float-only annotation rejects that value.
    # NOTE(review): confirm the draft loader handles "auto" the same way the
    # main model loader does.
    draft_rope_alpha: Optional[Union[float, Literal["auto"]]] = Field(
        default=None,
        description=(
            "Rope alpha for draft models (default: None).\n"
            'Same as alpha_value. Set to "auto" to auto-calculate.\n'
            "Leaving this value blank will either pull from the model "
            "or auto-calculate."
        ),
    )
    # str(CACHE_SIZES) renders as "typing.Literal[...]"; the [15:-1] slice
    # drops the "typing.Literal[" prefix and the closing "]".
    draft_cache_mode: Optional[CACHE_SIZES] = Field(
        default="FP16",
        description=(
            "Cache mode for draft models to save VRAM (default: FP16).\n"
            f"Possible values: {str(CACHE_SIZES)[15:-1]}."
        ),
    )
|
||||
|
||||
|
||||
class LoraInstanceModel(BaseConfigModel):
    """Model representing an instance of a Lora."""

    name: Optional[str] = None

    # Must not be negative (ge=0); 1.0 when omitted
    scaling: float = Field(default=1.0, ge=0)
|
||||
|
||||
|
||||
class LoraConfig(BaseConfigModel):
    """Options for Loras"""

    # TODO: convert this to a pathlib.path?
    lora_dir: Optional[str] = Field(
        default="loras",
        description="Directory to look for LoRAs (default: loras).",
    )
    # The YAML example indents "scaling" by two spaces so it lines up with
    # "name" under the "- " list marker and forms a valid mapping entry.
    loras: Optional[List[LoraInstanceModel]] = Field(
        default=None,
        description=(
            "List of LoRAs to load and associated scaling factors "
            "(default scale: 1.0).\n"
            "For the YAML file, add each entry as a YAML list:\n"
            "- name: lora1\n"
            "  scaling: 1.0"
        ),
    )
|
||||
|
||||
|
||||
class EmbeddingsConfig(BaseConfigModel):
    """
    Options for embedding models and loading.
    NOTE: Embeddings requires the "extras" feature to be installed
    Install it via "pip install .[extras]"
    """

    # TODO: convert this to a pathlib.path?
    embedding_model_dir: Optional[str] = Field(
        default="models",
        description="Directory to look for embedding models (default: models).",
    )

    # Restricted to the three listed device names
    embeddings_device: Optional[Literal["cpu", "auto", "cuda"]] = Field(
        default="cpu",
        description=(
            "Device to load embedding models on (default: cpu).\n"
            "Possible values: cpu, auto, cuda.\n"
            "NOTE: It's recommended to load embedding models on the CPU.\n"
            "If using an AMD GPU, set this value to 'cuda'."
        ),
    )
    embedding_model_name: Optional[str] = Field(
        default=None,
        description="An initial embedding model to load on the infinity backend.",
    )
|
||||
|
||||
|
||||
class SamplingConfig(BaseConfigModel):
    """Options for Sampling"""

    # No preset is selected unless one is named here
    override_preset: Optional[str] = Field(
        default=None,
        description=(
            "Select a sampler override preset (default: None).\n"
            "Find this in the sampler-overrides folder.\n"
            "This overrides default fallbacks for sampler values "
            "that are passed to the API."
        ),
    )
|
||||
|
||||
|
||||
class DeveloperConfig(BaseConfigModel):
    """Options for development and experimentation"""

    # Every switch in this section is off by default
    unsafe_launch: Optional[bool] = Field(
        default=False,
        description=(
            "Skip Exllamav2 version check (default: False).\n"
            "WARNING: It's highly recommended to update your dependencies rather "
            "than enabling this flag."
        ),
    )
    disable_request_streaming: Optional[bool] = Field(
        default=False,
        description="Disable API request streaming (default: False).",
    )
    cuda_malloc_backend: Optional[bool] = Field(
        default=False,
        description="Enable the torch CUDA malloc backend (default: False).",
    )
    uvloop: Optional[bool] = Field(
        default=False,
        description=(
            "Run asyncio using Uvloop or Winloop which can improve performance.\n"
            "NOTE: It's recommended to enable this, but if something breaks "
            "turn this off."
        ),
    )
    realtime_process_priority: Optional[bool] = Field(
        default=False,
        description=(
            "Set process to use a higher priority.\n"
            "For realtime process priority, run as administrator or sudo.\n"
            "Otherwise, the priority will be set to high."
        ),
    )
|
||||
|
||||
|
||||
class TabbyConfigModel(BaseModel):
    """Base model for a TabbyConfig."""

    # One field per config section. When a section is omitted, its default is
    # produced by that section model's `model_construct` factory.
    config: Optional[ConfigOverrideConfig] = Field(
        default_factory=ConfigOverrideConfig.model_construct,
    )
    network: Optional[NetworkConfig] = Field(
        default_factory=NetworkConfig.model_construct,
    )
    logging: Optional[LoggingConfig] = Field(
        default_factory=LoggingConfig.model_construct,
    )
    model: Optional[ModelConfig] = Field(
        default_factory=ModelConfig.model_construct,
    )
    draft_model: Optional[DraftModelConfig] = Field(
        default_factory=DraftModelConfig.model_construct,
    )
    lora: Optional[LoraConfig] = Field(
        default_factory=LoraConfig.model_construct,
    )
    embeddings: Optional[EmbeddingsConfig] = Field(
        default_factory=EmbeddingsConfig.model_construct,
    )
    sampling: Optional[SamplingConfig] = Field(
        default_factory=SamplingConfig.model_construct,
    )
    developer: Optional[DeveloperConfig] = Field(
        default_factory=DeveloperConfig.model_construct,
    )
    actions: Optional[UtilityActions] = Field(
        default_factory=UtilityActions.model_construct,
    )

    # validate_assignment re-validates values assigned after construction;
    # the empty protected_namespaces permits the field named "model".
    model_config = ConfigDict(validate_assignment=True, protected_namespaces=())
|
||||
@@ -76,9 +76,9 @@ def _get_download_folder(repo_id: str, repo_type: str, folder_name: Optional[str
|
||||
"""Gets the download folder for the repo."""
|
||||
|
||||
if repo_type == "lora":
|
||||
download_path = pathlib.Path(config.lora.get("lora_dir") or "loras")
|
||||
download_path = pathlib.Path(config.lora.lora_dir)
|
||||
else:
|
||||
download_path = pathlib.Path(config.model.get("model_dir") or "models")
|
||||
download_path = pathlib.Path(config.model.model_dir)
|
||||
|
||||
download_path = download_path / (folder_name or repo_id.split("/")[-1])
|
||||
return download_path
|
||||
|
||||
@@ -2,41 +2,19 @@
|
||||
Functions for logging generation events.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
from loguru import logger
|
||||
from typing import Dict, Optional
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class GenLogPreferences(BaseModel):
|
||||
"""Logging preference config."""
|
||||
|
||||
prompt: bool = False
|
||||
generation_params: bool = False
|
||||
|
||||
|
||||
# Global logging preferences constant
|
||||
PREFERENCES = GenLogPreferences()
|
||||
|
||||
|
||||
def update_from_dict(options_dict: Dict[str, bool]):
|
||||
"""Wrapper to set the logging config for generations"""
|
||||
global PREFERENCES
|
||||
|
||||
# Force bools on the dict
|
||||
for value in options_dict.values():
|
||||
if value is None:
|
||||
value = False
|
||||
|
||||
PREFERENCES = GenLogPreferences.model_validate(options_dict)
|
||||
from common.tabby_config import config
|
||||
|
||||
|
||||
def broadcast_status():
|
||||
"""Broadcasts the current logging status"""
|
||||
enabled = []
|
||||
if PREFERENCES.prompt:
|
||||
if config.logging.log_prompt:
|
||||
enabled.append("prompts")
|
||||
|
||||
if PREFERENCES.generation_params:
|
||||
if config.logging.log_generation_params:
|
||||
enabled.append("generation params")
|
||||
|
||||
if len(enabled) > 0:
|
||||
@@ -47,13 +25,13 @@ def broadcast_status():
|
||||
|
||||
def log_generation_params(**kwargs):
|
||||
"""Logs generation parameters to console."""
|
||||
if PREFERENCES.generation_params:
|
||||
if config.logging.log_generation_params:
|
||||
logger.info(f"Generation options: {kwargs}\n")
|
||||
|
||||
|
||||
def log_prompt(prompt: str, request_id: str, negative_prompt: Optional[str]):
|
||||
"""Logs the prompt to console."""
|
||||
if PREFERENCES.prompt:
|
||||
if config.logging.log_prompt:
|
||||
formatted_prompt = "\n" + prompt
|
||||
logger.info(
|
||||
f"Prompt (ID: {request_id}): {formatted_prompt if prompt else 'Empty'}\n"
|
||||
@@ -66,7 +44,7 @@ def log_prompt(prompt: str, request_id: str, negative_prompt: Optional[str]):
|
||||
|
||||
def log_response(request_id: str, response: str):
|
||||
"""Logs the response to console."""
|
||||
if PREFERENCES.prompt:
|
||||
if config.logging.log_prompt:
|
||||
formatted_response = "\n" + response
|
||||
logger.info(
|
||||
f"Response (ID: {request_id}): "
|
||||
|
||||
@@ -11,7 +11,6 @@ from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from common.tabby_config import config
|
||||
from common.utils import unwrap
|
||||
|
||||
|
||||
class TabbyRequestErrorMessage(BaseModel):
|
||||
@@ -39,7 +38,7 @@ def handle_request_error(message: str, exc_info: bool = True):
|
||||
"""Log a request error to the console."""
|
||||
|
||||
trace = traceback.format_exc()
|
||||
send_trace = unwrap(config.network.get("send_tracebacks"), False)
|
||||
send_trace = config.network.send_tracebacks
|
||||
|
||||
error_message = TabbyRequestErrorMessage(
|
||||
message=message, trace=trace if send_trace else None
|
||||
@@ -134,7 +133,7 @@ def get_global_depends():
|
||||
|
||||
depends = [Depends(add_request_id)]
|
||||
|
||||
if config.logging.get("requests"):
|
||||
if config.logging.log_requests:
|
||||
depends.append(Depends(log_request))
|
||||
|
||||
return depends
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import aiofiles
|
||||
import json
|
||||
import pathlib
|
||||
import yaml
|
||||
from ruamel.yaml import YAML
|
||||
from copy import deepcopy
|
||||
from loguru import logger
|
||||
from pydantic import AliasChoices, BaseModel, Field
|
||||
@@ -416,7 +416,10 @@ async def overrides_from_file(preset_name: str):
|
||||
overrides_container.selected_preset = preset_path.stem
|
||||
async with aiofiles.open(preset_path, "r", encoding="utf8") as raw_preset:
|
||||
contents = await raw_preset.read()
|
||||
preset = yaml.safe_load(contents)
|
||||
|
||||
# Create a temporary YAML parser
|
||||
yaml = YAML(typ="safe")
|
||||
preset = yaml.load(contents)
|
||||
overrides_from_dict(preset)
|
||||
|
||||
logger.info("Applied sampler overrides from file.")
|
||||
|
||||
@@ -1,62 +1,94 @@
|
||||
import yaml
|
||||
import pathlib
|
||||
from loguru import logger
|
||||
from inspect import getdoc
|
||||
from os import getenv
|
||||
from textwrap import dedent
|
||||
from typing import Optional
|
||||
|
||||
from common.utils import unwrap, merge_dicts
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
from ruamel.yaml import YAML
|
||||
from ruamel.yaml.comments import CommentedMap, CommentedSeq
|
||||
from ruamel.yaml.scalarstring import PreservedScalarString
|
||||
|
||||
from common.config_models import BaseConfigModel, TabbyConfigModel
|
||||
from common.utils import merge_dicts, unwrap
|
||||
|
||||
yaml = YAML(typ=["rt", "safe"])
|
||||
|
||||
|
||||
class TabbyConfig:
|
||||
"""Common config class for TabbyAPI. Loaded into sub-dictionaries from YAML file."""
|
||||
|
||||
# Sub-blocks of yaml
|
||||
network: dict = {}
|
||||
logging: dict = {}
|
||||
model: dict = {}
|
||||
draft_model: dict = {}
|
||||
lora: dict = {}
|
||||
sampling: dict = {}
|
||||
developer: dict = {}
|
||||
embeddings: dict = {}
|
||||
|
||||
class TabbyConfig(TabbyConfigModel):
|
||||
# Persistent defaults
|
||||
# TODO: make this pydantic?
|
||||
model_defaults: dict = {}
|
||||
|
||||
def load(self, arguments: Optional[dict] = None):
|
||||
"""Synchronously loads the global application config"""
|
||||
|
||||
# config is applied in order of items in the list
|
||||
configs = [
|
||||
self._from_file(pathlib.Path("config.yml")),
|
||||
self._from_args(unwrap(arguments, {})),
|
||||
]
|
||||
arguments_dict = unwrap(arguments, {})
|
||||
configs = [self._from_environment(), self._from_args(arguments_dict)]
|
||||
|
||||
# If actions aren't present, also look from the file
|
||||
# TODO: Change logic if file loading requires actions in the future
|
||||
if not arguments_dict.get("actions"):
|
||||
configs.insert(0, self._from_file(pathlib.Path("config.yml")))
|
||||
|
||||
merged_config = merge_dicts(*configs)
|
||||
|
||||
self.network = unwrap(merged_config.get("network"), {})
|
||||
self.logging = unwrap(merged_config.get("logging"), {})
|
||||
self.model = unwrap(merged_config.get("model"), {})
|
||||
self.draft_model = unwrap(self.model.get("draft"), {})
|
||||
self.lora = unwrap(self.model.get("lora"), {})
|
||||
self.sampling = unwrap(merged_config.get("sampling"), {})
|
||||
self.developer = unwrap(merged_config.get("developer"), {})
|
||||
self.embeddings = unwrap(merged_config.get("embeddings"), {})
|
||||
# validate and update config
|
||||
merged_config_model = TabbyConfigModel.model_validate(merged_config)
|
||||
for field in TabbyConfigModel.model_fields.keys():
|
||||
value = getattr(merged_config_model, field)
|
||||
setattr(self, field, value)
|
||||
|
||||
# Set model defaults dict once to prevent on-demand reconstruction
|
||||
default_keys = unwrap(self.model.get("use_as_default"), [])
|
||||
for key in default_keys:
|
||||
if key in self.model:
|
||||
self.model_defaults[key] = config.model[key]
|
||||
elif key in self.draft_model:
|
||||
self.model_defaults[key] = config.draft_model[key]
|
||||
# TODO: clean this up a bit
|
||||
for field in self.model.use_as_default:
|
||||
if hasattr(self.model, field):
|
||||
self.model_defaults[field] = getattr(config.model, field)
|
||||
elif hasattr(self.draft_model, field):
|
||||
self.model_defaults[field] = getattr(config.draft_model, field)
|
||||
else:
|
||||
logger.error(
|
||||
f"invalid item {field} in config option `model.use_as_default`"
|
||||
)
|
||||
|
||||
def _from_file(self, config_path: pathlib.Path):
|
||||
"""loads config from a given file path"""
|
||||
|
||||
legacy = False
|
||||
cfg = {}
|
||||
|
||||
# try loading from file
|
||||
try:
|
||||
with open(str(config_path.resolve()), "r", encoding="utf8") as config_file:
|
||||
return unwrap(yaml.safe_load(config_file), {})
|
||||
cfg = yaml.load(config_file)
|
||||
|
||||
# NOTE: Remove migration wrapper after a period of time
|
||||
# load legacy config files
|
||||
|
||||
# Model config migration
|
||||
model_cfg = unwrap(cfg.get("model"), {})
|
||||
|
||||
if model_cfg.get("draft"):
|
||||
legacy = True
|
||||
cfg["draft_model"] = model_cfg["draft"]
|
||||
|
||||
if model_cfg.get("lora"):
|
||||
legacy = True
|
||||
cfg["lora"] = model_cfg["lora"]
|
||||
|
||||
# Logging config migration
|
||||
# This will catch the majority of legacy config files
|
||||
logging_cfg = unwrap(cfg.get("logging"), {})
|
||||
unmigrated_log_keys = [
|
||||
key for key in logging_cfg.keys() if not key.startswith("log_")
|
||||
]
|
||||
if unmigrated_log_keys:
|
||||
legacy = True
|
||||
for key in unmigrated_log_keys:
|
||||
cfg["logging"][f"log_{key}"] = cfg["logging"][key]
|
||||
del cfg["logging"][key]
|
||||
except FileNotFoundError:
|
||||
logger.info(f"The '{config_path.name}' file cannot be found")
|
||||
except Exception as exc:
|
||||
@@ -65,25 +97,53 @@ class TabbyConfig:
|
||||
f"the following error:\n\n{exc}"
|
||||
)
|
||||
|
||||
# if no config file was loaded
|
||||
return {}
|
||||
if legacy:
|
||||
logger.warning(
|
||||
"Legacy config.yml file detected. Attempting auto-migration."
|
||||
)
|
||||
|
||||
# Create a temporary base config model
|
||||
new_cfg = TabbyConfigModel.model_validate(cfg)
|
||||
|
||||
try:
|
||||
config_path.rename(f"{config_path}.bak")
|
||||
generate_config_file(model=new_cfg, filename=config_path)
|
||||
logger.info(
|
||||
"Auto-migration successful. "
|
||||
'The old configuration is stored in "config.yml.bak".'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Auto-migration failed because of: {e}\n\n"
|
||||
"Reverted all changes.\n"
|
||||
"Either fix your config.yml and restart or\n"
|
||||
"Delete your old YAML file and create a new "
|
||||
'config by copying "config_sample.yml" to "config.yml".'
|
||||
)
|
||||
|
||||
# Restore the old config
|
||||
config_path.unlink(missing_ok=True)
|
||||
pathlib.Path(f"{config_path}.bak").rename(config_path)
|
||||
|
||||
# Don't use the partially loaded config
|
||||
logger.warning("Starting with no config loaded.")
|
||||
return {}
|
||||
|
||||
return unwrap(cfg, {})
|
||||
|
||||
def _from_args(self, args: dict):
|
||||
"""loads config from the provided arguments"""
|
||||
config = {}
|
||||
|
||||
config_override = unwrap(args.get("options", {}).get("config"))
|
||||
config_override = args.get("options", {}).get("config", None)
|
||||
if config_override:
|
||||
logger.info("Config file override detected in args.")
|
||||
config = self._from_file(pathlib.Path(config_override))
|
||||
return config # Return early if loading from file
|
||||
|
||||
for key in ["network", "model", "logging", "developer", "embeddings"]:
|
||||
for key in TabbyConfigModel.model_fields.keys():
|
||||
override = args.get(key)
|
||||
if override:
|
||||
if key == "logging":
|
||||
# Strip the "log_" prefix from logging keys if present
|
||||
override = {k.replace("log_", ""): v for k, v in override.items()}
|
||||
config[key] = override
|
||||
|
||||
return config
|
||||
@@ -91,12 +151,116 @@ class TabbyConfig:
|
||||
def _from_environment(self):
|
||||
"""loads configuration from environment variables"""
|
||||
|
||||
# TODO: load config from environment variables
|
||||
# this means that we can have host default to 0.0.0.0 in docker for example
|
||||
# this would also mean that docker containers no longer require a non
|
||||
# default config file to be used
|
||||
pass
|
||||
config = {}
|
||||
|
||||
for field_name in TabbyConfigModel.model_fields.keys():
|
||||
section_config = {}
|
||||
for sub_field_name in getattr(
|
||||
TabbyConfigModel(), field_name
|
||||
).model_fields.keys():
|
||||
setting = getenv(f"TABBY_{field_name}_{sub_field_name}".upper(), None)
|
||||
if setting is not None:
|
||||
section_config[sub_field_name] = setting
|
||||
|
||||
config[field_name] = section_config
|
||||
|
||||
return config
|
||||
|
||||
|
||||
# Create an empty instance of the config class
|
||||
config: TabbyConfig = TabbyConfig()
|
||||
|
||||
|
||||
def generate_config_file(
    model: Optional[BaseModel] = None,
    filename: str = "config_sample.yml",
) -> None:
    """
    Creates a config.yml file from Pydantic models.

    Args:
        model: The model instance to serialize. When None, a
            default-constructed TabbyConfigModel is written instead.
        filename: Destination path of the generated YAML file. It is passed
            straight to open(), so a pathlib.Path works as well.
    """

    schema = unwrap(model, TabbyConfigModel())
    preamble = """
    # Sample YAML file for configuration.
    # Comment and uncomment values as needed.
    # Every value has a default within the application.
    # This file serves to be a drop in for config.yml

    # Unless specified in the comments, DO NOT put these options in quotes!
    # You can use https://www.yamllint.com/ if you want to check your YAML formatting.\n
    """

    yaml_content = pydantic_model_to_yaml(schema)

    # Write with an explicit encoding: the config loader reads the file as
    # utf8, so relying on the platform default here could produce a file
    # that fails to load when comments or values contain non-ASCII text.
    with open(filename, "w", encoding="utf8") as f:
        # The preamble is indented to match this function; strip that
        # indentation and the leading newline before writing it out.
        f.write(dedent(preamble).lstrip())
        yaml.dump(yaml_content, f)
|
||||
|
||||
|
||||
def pydantic_model_to_yaml(model: BaseModel, indentation: int = 0) -> CommentedMap:
    """
    Build a CommentedMap from a Pydantic model, recursing into nested models.

    Every emitted key gets a YAML comment placed before it: the class
    docstring for nested config models, the field description otherwise.
    """

    output = CommentedMap()

    # True until the first key of this mapping has been emitted
    is_first_key = True

    for name, info in model.model_fields.items():
        current = getattr(model, name)

        if isinstance(current, BaseConfigModel):
            # Nested config section; skipped entirely when it opts out
            if not current._metadata.include_in_config:
                continue

            output[name] = pydantic_model_to_yaml(
                current, indentation=indentation + 2
            )
            note = getdoc(current)
        elif isinstance(current, list) and len(current) > 0:
            sequence = CommentedSeq()

            if isinstance(current[0], BaseModel):
                # List of models: convert each item, without item comments
                for entry in current:
                    sequence.append(
                        pydantic_model_to_yaml(entry, indentation=indentation + 2)
                    )
            else:
                # Plain list: emit in YAML flow style
                sequence.fa.set_flow_style()
                sequence += [
                    PreservedScalarString(entry) if isinstance(entry, str) else entry
                    for entry in current
                ]

            output[name] = sequence
            note = info.description
        else:
            # Scalars, empty lists and None are stored unchanged
            output[name] = current
            note = info.description

        if note:
            # Every comment except the first is separated by a blank line
            output.yaml_set_comment_before_after_key(
                name,
                before=note if is_first_key else f"\n{note}",
                indent=indentation,
            )

        # Any key that was emitted (commented or not) ends the "first" state
        is_first_key = False

    return output
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
"""Common utility functions"""
|
||||
|
||||
from types import NoneType
|
||||
from typing import Dict, Optional, Type, Union, get_args, get_origin, TypeVar
|
||||
|
||||
def unwrap(wrapped, default=None):
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def unwrap(wrapped: Optional[T], default: T = None) -> T:
|
||||
"""Unwrap function for Optionals."""
|
||||
if wrapped is None:
|
||||
return default
|
||||
@@ -14,13 +19,13 @@ def coalesce(*args):
|
||||
return next((arg for arg in args if arg is not None), None)
|
||||
|
||||
|
||||
def prune_dict(input_dict):
|
||||
def prune_dict(input_dict: Dict) -> Dict:
|
||||
"""Trim out instances of None from a dictionary."""
|
||||
|
||||
return {k: v for k, v in input_dict.items() if v is not None}
|
||||
|
||||
|
||||
def merge_dict(dict1, dict2):
|
||||
def merge_dict(dict1: Dict, dict2: Dict) -> Dict:
|
||||
"""Merge 2 dictionaries"""
|
||||
for key, value in dict2.items():
|
||||
if isinstance(value, dict) and key in dict1 and isinstance(dict1[key], dict):
|
||||
@@ -30,7 +35,7 @@ def merge_dict(dict1, dict2):
|
||||
return dict1
|
||||
|
||||
|
||||
def merge_dicts(*dicts):
|
||||
def merge_dicts(*dicts: Dict) -> Dict:
|
||||
"""Merge an arbitrary amount of dictionaries"""
|
||||
result = {}
|
||||
for dictionary in dicts:
|
||||
@@ -43,3 +48,33 @@ def flat_map(input_list):
|
||||
"""Flattens a list of lists into a single list."""
|
||||
|
||||
return [item for sublist in input_list for item in sublist]
|
||||
|
||||
|
||||
def is_list_type(type_hint) -> bool:
    """
    Checks if a type contains a list.

    Returns True when the hint is a list type itself (bare ``list``,
    ``List`` or a parameterized ``List[...]``) or when any of its type
    arguments, at any nesting depth, is one. Returns False otherwise.
    """

    # A bare ``list`` has no origin, so it has to be matched directly
    if type_hint is list or get_origin(type_hint) is list:
        return True

    # Recursively check for lists inside type arguments.
    # Hints without arguments yield an empty tuple, which gives False.
    return any(is_list_type(arg) for arg in get_args(type_hint))
|
||||
|
||||
|
||||
def unwrap_optional_type(type_hint) -> Type:
|
||||
"""
|
||||
Unwrap Optional[type] annotations.
|
||||
This is not the same as unwrap.
|
||||
"""
|
||||
|
||||
if get_origin(type_hint) is Union:
|
||||
args = get_args(type_hint)
|
||||
if NoneType in args:
|
||||
for arg in args:
|
||||
if arg is not NoneType:
|
||||
return arg
|
||||
|
||||
return type_hint
|
||||
|
||||
@@ -1,232 +1,218 @@
|
||||
# Sample YAML file for configuration.
|
||||
# Comment and uncomment values as needed. Every value has a default within the application.
|
||||
# This file serves to be a drop in for config.yml
|
||||
|
||||
# Unless specified in the comments, DO NOT put these options in quotes!
|
||||
# You can use https://www.yamllint.com/ if you want to check your YAML formatting.
|
||||
|
||||
# Options for networking
|
||||
network:
|
||||
# The IP to host on (default: 127.0.0.1).
|
||||
# Use 0.0.0.0 to expose on all network adapters
|
||||
host: 127.0.0.1
|
||||
|
||||
# The port to host on (default: 5000)
|
||||
port: 5000
|
||||
|
||||
# Disable HTTP token authenticaion with requests
|
||||
# WARNING: This will make your instance vulnerable!
|
||||
# Turn on this option if you are ONLY connecting from localhost
|
||||
disable_auth: False
|
||||
|
||||
# Send tracebacks over the API to clients (default: False)
|
||||
# NOTE: Only enable this for debug purposes
|
||||
send_tracebacks: False
|
||||
|
||||
# Select API servers to enable (default: ["OAI"])
|
||||
# Possible values: OAI
|
||||
api_servers: ["OAI"]
|
||||
|
||||
# Options for logging
|
||||
logging:
|
||||
# Enable prompt logging (default: False)
|
||||
prompt: False
|
||||
|
||||
# Enable generation parameter logging (default: False)
|
||||
generation_params: False
|
||||
|
||||
# Enable request logging (default: False)
|
||||
# NOTE: Only use this for debugging!
|
||||
requests: False
|
||||
|
||||
# Options for sampling
|
||||
sampling:
|
||||
# Override preset name. Find this in the sampler-overrides folder (default: None)
|
||||
# This overrides default fallbacks for sampler values that are passed to the API
|
||||
# Server-side overrides are NOT needed by default
|
||||
# WARNING: Using this can result in a generation speed penalty
|
||||
#override_preset:
|
||||
|
||||
# Options for development and experimentation
|
||||
developer:
|
||||
# Skips exllamav2 version check (default: False)
|
||||
# It's highly recommended to update your dependencies rather than enabling this flag
|
||||
# WARNING: Don't set this unless you know what you're doing!
|
||||
#unsafe_launch: False
|
||||
|
||||
# Disable all request streaming (default: False)
|
||||
# A kill switch for turning off SSE in the API server
|
||||
#disable_request_streaming: False
|
||||
|
||||
# Enable the torch CUDA malloc backend (default: False)
|
||||
# This can save a few MBs of VRAM, but has a risk of errors. Use at your own risk.
|
||||
#cuda_malloc_backend: False
|
||||
|
||||
# Enable Uvloop or Winloop (default: False)
|
||||
# Make the program utilize a faster async event loop which can improve performance
|
||||
# NOTE: It's recommended to enable this, but if something breaks, turn this off.
|
||||
#uvloop: False
|
||||
|
||||
# Set process to use a higher priority
|
||||
# For realtime process priority, run as administrator or sudo
|
||||
# Otherwise, the priority will be set to high
|
||||
#realtime_process_priority: False
|
||||
|
||||
# Options for model overrides and loading
|
||||
# Please read the comments to understand how arguments are handled between initial and API loads
|
||||
model:
|
||||
# Overrides the directory to look for models (default: models)
|
||||
# Windows users, DO NOT put this path in quotes! This directory will be invalid otherwise.
|
||||
model_dir: models
|
||||
|
||||
# Sends dummy model names when the models endpoint is queried
|
||||
# Enable this if the program is looking for a specific OAI model
|
||||
#use_dummy_models: False
|
||||
|
||||
# Allow direct loading of models from a completion or chat completion request
|
||||
inline_model_loading: False
|
||||
|
||||
# An initial model to load. Make sure the model is located in the model directory!
|
||||
# A model can be loaded later via the API.
|
||||
# REQUIRED: This must be filled out to load a model on startup!
|
||||
model_name:
|
||||
|
||||
# The below parameters only apply for initial loads
|
||||
# All API based loads do NOT inherit these settings unless specified in use_as_default
|
||||
|
||||
# Names of args to use as a default fallback for API load requests (default: [])
|
||||
# For example, if you always want cache_mode to be Q4 instead of on the inital model load,
|
||||
# Add "cache_mode" to this array
|
||||
# Ex. ["max_seq_len", "cache_mode"]
|
||||
#use_as_default: []
|
||||
|
||||
# The below parameters apply only if model_name is set
|
||||
|
||||
# Max sequence length (default: Empty)
|
||||
# Fetched from the model's base sequence length in config.json by default
|
||||
#max_seq_len:
|
||||
|
||||
# Overrides base model context length (default: Empty)
|
||||
# WARNING: Don't set this unless you know what you're doing!
|
||||
# Again, do NOT use this for configuring context length, use max_seq_len above ^
|
||||
# Only use this if the model's base sequence length in config.json is incorrect (ex. Mistral 7B)
|
||||
#override_base_seq_len:
|
||||
|
||||
# Load model with tensor parallelism
|
||||
# If a GPU split isn't provided, the TP loader will fallback to autosplit
|
||||
# Enabling ignores the gpu_split_auto and autosplit_reserve values
|
||||
#tensor_parallel: False
|
||||
|
||||
# Automatically allocate resources to GPUs (default: True)
|
||||
# NOTE: Not parsed for single GPU users
|
||||
#gpu_split_auto: True
|
||||
|
||||
# Reserve VRAM used for autosplit loading (default: 96 MB on GPU 0)
|
||||
# This is represented as an array of MB per GPU used
|
||||
#autosplit_reserve: [96]
|
||||
|
||||
# An integer array of GBs of vram to split between GPUs (default: [])
|
||||
# Used with tensor parallelism
|
||||
# NOTE: Not parsed for single GPU users
|
||||
#gpu_split: [20.6, 24]
|
||||
|
||||
# Rope scale (default: 1.0)
|
||||
# Same thing as compress_pos_emb
|
||||
# Only use if your model was trained on long context with rope (check config.json)
|
||||
# Leave blank to pull the value from the model
|
||||
#rope_scale: 1.0
|
||||
|
||||
# Rope alpha (default: 1.0)
|
||||
# Same thing as alpha_value
|
||||
# Set to "auto" to automatically calculate
|
||||
# Leave blank to pull the value from the model
|
||||
#rope_alpha: 1.0
|
||||
|
||||
# Enable different cache modes for VRAM savings (slight performance hit).
|
||||
# Possible values FP16, Q8, Q6, Q4. (default: FP16)
|
||||
#cache_mode: FP16
|
||||
|
||||
# Size of the prompt cache to allocate (default: max_seq_len)
|
||||
# This must be a multiple of 256. A larger cache uses more VRAM, but allows for more prompts to be processed at once.
|
||||
# NOTE: Cache size should not be less than max_seq_len.
|
||||
# For CFG, set this to 2 * max_seq_len to make room for both positive and negative prompts.
|
||||
#cache_size:
|
||||
|
||||
# Chunk size for prompt ingestion. A lower value reduces VRAM usage at the cost of ingestion speed (default: 2048)
|
||||
# NOTE: Effects vary depending on the model. An ideal value is between 512 and 4096
|
||||
#chunk_size: 2048
|
||||
|
||||
# Set the maximum amount of prompts to process at one time (default: None/Automatic)
|
||||
# This will be automatically calculated if left blank.
|
||||
# A max batch size of 1 processes prompts one at a time.
|
||||
# NOTE: Only available for Nvidia ampere (30 series) and above GPUs
|
||||
#max_batch_size:
|
||||
|
||||
# Set the prompt template for this model. If empty, attempts to look for the model's chat template. (default: None)
|
||||
# If a model contains multiple templates in its tokenizer_config.json, set prompt_template to the name
|
||||
# of the template you want to use.
|
||||
# NOTE: Only works with chat completion message lists!
|
||||
#prompt_template:
|
||||
|
||||
# Number of experts to use PER TOKEN. Fetched from the model's config.json if not specified (default: Empty)
|
||||
# WARNING: Don't set this unless you know what you're doing!
|
||||
# NOTE: For MoE models (ex. Mixtral) only!
|
||||
#num_experts_per_token:
|
||||
|
||||
# Enables fasttensors to possibly increase model loading speeds (default: False)
|
||||
#fasttensors: true
|
||||
|
||||
# Options for draft models (speculative decoding). This will use more VRAM!
|
||||
#draft:
|
||||
# Overrides the directory to look for draft (default: models)
|
||||
#draft_model_dir: models
|
||||
|
||||
# An initial draft model to load. Make sure this model is located in the model directory!
|
||||
# A draft model can be loaded later via the API.
|
||||
#draft_model_name: A model name
|
||||
|
||||
# The below parameters only apply for initial loads
|
||||
# All API based loads do NOT inherit these settings unless specified in use_as_default
|
||||
|
||||
# Rope scale for draft models (default: 1.0)
|
||||
# Same thing as compress_pos_emb
|
||||
# Only use if your draft model was trained on long context with rope (check config.json)
|
||||
#draft_rope_scale: 1.0
|
||||
|
||||
# Rope alpha for draft model (default: 1.0)
|
||||
# Same thing as alpha_value
|
||||
# Leave blank to automatically calculate alpha value
|
||||
#draft_rope_alpha: 1.0
|
||||
|
||||
# Enable different draft model cache modes for VRAM savings (slight performance hit).
|
||||
# Possible values FP16, Q8, Q6, Q4. (default: FP16)
|
||||
#draft_cache_mode: FP16
|
||||
|
||||
# Options for loras
|
||||
#lora:
|
||||
# Overrides the directory to look for loras (default: loras)
|
||||
#lora_dir: loras
|
||||
|
||||
# List of loras to load and associated scaling factors (default: 1.0). Comment out unused entries or add more rows as needed.
|
||||
#loras:
|
||||
#- name: lora1
|
||||
# scaling: 1.0
|
||||
|
||||
# Options for embedding models and loading.
|
||||
# NOTE: Embeddings requires the "extras" feature to be installed
|
||||
# Install it via "pip install .[extras]"
|
||||
embeddings:
|
||||
# Overrides directory to look for embedding models (default: models)
|
||||
embedding_model_dir: models
|
||||
|
||||
# Device to load embedding models on (default: cpu)
|
||||
# Possible values: cpu, auto, cuda
|
||||
# NOTE: It's recommended to load embedding models on the CPU.
|
||||
# If you'd like to load on an AMD gpu, set this value to "cuda" as well.
|
||||
embeddings_device: cpu
|
||||
|
||||
# The below parameters only apply for initial loads
|
||||
# All API based loads do NOT inherit these settings unless specified in use_as_default
|
||||
|
||||
# An initial embedding model to load on the infinity backend (default: None)
|
||||
embedding_model_name:
|
||||
# Sample YAML file for configuration.
|
||||
# Comment and uncomment values as needed.
|
||||
# Every value has a default within the application.
|
||||
# This file serves to be a drop in for config.yml
|
||||
|
||||
# Unless specified in the comments, DO NOT put these options in quotes!
|
||||
# You can use https://www.yamllint.com/ if you want to check your YAML formatting.
|
||||
|
||||
# Options for networking
|
||||
network:
|
||||
# The IP to host on (default: 127.0.0.1).
|
||||
# Use 0.0.0.0 to expose on all network adapters.
|
||||
host: 127.0.0.1
|
||||
|
||||
# The port to host on (default: 5000).
|
||||
port: 5000
|
||||
|
||||
# Disable HTTP token authentication with requests.
|
||||
# WARNING: This will make your instance vulnerable!
|
||||
# Turn on this option if you are ONLY connecting from localhost.
|
||||
disable_auth: false
|
||||
|
||||
# Send tracebacks over the API (default: False).
|
||||
# NOTE: Only enable this for debug purposes.
|
||||
send_tracebacks: false
|
||||
|
||||
# Select API servers to enable (default: ["OAI"]).
|
||||
# Possible values: OAI, Kobold.
|
||||
api_servers: ["OAI"]
|
||||
|
||||
# Options for logging
|
||||
logging:
|
||||
# Enable prompt logging (default: False).
|
||||
log_prompt: false
|
||||
|
||||
# Enable generation parameter logging (default: False).
|
||||
log_generation_params: false
|
||||
|
||||
# Enable request logging (default: False).
|
||||
# NOTE: Only use this for debugging!
|
||||
log_requests: false
|
||||
|
||||
# Options for model overrides and loading
|
||||
# Please read the comments to understand how arguments are handled
|
||||
# between initial and API loads
|
||||
model:
|
||||
# Directory to look for models (default: models).
|
||||
# Windows users, do NOT put this path in quotes!
|
||||
model_dir: models
|
||||
|
||||
# Allow direct loading of models from a completion or chat completion request (default: False).
|
||||
inline_model_loading: false
|
||||
|
||||
# Sends dummy model names when the models endpoint is queried.
|
||||
# Enable this if the client is looking for specific OAI models.
|
||||
use_dummy_models: false
|
||||
|
||||
# An initial model to load.
|
||||
# Make sure the model is located in the model directory!
|
||||
# REQUIRED: This must be filled out to load a model on startup.
|
||||
model_name:
|
||||
|
||||
# Names of args to use as a fallback for API load requests (default: []).
|
||||
# For example, if you always want cache_mode to be Q4 instead of on the initial model load, add "cache_mode" to this array.
|
||||
# Example: ['max_seq_len', 'cache_mode'].
|
||||
use_as_default: []
|
||||
|
||||
# Max sequence length (default: Empty).
|
||||
# Fetched from the model's base sequence length in config.json by default.
|
||||
max_seq_len:
|
||||
|
||||
# Overrides base model context length (default: Empty).
|
||||
# WARNING: Don't set this unless you know what you're doing!
|
||||
# Again, do NOT use this for configuring context length, use max_seq_len above ^
|
||||
override_base_seq_len:
|
||||
|
||||
# Load model with tensor parallelism.
|
||||
# Falls back to autosplit if GPU split isn't provided.
|
||||
# This ignores the gpu_split_auto value.
|
||||
tensor_parallel: false
|
||||
|
||||
# Automatically allocate resources to GPUs (default: True).
|
||||
# Not parsed for single GPU users.
|
||||
gpu_split_auto: true
|
||||
|
||||
# Reserve VRAM used for autosplit loading (default: 96 MB on GPU 0).
|
||||
# Represented as an array of MB per GPU.
|
||||
autosplit_reserve: [96]
|
||||
|
||||
# An integer array of GBs of VRAM to split between GPUs (default: []).
|
||||
# Used with tensor parallelism.
|
||||
gpu_split: []
|
||||
|
||||
# Rope scale (default: 1.0).
|
||||
# Same as compress_pos_emb.
|
||||
# Use if the model was trained on long context with rope.
|
||||
# Leave blank to pull the value from the model.
|
||||
rope_scale: 1.0
|
||||
|
||||
# Rope alpha (default: None).
|
||||
# Same as alpha_value. Set to "auto" to auto-calculate.
|
||||
# Leaving this value blank will either pull from the model or auto-calculate.
|
||||
rope_alpha:
|
||||
|
||||
# Enable different cache modes for VRAM savings (default: FP16).
|
||||
# Possible values: 'FP16', 'Q8', 'Q6', 'Q4'.
|
||||
cache_mode: FP16
|
||||
|
||||
# Size of the prompt cache to allocate (default: max_seq_len).
|
||||
# Must be a multiple of 256 and can't be less than max_seq_len.
|
||||
# For CFG, set this to 2 * max_seq_len.
|
||||
cache_size:
|
||||
|
||||
# Chunk size for prompt ingestion (default: 2048).
|
||||
# A lower value reduces VRAM usage but decreases ingestion speed.
|
||||
# NOTE: Effects vary depending on the model.
|
||||
# An ideal value is between 512 and 4096.
|
||||
chunk_size: 2048
|
||||
|
||||
# Set the maximum number of prompts to process at one time (default: None/Automatic).
|
||||
# Automatically calculated if left blank.
|
||||
# NOTE: Only available for Nvidia ampere (30 series) and above GPUs.
|
||||
max_batch_size:
|
||||
|
||||
# Set the prompt template for this model. (default: None)
|
||||
# If empty, attempts to look for the model's chat template.
|
||||
# If a model contains multiple templates in its tokenizer_config.json,
|
||||
# set prompt_template to the name of the template you want to use.
|
||||
# NOTE: Only works with chat completion message lists!
|
||||
prompt_template:
|
||||
|
||||
# Number of experts to use per token.
|
||||
# Fetched from the model's config.json if empty.
|
||||
# NOTE: For MoE models only.
|
||||
# WARNING: Don't set this unless you know what you're doing!
|
||||
num_experts_per_token:
|
||||
|
||||
# Enables fasttensors to possibly increase model loading speeds (default: False).
|
||||
fasttensors: false
|
||||
|
||||
# Options for draft models (speculative decoding)
|
||||
# This will use more VRAM!
|
||||
draft_model:
|
||||
# Directory to look for draft models (default: models)
|
||||
draft_model_dir: models
|
||||
|
||||
# An initial draft model to load.
|
||||
# Ensure the model is in the model directory.
|
||||
draft_model_name:
|
||||
|
||||
# Rope scale for draft models (default: 1.0).
|
||||
# Same as compress_pos_emb.
|
||||
# Use if the draft model was trained on long context with rope.
|
||||
draft_rope_scale: 1.0
|
||||
|
||||
# Rope alpha for draft models (default: None).
|
||||
# Same as alpha_value. Set to "auto" to auto-calculate.
|
||||
# Leaving this value blank will either pull from the model or auto-calculate.
|
||||
draft_rope_alpha:
|
||||
|
||||
# Cache mode for draft models to save VRAM (default: FP16).
|
||||
# Possible values: 'FP16', 'Q8', 'Q6', 'Q4'.
|
||||
draft_cache_mode: FP16
|
||||
|
||||
# Options for Loras
|
||||
lora:
|
||||
# Directory to look for LoRAs (default: loras).
|
||||
lora_dir: loras
|
||||
|
||||
# List of LoRAs to load and associated scaling factors (default scale: 1.0).
|
||||
# For the YAML file, add each entry as a YAML list:
|
||||
# - name: lora1
|
||||
# scaling: 1.0
|
||||
loras:
|
||||
|
||||
# Options for embedding models and loading.
|
||||
# NOTE: Embeddings requires the "extras" feature to be installed
|
||||
# Install it via "pip install .[extras]"
|
||||
embeddings:
|
||||
# Directory to look for embedding models (default: models).
|
||||
embedding_model_dir: models
|
||||
|
||||
# Device to load embedding models on (default: cpu).
|
||||
# Possible values: cpu, auto, cuda.
|
||||
# NOTE: It's recommended to load embedding models on the CPU.
|
||||
# If using an AMD GPU, set this value to 'cuda'.
|
||||
embeddings_device: cpu
|
||||
|
||||
# An initial embedding model to load on the infinity backend.
|
||||
embedding_model_name:
|
||||
|
||||
# Options for Sampling
|
||||
sampling:
|
||||
# Select a sampler override preset (default: None).
|
||||
# Find this in the sampler-overrides folder.
|
||||
# This overrides default fallbacks for sampler values that are passed to the API.
|
||||
override_preset:
|
||||
|
||||
# Options for development and experimentation
|
||||
developer:
|
||||
# Skip Exllamav2 version check (default: False).
|
||||
# WARNING: It's highly recommended to update your dependencies rather than enabling this flag.
|
||||
unsafe_launch: false
|
||||
|
||||
# Disable API request streaming (default: False).
|
||||
disable_request_streaming: false
|
||||
|
||||
# Enable the torch CUDA malloc backend (default: False).
|
||||
cuda_malloc_backend: false
|
||||
|
||||
# Run asyncio using Uvloop or Winloop which can improve performance.
|
||||
# NOTE: It's recommended to enable this, but if something breaks turn this off.
|
||||
uvloop: false
|
||||
|
||||
# Set process to use a higher priority.
|
||||
# For realtime process priority, run as administrator or sudo.
|
||||
# Otherwise, the priority will be set to high.
|
||||
realtime_process_priority: false
|
||||
|
||||
@@ -8,7 +8,6 @@ from common.auth import check_api_key
|
||||
from common.model import check_embeddings_container, check_model_container
|
||||
from common.networking import handle_request_error, run_with_request_disconnect
|
||||
from common.tabby_config import config
|
||||
from common.utils import unwrap
|
||||
from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
|
||||
from endpoints.OAI.types.chat_completion import (
|
||||
ChatCompletionRequest,
|
||||
@@ -71,9 +70,7 @@ async def completion_request(
|
||||
if isinstance(data.prompt, list):
|
||||
data.prompt = "\n".join(data.prompt)
|
||||
|
||||
disable_request_streaming = unwrap(
|
||||
config.developer.get("disable_request_streaming"), False
|
||||
)
|
||||
disable_request_streaming = config.developer.disable_request_streaming
|
||||
|
||||
# Set an empty JSON schema if the request wants a JSON response
|
||||
if data.response_format.type == "json":
|
||||
@@ -135,9 +132,7 @@ async def chat_completion_request(
|
||||
if data.response_format.type == "json":
|
||||
data.json_schema = {"type": "object"}
|
||||
|
||||
disable_request_streaming = unwrap(
|
||||
config.developer.get("disable_request_streaming"), False
|
||||
)
|
||||
disable_request_streaming = config.developer.disable_request_streaming
|
||||
|
||||
if data.stream and not disable_request_streaming:
|
||||
return EventSourceResponse(
|
||||
|
||||
@@ -130,7 +130,7 @@ async def load_inline_model(model_name: str, request: Request):
|
||||
|
||||
raise HTTPException(401, error_message)
|
||||
|
||||
if not unwrap(config.model.get("inline_model_loading"), False):
|
||||
if not config.model.inline_model_loading:
|
||||
logger.warning(
|
||||
f"Unable to switch model to {model_name} because "
|
||||
'"inline_model_loading" is not True in config.yml.'
|
||||
@@ -138,7 +138,7 @@ async def load_inline_model(model_name: str, request: Request):
|
||||
|
||||
return
|
||||
|
||||
model_path = pathlib.Path(unwrap(config.model.get("model_dir"), "models"))
|
||||
model_path = pathlib.Path(config.model.model_dir)
|
||||
model_path = model_path / model_name
|
||||
|
||||
# Model path doesn't exist
|
||||
|
||||
@@ -62,17 +62,17 @@ async def list_models(request: Request) -> ModelList:
|
||||
Requires an admin key to see all models.
|
||||
"""
|
||||
|
||||
model_dir = unwrap(config.model.get("model_dir"), "models")
|
||||
model_dir = config.model.model_dir
|
||||
model_path = pathlib.Path(model_dir)
|
||||
|
||||
draft_model_dir = config.draft_model.get("draft_model_dir")
|
||||
draft_model_dir = config.draft_model.draft_model_dir
|
||||
|
||||
if get_key_permission(request) == "admin":
|
||||
models = get_model_list(model_path.resolve(), draft_model_dir)
|
||||
else:
|
||||
models = await get_current_model_list()
|
||||
|
||||
if unwrap(config.model.get("use_dummy_models"), False):
|
||||
if config.model.use_dummy_models:
|
||||
models.data.insert(0, ModelCard(id="gpt-3.5-turbo"))
|
||||
|
||||
return models
|
||||
@@ -98,7 +98,7 @@ async def list_draft_models(request: Request) -> ModelList:
|
||||
"""
|
||||
|
||||
if get_key_permission(request) == "admin":
|
||||
draft_model_dir = unwrap(config.draft_model.get("draft_model_dir"), "models")
|
||||
draft_model_dir = config.draft_model.draft_model_dir
|
||||
draft_model_path = pathlib.Path(draft_model_dir)
|
||||
|
||||
models = get_model_list(draft_model_path.resolve())
|
||||
@@ -122,7 +122,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
|
||||
|
||||
raise HTTPException(400, error_message)
|
||||
|
||||
model_path = pathlib.Path(unwrap(config.model.get("model_dir"), "models"))
|
||||
model_path = pathlib.Path(config.model.model_dir)
|
||||
model_path = model_path / data.name
|
||||
|
||||
draft_model_path = None
|
||||
@@ -135,7 +135,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
|
||||
|
||||
raise HTTPException(400, error_message)
|
||||
|
||||
draft_model_path = unwrap(config.draft_model.get("draft_model_dir"), "models")
|
||||
draft_model_path = config.draft_model.draft_model_dir
|
||||
|
||||
if not model_path.exists():
|
||||
error_message = handle_request_error(
|
||||
@@ -192,7 +192,7 @@ async def list_all_loras(request: Request) -> LoraList:
|
||||
"""
|
||||
|
||||
if get_key_permission(request) == "admin":
|
||||
lora_path = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras"))
|
||||
lora_path = pathlib.Path(config.lora.lora_dir)
|
||||
loras = get_lora_list(lora_path.resolve())
|
||||
else:
|
||||
loras = get_active_loras()
|
||||
@@ -227,7 +227,7 @@ async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse:
|
||||
|
||||
raise HTTPException(400, error_message)
|
||||
|
||||
lora_dir = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras"))
|
||||
lora_dir = pathlib.Path(config.lora.lora_dir)
|
||||
if not lora_dir.exists():
|
||||
error_message = handle_request_error(
|
||||
"A parent lora directory does not exist for load. Check your config.yml?",
|
||||
@@ -266,9 +266,7 @@ async def list_embedding_models(request: Request) -> ModelList:
|
||||
"""
|
||||
|
||||
if get_key_permission(request) == "admin":
|
||||
embedding_model_dir = unwrap(
|
||||
config.embeddings.get("embedding_model_dir"), "models"
|
||||
)
|
||||
embedding_model_dir = config.embeddings.embedding_model_dir
|
||||
embedding_model_path = pathlib.Path(embedding_model_dir)
|
||||
|
||||
models = get_model_list(embedding_model_path.resolve())
|
||||
@@ -302,9 +300,7 @@ async def load_embedding_model(
|
||||
|
||||
raise HTTPException(400, error_message)
|
||||
|
||||
embedding_model_dir = pathlib.Path(
|
||||
unwrap(config.embeddings.get("embedding_model_dir"), "models")
|
||||
)
|
||||
embedding_model_dir = pathlib.Path(config.embeddings.embedding_model_dir)
|
||||
embedding_model_path = embedding_model_dir / data.name
|
||||
|
||||
if not embedding_model_path.exists():
|
||||
|
||||
@@ -4,9 +4,8 @@ from pydantic import BaseModel, Field, ConfigDict
|
||||
from time import time
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from common.gen_logging import GenLogPreferences
|
||||
from common.config_models import LoggingConfig
|
||||
from common.tabby_config import config
|
||||
from common.utils import unwrap
|
||||
|
||||
|
||||
class ModelCardParameters(BaseModel):
|
||||
@@ -34,7 +33,7 @@ class ModelCard(BaseModel):
|
||||
object: str = "model"
|
||||
created: int = Field(default_factory=lambda: int(time()))
|
||||
owned_by: str = "tabbyAPI"
|
||||
logging: Optional[GenLogPreferences] = None
|
||||
logging: Optional[LoggingConfig] = None
|
||||
parameters: Optional[ModelCardParameters] = None
|
||||
|
||||
|
||||
@@ -118,11 +117,7 @@ class EmbeddingModelLoadRequest(BaseModel):
|
||||
name: str
|
||||
|
||||
# Set default from the config
|
||||
embeddings_device: Optional[str] = Field(
|
||||
default_factory=lambda: unwrap(
|
||||
config.embeddings.get("embeddings_device"), "cpu"
|
||||
)
|
||||
)
|
||||
embeddings_device: Optional[str] = Field(config.embeddings.embeddings_device)
|
||||
|
||||
|
||||
class ModelLoadResponse(BaseModel):
|
||||
|
||||
@@ -2,8 +2,9 @@ import pathlib
|
||||
from asyncio import CancelledError
|
||||
from typing import Optional
|
||||
|
||||
from common import gen_logging, model
|
||||
from common import model
|
||||
from common.networking import get_generator_error, handle_request_disconnect
|
||||
from common.tabby_config import config
|
||||
from common.utils import unwrap
|
||||
from endpoints.core.types.model import (
|
||||
ModelCard,
|
||||
@@ -77,7 +78,7 @@ def get_current_model():
|
||||
model_card = ModelCard(
|
||||
id=unwrap(model_params.pop("name", None), "unknown"),
|
||||
parameters=ModelCardParameters.model_validate(model_params),
|
||||
logging=gen_logging.PREFERENCES,
|
||||
logging=config.logging,
|
||||
)
|
||||
|
||||
if draft_model_params:
|
||||
|
||||
@@ -8,7 +8,6 @@ from loguru import logger
|
||||
from common.logger import UVICORN_LOG_CONFIG
|
||||
from common.networking import get_global_depends
|
||||
from common.tabby_config import config
|
||||
from common.utils import unwrap
|
||||
from endpoints.Kobold import router as KoboldRouter
|
||||
from endpoints.OAI import router as OAIRouter
|
||||
from endpoints.core.router import router as CoreRouter
|
||||
@@ -36,7 +35,7 @@ def setup_app(host: Optional[str] = None, port: Optional[int] = None):
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
api_servers = unwrap(config.network.get("api_servers"), [])
|
||||
api_servers = config.network.api_servers
|
||||
|
||||
# Map for API id to server router
|
||||
router_mapping = {"oai": OAIRouter, "kobold": KoboldRouter}
|
||||
|
||||
65
main.py
65
main.py
@@ -1,7 +1,6 @@
|
||||
"""The main tabbyAPI module. Contains the FastAPI server and endpoints."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
import platform
|
||||
@@ -12,12 +11,12 @@ from typing import Optional
|
||||
from common import gen_logging, sampling, model
|
||||
from common.args import convert_args_to_dict, init_argparser
|
||||
from common.auth import load_auth_keys
|
||||
from common.actions import branch_to_actions
|
||||
from common.logger import setup_logger
|
||||
from common.networking import is_port_in_use
|
||||
from common.signals import signal_handler
|
||||
from common.tabby_config import config
|
||||
from common.utils import unwrap
|
||||
from endpoints.server import export_openapi, start_api
|
||||
from endpoints.server import start_api
|
||||
from endpoints.utils import do_export_openapi
|
||||
|
||||
if not do_export_openapi:
|
||||
@@ -27,8 +26,8 @@ if not do_export_openapi:
|
||||
async def entrypoint_async():
|
||||
"""Async entry function for program startup"""
|
||||
|
||||
host = unwrap(config.network.get("host"), "127.0.0.1")
|
||||
port = unwrap(config.network.get("port"), 5000)
|
||||
host = config.network.host
|
||||
port = config.network.port
|
||||
|
||||
# Check if the port is available and attempt to bind a fallback
|
||||
if is_port_in_use(port):
|
||||
@@ -50,16 +49,12 @@ async def entrypoint_async():
|
||||
port = fallback_port
|
||||
|
||||
# Initialize auth keys
|
||||
await load_auth_keys(unwrap(config.network.get("disable_auth"), False))
|
||||
|
||||
# Override the generation log options if given
|
||||
if config.logging:
|
||||
gen_logging.update_from_dict(config.logging)
|
||||
await load_auth_keys(config.network.disable_auth)
|
||||
|
||||
gen_logging.broadcast_status()
|
||||
|
||||
# Set sampler parameter overrides if provided
|
||||
sampling_override_preset = config.sampling.get("override_preset")
|
||||
sampling_override_preset = config.sampling.override_preset
|
||||
if sampling_override_preset:
|
||||
try:
|
||||
await sampling.overrides_from_file(sampling_override_preset)
|
||||
@@ -68,29 +63,38 @@ async def entrypoint_async():
|
||||
|
||||
# If an initial model name is specified, create a container
|
||||
# and load the model
|
||||
model_name = config.model.get("model_name")
|
||||
model_name = config.model.model_name
|
||||
if model_name:
|
||||
model_path = pathlib.Path(unwrap(config.model.get("model_dir"), "models"))
|
||||
model_path = pathlib.Path(config.model.model_dir)
|
||||
model_path = model_path / model_name
|
||||
|
||||
await model.load_model(model_path.resolve(), **config.model)
|
||||
# TODO: remove model_dump()
|
||||
await model.load_model(
|
||||
model_path.resolve(),
|
||||
**config.model.model_dump(),
|
||||
draft=config.draft_model.model_dump(),
|
||||
)
|
||||
|
||||
# Load loras after loading the model
|
||||
if config.lora.get("loras"):
|
||||
lora_dir = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras"))
|
||||
await model.container.load_loras(lora_dir.resolve(), **config.lora)
|
||||
if config.lora.loras:
|
||||
lora_dir = pathlib.Path(config.lora.lora_dir)
|
||||
# TODO: remove model_dump()
|
||||
await model.container.load_loras(
|
||||
lora_dir.resolve(), **config.lora.model_dump()
|
||||
)
|
||||
|
||||
# If an initial embedding model name is specified, create a separate container
|
||||
# and load the model
|
||||
embedding_model_name = config.embeddings.get("embedding_model_name")
|
||||
embedding_model_name = config.embeddings.embedding_model_name
|
||||
if embedding_model_name:
|
||||
embedding_model_path = pathlib.Path(
|
||||
unwrap(config.embeddings.get("embedding_model_dir"), "models")
|
||||
)
|
||||
embedding_model_path = pathlib.Path(config.embeddings.embedding_model_dir)
|
||||
embedding_model_path = embedding_model_path / embedding_model_name
|
||||
|
||||
try:
|
||||
await model.load_embedding_model(embedding_model_path, **config.embeddings)
|
||||
# TODO: remove model_dump()
|
||||
await model.load_embedding_model(
|
||||
embedding_model_path, **config.embeddings.model_dump()
|
||||
)
|
||||
except ImportError as ex:
|
||||
logger.error(ex.msg)
|
||||
|
||||
@@ -112,18 +116,13 @@ def entrypoint(arguments: Optional[dict] = None):
|
||||
# load config
|
||||
config.load(arguments)
|
||||
|
||||
if do_export_openapi:
|
||||
openapi_json = export_openapi()
|
||||
|
||||
with open("openapi.json", "w") as f:
|
||||
f.write(json.dumps(openapi_json))
|
||||
logger.info("Successfully wrote OpenAPI spec to openapi.json")
|
||||
|
||||
# branch to default paths if required
|
||||
if branch_to_actions():
|
||||
return
|
||||
|
||||
# Check exllamav2 version and give a descriptive error if it's too old
|
||||
# Skip if launching unsafely
|
||||
if unwrap(config.developer.get("unsafe_launch"), False):
|
||||
if config.developer.unsafe_launch:
|
||||
logger.warning(
|
||||
"UNSAFE: Skipping ExllamaV2 version check.\n"
|
||||
"If you aren't a developer, please keep this off!"
|
||||
@@ -132,12 +131,12 @@ def entrypoint(arguments: Optional[dict] = None):
|
||||
check_exllama_version()
|
||||
|
||||
# Enable CUDA malloc backend
|
||||
if unwrap(config.developer.get("cuda_malloc_backend"), False):
|
||||
if config.developer.cuda_malloc_backend:
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"
|
||||
logger.warning("EXPERIMENTAL: Enabled the pytorch CUDA malloc backend.")
|
||||
|
||||
# Use Uvloop/Winloop
|
||||
if unwrap(config.developer.get("uvloop"), False):
|
||||
if config.developer.uvloop:
|
||||
if platform.system() == "Windows":
|
||||
from winloop import install
|
||||
else:
|
||||
@@ -149,7 +148,7 @@ def entrypoint(arguments: Optional[dict] = None):
|
||||
logger.warning("EXPERIMENTAL: Running program with Uvloop/Winloop.")
|
||||
|
||||
# Set the process priority
|
||||
if unwrap(config.developer.get("realtime_process_priority"), False):
|
||||
if config.developer.realtime_process_priority:
|
||||
import psutil
|
||||
|
||||
current_process = psutil.Process(os.getpid())
|
||||
|
||||
@@ -18,7 +18,7 @@ requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"fastapi-slim >= 0.110.0",
|
||||
"pydantic >= 2.0.0",
|
||||
"PyYAML",
|
||||
"ruamel.yaml",
|
||||
"rich",
|
||||
"uvicorn >= 0.28.1",
|
||||
"jinja2 >= 3.0.0",
|
||||
|
||||
Reference in New Issue
Block a user