mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-29 18:51:53 +00:00
Embeddings: Add model load checks
Same as the normal model container. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -18,6 +18,8 @@ except ImportError:
|
|||||||
|
|
||||||
class InfinityContainer:
|
class InfinityContainer:
|
||||||
model_dir: pathlib.Path
|
model_dir: pathlib.Path
|
||||||
|
model_is_loading: bool = False
|
||||||
|
model_loaded: bool = False
|
||||||
|
|
||||||
# Conditionally set the type hint based on importability
|
# Conditionally set the type hint based on importability
|
||||||
# TODO: Clean this up
|
# TODO: Clean this up
|
||||||
@@ -30,6 +32,8 @@ class InfinityContainer:
|
|||||||
self.model_dir = model_directory
|
self.model_dir = model_directory
|
||||||
|
|
||||||
async def load(self, **kwargs):
|
async def load(self, **kwargs):
|
||||||
|
self.model_is_loading = True
|
||||||
|
|
||||||
# Use cpu by default
|
# Use cpu by default
|
||||||
device = unwrap(kwargs.get("device"), "cpu")
|
device = unwrap(kwargs.get("device"), "cpu")
|
||||||
|
|
||||||
@@ -44,6 +48,9 @@ class InfinityContainer:
|
|||||||
self.engine = AsyncEmbeddingEngine.from_args(engine_args)
|
self.engine = AsyncEmbeddingEngine.from_args(engine_args)
|
||||||
await self.engine.astart()
|
await self.engine.astart()
|
||||||
|
|
||||||
|
self.model_loaded = True
|
||||||
|
logger.info("Embedding model successfully loaded.")
|
||||||
|
|
||||||
async def unload(self):
|
async def unload(self):
|
||||||
await self.engine.astop()
|
await self.engine.astop()
|
||||||
self.engine = None
|
self.engine = None
|
||||||
|
|||||||
@@ -162,9 +162,15 @@ async def check_model_container():
|
|||||||
|
|
||||||
|
|
||||||
async def check_embeddings_container():
|
async def check_embeddings_container():
|
||||||
"""FastAPI depends that checks if an embeddings model is loaded."""
|
"""
|
||||||
|
FastAPI depends that checks if an embeddings model is loaded.
|
||||||
|
|
||||||
if embeddings_container is None:
|
This is the same as the model container check, but with embeddings instead.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if embeddings_container is None or not (
|
||||||
|
embeddings_container.model_is_loading or embeddings_container.model_loaded
|
||||||
|
):
|
||||||
error_message = handle_request_error(
|
error_message = handle_request_error(
|
||||||
"No embeddings models are currently loaded.",
|
"No embeddings models are currently loaded.",
|
||||||
exc_info=False,
|
exc_info=False,
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from sys import maxsize
|
|||||||
|
|
||||||
from common import config, model
|
from common import config, model
|
||||||
from common.auth import check_api_key
|
from common.auth import check_api_key
|
||||||
from common.model import check_model_container
|
from common.model import check_embeddings_container, check_model_container
|
||||||
from common.networking import handle_request_error, run_with_request_disconnect
|
from common.networking import handle_request_error, run_with_request_disconnect
|
||||||
from common.utils import unwrap
|
from common.utils import unwrap
|
||||||
from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
|
from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
|
||||||
@@ -132,7 +132,7 @@ async def chat_completion_request(
|
|||||||
# Embeddings endpoint
|
# Embeddings endpoint
|
||||||
@router.post(
|
@router.post(
|
||||||
"/v1/embeddings",
|
"/v1/embeddings",
|
||||||
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
dependencies=[Depends(check_api_key), Depends(check_embeddings_container)],
|
||||||
)
|
)
|
||||||
async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsResponse:
|
async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsResponse:
|
||||||
embeddings_task = asyncio.create_task(get_embeddings(data, request))
|
embeddings_task = asyncio.create_task(get_embeddings(data, request))
|
||||||
|
|||||||
Reference in New Issue
Block a user