mirror of https://github.com/theroyallab/tabbyAPI.git
Config: Migrate to global class instead of dicts
The config categories can keep a defined separation while preserving the dynamic nature of adding new config options, by making all of the internal class vars dictionaries. This was necessary because globally registered callbacks captured the previous global_config var, which held a stale, unpopulated state. Signed-off-by: kingbri <bdashore3@proton.me>
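As a rough sketch of the pattern this commit moves to (the class body and load helper below are illustrative assumptions, not the repo's exact code): a single shared instance in common.tabby_config exposes one plain dict per category, and loading mutates those dicts in place, so callbacks that captured the instance before config was parsed still read the populated values.

# Hypothetical sketch of common/tabby_config.py after this change.
# Category vars stay plain dicts so new options need no class changes.
class TabbyConfig:
    model: dict = {}
    draft_model: dict = {}
    lora: dict = {}
    embeddings: dict = {}

    def load(self, raw_config: dict):
        # Update in place instead of rebinding a module-level dict;
        # rebinding was what left early-captured callbacks holding an
        # empty, unpopulated state.
        self.model.update(raw_config.get("model") or {})
        self.draft_model.update(raw_config.get("draft_model") or {})
        self.lora.update(raw_config.get("lora") or {})
        self.embeddings.update(raw_config.get("embeddings") or {})

# One shared instance, imported everywhere as:
# from common.tabby_config import config
config = TabbyConfig()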
@@ -4,11 +4,12 @@ from sys import maxsize
 from fastapi import APIRouter, Depends, HTTPException, Request
 from sse_starlette import EventSourceResponse
 
-from common import config, model, sampling
+from common import model, sampling
 from common.auth import check_admin_key, check_api_key, get_key_permission
 from common.downloader import hf_repo_download
 from common.model import check_embeddings_container, check_model_container
 from common.networking import handle_request_error, run_with_request_disconnect
+from common.tabby_config import config
 from common.templating import PromptTemplate, get_all_templates
 from common.utils import unwrap
 from endpoints.core.types.auth import AuthPermissionResponse
@@ -61,18 +62,17 @@ async def list_models(request: Request) -> ModelList:
     Requires an admin key to see all models.
     """
 
-    model_config = config.model_config()
-    model_dir = unwrap(model_config.get("model_dir"), "models")
+    model_dir = unwrap(config.model.get("model_dir"), "models")
     model_path = pathlib.Path(model_dir)
 
-    draft_model_dir = config.draft_model_config().get("draft_model_dir")
+    draft_model_dir = config.draft_model.get("draft_model_dir")
 
     if get_key_permission(request) == "admin":
         models = get_model_list(model_path.resolve(), draft_model_dir)
     else:
         models = await get_current_model_list()
 
-    if unwrap(model_config.get("use_dummy_models"), False):
+    if unwrap(config.model.get("use_dummy_models"), False):
         models.data.insert(0, ModelCard(id="gpt-3.5-turbo"))
 
     return models
@@ -98,9 +98,7 @@ async def list_draft_models(request: Request) -> ModelList:
     """
 
     if get_key_permission(request) == "admin":
-        draft_model_dir = unwrap(
-            config.draft_model_config().get("draft_model_dir"), "models"
-        )
+        draft_model_dir = unwrap(config.draft_model.get("draft_model_dir"), "models")
         draft_model_path = pathlib.Path(draft_model_dir)
 
         models = get_model_list(draft_model_path.resolve())
@@ -124,7 +122,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
 
         raise HTTPException(400, error_message)
 
-    model_path = pathlib.Path(unwrap(config.model_config().get("model_dir"), "models"))
+    model_path = pathlib.Path(unwrap(config.model.get("model_dir"), "models"))
     model_path = model_path / data.name
 
     draft_model_path = None
@@ -137,9 +135,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
 
             raise HTTPException(400, error_message)
 
-        draft_model_path = unwrap(
-            config.draft_model_config().get("draft_model_dir"), "models"
-        )
+        draft_model_path = unwrap(config.draft_model.get("draft_model_dir"), "models")
 
     if not model_path.exists():
         error_message = handle_request_error(
@@ -196,7 +192,7 @@ async def list_all_loras(request: Request) -> LoraList:
     """
 
     if get_key_permission(request) == "admin":
-        lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
+        lora_path = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras"))
         loras = get_lora_list(lora_path.resolve())
     else:
         loras = get_active_loras()
@@ -231,7 +227,7 @@ async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse:
 
         raise HTTPException(400, error_message)
 
-    lora_dir = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
+    lora_dir = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras"))
     if not lora_dir.exists():
         error_message = handle_request_error(
             "A parent lora directory does not exist for load. Check your config.yml?",
@@ -271,7 +267,7 @@ async def list_embedding_models(request: Request) -> ModelList:
 
     if get_key_permission(request) == "admin":
         embedding_model_dir = unwrap(
-            config.embeddings_config().get("embedding_model_dir"), "models"
+            config.embeddings.get("embedding_model_dir"), "models"
        )
         embedding_model_path = pathlib.Path(embedding_model_dir)
 
@@ -307,7 +303,7 @@ async def load_embedding_model(
         raise HTTPException(400, error_message)
 
     embedding_model_dir = pathlib.Path(
-        unwrap(config.model_config().get("embedding_model_dir"), "models")
+        unwrap(config.embeddings.get("embedding_model_dir"), "models")
     )
     embedding_model_path = embedding_model_dir / data.name
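Call sites migrate from fetching a category dict through a module-level accessor to direct attribute access on the shared instance, as the hunks above show:

# Before: category accessor returned a dict from the old global config
model_dir = unwrap(config.model_config().get("model_dir"), "models")

# After: the category is a dict attribute on the global config instance
model_dir = unwrap(config.model.get("model_dir"), "models")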