mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-15 00:07:28 +00:00
config is now backed by pydantic (WIP)
- add models for config options - add function to regenerate config.yml - replace references to config with pydantic compatible references - remove unnecessary unwrap() statements TODO: - auto generate env vars - auto generate argparse - test loading a model
This commit is contained in:
@@ -62,17 +62,17 @@ async def list_models(request: Request) -> ModelList:
|
||||
Requires an admin key to see all models.
|
||||
"""
|
||||
|
||||
model_dir = unwrap(config.model.get("model_dir"), "models")
|
||||
model_dir = config.model.model_dir
|
||||
model_path = pathlib.Path(model_dir)
|
||||
|
||||
draft_model_dir = config.draft_model.get("draft_model_dir")
|
||||
draft_model_dir = config.draft_model.draft_model_dir
|
||||
|
||||
if get_key_permission(request) == "admin":
|
||||
models = get_model_list(model_path.resolve(), draft_model_dir)
|
||||
else:
|
||||
models = await get_current_model_list()
|
||||
|
||||
if unwrap(config.model.get("use_dummy_models"), False):
|
||||
if config.model.use_dummy_models:
|
||||
models.data.insert(0, ModelCard(id="gpt-3.5-turbo"))
|
||||
|
||||
return models
|
||||
@@ -98,7 +98,7 @@ async def list_draft_models(request: Request) -> ModelList:
|
||||
"""
|
||||
|
||||
if get_key_permission(request) == "admin":
|
||||
draft_model_dir = unwrap(config.draft_model.get("draft_model_dir"), "models")
|
||||
draft_model_dir = config.draft_model.draft_model_dir
|
||||
draft_model_path = pathlib.Path(draft_model_dir)
|
||||
|
||||
models = get_model_list(draft_model_path.resolve())
|
||||
@@ -122,7 +122,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
|
||||
|
||||
raise HTTPException(400, error_message)
|
||||
|
||||
model_path = pathlib.Path(unwrap(config.model.get("model_dir"), "models"))
|
||||
model_path = pathlib.Path(config.model.model_dir)
|
||||
model_path = model_path / data.name
|
||||
|
||||
draft_model_path = None
|
||||
@@ -135,7 +135,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
|
||||
|
||||
raise HTTPException(400, error_message)
|
||||
|
||||
draft_model_path = unwrap(config.draft_model.get("draft_model_dir"), "models")
|
||||
draft_model_path = config.draft_model.draft_model_dir
|
||||
|
||||
if not model_path.exists():
|
||||
error_message = handle_request_error(
|
||||
@@ -192,7 +192,7 @@ async def list_all_loras(request: Request) -> LoraList:
|
||||
"""
|
||||
|
||||
if get_key_permission(request) == "admin":
|
||||
lora_path = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras"))
|
||||
lora_path = pathlib.Path(config.lora.lora_dir)
|
||||
loras = get_lora_list(lora_path.resolve())
|
||||
else:
|
||||
loras = get_active_loras()
|
||||
@@ -227,7 +227,7 @@ async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse:
|
||||
|
||||
raise HTTPException(400, error_message)
|
||||
|
||||
lora_dir = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras"))
|
||||
lora_dir = pathlib.Path(config.lora.lora_dir)
|
||||
if not lora_dir.exists():
|
||||
error_message = handle_request_error(
|
||||
"A parent lora directory does not exist for load. Check your config.yml?",
|
||||
@@ -266,9 +266,7 @@ async def list_embedding_models(request: Request) -> ModelList:
|
||||
"""
|
||||
|
||||
if get_key_permission(request) == "admin":
|
||||
embedding_model_dir = unwrap(
|
||||
config.embeddings.get("embedding_model_dir"), "models"
|
||||
)
|
||||
embedding_model_dir = config.embeddings.embedding_model_dir
|
||||
embedding_model_path = pathlib.Path(embedding_model_dir)
|
||||
|
||||
models = get_model_list(embedding_model_path.resolve())
|
||||
@@ -302,9 +300,7 @@ async def load_embedding_model(
|
||||
|
||||
raise HTTPException(400, error_message)
|
||||
|
||||
embedding_model_dir = pathlib.Path(
|
||||
unwrap(config.embeddings.get("embedding_model_dir"), "models")
|
||||
)
|
||||
embedding_model_dir = pathlib.Path(config.embeddings.embedding_model_dir)
|
||||
embedding_model_path = embedding_model_dir / data.name
|
||||
|
||||
if not embedding_model_path.exists():
|
||||
|
||||
@@ -4,7 +4,7 @@ from pydantic import BaseModel, Field, ConfigDict
|
||||
from time import time
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from common.gen_logging import GenLogPreferences
|
||||
from common.config_models import logging_config_model
|
||||
from common.model import get_config_default
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ class ModelCard(BaseModel):
|
||||
object: str = "model"
|
||||
created: int = Field(default_factory=lambda: int(time()))
|
||||
owned_by: str = "tabbyAPI"
|
||||
logging: Optional[GenLogPreferences] = None
|
||||
logging: Optional[logging_config_model] = None
|
||||
parameters: Optional[ModelCardParameters] = None
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user