mirror of https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-27 09:41:54 +00:00
Merge pull request #158 from AlpinDale/embeddings
feat: add embeddings support via Infinity-emb
backends/infinity/model.py (new file, 66 lines)
@@ -0,0 +1,66 @@
import gc
import pathlib
import torch
from loguru import logger
from typing import List, Optional

from common.utils import unwrap

# Conditionally import infinity to sidestep its logger
# TODO: Make this prettier
try:
    from infinity_emb import EngineArgs, AsyncEmbeddingEngine

    has_infinity_emb = True
except ImportError:
    has_infinity_emb = False


class InfinityContainer:
    model_dir: pathlib.Path
    model_is_loading: bool = False
    model_loaded: bool = False

    # Conditionally set the type hint based on importability
    # TODO: Clean this up
    if has_infinity_emb:
        engine: Optional[AsyncEmbeddingEngine] = None
    else:
        engine = None

    def __init__(self, model_directory: pathlib.Path):
        self.model_dir = model_directory

    async def load(self, **kwargs):
        self.model_is_loading = True

        # Use cpu by default
        device = unwrap(kwargs.get("embeddings_device"), "cpu")

        engine_args = EngineArgs(
            model_name_or_path=str(self.model_dir),
            engine="torch",
            device=device,
            bettertransformer=False,
            model_warmup=False,
        )

        self.engine = AsyncEmbeddingEngine.from_args(engine_args)
        await self.engine.astart()

        self.model_loaded = True
        logger.info("Embedding model successfully loaded.")

    async def unload(self):
        await self.engine.astop()
        self.engine = None

        gc.collect()
        torch.cuda.empty_cache()

        logger.info("Embedding model unloaded.")

    async def generate(self, sentence_input: List[str]):
        result_embeddings, usage = await self.engine.embed(sentence_input)

        return {"embeddings": result_embeddings, "usage": usage}
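Taken together, the container's lifecycle is load, generate, unload. A minimal driver sketch (the model directory is a hypothetical local embedding model, and infinity-emb must be installed via the extras):

import asyncio
import pathlib

from backends.infinity.model import InfinityContainer


async def main():
    # Hypothetical local embedding model directory
    container = InfinityContainer(pathlib.Path("models/bge-small-en-v1.5"))
    await container.load(embeddings_device="cpu")

    result = await container.generate(["TabbyAPI now serves embeddings."])
    print(len(result["embeddings"]), "embedding(s),", result["usage"], "tokens")

    await container.unload()


asyncio.run(main())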
@@ -23,6 +23,7 @@ def init_argparser():
     )
     add_network_args(parser)
     add_model_args(parser)
+    add_embeddings_args(parser)
     add_logging_args(parser)
     add_developer_args(parser)
     add_sampling_args(parser)
@@ -209,3 +210,22 @@ def add_sampling_args(parser: argparse.ArgumentParser):
     sampling_group.add_argument(
         "--override-preset", type=str, help="Select a sampler override preset"
     )
+
+
+def add_embeddings_args(parser: argparse.ArgumentParser):
+    """Adds arguments specific to embeddings"""
+
+    embeddings_group = parser.add_argument_group("embeddings")
+    embeddings_group.add_argument(
+        "--embedding-model-dir",
+        type=str,
+        help="Overrides the directory to look for models",
+    )
+    embeddings_group.add_argument(
+        "--embedding-model-name", type=str, help="An initial model to load"
+    )
+    embeddings_group.add_argument(
+        "--embeddings-device",
+        type=str,
+        help="Device to use for embeddings. Options: (cpu, auto, cuda)",
+    )
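As a quick smoke test, the new group parses like any other argparse flags. A sketch, assuming add_embeddings_args is imported from the args module this hunk lives in (the file path is not shown in this view):

import argparse

parser = argparse.ArgumentParser()
add_embeddings_args(parser)

args = parser.parse_args(["--embedding-model-dir", "models", "--embeddings-device", "cpu"])

# argparse maps the dashed flags to underscored namespace attributes
print(args.embedding_model_dir, args.embedding_model_name, args.embeddings_device)
# -> models None cpu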
@@ -59,6 +59,11 @@ def from_args(args: dict):
     cur_developer_config = developer_config()
     GLOBAL_CONFIG["developer"] = {**cur_developer_config, **developer_override}

+    embeddings_override = args.get("embeddings")
+    if embeddings_override:
+        cur_embeddings_config = embeddings_config()
+        GLOBAL_CONFIG["embeddings"] = {**cur_embeddings_config, **embeddings_override}
+

 def sampling_config():
     """Returns the sampling parameter config from the global config"""
@@ -95,3 +100,8 @@ def logging_config():
 def developer_config():
     """Returns the developer specific config from the global config"""
     return unwrap(GLOBAL_CONFIG.get("developer"), {})
+
+
+def embeddings_config():
+    """Returns the embeddings config from the global config"""
+    return unwrap(GLOBAL_CONFIG.get("embeddings"), {})
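Because {**cur_embeddings_config, **embeddings_override} merges right over left, CLI-supplied keys win over config.yml values key by key. A tiny illustration of that dict semantics:

cur_embeddings_config = {"embedding_model_dir": "models", "embeddings_device": "cpu"}
embeddings_override = {"embeddings_device": "cuda"}

print({**cur_embeddings_config, **embeddings_override})
# -> {'embedding_model_dir': 'models', 'embeddings_device': 'cuda'}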
@@ -20,6 +20,15 @@ if not do_export_openapi:

 # Global model container
 container: Optional[ExllamaV2Container] = None
+embeddings_container = None
+
+# Type hint the infinity emb container if it exists
+from backends.infinity.model import has_infinity_emb
+
+if has_infinity_emb:
+    from backends.infinity.model import InfinityContainer
+
+    embeddings_container: Optional[InfinityContainer] = None


 def load_progress(module, modules):
@@ -48,8 +57,6 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
                 f'Model "{loaded_model_name}" is already loaded! Aborting.'
             )

-        # Unload the existing model
-        if container and container.model:
         logger.info("Unloading existing model.")
         await unload_model()
@@ -100,6 +107,41 @@ async def unload_loras():
     await container.unload(loras_only=True)


+async def load_embedding_model(model_path: pathlib.Path, **kwargs):
+    global embeddings_container
+
+    # Break out if infinity isn't installed
+    if not has_infinity_emb:
+        raise ImportError(
+            "Skipping embeddings because infinity-emb is not installed.\n"
+            "Please run the following command in your environment "
+            "to install extra packages:\n"
+            "pip install -U .[extras]"
+        )
+
+    # Check if the model is already loaded
+    if embeddings_container and embeddings_container.engine:
+        loaded_model_name = embeddings_container.model_dir.name
+
+        if loaded_model_name == model_path.name and embeddings_container.model_loaded:
+            raise ValueError(
+                f'Embeddings model "{loaded_model_name}" is already loaded! Aborting.'
+            )
+
+        logger.info("Unloading existing embeddings model.")
+        await unload_embedding_model()
+
+    embeddings_container = InfinityContainer(model_path)
+    await embeddings_container.load(**kwargs)
+
+
+async def unload_embedding_model():
+    global embeddings_container
+
+    await embeddings_container.unload()
+    embeddings_container = None
+
+
 def get_config_default(key, fallback=None, is_draft=False):
     """Fetches a default value from model config if allowed by the user."""
@@ -126,3 +168,21 @@ async def check_model_container():
     ).error.message

     raise HTTPException(400, error_message)
+
+
+async def check_embeddings_container():
+    """
+    FastAPI depends that checks if an embeddings model is loaded.
+
+    This is the same as the model container check, but with embeddings instead.
+    """
+
+    if embeddings_container is None or not (
+        embeddings_container.model_is_loading or embeddings_container.model_loaded
+    ):
+        error_message = handle_request_error(
+            "No embedding models are currently loaded.",
+            exc_info=False,
+        ).error.message
+
+        raise HTTPException(400, error_message)
@@ -1,16 +1,32 @@
+import asyncio
 import signal
 import sys
 from loguru import logger
 from types import FrameType

+from common import model
+

 def signal_handler(*_):
     """Signal handler for main function. Run before uvicorn starts."""

     logger.warning("Shutdown signal called. Exiting gracefully.")
+
+    # Run async unloads for model
+    asyncio.ensure_future(signal_handler_async())
+
+    # Exit the program
     sys.exit(0)


+async def signal_handler_async(*_):
+    if model.container:
+        await model.container.unload()
+
+    if model.embeddings_container:
+        await model.embeddings_container.unload()
+
+
 def uvicorn_signal_handler(signal_event: signal.Signals):
     """Overrides uvicorn's signal handler."""
@@ -201,3 +201,19 @@ model:
 #loras:
 #- name: lora1
 #  scaling: 1.0
+
+# Options for embedding models and loading.
+# NOTE: Embeddings requires the "extras" feature to be installed
+# Install it via "pip install .[extras]"
+embeddings:
+  # Overrides directory to look for embedding models (default: models)
+  embedding_model_dir: models
+
+  # An initial embedding model to load on the infinity backend (default: None)
+  embedding_model_name:
+
+  # Device to load embedding models on (default: cpu)
+  # Possible values: cpu, auto, cuda
+  # NOTE: It's recommended to load embedding models on the CPU.
+  # If you'd like to load on an AMD gpu, set this value to "cuda" as well.
+  embeddings_device: cpu
@@ -5,7 +5,7 @@ from sys import maxsize

 from common import config, model
 from common.auth import check_api_key
-from common.model import check_model_container
+from common.model import check_embeddings_container, check_model_container
 from common.networking import handle_request_error, run_with_request_disconnect
 from common.utils import unwrap
 from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
@@ -13,6 +13,7 @@ from endpoints.OAI.types.chat_completion import (
     ChatCompletionRequest,
     ChatCompletionResponse,
 )
+from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse
 from endpoints.OAI.utils.chat_completion import (
     format_prompt_with_template,
     generate_chat_completion,
@@ -22,6 +23,7 @@ from endpoints.OAI.utils.completion import (
     generate_completion,
     stream_generate_completion,
 )
+from endpoints.OAI.utils.embeddings import get_embeddings


 api_name = "OAI"
@@ -134,3 +136,19 @@ async def chat_completion_request(
         disconnect_message=f"Chat completion {request.state.id} cancelled by user.",
     )
     return response
+
+
+# Embeddings endpoint
+@router.post(
+    "/v1/embeddings",
+    dependencies=[Depends(check_api_key), Depends(check_embeddings_container)],
+)
+async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsResponse:
+    embeddings_task = asyncio.create_task(get_embeddings(data, request))
+    response = await run_with_request_disconnect(
+        request,
+        embeddings_task,
+        f"Embeddings request {request.state.id} cancelled by user.",
+    )
+
+    return response
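From the client side, the new route is wire-compatible with the OpenAI embeddings API. A sketch using the requests library (the host, port, and x-api-key header value are placeholders for your deployment):

import requests

response = requests.post(
    "http://127.0.0.1:5000/v1/embeddings",
    headers={"x-api-key": "<your-api-key>"},
    json={"input": ["TabbyAPI is an exllamav2 API server."]},
)

# First few dimensions of the first embedding vector
print(response.json()["data"][0]["embedding"][:4])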
endpoints/OAI/types/embedding.py (new file, 42 lines)
@@ -0,0 +1,42 @@
from typing import List, Optional

from pydantic import BaseModel, Field


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class EmbeddingsRequest(BaseModel):
    input: List[str] = Field(
        ..., description="List of input texts to generate embeddings for."
    )
    encoding_format: str = Field(
        "float",
        description="Encoding format for the embeddings. "
        "Can be 'float' or 'base64'.",
    )
    model: Optional[str] = Field(
        None,
        description="Name of the embedding model to use. "
        "If not provided, the default model will be used.",
    )


class EmbeddingObject(BaseModel):
    object: str = Field("embedding", description="Type of the object.")
    embedding: List[float] = Field(
        ..., description="Embedding values as a list of floats."
    )
    index: int = Field(
        ..., description="Index of the input text corresponding to the embedding."
    )


class EmbeddingsResponse(BaseModel):
    object: str = Field("list", description="Type of the response object.")
    data: List[EmbeddingObject] = Field(..., description="List of embedding objects.")
    model: str = Field(..., description="Name of the embedding model used.")
    usage: UsageInfo = Field(..., description="Information about token usage.")
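These models mirror the OpenAI embeddings schema. A quick round-trip sketch (values are illustrative; assumes pydantic v2, consistent with the model_dump calls elsewhere in this PR):

request = EmbeddingsRequest(input=["hello", "world"])
print(request.encoding_format)  # defaults to "float"

response = EmbeddingsResponse(
    data=[EmbeddingObject(embedding=[0.1, 0.2, 0.3], index=0)],
    model="bge-small-en-v1.5",
    usage=UsageInfo(prompt_tokens=2, total_tokens=2),
)
print(response.model_dump_json())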
endpoints/OAI/utils/embeddings.py (new file, 64 lines)
@@ -0,0 +1,64 @@
"""
This file is derived from
[text-generation-webui openai extension embeddings](https://github.com/oobabooga/text-generation-webui/blob/1a7c027386f43b84f3ca3b0ff04ca48d861c2d7a/extensions/openai/embeddings.py)
and modified.
The changes introduced are: Suppression of progress bar,
typing/pydantic classes moved into this file,
embeddings function declared async.
"""

import base64
from fastapi import Request
import numpy as np
from loguru import logger

from common import model
from endpoints.OAI.types.embedding import (
    EmbeddingObject,
    EmbeddingsRequest,
    EmbeddingsResponse,
    UsageInfo,
)


def float_list_to_base64(float_array: np.ndarray) -> str:
    """
    Converts the provided list to a float32 array for OpenAI

    Ex. float_array = np.array(float_list, dtype="float32")
    """

    # Encode raw bytes into base64
    encoded_bytes = base64.b64encode(float_array.tobytes())

    # Turn raw base64 encoded bytes into ASCII
    ascii_string = encoded_bytes.decode("ascii")
    return ascii_string


async def get_embeddings(data: EmbeddingsRequest, request: Request) -> dict:
    model_path = model.embeddings_container.model_dir

    logger.info(f"Received embeddings request {request.state.id}")
    embedding_data = await model.embeddings_container.generate(data.input)

    # OAI expects a return of base64 if the input is base64
    embedding_object = [
        EmbeddingObject(
            embedding=float_list_to_base64(emb)
            if data.encoding_format == "base64"
            else emb.tolist(),
            index=n,
        )
        for n, emb in enumerate(embedding_data.get("embeddings"))
    ]

    usage = embedding_data.get("usage")
    response = EmbeddingsResponse(
        data=embedding_object,
        model=model_path.name,
        usage=UsageInfo(prompt_tokens=usage, total_tokens=usage),
    )

    logger.info(f"Finished embeddings request {request.state.id}")

    return response
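Clients that set encoding_format to "base64" receive the raw float32 bytes of each vector as ASCII base64. The inverse transform on the client side is roughly (assuming float32, matching the helper above):

import base64

import numpy as np


def base64_to_float_list(encoded: str) -> list:
    """Inverse of float_list_to_base64: base64 ASCII back to float32 values."""
    raw_bytes = base64.b64decode(encoded)
    return np.frombuffer(raw_bytes, dtype="float32").tolist()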
@@ -7,7 +7,7 @@ from sse_starlette import EventSourceResponse
 from common import config, model, sampling
 from common.auth import check_admin_key, check_api_key, get_key_permission
 from common.downloader import hf_repo_download
-from common.model import check_model_container
+from common.model import check_embeddings_container, check_model_container
 from common.networking import handle_request_error, run_with_request_disconnect
 from common.templating import PromptTemplate, get_all_templates
 from common.utils import unwrap
@@ -15,6 +15,7 @@ from endpoints.core.types.auth import AuthPermissionResponse
 from endpoints.core.types.download import DownloadRequest, DownloadResponse
 from endpoints.core.types.lora import LoraList, LoraLoadRequest, LoraLoadResponse
 from endpoints.core.types.model import (
+    EmbeddingModelLoadRequest,
     ModelCard,
     ModelList,
     ModelLoadRequest,
@@ -253,6 +254,93 @@ async def unload_loras():
     await model.unload_loras()


+@router.get("/v1/model/embedding/list", dependencies=[Depends(check_api_key)])
+async def list_embedding_models(request: Request) -> ModelList:
+    """
+    Lists all embedding models in the model directory.
+
+    Requires an admin key to see all embedding models.
+    """
+
+    if get_key_permission(request) == "admin":
+        embedding_model_dir = unwrap(
+            config.embeddings_config().get("embedding_model_dir"), "models"
+        )
+        embedding_model_path = pathlib.Path(embedding_model_dir)
+
+        models = get_model_list(embedding_model_path.resolve())
+    else:
+        models = await get_current_model_list(model_type="embedding")
+
+    return models
+
+
+@router.get(
+    "/v1/model/embedding",
+    dependencies=[Depends(check_api_key), Depends(check_embeddings_container)],
+)
+async def get_embedding_model() -> ModelList:
+    """Returns the currently loaded embedding model."""
+
+    return get_current_model_list(model_type="embedding")[0]
+
+
+@router.post("/v1/model/embedding/load", dependencies=[Depends(check_admin_key)])
+async def load_embedding_model(
+    request: Request, data: EmbeddingModelLoadRequest
+) -> ModelLoadResponse:
+    # Verify request parameters
+    if not data.name:
+        error_message = handle_request_error(
+            "A model name was not provided for load.",
+            exc_info=False,
+        ).error.message
+
+        raise HTTPException(400, error_message)
+
+    embedding_model_dir = pathlib.Path(
+        unwrap(config.model_config().get("embedding_model_dir"), "models")
+    )
+    embedding_model_path = embedding_model_dir / data.name
+
+    if not embedding_model_path.exists():
+        error_message = handle_request_error(
+            "Could not find the embedding model path for load. "
+            + "Check model name or config.yml?",
+            exc_info=False,
+        ).error.message
+
+        raise HTTPException(400, error_message)
+
+    try:
+        load_task = asyncio.create_task(
+            model.load_embedding_model(embedding_model_path, **data.model_dump())
+        )
+        await run_with_request_disconnect(
+            request, load_task, "Embedding model load request cancelled by user."
+        )
+    except Exception as exc:
+        error_message = handle_request_error(str(exc)).error.message
+
+        raise HTTPException(400, error_message) from exc
+
+    response = ModelLoadResponse(
+        model_type="embedding_model", module=1, modules=1, status="finished"
+    )
+
+    return response
+
+
+@router.post(
+    "/v1/model/embedding/unload",
+    dependencies=[Depends(check_admin_key), Depends(check_embeddings_container)],
+)
+async def unload_embedding_model():
+    """Unloads the current embedding model."""
+
+    await model.unload_embedding_model()
+
+
 # Encode tokens endpoint
 @router.post(
     "/v1/token/encode",
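For illustration, driving the new management routes over HTTP might look like the following (base URL, x-admin-key header, and model name are placeholders, not part of this PR):

import requests

base_url = "http://127.0.0.1:5000"
headers = {"x-admin-key": "<your-admin-key>"}

# Load an embedding model by its folder name under embedding_model_dir
requests.post(
    f"{base_url}/v1/model/embedding/load",
    headers=headers,
    json={"name": "bge-small-en-v1.5", "embeddings_device": "cpu"},
)

# Inspect the currently loaded embedding model, then unload it
print(requests.get(f"{base_url}/v1/model/embedding", headers=headers).json())
requests.post(f"{base_url}/v1/model/embedding/unload", headers=headers)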
@@ -137,6 +137,11 @@ class ModelLoadRequest(BaseModel):
     skip_queue: Optional[bool] = False


+class EmbeddingModelLoadRequest(BaseModel):
+    name: str
+    embeddings_device: Optional[str] = None
+
+
 class ModelLoadResponse(BaseModel):
     """Represents a model load response."""
@@ -32,15 +32,26 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = None
     return model_card_list


-async def get_current_model_list(is_draft: bool = False):
-    """Gets the current model in list format and with path only."""
+async def get_current_model_list(model_type: str = "model"):
+    """
+    Gets the current model in list format and with path only.
+
+    Unified for fetching both models and embedding models.
+    """
+
     current_models = []
+    model_path = None
+
     # Make sure the model container exists
-    if model.container:
-        model_path = model.container.get_model_path(is_draft)
-        if model_path:
-            current_models.append(ModelCard(id=model_path.name))
+    if model_type == "model" or model_type == "draft":
+        if model.container:
+            model_path = model.container.get_model_path(model_type == "draft")
+    elif model_type == "embedding":
+        if model.embeddings_container:
+            model_path = model.embeddings_container.model_dir
+
+    if model_path:
+        current_models.append(ModelCard(id=model_path.name))

     return ModelList(data=current_models)
main.py (15 additions)
@@ -87,6 +87,21 @@ async def entrypoint_async():
     lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
     await model.container.load_loras(lora_dir.resolve(), **lora_config)

+    # If an initial embedding model name is specified, create a separate container
+    # and load the model
+    embedding_config = config.embeddings_config()
+    embedding_model_name = embedding_config.get("embedding_model_name")
+    if embedding_model_name:
+        embedding_model_path = pathlib.Path(
+            unwrap(embedding_config.get("embedding_model_dir"), "models")
+        )
+        embedding_model_path = embedding_model_path / embedding_model_name
+
+        try:
+            await model.load_embedding_model(embedding_model_path, **embedding_config)
+        except ImportError as ex:
+            logger.error(ex.msg)
+
     await start_api(host, port)
@@ -47,7 +47,8 @@ dependencies = [
 [project.optional-dependencies]
 extras = [
     # Heavy dependencies that aren't for everyday use
-    "outlines"
+    "outlines",
+    "sentence-transformers"
 ]
 dev = [
     "ruff == 0.3.2"