mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-20 14:28:54 +00:00
Merge pull request #158 from AlpinDale/embeddings
feat: add embeddings support via Infinity-emb
This commit is contained in:
@@ -23,6 +23,7 @@ def init_argparser():
|
||||
)
|
||||
add_network_args(parser)
|
||||
add_model_args(parser)
|
||||
add_embeddings_args(parser)
|
||||
add_logging_args(parser)
|
||||
add_developer_args(parser)
|
||||
add_sampling_args(parser)
|
||||
@@ -209,3 +210,22 @@ def add_sampling_args(parser: argparse.ArgumentParser):
|
||||
sampling_group.add_argument(
|
||||
"--override-preset", type=str, help="Select a sampler override preset"
|
||||
)
|
||||
|
||||
|
||||
def add_embeddings_args(parser: argparse.ArgumentParser):
    """Adds arguments specific to embeddings

    Registers an "embeddings" argument group on the given parser with three
    optional string flags: the directory to search for embedding models, the
    initial embedding model to load, and the device to use for embeddings.
    """

    embeddings_group = parser.add_argument_group("embeddings")
    embeddings_group.add_argument(
        "--embedding-model-dir",
        type=str,
        # Say "embedding models" so --help does not read like the model dir flag
        help="Overrides the directory to look for embedding models",
    )
    embeddings_group.add_argument(
        "--embedding-model-name",
        type=str,
        help="An initial embedding model to load",
    )
    embeddings_group.add_argument(
        "--embeddings-device",
        type=str,
        help="Device to use for embeddings. Options: (cpu, auto, cuda)",
    )
|
||||
|
||||
@@ -59,6 +59,11 @@ def from_args(args: dict):
|
||||
cur_developer_config = developer_config()
|
||||
GLOBAL_CONFIG["developer"] = {**cur_developer_config, **developer_override}
|
||||
|
||||
embeddings_override = args.get("embeddings")
|
||||
if embeddings_override:
|
||||
cur_embeddings_config = embeddings_config()
|
||||
GLOBAL_CONFIG["embeddings"] = {**cur_embeddings_config, **embeddings_override}
|
||||
|
||||
|
||||
def sampling_config():
|
||||
"""Returns the sampling parameter config from the global config"""
|
||||
@@ -95,3 +100,8 @@ def logging_config():
|
||||
def developer_config():
    """Looks up the developer section of the global config"""

    # An empty dict is handed to unwrap as the fallback value
    developer_section = GLOBAL_CONFIG.get("developer")
    return unwrap(developer_section, {})
|
||||
|
||||
|
||||
def embeddings_config():
    """Looks up the embeddings section of the global config"""

    # An empty dict is handed to unwrap as the fallback value
    embeddings_section = GLOBAL_CONFIG.get("embeddings")
    return unwrap(embeddings_section, {})
|
||||
|
||||
@@ -20,6 +20,15 @@ if not do_export_openapi:
|
||||
|
||||
# Global model container
|
||||
container: Optional[ExllamaV2Container] = None
|
||||
embeddings_container = None
|
||||
|
||||
# Type hint the infinity emb container if it exists
|
||||
from backends.infinity.model import has_infinity_emb
|
||||
|
||||
if has_infinity_emb:
|
||||
from backends.infinity.model import InfinityContainer
|
||||
|
||||
embeddings_container: Optional[InfinityContainer] = None
|
||||
|
||||
|
||||
def load_progress(module, modules):
|
||||
@@ -48,8 +57,6 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
|
||||
f'Model "{loaded_model_name}" is already loaded! Aborting.'
|
||||
)
|
||||
|
||||
# Unload the existing model
|
||||
if container and container.model:
|
||||
logger.info("Unloading existing model.")
|
||||
await unload_model()
|
||||
|
||||
@@ -100,6 +107,41 @@ async def unload_loras():
|
||||
await container.unload(loras_only=True)
|
||||
|
||||
|
||||
async def load_embedding_model(model_path: pathlib.Path, **kwargs):
    """
    Loads an embeddings model into the global embeddings container.

    An already-initialized container is unloaded first, unless it holds the
    same model name in a loaded state, which is treated as an error. Extra
    keyword arguments are forwarded to the new container's load call.

    Raises:
        ImportError: infinity-emb is not installed.
        ValueError: the requested model is already loaded.
    """

    global embeddings_container

    # Embeddings need the optional infinity-emb backend
    if not has_infinity_emb:
        raise ImportError(
            "Skipping embeddings because infinity-emb is not installed.\n"
            "Please run the following command in your environment "
            "to install extra packages:\n"
            "pip install -U .[extras]"
        )

    # Only a container with a live engine needs the duplicate check + unload
    current = embeddings_container
    if current and current.engine:
        loaded_model_name = current.model_dir.name
        is_same_model = loaded_model_name == model_path.name

        if is_same_model and current.model_loaded:
            raise ValueError(
                f'Embeddings model "{loaded_model_name}" is already loaded! Aborting.'
            )

        logger.info("Unloading existing embeddings model.")
        await unload_embedding_model()

    # Swap in a fresh container for the requested model and load it
    embeddings_container = InfinityContainer(model_path)
    await embeddings_container.load(**kwargs)
