mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-15 00:07:28 +00:00
Embeddings: Add model load checks
Same as the normal model container.
Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -18,6 +18,8 @@ except ImportError:
|
||||
|
||||
class InfinityContainer:
|
||||
model_dir: pathlib.Path
|
||||
model_is_loading: bool = False
|
||||
model_loaded: bool = False
|
||||
|
||||
# Conditionally set the type hint based on importability
|
||||
# TODO: Clean this up
|
||||
@@ -30,6 +32,8 @@ class InfinityContainer:
|
||||
self.model_dir = model_directory
|
||||
|
||||
async def load(self, **kwargs):
|
||||
self.model_is_loading = True
|
||||
|
||||
# Use cpu by default
|
||||
device = unwrap(kwargs.get("device"), "cpu")
|
||||
|
||||
@@ -44,6 +48,9 @@ class InfinityContainer:
|
||||
self.engine = AsyncEmbeddingEngine.from_args(engine_args)
|
||||
await self.engine.astart()
|
||||
|
||||
self.model_loaded = True
|
||||
logger.info("Embedding model successfully loaded.")
|
||||
|
||||
async def unload(self):
|
||||
await self.engine.astop()
|
||||
self.engine = None
|
||||
|
||||
Reference in New Issue
Block a user