diff --git a/backends/infinity/model.py b/backends/infinity/model.py
index 27fc9e5..4c9bb69 100644
--- a/backends/infinity/model.py
+++ b/backends/infinity/model.py
@@ -18,6 +18,8 @@ except ImportError:
 
 class InfinityContainer:
     model_dir: pathlib.Path
+    model_is_loading: bool = False
+    model_loaded: bool = False
 
     # Conditionally set the type hint based on importablity
     # TODO: Clean this up
@@ -30,6 +32,8 @@ class InfinityContainer:
         self.model_dir = model_directory
 
     async def load(self, **kwargs):
+        self.model_is_loading = True
+
         # Use cpu by default
         device = unwrap(kwargs.get("device"), "cpu")
 
@@ -44,6 +48,9 @@ class InfinityContainer:
         self.engine = AsyncEmbeddingEngine.from_args(engine_args)
         await self.engine.astart()
 
+        self.model_loaded = True
+        logger.info("Embedding model successfully loaded.")
+
     async def unload(self):
         await self.engine.astop()
         self.engine = None
diff --git a/common/model.py b/common/model.py
index b4b259e..3776ff9 100644
--- a/common/model.py
+++ b/common/model.py
@@ -162,9 +162,15 @@ async def check_model_container():
 
 
 async def check_embeddings_container():
-    """FastAPI depends that checks if an embeddings model is loaded."""
+    """
+    FastAPI depends that checks if an embeddings model is loaded.
 
-    if embeddings_container is None:
+    This is the same as the model container check, but with embeddings instead.
+    """
+
+    if embeddings_container is None or not (
+        embeddings_container.model_is_loading or embeddings_container.model_loaded
+    ):
         error_message = handle_request_error(
             "No embeddings models are currently loaded.",
             exc_info=False,
diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py
index b702e52..b428c00 100644
--- a/endpoints/OAI/router.py
+++ b/endpoints/OAI/router.py
@@ -5,7 +5,7 @@ from sys import maxsize
 
 from common import config, model
 from common.auth import check_api_key
-from common.model import check_model_container
+from common.model import check_embeddings_container, check_model_container
 from common.networking import handle_request_error, run_with_request_disconnect
 from common.utils import unwrap
 from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
@@ -132,7 +132,7 @@ async def chat_completion_request(
 # Embeddings endpoint
 @router.post(
     "/v1/embeddings",
-    dependencies=[Depends(check_api_key), Depends(check_model_container)],
+    dependencies=[Depends(check_api_key), Depends(check_embeddings_container)],
 )
 async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsResponse:
     embeddings_task = asyncio.create_task(get_embeddings(data, request))