|
||||
|
||||
|
||||
async def unload_embedding_model():
    """
    Unloads the global embeddings container and clears the reference to it.

    Does nothing when no embeddings container is set, so callers can invoke
    this unconditionally.
    """

    global embeddings_container

    # Nothing to unload; calling .unload() on None would raise AttributeError
    if embeddings_container is None:
        return

    await embeddings_container.unload()
    embeddings_container = None
|
||||
|
||||
|
||||
def get_config_default(key, fallback=None, is_draft=False):
|
||||
"""Fetches a default value from model config if allowed by the user."""
|
||||
|
||||
@@ -126,3 +168,21 @@ async def check_model_container():
|
||||
).error.message
|
||||
|
||||
raise HTTPException(400, error_message)
|
||||
|
||||
|
||||
async def check_embeddings_container():
    """
    FastAPI dependency that guards routes needing an embeddings model.

    Same idea as the model container check, applied to the embeddings
    container: raises an HTTP 400 unless a container exists and is either
    loading or loaded.
    """

    active = embeddings_container

    # Guard clause: a container that is loading or loaded passes the check
    if active is not None and (active.model_is_loading or active.model_loaded):
        return

    request_error = handle_request_error(
        "No embedding models are currently loaded.",
        exc_info=False,
    )
    error_message = request_error.error.message

    raise HTTPException(400, error_message)
|
||||
|
||||
@@ -1,16 +1,32 @@
|
||||
import asyncio
|
||||
import signal
|
||||
import sys
|
||||
from loguru import logger
|
||||
from types import FrameType
|
||||
|
||||
from common import model
|
||||
|
||||
|
||||
def signal_handler(*_):
    """Signal handler for main function. Run before uvicorn starts.

    Logs the shutdown, runs the async container unloads in
    signal_handler_async, and exits the process with status 0 once they
    have finished. Positional arguments passed by the signal machinery are
    ignored.
    """

    logger.warning("Shutdown signal called. Exiting gracefully.")

    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = None

    if loop is None:
        # No event loop is running in this thread, so the unloads can be
        # driven to completion synchronously before exiting.
        asyncio.run(signal_handler_async())
        sys.exit(0)

    def _exit_after_unload(_task):
        # SystemExit raised from a loop callback propagates out of the loop
        sys.exit(0)

    def _start_unload():
        unload_task = loop.create_task(signal_handler_async())
        unload_task.add_done_callback(_exit_after_unload)

    # Exiting right here would tear the loop down before the unload
    # coroutine ever ran. Hand the work to the running loop instead and exit
    # from its completion callback; the threadsafe variant also wakes a loop
    # that is currently blocked waiting for I/O.
    loop.call_soon_threadsafe(_start_unload)
|
||||
|
||||
|
||||
async def signal_handler_async(*_):
    """Unloads the model container and then the embeddings container, if set."""

    # Each attribute is looked up right before its unload, in this order
    for container_attr in ("container", "embeddings_container"):
        active_container = getattr(model, container_attr)
        if active_container:
            await active_container.unload()
|
||||
|
||||
|
||||
def uvicorn_signal_handler(signal_event: signal.Signals):
|
||||
"""Overrides uvicorn's signal handler."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user