Merge pull request #185 from SecretiveShell/refactor-config-loading

Refactor config loading
This commit is contained in:
Brian Dashore
2024-09-05 18:00:32 -04:00
committed by GitHub
10 changed files with 160 additions and 176 deletions

View File

@@ -3,10 +3,11 @@ from fastapi import APIRouter, Depends, HTTPException, Request
from sse_starlette import EventSourceResponse
from sys import maxsize
from common import config, model
from common import model
from common.auth import check_api_key
from common.model import check_embeddings_container, check_model_container
from common.networking import handle_request_error, run_with_request_disconnect
from common.tabby_config import config
from common.utils import unwrap
from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
from endpoints.OAI.types.chat_completion import (
@@ -64,7 +65,7 @@ async def completion_request(
data.prompt = "\n".join(data.prompt)
disable_request_streaming = unwrap(
config.developer_config().get("disable_request_streaming"), False
config.developer.get("disable_request_streaming"), False
)
# Set an empty JSON schema if the request wants a JSON response
@@ -128,7 +129,7 @@ async def chat_completion_request(
data.json_schema = {"type": "object"}
disable_request_streaming = unwrap(
config.developer_config().get("disable_request_streaming"), False
config.developer.get("disable_request_streaming"), False
)
if data.stream and not disable_request_streaming:

View File

@@ -4,11 +4,12 @@ from sys import maxsize
from fastapi import APIRouter, Depends, HTTPException, Request
from sse_starlette import EventSourceResponse
from common import config, model, sampling
from common import model, sampling
from common.auth import check_admin_key, check_api_key, get_key_permission
from common.downloader import hf_repo_download
from common.model import check_embeddings_container, check_model_container
from common.networking import handle_request_error, run_with_request_disconnect
from common.tabby_config import config
from common.templating import PromptTemplate, get_all_templates
from common.utils import unwrap
from endpoints.core.types.auth import AuthPermissionResponse
@@ -61,18 +62,17 @@ async def list_models(request: Request) -> ModelList:
Requires an admin key to see all models.
"""
model_config = config.model_config()
model_dir = unwrap(model_config.get("model_dir"), "models")
model_dir = unwrap(config.model.get("model_dir"), "models")
model_path = pathlib.Path(model_dir)
draft_model_dir = config.draft_model_config().get("draft_model_dir")
draft_model_dir = config.draft_model.get("draft_model_dir")
if get_key_permission(request) == "admin":
models = get_model_list(model_path.resolve(), draft_model_dir)
else:
models = await get_current_model_list()
if unwrap(model_config.get("use_dummy_models"), False):
if unwrap(config.model.get("use_dummy_models"), False):
models.data.insert(0, ModelCard(id="gpt-3.5-turbo"))
return models
@@ -98,9 +98,7 @@ async def list_draft_models(request: Request) -> ModelList:
"""
if get_key_permission(request) == "admin":
draft_model_dir = unwrap(
config.draft_model_config().get("draft_model_dir"), "models"
)
draft_model_dir = unwrap(config.draft_model.get("draft_model_dir"), "models")
draft_model_path = pathlib.Path(draft_model_dir)
models = get_model_list(draft_model_path.resolve())
@@ -124,7 +122,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
raise HTTPException(400, error_message)
model_path = pathlib.Path(unwrap(config.model_config().get("model_dir"), "models"))
model_path = pathlib.Path(unwrap(config.model.get("model_dir"), "models"))
model_path = model_path / data.name
draft_model_path = None
@@ -137,9 +135,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
raise HTTPException(400, error_message)
draft_model_path = unwrap(
config.draft_model_config().get("draft_model_dir"), "models"
)
draft_model_path = unwrap(config.draft_model.get("draft_model_dir"), "models")
if not model_path.exists():
error_message = handle_request_error(
@@ -196,7 +192,7 @@ async def list_all_loras(request: Request) -> LoraList:
"""
if get_key_permission(request) == "admin":
lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
lora_path = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras"))
loras = get_lora_list(lora_path.resolve())
else:
loras = get_active_loras()
@@ -231,7 +227,7 @@ async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse:
raise HTTPException(400, error_message)
lora_dir = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
lora_dir = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras"))
if not lora_dir.exists():
error_message = handle_request_error(
"A parent lora directory does not exist for load. Check your config.yml?",
@@ -271,7 +267,7 @@ async def list_embedding_models(request: Request) -> ModelList:
if get_key_permission(request) == "admin":
embedding_model_dir = unwrap(
config.embeddings_config().get("embedding_model_dir"), "models"
config.embeddings.get("embedding_model_dir"), "models"
)
embedding_model_path = pathlib.Path(embedding_model_dir)
@@ -307,7 +303,7 @@ async def load_embedding_model(
raise HTTPException(400, error_message)
embedding_model_dir = pathlib.Path(
unwrap(config.model_config().get("embedding_model_dir"), "models")
unwrap(config.embeddings.get("embedding_model_dir"), "models")
)
embedding_model_path = embedding_model_dir / data.name

View File

@@ -5,9 +5,9 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from common import config
from common.logger import UVICORN_LOG_CONFIG
from common.networking import get_global_depends
from common.tabby_config import config
from common.utils import unwrap
from endpoints.Kobold import router as KoboldRouter
from endpoints.OAI import router as OAIRouter
@@ -36,7 +36,7 @@ def setup_app(host: Optional[str] = None, port: Optional[int] = None):
allow_headers=["*"],
)
api_servers = unwrap(config.network_config().get("api_servers"), [])
api_servers = unwrap(config.network.get("api_servers"), [])
# Map for API id to server router
router_mapping = {"oai": OAIRouter, "kobold": KoboldRouter}