Config: Migrate to global class instead of dicts

The config categories can have defined separation, but preserve
the dynamic nature of adding new config options by making all the
internal class vars as dictionaries.

This was necessary since storing global callbacks stored a state
of the previous global_config var that wasn't populated.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-09-04 23:13:36 -04:00
parent e772fa2981
commit 93872b34d7
10 changed files with 149 additions and 153 deletions

View File

@@ -3,10 +3,11 @@ from fastapi import APIRouter, Depends, HTTPException, Request
from sse_starlette import EventSourceResponse
from sys import maxsize
from common import config, model
from common import model
from common.auth import check_api_key
from common.model import check_embeddings_container, check_model_container
from common.networking import handle_request_error, run_with_request_disconnect
from common.tabby_config import config
from common.utils import unwrap
from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
from endpoints.OAI.types.chat_completion import (
@@ -58,7 +59,7 @@ async def completion_request(
data.prompt = "\n".join(data.prompt)
disable_request_streaming = unwrap(
config.developer_config().get("disable_request_streaming"), False
config.developer.get("disable_request_streaming"), False
)
# Set an empty JSON schema if the request wants a JSON response
@@ -117,7 +118,7 @@ async def chat_completion_request(
data.json_schema = {"type": "object"}
disable_request_streaming = unwrap(
config.developer_config().get("disable_request_streaming"), False
config.developer.get("disable_request_streaming"), False
)
if data.stream and not disable_request_streaming:

View File

@@ -4,11 +4,12 @@ from sys import maxsize
from fastapi import APIRouter, Depends, HTTPException, Request
from sse_starlette import EventSourceResponse
from common import config, model, sampling
from common import model, sampling
from common.auth import check_admin_key, check_api_key, get_key_permission
from common.downloader import hf_repo_download
from common.model import check_embeddings_container, check_model_container
from common.networking import handle_request_error, run_with_request_disconnect
from common.tabby_config import config
from common.templating import PromptTemplate, get_all_templates
from common.utils import unwrap
from endpoints.core.types.auth import AuthPermissionResponse
@@ -61,18 +62,17 @@ async def list_models(request: Request) -> ModelList:
Requires an admin key to see all models.
"""
model_config = config.model_config()
model_dir = unwrap(model_config.get("model_dir"), "models")
model_dir = unwrap(config.model.get("model_dir"), "models")
model_path = pathlib.Path(model_dir)
draft_model_dir = config.draft_model_config().get("draft_model_dir")
draft_model_dir = config.draft_model.get("draft_model_dir")
if get_key_permission(request) == "admin":
models = get_model_list(model_path.resolve(), draft_model_dir)
else:
models = await get_current_model_list()
if unwrap(model_config.get("use_dummy_models"), False):
if unwrap(config.model.get("use_dummy_models"), False):
models.data.insert(0, ModelCard(id="gpt-3.5-turbo"))
return models
@@ -98,9 +98,7 @@ async def list_draft_models(request: Request) -> ModelList:
"""
if get_key_permission(request) == "admin":
draft_model_dir = unwrap(
config.draft_model_config().get("draft_model_dir"), "models"
)
draft_model_dir = unwrap(config.draft_model.get("draft_model_dir"), "models")
draft_model_path = pathlib.Path(draft_model_dir)
models = get_model_list(draft_model_path.resolve())
@@ -124,7 +122,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
raise HTTPException(400, error_message)
model_path = pathlib.Path(unwrap(config.model_config().get("model_dir"), "models"))
model_path = pathlib.Path(unwrap(config.model.get("model_dir"), "models"))
model_path = model_path / data.name
draft_model_path = None
@@ -137,9 +135,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
raise HTTPException(400, error_message)
draft_model_path = unwrap(
config.draft_model_config().get("draft_model_dir"), "models"
)
draft_model_path = unwrap(config.draft_model.get("draft_model_dir"), "models")
if not model_path.exists():
error_message = handle_request_error(
@@ -196,7 +192,7 @@ async def list_all_loras(request: Request) -> LoraList:
"""
if get_key_permission(request) == "admin":
lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
lora_path = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras"))
loras = get_lora_list(lora_path.resolve())
else:
loras = get_active_loras()
@@ -231,7 +227,7 @@ async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse:
raise HTTPException(400, error_message)
lora_dir = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
lora_dir = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras"))
if not lora_dir.exists():
error_message = handle_request_error(
"A parent lora directory does not exist for load. Check your config.yml?",
@@ -271,7 +267,7 @@ async def list_embedding_models(request: Request) -> ModelList:
if get_key_permission(request) == "admin":
embedding_model_dir = unwrap(
config.embeddings_config().get("embedding_model_dir"), "models"
config.embeddings.get("embedding_model_dir"), "models"
)
embedding_model_path = pathlib.Path(embedding_model_dir)
@@ -307,7 +303,7 @@ async def load_embedding_model(
raise HTTPException(400, error_message)
embedding_model_dir = pathlib.Path(
unwrap(config.model_config().get("embedding_model_dir"), "models")
unwrap(config.embeddings.get("embedding_model_dir"), "models")
)
embedding_model_path = embedding_model_dir / data.name

View File

@@ -5,9 +5,9 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from common import config
from common.logger import UVICORN_LOG_CONFIG
from common.networking import get_global_depends
from common.tabby_config import config
from common.utils import unwrap
from endpoints.Kobold import router as KoboldRouter
from endpoints.OAI import router as OAIRouter
@@ -36,7 +36,7 @@ def setup_app(host: Optional[str] = None, port: Optional[int] = None):
allow_headers=["*"],
)
api_servers = unwrap(config.network_config().get("api_servers"), [])
api_servers = unwrap(config.network.get("api_servers"), [])
# Map for API id to server router
router_mapping = {"oai": OAIRouter, "kobold": KoboldRouter}