mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-26 01:08:52 +00:00
Config: Migrate to global class instead of dicts
The config categories can have defined separation, but preserve the dynamic nature of adding new config options by making all the internal class vars as dictionaries. This was necessary since storing global callbacks stored a state of the previous global_config var that wasn't populated. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -3,10 +3,11 @@ from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sse_starlette import EventSourceResponse
|
||||
from sys import maxsize
|
||||
|
||||
from common import config, model
|
||||
from common import model
|
||||
from common.auth import check_api_key
|
||||
from common.model import check_embeddings_container, check_model_container
|
||||
from common.networking import handle_request_error, run_with_request_disconnect
|
||||
from common.tabby_config import config
|
||||
from common.utils import unwrap
|
||||
from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
|
||||
from endpoints.OAI.types.chat_completion import (
|
||||
@@ -58,7 +59,7 @@ async def completion_request(
|
||||
data.prompt = "\n".join(data.prompt)
|
||||
|
||||
disable_request_streaming = unwrap(
|
||||
config.developer_config().get("disable_request_streaming"), False
|
||||
config.developer.get("disable_request_streaming"), False
|
||||
)
|
||||
|
||||
# Set an empty JSON schema if the request wants a JSON response
|
||||
@@ -117,7 +118,7 @@ async def chat_completion_request(
|
||||
data.json_schema = {"type": "object"}
|
||||
|
||||
disable_request_streaming = unwrap(
|
||||
config.developer_config().get("disable_request_streaming"), False
|
||||
config.developer.get("disable_request_streaming"), False
|
||||
)
|
||||
|
||||
if data.stream and not disable_request_streaming:
|
||||
|
||||
@@ -4,11 +4,12 @@ from sys import maxsize
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sse_starlette import EventSourceResponse
|
||||
|
||||
from common import config, model, sampling
|
||||
from common import model, sampling
|
||||
from common.auth import check_admin_key, check_api_key, get_key_permission
|
||||
from common.downloader import hf_repo_download
|
||||
from common.model import check_embeddings_container, check_model_container
|
||||
from common.networking import handle_request_error, run_with_request_disconnect
|
||||
from common.tabby_config import config
|
||||
from common.templating import PromptTemplate, get_all_templates
|
||||
from common.utils import unwrap
|
||||
from endpoints.core.types.auth import AuthPermissionResponse
|
||||
@@ -61,18 +62,17 @@ async def list_models(request: Request) -> ModelList:
|
||||
Requires an admin key to see all models.
|
||||
"""
|
||||
|
||||
model_config = config.model_config()
|
||||
model_dir = unwrap(model_config.get("model_dir"), "models")
|
||||
model_dir = unwrap(config.model.get("model_dir"), "models")
|
||||
model_path = pathlib.Path(model_dir)
|
||||
|
||||
draft_model_dir = config.draft_model_config().get("draft_model_dir")
|
||||
draft_model_dir = config.draft_model.get("draft_model_dir")
|
||||
|
||||
if get_key_permission(request) == "admin":
|
||||
models = get_model_list(model_path.resolve(), draft_model_dir)
|
||||
else:
|
||||
models = await get_current_model_list()
|
||||
|
||||
if unwrap(model_config.get("use_dummy_models"), False):
|
||||
if unwrap(config.model.get("use_dummy_models"), False):
|
||||
models.data.insert(0, ModelCard(id="gpt-3.5-turbo"))
|
||||
|
||||
return models
|
||||
@@ -98,9 +98,7 @@ async def list_draft_models(request: Request) -> ModelList:
|
||||
"""
|
||||
|
||||
if get_key_permission(request) == "admin":
|
||||
draft_model_dir = unwrap(
|
||||
config.draft_model_config().get("draft_model_dir"), "models"
|
||||
)
|
||||
draft_model_dir = unwrap(config.draft_model.get("draft_model_dir"), "models")
|
||||
draft_model_path = pathlib.Path(draft_model_dir)
|
||||
|
||||
models = get_model_list(draft_model_path.resolve())
|
||||
@@ -124,7 +122,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
|
||||
|
||||
raise HTTPException(400, error_message)
|
||||
|
||||
model_path = pathlib.Path(unwrap(config.model_config().get("model_dir"), "models"))
|
||||
model_path = pathlib.Path(unwrap(config.model.get("model_dir"), "models"))
|
||||
model_path = model_path / data.name
|
||||
|
||||
draft_model_path = None
|
||||
@@ -137,9 +135,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
|
||||
|
||||
raise HTTPException(400, error_message)
|
||||
|
||||
draft_model_path = unwrap(
|
||||
config.draft_model_config().get("draft_model_dir"), "models"
|
||||
)
|
||||
draft_model_path = unwrap(config.draft_model.get("draft_model_dir"), "models")
|
||||
|
||||
if not model_path.exists():
|
||||
error_message = handle_request_error(
|
||||
@@ -196,7 +192,7 @@ async def list_all_loras(request: Request) -> LoraList:
|
||||
"""
|
||||
|
||||
if get_key_permission(request) == "admin":
|
||||
lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
|
||||
lora_path = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras"))
|
||||
loras = get_lora_list(lora_path.resolve())
|
||||
else:
|
||||
loras = get_active_loras()
|
||||
@@ -231,7 +227,7 @@ async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse:
|
||||
|
||||
raise HTTPException(400, error_message)
|
||||
|
||||
lora_dir = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
|
||||
lora_dir = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras"))
|
||||
if not lora_dir.exists():
|
||||
error_message = handle_request_error(
|
||||
"A parent lora directory does not exist for load. Check your config.yml?",
|
||||
@@ -271,7 +267,7 @@ async def list_embedding_models(request: Request) -> ModelList:
|
||||
|
||||
if get_key_permission(request) == "admin":
|
||||
embedding_model_dir = unwrap(
|
||||
config.embeddings_config().get("embedding_model_dir"), "models"
|
||||
config.embeddings.get("embedding_model_dir"), "models"
|
||||
)
|
||||
embedding_model_path = pathlib.Path(embedding_model_dir)
|
||||
|
||||
@@ -307,7 +303,7 @@ async def load_embedding_model(
|
||||
raise HTTPException(400, error_message)
|
||||
|
||||
embedding_model_dir = pathlib.Path(
|
||||
unwrap(config.model_config().get("embedding_model_dir"), "models")
|
||||
unwrap(config.embeddings.get("embedding_model_dir"), "models")
|
||||
)
|
||||
embedding_model_path = embedding_model_dir / data.name
|
||||
|
||||
|
||||
@@ -5,9 +5,9 @@ from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from loguru import logger
|
||||
|
||||
from common import config
|
||||
from common.logger import UVICORN_LOG_CONFIG
|
||||
from common.networking import get_global_depends
|
||||
from common.tabby_config import config
|
||||
from common.utils import unwrap
|
||||
from endpoints.Kobold import router as KoboldRouter
|
||||
from endpoints.OAI import router as OAIRouter
|
||||
@@ -36,7 +36,7 @@ def setup_app(host: Optional[str] = None, port: Optional[int] = None):
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
api_servers = unwrap(config.network_config().get("api_servers"), [])
|
||||
api_servers = unwrap(config.network.get("api_servers"), [])
|
||||
|
||||
# Map for API id to server router
|
||||
router_mapping = {"oai": OAIRouter, "kobold": KoboldRouter}
|
||||
|
||||
Reference in New Issue
Block a user