mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-20 14:28:54 +00:00
Merge pull request #158 from AlpinDale/embeddings
feat: add embeddings support via Infinity-emb
This commit is contained in:
@@ -5,7 +5,7 @@ from sys import maxsize
|
||||
|
||||
from common import config, model
|
||||
from common.auth import check_api_key
|
||||
from common.model import check_model_container
|
||||
from common.model import check_embeddings_container, check_model_container
|
||||
from common.networking import handle_request_error, run_with_request_disconnect
|
||||
from common.utils import unwrap
|
||||
from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
|
||||
@@ -13,6 +13,7 @@ from endpoints.OAI.types.chat_completion import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
)
|
||||
from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse
|
||||
from endpoints.OAI.utils.chat_completion import (
|
||||
format_prompt_with_template,
|
||||
generate_chat_completion,
|
||||
@@ -22,6 +23,7 @@ from endpoints.OAI.utils.completion import (
|
||||
generate_completion,
|
||||
stream_generate_completion,
|
||||
)
|
||||
from endpoints.OAI.utils.embeddings import get_embeddings
|
||||
|
||||
|
||||
api_name = "OAI"
|
||||
@@ -134,3 +136,19 @@ async def chat_completion_request(
|
||||
disconnect_message=f"Chat completion {request.state.id} cancelled by user.",
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
# Embeddings endpoint
@router.post(
    "/v1/embeddings",
    dependencies=[Depends(check_api_key), Depends(check_embeddings_container)],
)
async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsResponse:
    """Generates embeddings for the request's input texts."""

    # Run generation as a task so it can be torn down on client disconnect
    task = asyncio.create_task(get_embeddings(data, request))
    return await run_with_request_disconnect(
        request,
        task,
        f"Embeddings request {request.state.id} cancelled by user.",
    )
|
||||
|
||||
42
endpoints/OAI/types/embedding.py
Normal file
42
endpoints/OAI/types/embedding.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class UsageInfo(BaseModel):
    """Token usage accounting attached to an embeddings response."""

    # Tokens consumed by the input text(s)
    prompt_tokens: int = 0
    # Total tokens processed for the request
    total_tokens: int = 0
    # Embeddings produce no completion tokens; kept for OAI schema parity
    completion_tokens: Optional[int] = 0
|
||||
|
||||
|
||||
class EmbeddingsRequest(BaseModel):
    """Request body for the /v1/embeddings endpoint."""

    input: List[str] = Field(
        ...,
        description="List of input texts to generate embeddings for.",
    )
    encoding_format: str = Field(
        default="float",
        description="Encoding format for the embeddings. Can be 'float' or 'base64'.",
    )
    model: Optional[str] = Field(
        default=None,
        description=(
            "Name of the embedding model to use. "
            "If not provided, the default model will be used."
        ),
    )
|
||||
|
||||
|
||||
class EmbeddingObject(BaseModel):
    """A single embedding vector paired with the index of its input text."""

    object: str = Field(default="embedding", description="Type of the object.")
    embedding: List[float] = Field(
        ...,
        description="Embedding values as a list of floats.",
    )
    index: int = Field(
        ...,
        description="Index of the input text corresponding to the embedding.",
    )
|
||||
|
||||
|
||||
class EmbeddingsResponse(BaseModel):
    """Top-level response object for the /v1/embeddings endpoint."""

    object: str = Field(default="list", description="Type of the response object.")
    data: List[EmbeddingObject] = Field(
        ...,
        description="List of embedding objects.",
    )
    model: str = Field(..., description="Name of the embedding model used.")
    usage: UsageInfo = Field(..., description="Information about token usage.")
|
||||
64
endpoints/OAI/utils/embeddings.py
Normal file
64
endpoints/OAI/utils/embeddings.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
This file is derived from
|
||||
[text-generation-webui openai extension embeddings](https://github.com/oobabooga/text-generation-webui/blob/1a7c027386f43b84f3ca3b0ff04ca48d861c2d7a/extensions/openai/embeddings.py)
|
||||
and modified.
|
||||
The changes introduced are: Suppression of progress bar,
|
||||
typing/pydantic classes moved into this file,
|
||||
embeddings function declared async.
|
||||
"""
|
||||
|
||||
import base64
|
||||
from fastapi import Request
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from common import model
|
||||
from endpoints.OAI.types.embedding import (
|
||||
EmbeddingObject,
|
||||
EmbeddingsRequest,
|
||||
EmbeddingsResponse,
|
||||
UsageInfo,
|
||||
)
|
||||
|
||||
|
||||
def float_list_to_base64(float_array: np.ndarray) -> str:
    """
    Encode a numpy float array as an ASCII base64 string for OpenAI.

    Callers are expected to pass a float32 array,
    e.g. float_array = np.array(float_list, dtype="float32")
    """

    # base64-encode the array's raw bytes, then render as ASCII text
    raw_bytes = float_array.tobytes()
    return base64.b64encode(raw_bytes).decode("ascii")
|
||||
|
||||
|
||||
async def get_embeddings(data: EmbeddingsRequest, request: Request) -> EmbeddingsResponse:
    """
    Generate embeddings for the request's input texts.

    Runs the loaded embeddings container over data.input and wraps the
    result in an OAI-compatible EmbeddingsResponse. When the client asks
    for base64 encoding, each vector is returned base64-encoded instead
    of as a list of floats.
    """

    model_path = model.embeddings_container.model_dir

    # Fixed typo in the original log message ("Recieved")
    logger.info(f"Received embeddings request {request.state.id}")
    embedding_data = await model.embeddings_container.generate(data.input)

    # OAI expects a return of base64 if the input is base64
    # NOTE(review): base64 output assigns a str to EmbeddingObject.embedding,
    # which is typed List[float] — confirm pydantic validation accepts this.
    embedding_object = [
        EmbeddingObject(
            embedding=float_list_to_base64(emb)
            if data.encoding_format == "base64"
            else emb.tolist(),
            index=n,
        )
        for n, emb in enumerate(embedding_data.get("embeddings"))
    ]

    # The container reports a single token count used for both usage fields
    usage = embedding_data.get("usage")
    response = EmbeddingsResponse(
        data=embedding_object,
        model=model_path.name,
        usage=UsageInfo(prompt_tokens=usage, total_tokens=usage),
    )

    logger.info(f"Finished embeddings request {request.state.id}")

    return response
|
||||
@@ -7,7 +7,7 @@ from sse_starlette import EventSourceResponse
|
||||
from common import config, model, sampling
|
||||
from common.auth import check_admin_key, check_api_key, get_key_permission
|
||||
from common.downloader import hf_repo_download
|
||||
from common.model import check_model_container
|
||||
from common.model import check_embeddings_container, check_model_container
|
||||
from common.networking import handle_request_error, run_with_request_disconnect
|
||||
from common.templating import PromptTemplate, get_all_templates
|
||||
from common.utils import unwrap
|
||||
@@ -15,6 +15,7 @@ from endpoints.core.types.auth import AuthPermissionResponse
|
||||
from endpoints.core.types.download import DownloadRequest, DownloadResponse
|
||||
from endpoints.core.types.lora import LoraList, LoraLoadRequest, LoraLoadResponse
|
||||
from endpoints.core.types.model import (
|
||||
EmbeddingModelLoadRequest,
|
||||
ModelCard,
|
||||
ModelList,
|
||||
ModelLoadRequest,
|
||||
@@ -253,6 +254,93 @@ async def unload_loras():
|
||||
await model.unload_loras()
|
||||
|
||||
|
||||
@router.get("/v1/model/embedding/list", dependencies=[Depends(check_api_key)])
async def list_embedding_models(request: Request) -> ModelList:
    """
    Lists all embedding models in the model directory.

    Requires an admin key to see all embedding models.
    """

    # Non-admin keys only see the currently loaded embedding model
    if get_key_permission(request) != "admin":
        return await get_current_model_list(model_type="embedding")

    # Admin keys get a listing of the whole embedding model directory
    model_dir = unwrap(
        config.embeddings_config().get("embedding_model_dir"), "models"
    )
    return get_model_list(pathlib.Path(model_dir).resolve())
|
||||
|
||||
|
||||
@router.get(
    "/v1/model/embedding",
    dependencies=[Depends(check_api_key), Depends(check_embeddings_container)],
)
async def get_embedding_model() -> ModelList:
    """Returns the currently loaded embedding model."""

    # Bugfix: get_current_model_list is a coroutine; the original returned it
    # un-awaited and subscripted it ("...)[0]"), which raises a TypeError at
    # runtime. Await it and return the ModelList, matching the annotation.
    return await get_current_model_list(model_type="embedding")
|
||||
|
||||
|
||||
@router.post("/v1/model/embedding/load", dependencies=[Depends(check_admin_key)])
async def load_embedding_model(
    request: Request, data: EmbeddingModelLoadRequest
) -> ModelLoadResponse:
    """
    Loads an embedding model by name from the embedding model directory.

    Raises an HTTP 400 if the name is missing, the path does not exist,
    or the load itself fails.
    """

    # Verify request parameters
    if not data.name:
        error_message = handle_request_error(
            "A model name was not provided for load.",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)

    # Consistency fix: read the directory from embeddings_config() like
    # list_embedding_models does, not from model_config().
    embedding_model_dir = pathlib.Path(
        unwrap(config.embeddings_config().get("embedding_model_dir"), "models")
    )
    embedding_model_path = embedding_model_dir / data.name

    if not embedding_model_path.exists():
        error_message = handle_request_error(
            "Could not find the embedding model path for load. "
            + "Check model name or config.yml?",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)

    try:
        # Load in a task so the request can cancel it on client disconnect
        load_task = asyncio.create_task(
            model.load_embedding_model(embedding_model_path, **data.model_dump())
        )
        await run_with_request_disconnect(
            request, load_task, "Embedding model load request cancelled by user."
        )
    except Exception as exc:
        error_message = handle_request_error(str(exc)).error.message

        raise HTTPException(400, error_message) from exc

    # Embedding models load as a single module
    response = ModelLoadResponse(
        model_type="embedding_model", module=1, modules=1, status="finished"
    )

    return response
|
||||
|
||||
|
||||
@router.post(
    "/v1/model/embedding/unload",
    dependencies=[Depends(check_admin_key), Depends(check_embeddings_container)],
)
async def unload_embedding_model():
    """Unloads the current embedding model."""

    # check_embeddings_container dependency guarantees a container is loaded
    await model.unload_embedding_model()
|
||||
|
||||
|
||||
# Encode tokens endpoint
|
||||
@router.post(
|
||||
"/v1/token/encode",
|
||||
|
||||
@@ -137,6 +137,11 @@ class ModelLoadRequest(BaseModel):
|
||||
skip_queue: Optional[bool] = False
|
||||
|
||||
|
||||
class EmbeddingModelLoadRequest(BaseModel):
    """Request body for loading an embedding model."""

    # Directory name of the embedding model to load
    name: str
    # Device override for the embeddings backend; None uses the backend
    # default. Accepted values depend on the loader — TODO confirm
    embeddings_device: Optional[str] = None
|
||||
|
||||
|
||||
class ModelLoadResponse(BaseModel):
|
||||
"""Represents a model load response."""
|
||||
|
||||
|
||||
@@ -32,15 +32,26 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = N
|
||||
return model_card_list
|
||||
|
||||
|
||||
async def get_current_model_list(model_type: str = "model"):
    """
    Gets the current model in list format and with path only.

    Unified for fetching both models and embedding models.

    Args:
        model_type: One of "model", "draft", or "embedding"; any other
            value yields an empty list.

    Returns:
        A ModelList with at most one ModelCard for the loaded model.
    """

    current_models = []
    model_path = None

    # Make sure the relevant container exists before reading its path
    if model_type in ("model", "draft"):
        if model.container:
            model_path = model.container.get_model_path(model_type == "draft")
    elif model_type == "embedding":
        if model.embeddings_container:
            model_path = model.embeddings_container.model_dir

    if model_path:
        current_models.append(ModelCard(id=model_path.name))

    return ModelList(data=current_models)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user