mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-24 16:29:18 +00:00
config is now backed by pydantic (WIP)
- add models for config options - add function to regenerate config.yml - replace references to config with pydantic compatible references - remove unnecessary unwrap() statements TODO: - auto generate env vars - auto generate argparse - test loading a model
This commit is contained in:
248
common/config_models.py
Normal file
248
common/config_models.py
Normal file
@@ -0,0 +1,248 @@
|
|||||||
|
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||||
|
from typing import List, Optional, Union, get_type_hints
|
||||||
|
|
||||||
|
from common.utils import unwrap
|
||||||
|
|
||||||
|
|
||||||
|
class config_config_model(BaseModel):
|
||||||
|
config: Optional[str] = Field(
|
||||||
|
None, description="Path to an overriding config.yml file"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class network_config_model(BaseModel):
|
||||||
|
host: Optional[str] = Field("127.0.0.1", description="The IP to host on")
|
||||||
|
port: Optional[int] = Field(5000, description="The port to host on")
|
||||||
|
disable_auth: Optional[bool] = Field(
|
||||||
|
False, description="Disable HTTP token authentication with requests"
|
||||||
|
)
|
||||||
|
send_tracebacks: Optional[bool] = Field(
|
||||||
|
False, description="Decide whether to send error tracebacks over the API"
|
||||||
|
)
|
||||||
|
api_servers: Optional[List[str]] = Field(
|
||||||
|
[
|
||||||
|
"OAI",
|
||||||
|
],
|
||||||
|
description="API servers to enable. Options: (OAI, Kobold)",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class logging_config_model(BaseModel):
|
||||||
|
log_prompt: Optional[bool] = Field(False, description="Enable prompt logging")
|
||||||
|
log_generation_params: Optional[bool] = Field(
|
||||||
|
False, description="Enable generation parameter logging"
|
||||||
|
)
|
||||||
|
log_requests: Optional[bool] = Field(False, description="Enable request logging")
|
||||||
|
|
||||||
|
|
||||||
|
class model_config_model(BaseModel):
|
||||||
|
model_dir: str = Field(
|
||||||
|
"models",
|
||||||
|
description="Overrides the directory to look for models (default: models). Windows users, do NOT put this path in quotes.",
|
||||||
|
)
|
||||||
|
use_dummy_models: Optional[bool] = Field(
|
||||||
|
False,
|
||||||
|
description="Sends dummy model names when the models endpoint is queried. Enable this if looking for specific OAI models.",
|
||||||
|
)
|
||||||
|
model_name: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description="An initial model to load. Make sure the model is located in the model directory! REQUIRED: This must be filled out to load a model on startup.",
|
||||||
|
)
|
||||||
|
use_as_default: List[str] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="Names of args to use as a default fallback for API load requests (default: []). Example: ['max_seq_len', 'cache_mode']",
|
||||||
|
)
|
||||||
|
max_seq_len: Optional[int] = Field(
|
||||||
|
None,
|
||||||
|
description="Max sequence length. Fetched from the model's base sequence length in config.json by default.",
|
||||||
|
)
|
||||||
|
override_base_seq_len: Optional[int] = Field(
|
||||||
|
None,
|
||||||
|
description="Overrides base model context length. WARNING: Only use this if the model's base sequence length is incorrect.",
|
||||||
|
)
|
||||||
|
tensor_parallel: Optional[bool] = Field(
|
||||||
|
False,
|
||||||
|
description="Load model with tensor parallelism. Fallback to autosplit if GPU split isn't provided.",
|
||||||
|
)
|
||||||
|
gpu_split_auto: Optional[bool] = Field(
|
||||||
|
True,
|
||||||
|
description="Automatically allocate resources to GPUs (default: True). Not parsed for single GPU users.",
|
||||||
|
)
|
||||||
|
autosplit_reserve: List[int] = Field(
|
||||||
|
[96],
|
||||||
|
description="Reserve VRAM used for autosplit loading (default: 96 MB on GPU 0). Represented as an array of MB per GPU.",
|
||||||
|
)
|
||||||
|
gpu_split: List[float] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="An integer array of GBs of VRAM to split between GPUs (default: []). Used with tensor parallelism.",
|
||||||
|
)
|
||||||
|
rope_scale: Optional[float] = Field(
|
||||||
|
1.0,
|
||||||
|
description="Rope scale (default: 1.0). Same as compress_pos_emb. Only use if the model was trained on long context with rope.",
|
||||||
|
)
|
||||||
|
rope_alpha: Optional[Union[float, str]] = Field(
|
||||||
|
1.0,
|
||||||
|
description="Rope alpha (default: 1.0). Same as alpha_value. Set to 'auto' to auto-calculate.",
|
||||||
|
)
|
||||||
|
cache_mode: Optional[str] = Field(
|
||||||
|
"FP16",
|
||||||
|
description="Enable different cache modes for VRAM savings (default: FP16). Possible values: FP16, Q8, Q6, Q4.",
|
||||||
|
)
|
||||||
|
cache_size: Optional[int] = Field(
|
||||||
|
None,
|
||||||
|
description="Size of the prompt cache to allocate (default: max_seq_len). Must be a multiple of 256.",
|
||||||
|
)
|
||||||
|
chunk_size: Optional[int] = Field(
|
||||||
|
2048,
|
||||||
|
description="Chunk size for prompt ingestion (default: 2048). A lower value reduces VRAM usage but decreases ingestion speed.",
|
||||||
|
)
|
||||||
|
max_batch_size: Optional[int] = Field(
|
||||||
|
None,
|
||||||
|
description="Set the maximum number of prompts to process at one time (default: None/Automatic). Automatically calculated if left blank.",
|
||||||
|
)
|
||||||
|
prompt_template: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description="Set the prompt template for this model. If empty, attempts to look for the model's chat template.",
|
||||||
|
)
|
||||||
|
num_experts_per_token: Optional[int] = Field(
|
||||||
|
None,
|
||||||
|
description="Number of experts to use per token. Fetched from the model's config.json. For MoE models only.",
|
||||||
|
)
|
||||||
|
fasttensors: Optional[bool] = Field(
|
||||||
|
False,
|
||||||
|
description="Enables fasttensors to possibly increase model loading speeds (default: False).",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class draft_model_config_model(BaseModel):
|
||||||
|
draft_model_dir: Optional[str] = Field(
|
||||||
|
"models",
|
||||||
|
description="Overrides the directory to look for draft models (default: models)",
|
||||||
|
)
|
||||||
|
draft_model_name: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description="An initial draft model to load. Ensure the model is in the model directory.",
|
||||||
|
)
|
||||||
|
draft_rope_scale: Optional[float] = Field(
|
||||||
|
1.0,
|
||||||
|
description="Rope scale for draft models (default: 1.0). Same as compress_pos_emb. Use if the draft model was trained on long context with rope.",
|
||||||
|
)
|
||||||
|
draft_rope_alpha: Optional[float] = Field(
|
||||||
|
None,
|
||||||
|
description="Rope alpha for draft models (default: None). Same as alpha_value. Leave blank to auto-calculate the alpha value.",
|
||||||
|
)
|
||||||
|
draft_cache_mode: Optional[str] = Field(
|
||||||
|
"FP16",
|
||||||
|
description="Cache mode for draft models to save VRAM (default: FP16). Possible values: FP16, Q8, Q6, Q4.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class lora_instance_model(BaseModel):
|
||||||
|
name: str = Field(..., description="Name of the LoRA model")
|
||||||
|
scaling: float = Field(
|
||||||
|
1.0, description="Scaling factor for the LoRA model (default: 1.0)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class lora_config_model(BaseModel):
|
||||||
|
lora_dir: Optional[str] = Field(
|
||||||
|
"loras", description="Directory to look for LoRAs (default: 'loras')"
|
||||||
|
)
|
||||||
|
loras: Optional[List[lora_instance_model]] = Field(
|
||||||
|
None,
|
||||||
|
description="List of LoRAs to load and associated scaling factors (default scaling: 1.0)",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class sampling_config_model(BaseModel):
|
||||||
|
override_preset: Optional[str] = Field(
|
||||||
|
None, description="Select a sampler override preset"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class developer_config_model(BaseModel):
|
||||||
|
unsafe_launch: Optional[bool] = Field(
|
||||||
|
False, description="Skip Exllamav2 version check"
|
||||||
|
)
|
||||||
|
disable_request_streaming: Optional[bool] = Field(
|
||||||
|
False, description="Disables API request streaming"
|
||||||
|
)
|
||||||
|
cuda_malloc_backend: Optional[bool] = Field(
|
||||||
|
False, description="Runs with the pytorch CUDA malloc backend"
|
||||||
|
)
|
||||||
|
uvloop: Optional[bool] = Field(
|
||||||
|
False, description="Run asyncio using Uvloop or Winloop"
|
||||||
|
)
|
||||||
|
realtime_process_priority: Optional[bool] = Field(
|
||||||
|
False,
|
||||||
|
description="Set process to use a higher priority For realtime process priority, run as administrator or sudo Otherwise, the priority will be set to high",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class embeddings_config_model(BaseModel):
|
||||||
|
embedding_model_dir: Optional[str] = Field(
|
||||||
|
"models",
|
||||||
|
description="Overrides directory to look for embedding models (default: models)",
|
||||||
|
)
|
||||||
|
embeddings_device: Optional[str] = Field(
|
||||||
|
"cpu",
|
||||||
|
description="Device to load embedding models on (default: cpu). Possible values: cpu, auto, cuda. If using an AMD GPU, set this value to 'cuda'.",
|
||||||
|
)
|
||||||
|
embedding_model_name: Optional[str] = Field(
|
||||||
|
None, description="The embeddings model to load"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class tabby_config_model(BaseModel):
|
||||||
|
config: config_config_model = Field(default_factory=config_config_model)
|
||||||
|
network: network_config_model = Field(default_factory=network_config_model)
|
||||||
|
logging: logging_config_model = Field(default_factory=logging_config_model)
|
||||||
|
model: model_config_model = Field(default_factory=model_config_model)
|
||||||
|
draft_model: draft_model_config_model = Field(
|
||||||
|
default_factory=draft_model_config_model
|
||||||
|
)
|
||||||
|
lora: lora_config_model = Field(default_factory=lora_config_model)
|
||||||
|
sampling: sampling_config_model = Field(default_factory=sampling_config_model)
|
||||||
|
developer: developer_config_model = Field(default_factory=developer_config_model)
|
||||||
|
embeddings: embeddings_config_model = Field(default_factory=embeddings_config_model)
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
def set_defaults(cls, values):
|
||||||
|
for field_name, field_value in values.items():
|
||||||
|
if field_value is None:
|
||||||
|
default_instance = cls.__annotations__[field_name]().dict()
|
||||||
|
values[field_name] = cls.__annotations__[field_name](**default_instance)
|
||||||
|
return values
|
||||||
|
|
||||||
|
model_config = ConfigDict(validate_assignment=True)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_config_file(filename="config_sample.yml", indentation=2):
|
||||||
|
schema = tabby_config_model.model_json_schema()
|
||||||
|
|
||||||
|
def dump_def(id: str, indent=2):
|
||||||
|
yaml = ""
|
||||||
|
indent = " " * indentation * indent
|
||||||
|
id = id.split("/")[-1]
|
||||||
|
|
||||||
|
section = schema["$defs"][id]["properties"]
|
||||||
|
for property in section.keys(): # get type
|
||||||
|
comment = section[property]["description"]
|
||||||
|
yaml += f"{indent}# {comment}\n"
|
||||||
|
|
||||||
|
value = unwrap(section[property].get("default"), "")
|
||||||
|
yaml += f"{indent}{property}: {value}\n\n"
|
||||||
|
|
||||||
|
return yaml + "\n"
|
||||||
|
|
||||||
|
yaml = ""
|
||||||
|
for section in schema["properties"].keys():
|
||||||
|
yaml += f"{section}:\n"
|
||||||
|
yaml += dump_def(schema["properties"][section]["$ref"])
|
||||||
|
yaml += "\n"
|
||||||
|
|
||||||
|
with open(filename, "w") as f:
|
||||||
|
f.write(yaml)
|
||||||
|
|
||||||
|
|
||||||
|
# generate_config_file("test.yml")
|
||||||
@@ -76,9 +76,9 @@ def _get_download_folder(repo_id: str, repo_type: str, folder_name: Optional[str
|
|||||||
"""Gets the download folder for the repo."""
|
"""Gets the download folder for the repo."""
|
||||||
|
|
||||||
if repo_type == "lora":
|
if repo_type == "lora":
|
||||||
download_path = pathlib.Path(config.lora.get("lora_dir") or "loras")
|
download_path = pathlib.Path(config.lora.lora_dir)
|
||||||
else:
|
else:
|
||||||
download_path = pathlib.Path(config.model.get("model_dir") or "models")
|
download_path = pathlib.Path(config.model.model_dir)
|
||||||
|
|
||||||
download_path = download_path / (folder_name or repo_id.split("/")[-1])
|
download_path = download_path / (folder_name or repo_id.split("/")[-1])
|
||||||
return download_path
|
return download_path
|
||||||
|
|||||||
@@ -6,37 +6,19 @@ from pydantic import BaseModel
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from common.tabby_config import config
|
||||||
class GenLogPreferences(BaseModel):
|
|
||||||
"""Logging preference config."""
|
|
||||||
|
|
||||||
prompt: bool = False
|
|
||||||
generation_params: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
# Global logging preferences constant
|
# Global logging preferences constant
|
||||||
PREFERENCES = GenLogPreferences()
|
PREFERENCES = config.logging
|
||||||
|
|
||||||
|
|
||||||
def update_from_dict(options_dict: Dict[str, bool]):
|
|
||||||
"""Wrapper to set the logging config for generations"""
|
|
||||||
global PREFERENCES
|
|
||||||
|
|
||||||
# Force bools on the dict
|
|
||||||
for value in options_dict.values():
|
|
||||||
if value is None:
|
|
||||||
value = False
|
|
||||||
|
|
||||||
PREFERENCES = GenLogPreferences.model_validate(options_dict)
|
|
||||||
|
|
||||||
|
|
||||||
def broadcast_status():
|
def broadcast_status():
|
||||||
"""Broadcasts the current logging status"""
|
"""Broadcasts the current logging status"""
|
||||||
enabled = []
|
enabled = []
|
||||||
if PREFERENCES.prompt:
|
if PREFERENCES.log_prompt:
|
||||||
enabled.append("prompts")
|
enabled.append("prompts")
|
||||||
|
|
||||||
if PREFERENCES.generation_params:
|
if PREFERENCES.log_generation_params:
|
||||||
enabled.append("generation params")
|
enabled.append("generation params")
|
||||||
|
|
||||||
if len(enabled) > 0:
|
if len(enabled) > 0:
|
||||||
@@ -47,13 +29,13 @@ def broadcast_status():
|
|||||||
|
|
||||||
def log_generation_params(**kwargs):
|
def log_generation_params(**kwargs):
|
||||||
"""Logs generation parameters to console."""
|
"""Logs generation parameters to console."""
|
||||||
if PREFERENCES.generation_params:
|
if PREFERENCES.log_generation_params:
|
||||||
logger.info(f"Generation options: {kwargs}\n")
|
logger.info(f"Generation options: {kwargs}\n")
|
||||||
|
|
||||||
|
|
||||||
def log_prompt(prompt: str, request_id: str, negative_prompt: Optional[str]):
|
def log_prompt(prompt: str, request_id: str, negative_prompt: Optional[str]):
|
||||||
"""Logs the prompt to console."""
|
"""Logs the prompt to console."""
|
||||||
if PREFERENCES.prompt:
|
if PREFERENCES.log_prompt:
|
||||||
formatted_prompt = "\n" + prompt
|
formatted_prompt = "\n" + prompt
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Prompt (ID: {request_id}): {formatted_prompt if prompt else 'Empty'}\n"
|
f"Prompt (ID: {request_id}): {formatted_prompt if prompt else 'Empty'}\n"
|
||||||
@@ -66,7 +48,7 @@ def log_prompt(prompt: str, request_id: str, negative_prompt: Optional[str]):
|
|||||||
|
|
||||||
def log_response(request_id: str, response: str):
|
def log_response(request_id: str, response: str):
|
||||||
"""Logs the response to console."""
|
"""Logs the response to console."""
|
||||||
if PREFERENCES.prompt:
|
if PREFERENCES.log_prompt:
|
||||||
formatted_response = "\n" + response
|
formatted_response = "\n" + response
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Response (ID: {request_id}): "
|
f"Response (ID: {request_id}): "
|
||||||
|
|||||||
@@ -153,7 +153,7 @@ async def unload_embedding_model():
|
|||||||
def get_config_default(key: str, model_type: str = "model"):
|
def get_config_default(key: str, model_type: str = "model"):
|
||||||
"""Fetches a default value from model config if allowed by the user."""
|
"""Fetches a default value from model config if allowed by the user."""
|
||||||
|
|
||||||
default_keys = unwrap(config.model.get("use_as_default"), [])
|
default_keys = unwrap(config.model.use_as_default, [])
|
||||||
|
|
||||||
# Add extra keys to defaults
|
# Add extra keys to defaults
|
||||||
default_keys.append("embeddings_device")
|
default_keys.append("embeddings_device")
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ def handle_request_error(message: str, exc_info: bool = True):
|
|||||||
"""Log a request error to the console."""
|
"""Log a request error to the console."""
|
||||||
|
|
||||||
trace = traceback.format_exc()
|
trace = traceback.format_exc()
|
||||||
send_trace = unwrap(config.network.get("send_tracebacks"), False)
|
send_trace = config.network.send_tracebacks
|
||||||
|
|
||||||
error_message = TabbyRequestErrorMessage(
|
error_message = TabbyRequestErrorMessage(
|
||||||
message=message, trace=trace if send_trace else None
|
message=message, trace=trace if send_trace else None
|
||||||
@@ -134,7 +134,7 @@ def get_global_depends():
|
|||||||
|
|
||||||
depends = [Depends(add_request_id)]
|
depends = [Depends(add_request_id)]
|
||||||
|
|
||||||
if config.logging.get("requests"):
|
if config.logging.log_requests:
|
||||||
depends.append(Depends(log_request))
|
depends.append(Depends(log_request))
|
||||||
|
|
||||||
return depends
|
return depends
|
||||||
|
|||||||
@@ -4,21 +4,11 @@ from loguru import logger
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from common.utils import unwrap, merge_dicts
|
from common.utils import unwrap, merge_dicts
|
||||||
|
from common.config_models import tabby_config_model
|
||||||
|
import common.config_models
|
||||||
|
|
||||||
|
|
||||||
class TabbyConfig:
|
class TabbyConfig(tabby_config_model):
|
||||||
network: dict = {}
|
|
||||||
logging: dict = {}
|
|
||||||
model: dict = {}
|
|
||||||
draft_model: dict = {}
|
|
||||||
lora: dict = {}
|
|
||||||
sampling: dict = {}
|
|
||||||
developer: dict = {}
|
|
||||||
embeddings: dict = {}
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def load_config(self, arguments: Optional[dict] = None):
|
def load_config(self, arguments: Optional[dict] = None):
|
||||||
"""load the global application config"""
|
"""load the global application config"""
|
||||||
|
|
||||||
@@ -30,14 +20,11 @@ class TabbyConfig:
|
|||||||
|
|
||||||
merged_config = merge_dicts(*configs)
|
merged_config = merge_dicts(*configs)
|
||||||
|
|
||||||
self.network = unwrap(merged_config.get("network"), {})
|
for field in tabby_config_model.model_fields.keys():
|
||||||
self.logging = unwrap(merged_config.get("logging"), {})
|
value = unwrap(merged_config.get(field), {})
|
||||||
self.model = unwrap(merged_config.get("model"), {})
|
model = getattr(common.config_models, f"{field}_config_model")
|
||||||
self.draft_model = unwrap(merged_config.get("draft"), {})
|
|
||||||
self.lora = unwrap(merged_config.get("draft"), {})
|
setattr(self, field, model.parse_obj(value))
|
||||||
self.sampling = unwrap(merged_config.get("sampling"), {})
|
|
||||||
self.developer = unwrap(merged_config.get("developer"), {})
|
|
||||||
self.embeddings = unwrap(merged_config.get("embeddings"), {})
|
|
||||||
|
|
||||||
def _from_file(self, config_path: pathlib.Path):
|
def _from_file(self, config_path: pathlib.Path):
|
||||||
"""loads config from a given file path"""
|
"""loads config from a given file path"""
|
||||||
|
|||||||
@@ -58,9 +58,7 @@ async def completion_request(
|
|||||||
if isinstance(data.prompt, list):
|
if isinstance(data.prompt, list):
|
||||||
data.prompt = "\n".join(data.prompt)
|
data.prompt = "\n".join(data.prompt)
|
||||||
|
|
||||||
disable_request_streaming = unwrap(
|
disable_request_streaming = config.developer.disable_request_streaming
|
||||||
config.developer.get("disable_request_streaming"), False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set an empty JSON schema if the request wants a JSON response
|
# Set an empty JSON schema if the request wants a JSON response
|
||||||
if data.response_format.type == "json":
|
if data.response_format.type == "json":
|
||||||
@@ -117,9 +115,7 @@ async def chat_completion_request(
|
|||||||
if data.response_format.type == "json":
|
if data.response_format.type == "json":
|
||||||
data.json_schema = {"type": "object"}
|
data.json_schema = {"type": "object"}
|
||||||
|
|
||||||
disable_request_streaming = unwrap(
|
disable_request_streaming = config.developer.disable_request_streaming
|
||||||
config.developer.get("disable_request_streaming"), False
|
|
||||||
)
|
|
||||||
|
|
||||||
if data.stream and not disable_request_streaming:
|
if data.stream and not disable_request_streaming:
|
||||||
return EventSourceResponse(
|
return EventSourceResponse(
|
||||||
|
|||||||
@@ -62,17 +62,17 @@ async def list_models(request: Request) -> ModelList:
|
|||||||
Requires an admin key to see all models.
|
Requires an admin key to see all models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_dir = unwrap(config.model.get("model_dir"), "models")
|
model_dir = config.model.model_dir
|
||||||
model_path = pathlib.Path(model_dir)
|
model_path = pathlib.Path(model_dir)
|
||||||
|
|
||||||
draft_model_dir = config.draft_model.get("draft_model_dir")
|
draft_model_dir = config.draft_model.draft_model_dir
|
||||||
|
|
||||||
if get_key_permission(request) == "admin":
|
if get_key_permission(request) == "admin":
|
||||||
models = get_model_list(model_path.resolve(), draft_model_dir)
|
models = get_model_list(model_path.resolve(), draft_model_dir)
|
||||||
else:
|
else:
|
||||||
models = await get_current_model_list()
|
models = await get_current_model_list()
|
||||||
|
|
||||||
if unwrap(config.model.get("use_dummy_models"), False):
|
if config.model.use_dummy_models:
|
||||||
models.data.insert(0, ModelCard(id="gpt-3.5-turbo"))
|
models.data.insert(0, ModelCard(id="gpt-3.5-turbo"))
|
||||||
|
|
||||||
return models
|
return models
|
||||||
@@ -98,7 +98,7 @@ async def list_draft_models(request: Request) -> ModelList:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
if get_key_permission(request) == "admin":
|
if get_key_permission(request) == "admin":
|
||||||
draft_model_dir = unwrap(config.draft_model.get("draft_model_dir"), "models")
|
draft_model_dir = config.draft_model.draft_model_dir
|
||||||
draft_model_path = pathlib.Path(draft_model_dir)
|
draft_model_path = pathlib.Path(draft_model_dir)
|
||||||
|
|
||||||
models = get_model_list(draft_model_path.resolve())
|
models = get_model_list(draft_model_path.resolve())
|
||||||
@@ -122,7 +122,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
|
|||||||
|
|
||||||
raise HTTPException(400, error_message)
|
raise HTTPException(400, error_message)
|
||||||
|
|
||||||
model_path = pathlib.Path(unwrap(config.model.get("model_dir"), "models"))
|
model_path = pathlib.Path(config.model.model_dir)
|
||||||
model_path = model_path / data.name
|
model_path = model_path / data.name
|
||||||
|
|
||||||
draft_model_path = None
|
draft_model_path = None
|
||||||
@@ -135,7 +135,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
|
|||||||
|
|
||||||
raise HTTPException(400, error_message)
|
raise HTTPException(400, error_message)
|
||||||
|
|
||||||
draft_model_path = unwrap(config.draft_model.get("draft_model_dir"), "models")
|
draft_model_path = config.draft_model.draft_model_dir
|
||||||
|
|
||||||
if not model_path.exists():
|
if not model_path.exists():
|
||||||
error_message = handle_request_error(
|
error_message = handle_request_error(
|
||||||
@@ -192,7 +192,7 @@ async def list_all_loras(request: Request) -> LoraList:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
if get_key_permission(request) == "admin":
|
if get_key_permission(request) == "admin":
|
||||||
lora_path = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras"))
|
lora_path = pathlib.Path(config.lora.lora_dir)
|
||||||
loras = get_lora_list(lora_path.resolve())
|
loras = get_lora_list(lora_path.resolve())
|
||||||
else:
|
else:
|
||||||
loras = get_active_loras()
|
loras = get_active_loras()
|
||||||
@@ -227,7 +227,7 @@ async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse:
|
|||||||
|
|
||||||
raise HTTPException(400, error_message)
|
raise HTTPException(400, error_message)
|
||||||
|
|
||||||
lora_dir = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras"))
|
lora_dir = pathlib.Path(config.lora.lora_dir)
|
||||||
if not lora_dir.exists():
|
if not lora_dir.exists():
|
||||||
error_message = handle_request_error(
|
error_message = handle_request_error(
|
||||||
"A parent lora directory does not exist for load. Check your config.yml?",
|
"A parent lora directory does not exist for load. Check your config.yml?",
|
||||||
@@ -266,9 +266,7 @@ async def list_embedding_models(request: Request) -> ModelList:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
if get_key_permission(request) == "admin":
|
if get_key_permission(request) == "admin":
|
||||||
embedding_model_dir = unwrap(
|
embedding_model_dir = config.embeddings.embedding_model_dir
|
||||||
config.embeddings.get("embedding_model_dir"), "models"
|
|
||||||
)
|
|
||||||
embedding_model_path = pathlib.Path(embedding_model_dir)
|
embedding_model_path = pathlib.Path(embedding_model_dir)
|
||||||
|
|
||||||
models = get_model_list(embedding_model_path.resolve())
|
models = get_model_list(embedding_model_path.resolve())
|
||||||
@@ -302,9 +300,7 @@ async def load_embedding_model(
|
|||||||
|
|
||||||
raise HTTPException(400, error_message)
|
raise HTTPException(400, error_message)
|
||||||
|
|
||||||
embedding_model_dir = pathlib.Path(
|
embedding_model_dir = pathlib.Path(config.embeddings.embedding_model_dir)
|
||||||
unwrap(config.embeddings.get("embedding_model_dir"), "models")
|
|
||||||
)
|
|
||||||
embedding_model_path = embedding_model_dir / data.name
|
embedding_model_path = embedding_model_dir / data.name
|
||||||
|
|
||||||
if not embedding_model_path.exists():
|
if not embedding_model_path.exists():
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from pydantic import BaseModel, Field, ConfigDict
|
|||||||
from time import time
|
from time import time
|
||||||
from typing import List, Literal, Optional, Union
|
from typing import List, Literal, Optional, Union
|
||||||
|
|
||||||
from common.gen_logging import GenLogPreferences
|
from common.config_models import logging_config_model
|
||||||
from common.model import get_config_default
|
from common.model import get_config_default
|
||||||
|
|
||||||
|
|
||||||
@@ -33,7 +33,7 @@ class ModelCard(BaseModel):
|
|||||||
object: str = "model"
|
object: str = "model"
|
||||||
created: int = Field(default_factory=lambda: int(time()))
|
created: int = Field(default_factory=lambda: int(time()))
|
||||||
owned_by: str = "tabbyAPI"
|
owned_by: str = "tabbyAPI"
|
||||||
logging: Optional[GenLogPreferences] = None
|
logging: Optional[logging_config_model] = None
|
||||||
parameters: Optional[ModelCardParameters] = None
|
parameters: Optional[ModelCardParameters] = None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ def setup_app(host: Optional[str] = None, port: Optional[int] = None):
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
api_servers = unwrap(config.network.get("api_servers"), [])
|
api_servers = config.network.api_servers
|
||||||
|
|
||||||
# Map for API id to server router
|
# Map for API id to server router
|
||||||
router_mapping = {"oai": OAIRouter, "kobold": KoboldRouter}
|
router_mapping = {"oai": OAIRouter, "kobold": KoboldRouter}
|
||||||
|
|||||||
34
main.py
34
main.py
@@ -27,8 +27,8 @@ if not do_export_openapi:
|
|||||||
async def entrypoint_async():
|
async def entrypoint_async():
|
||||||
"""Async entry function for program startup"""
|
"""Async entry function for program startup"""
|
||||||
|
|
||||||
host = unwrap(config.network.get("host"), "127.0.0.1")
|
host = config.network.host
|
||||||
port = unwrap(config.network.get("port"), 5000)
|
port = config.network.port
|
||||||
|
|
||||||
# Check if the port is available and attempt to bind a fallback
|
# Check if the port is available and attempt to bind a fallback
|
||||||
if is_port_in_use(port):
|
if is_port_in_use(port):
|
||||||
@@ -50,16 +50,12 @@ async def entrypoint_async():
|
|||||||
port = fallback_port
|
port = fallback_port
|
||||||
|
|
||||||
# Initialize auth keys
|
# Initialize auth keys
|
||||||
load_auth_keys(unwrap(config.network.get("disable_auth"), False))
|
load_auth_keys(config.network.disable_auth)
|
||||||
|
|
||||||
# Override the generation log options if given
|
|
||||||
if config.logging:
|
|
||||||
gen_logging.update_from_dict(config.logging)
|
|
||||||
|
|
||||||
gen_logging.broadcast_status()
|
gen_logging.broadcast_status()
|
||||||
|
|
||||||
# Set sampler parameter overrides if provided
|
# Set sampler parameter overrides if provided
|
||||||
sampling_override_preset = config.sampling.get("override_preset")
|
sampling_override_preset = config.sampling.override_preset
|
||||||
if sampling_override_preset:
|
if sampling_override_preset:
|
||||||
try:
|
try:
|
||||||
sampling.overrides_from_file(sampling_override_preset)
|
sampling.overrides_from_file(sampling_override_preset)
|
||||||
@@ -68,25 +64,23 @@ async def entrypoint_async():
|
|||||||
|
|
||||||
# If an initial model name is specified, create a container
|
# If an initial model name is specified, create a container
|
||||||
# and load the model
|
# and load the model
|
||||||
model_name = config.model.get("model_name")
|
model_name = config.model.model_name
|
||||||
if model_name:
|
if model_name:
|
||||||
model_path = pathlib.Path(unwrap(config.model.get("model_dir"), "models"))
|
model_path = pathlib.Path(config.model.model_dir)
|
||||||
model_path = model_path / model_name
|
model_path = model_path / model_name
|
||||||
|
|
||||||
await model.load_model(model_path.resolve(), **config.model)
|
await model.load_model(model_path.resolve(), **config.model)
|
||||||
|
|
||||||
# Load loras after loading the model
|
# Load loras after loading the model
|
||||||
if config.lora.get("loras"):
|
if config.lora.loras:
|
||||||
lora_dir = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras"))
|
lora_dir = pathlib.Path(config.lora.lora_dir)
|
||||||
await model.container.load_loras(lora_dir.resolve(), **config.lora)
|
await model.container.load_loras(lora_dir.resolve(), **config.lora)
|
||||||
|
|
||||||
# If an initial embedding model name is specified, create a separate container
|
# If an initial embedding model name is specified, create a separate container
|
||||||
# and load the model
|
# and load the model
|
||||||
embedding_model_name = config.embeddings.get("embedding_model_name")
|
embedding_model_name = config.embeddings.embedding_model_name
|
||||||
if embedding_model_name:
|
if embedding_model_name:
|
||||||
embedding_model_path = pathlib.Path(
|
embedding_model_path = pathlib.Path(config.embeddings.embedding_model_dir)
|
||||||
unwrap(config.embeddings.get("embedding_model_dir"), "models")
|
|
||||||
)
|
|
||||||
embedding_model_path = embedding_model_path / embedding_model_name
|
embedding_model_path = embedding_model_path / embedding_model_name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -124,7 +118,7 @@ def entrypoint(arguments: Optional[dict] = None):
|
|||||||
# Check exllamav2 version and give a descriptive error if it's too old
|
# Check exllamav2 version and give a descriptive error if it's too old
|
||||||
# Skip if launching unsafely
|
# Skip if launching unsafely
|
||||||
print(f"MAIN.PY {config=}")
|
print(f"MAIN.PY {config=}")
|
||||||
if unwrap(config.developer.get("unsafe_launch"), False):
|
if config.developer.unsafe_launch:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"UNSAFE: Skipping ExllamaV2 version check.\n"
|
"UNSAFE: Skipping ExllamaV2 version check.\n"
|
||||||
"If you aren't a developer, please keep this off!"
|
"If you aren't a developer, please keep this off!"
|
||||||
@@ -133,12 +127,12 @@ def entrypoint(arguments: Optional[dict] = None):
|
|||||||
check_exllama_version()
|
check_exllama_version()
|
||||||
|
|
||||||
# Enable CUDA malloc backend
|
# Enable CUDA malloc backend
|
||||||
if unwrap(config.developer.get("cuda_malloc_backend"), False):
|
if config.developer.cuda_malloc_backend:
|
||||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"
|
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"
|
||||||
logger.warning("EXPERIMENTAL: Enabled the pytorch CUDA malloc backend.")
|
logger.warning("EXPERIMENTAL: Enabled the pytorch CUDA malloc backend.")
|
||||||
|
|
||||||
# Use Uvloop/Winloop
|
# Use Uvloop/Winloop
|
||||||
if unwrap(config.developer.get("uvloop"), False):
|
if config.developer.uvloop:
|
||||||
if platform.system() == "Windows":
|
if platform.system() == "Windows":
|
||||||
from winloop import install
|
from winloop import install
|
||||||
else:
|
else:
|
||||||
@@ -150,7 +144,7 @@ def entrypoint(arguments: Optional[dict] = None):
|
|||||||
logger.warning("EXPERIMENTAL: Running program with Uvloop/Winloop.")
|
logger.warning("EXPERIMENTAL: Running program with Uvloop/Winloop.")
|
||||||
|
|
||||||
# Set the process priority
|
# Set the process priority
|
||||||
if unwrap(config.developer.get("realtime_process_priority"), False):
|
if config.developer.realtime_process_priority:
|
||||||
import psutil
|
import psutil
|
||||||
|
|
||||||
current_process = psutil.Process(os.getpid())
|
current_process = psutil.Process(os.getpid())
|
||||||
|
|||||||
Reference in New Issue
Block a user