Config: Migrate to global class instead of dicts

The config categories can have defined separation, but preserve
the dynamic nature of adding new config options by making all the
internal class vars as dictionaries.

This was necessary since storing global callbacks stored a state
of the previous global_config var that wasn't populated.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-09-04 23:13:36 -04:00
parent e772fa2981
commit 93872b34d7
10 changed files with 149 additions and 153 deletions

View File

@@ -4,11 +4,12 @@ from sys import maxsize
from fastapi import APIRouter, Depends, HTTPException, Request
from sse_starlette import EventSourceResponse
from common import config, model, sampling
from common import model, sampling
from common.auth import check_admin_key, check_api_key, get_key_permission
from common.downloader import hf_repo_download
from common.model import check_embeddings_container, check_model_container
from common.networking import handle_request_error, run_with_request_disconnect
from common.tabby_config import config
from common.templating import PromptTemplate, get_all_templates
from common.utils import unwrap
from endpoints.core.types.auth import AuthPermissionResponse
@@ -61,18 +62,17 @@ async def list_models(request: Request) -> ModelList:
Requires an admin key to see all models.
"""
model_config = config.model_config()
model_dir = unwrap(model_config.get("model_dir"), "models")
model_dir = unwrap(config.model.get("model_dir"), "models")
model_path = pathlib.Path(model_dir)
draft_model_dir = config.draft_model_config().get("draft_model_dir")
draft_model_dir = config.draft_model.get("draft_model_dir")
if get_key_permission(request) == "admin":
models = get_model_list(model_path.resolve(), draft_model_dir)
else:
models = await get_current_model_list()
if unwrap(model_config.get("use_dummy_models"), False):
if unwrap(config.model.get("use_dummy_models"), False):
models.data.insert(0, ModelCard(id="gpt-3.5-turbo"))
return models
@@ -98,9 +98,7 @@ async def list_draft_models(request: Request) -> ModelList:
"""
if get_key_permission(request) == "admin":
draft_model_dir = unwrap(
config.draft_model_config().get("draft_model_dir"), "models"
)
draft_model_dir = unwrap(config.draft_model.get("draft_model_dir"), "models")
draft_model_path = pathlib.Path(draft_model_dir)
models = get_model_list(draft_model_path.resolve())
@@ -124,7 +122,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
raise HTTPException(400, error_message)
model_path = pathlib.Path(unwrap(config.model_config().get("model_dir"), "models"))
model_path = pathlib.Path(unwrap(config.model.get("model_dir"), "models"))
model_path = model_path / data.name
draft_model_path = None
@@ -137,9 +135,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
raise HTTPException(400, error_message)
draft_model_path = unwrap(
config.draft_model_config().get("draft_model_dir"), "models"
)
draft_model_path = unwrap(config.draft_model.get("draft_model_dir"), "models")
if not model_path.exists():
error_message = handle_request_error(
@@ -196,7 +192,7 @@ async def list_all_loras(request: Request) -> LoraList:
"""
if get_key_permission(request) == "admin":
lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
lora_path = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras"))
loras = get_lora_list(lora_path.resolve())
else:
loras = get_active_loras()
@@ -231,7 +227,7 @@ async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse:
raise HTTPException(400, error_message)
lora_dir = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
lora_dir = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras"))
if not lora_dir.exists():
error_message = handle_request_error(
"A parent lora directory does not exist for load. Check your config.yml?",
@@ -271,7 +267,7 @@ async def list_embedding_models(request: Request) -> ModelList:
if get_key_permission(request) == "admin":
embedding_model_dir = unwrap(
config.embeddings_config().get("embedding_model_dir"), "models"
config.embeddings.get("embedding_model_dir"), "models"
)
embedding_model_path = pathlib.Path(embedding_model_dir)
@@ -307,7 +303,7 @@ async def load_embedding_model(
raise HTTPException(400, error_message)
embedding_model_dir = pathlib.Path(
unwrap(config.model_config().get("embedding_model_dir"), "models")
unwrap(config.embeddings.get("embedding_model_dir"), "models")
)
embedding_model_path = embedding_model_dir / data.name