Merge pull request #158 from AlpinDale/embeddings

feat: add embeddings support via Infinity-emb
Brian Dashore
2024-07-31 20:33:12 -04:00
committed by GitHub
14 changed files with 443 additions and 11 deletions

View File

@@ -5,7 +5,7 @@ from sys import maxsize
from common import config, model
from common.auth import check_api_key
-from common.model import check_model_container
+from common.model import check_embeddings_container, check_model_container
from common.networking import handle_request_error, run_with_request_disconnect
from common.utils import unwrap
from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
@@ -13,6 +13,7 @@ from endpoints.OAI.types.chat_completion import (
ChatCompletionRequest,
ChatCompletionResponse,
)
from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse
from endpoints.OAI.utils.chat_completion import (
format_prompt_with_template,
generate_chat_completion,
@@ -22,6 +23,7 @@ from endpoints.OAI.utils.completion import (
generate_completion,
stream_generate_completion,
)
from endpoints.OAI.utils.embeddings import get_embeddings
api_name = "OAI"
@@ -134,3 +136,19 @@ async def chat_completion_request(
disconnect_message=f"Chat completion {request.state.id} cancelled by user.",
)
return response
# Embeddings endpoint
@router.post(
"/v1/embeddings",
dependencies=[Depends(check_api_key), Depends(check_embeddings_container)],
)
async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsResponse:
embeddings_task = asyncio.create_task(get_embeddings(data, request))
response = await run_with_request_disconnect(
request,
embeddings_task,
f"Embeddings request {request.state.id} cancelled by user.",
)
return response
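For reference, a minimal client-side sketch of calling the new route. The base URL, port, and x-api-key header name are assumptions for illustration, not part of this diff:

import requests

# Hypothetical client call; adjust host, port, and auth header to your deployment.
resp = requests.post(
    "http://localhost:5000/v1/embeddings",
    headers={"x-api-key": "your-api-key"},
    json={
        "input": ["The quick brown fox", "jumps over the lazy dog"],
        "encoding_format": "float",
    },
)
resp.raise_for_status()
for item in resp.json()["data"]:
    print(item["index"], len(item["embedding"]))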

View File

@@ -0,0 +1,42 @@
from typing import List, Optional
from pydantic import BaseModel, Field
class UsageInfo(BaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
class EmbeddingsRequest(BaseModel):
input: List[str] = Field(
..., description="List of input texts to generate embeddings for."
)
encoding_format: str = Field(
"float",
description="Encoding format for the embeddings. "
"Can be 'float' or 'base64'.",
)
model: Optional[str] = Field(
None,
description="Name of the embedding model to use. "
"If not provided, the default model will be used.",
)
class EmbeddingObject(BaseModel):
object: str = Field("embedding", description="Type of the object.")
embedding: List[float] = Field(
..., description="Embedding values as a list of floats."
)
index: int = Field(
..., description="Index of the input text corresponding to " "the embedding."
)
class EmbeddingsResponse(BaseModel):
object: str = Field("list", description="Type of the response object.")
data: List[EmbeddingObject] = Field(..., description="List of embedding objects.")
model: str = Field(..., description="Name of the embedding model used.")
usage: UsageInfo = Field(..., description="Information about token usage.")
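As a sketch of how these models nest, the snippet below builds a response by hand to show the wire shape. Pydantic v2 is assumed (consistent with the model_dump() call used elsewhere in this PR), and the values are made up:

# Illustrative only: hand-built response showing how the models compose.
example = EmbeddingsResponse(
    data=[
        EmbeddingObject(embedding=[0.12, -0.03], index=0),
        EmbeddingObject(embedding=[0.08, 0.41], index=1),
    ],
    model="example-embedding-model",  # hypothetical model name
    usage=UsageInfo(prompt_tokens=9, total_tokens=9),
)
print(example.model_dump_json(indent=2))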

View File

@@ -0,0 +1,64 @@
"""
This file is derived from
[text-generation-webui openai extension embeddings](https://github.com/oobabooga/text-generation-webui/blob/1a7c027386f43b84f3ca3b0ff04ca48d861c2d7a/extensions/openai/embeddings.py)
and modified.
The changes introduced are: suppression of the progress bar,
typing/pydantic classes moved into this file,
and the embeddings function declared async.
"""
import base64
from fastapi import Request
import numpy as np
from loguru import logger
from common import model
from endpoints.OAI.types.embedding import (
EmbeddingObject,
EmbeddingsRequest,
EmbeddingsResponse,
UsageInfo,
)
def float_list_to_base64(float_array: np.ndarray) -> str:
"""
Encodes the provided float32 numpy array as a base64 ASCII string for OpenAI.
Ex. float_array = np.array(float_list, dtype="float32")
"""
# Encode raw bytes into base64
encoded_bytes = base64.b64encode(float_array.tobytes())
# Turn raw base64 encoded bytes into ASCII
ascii_string = encoded_bytes.decode("ascii")
return ascii_string
async def get_embeddings(data: EmbeddingsRequest, request: Request) -> EmbeddingsResponse:
model_path = model.embeddings_container.model_dir
logger.info(f"Recieved embeddings request {request.state.id}")
embedding_data = await model.embeddings_container.generate(data.input)
# OAI expects a return of base64 if the input is base64
embedding_object = [
EmbeddingObject(
embedding=float_list_to_base64(emb)
if data.encoding_format == "base64"
else emb.tolist(),
index=n,
)
for n, emb in enumerate(embedding_data.get("embeddings"))
]
usage = embedding_data.get("usage")
response = EmbeddingsResponse(
data=embedding_object,
model=model_path.name,
usage=UsageInfo(prompt_tokens=usage, total_tokens=usage),
)
logger.info(f"Finished embeddings request {request.state.id}")
return response
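Since float_list_to_base64 emits the raw float32 bytes of the array, a client that requested encoding_format "base64" can recover the vector with numpy. A minimal client-side decoding sketch (not part of this diff):

import base64

import numpy as np

def base64_to_float_list(encoded: str) -> np.ndarray:
    """Inverse of float_list_to_base64: decode base64 back into a float32 array."""
    raw = base64.b64decode(encoded)
    return np.frombuffer(raw, dtype="float32")

The round trip assumes both sides agree on float32 and native byte order, which tobytes() and frombuffer() use by default on the same platform.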

View File

@@ -7,7 +7,7 @@ from sse_starlette import EventSourceResponse
from common import config, model, sampling
from common.auth import check_admin_key, check_api_key, get_key_permission
from common.downloader import hf_repo_download
-from common.model import check_model_container
+from common.model import check_embeddings_container, check_model_container
from common.networking import handle_request_error, run_with_request_disconnect
from common.templating import PromptTemplate, get_all_templates
from common.utils import unwrap
@@ -15,6 +15,7 @@ from endpoints.core.types.auth import AuthPermissionResponse
from endpoints.core.types.download import DownloadRequest, DownloadResponse
from endpoints.core.types.lora import LoraList, LoraLoadRequest, LoraLoadResponse
from endpoints.core.types.model import (
EmbeddingModelLoadRequest,
ModelCard,
ModelList,
ModelLoadRequest,
@@ -253,6 +254,93 @@ async def unload_loras():
await model.unload_loras()
@router.get("/v1/model/embedding/list", dependencies=[Depends(check_api_key)])
async def list_embedding_models(request: Request) -> ModelList:
"""
Lists all embedding models in the model directory.
Requires an admin key to see all embedding models.
"""
if get_key_permission(request) == "admin":
embedding_model_dir = unwrap(
config.embeddings_config().get("embedding_model_dir"), "models"
)
embedding_model_path = pathlib.Path(embedding_model_dir)
models = get_model_list(embedding_model_path.resolve())
else:
models = await get_current_model_list(model_type="embedding")
return models
@router.get(
"/v1/model/embedding",
dependencies=[Depends(check_api_key), Depends(check_embeddings_container)],
)
async def get_embedding_model() -> ModelCard:
"""Returns the currently loaded embedding model."""
models = await get_current_model_list(model_type="embedding")
return models.data[0]
@router.post("/v1/model/embedding/load", dependencies=[Depends(check_admin_key)])
async def load_embedding_model(
request: Request, data: EmbeddingModelLoadRequest
) -> ModelLoadResponse:
# Verify request parameters
if not data.name:
error_message = handle_request_error(
"A model name was not provided for load.",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
embedding_model_dir = pathlib.Path(
unwrap(config.embeddings_config().get("embedding_model_dir"), "models")
)
embedding_model_path = embedding_model_dir / data.name
if not embedding_model_path.exists():
error_message = handle_request_error(
"Could not find the embedding model path for load. "
+ "Check model name or config.yml?",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
try:
load_task = asyncio.create_task(
model.load_embedding_model(embedding_model_path, **data.model_dump())
)
await run_with_request_disconnect(
request, load_task, "Embedding model load request cancelled by user."
)
except Exception as exc:
error_message = handle_request_error(str(exc)).error.message
raise HTTPException(400, error_message) from exc
response = ModelLoadResponse(
model_type="embedding_model", module=1, modules=1, status="finished"
)
return response
@router.post(
"/v1/model/embedding/unload",
dependencies=[Depends(check_admin_key), Depends(check_embeddings_container)],
)
async def unload_embedding_model():
"""Unloads the current embedding model."""
await model.unload_embedding_model()
# Encode tokens endpoint
@router.post(
"/v1/token/encode",

View File

@@ -137,6 +137,11 @@ class ModelLoadRequest(BaseModel):
skip_queue: Optional[bool] = False
class EmbeddingModelLoadRequest(BaseModel):
name: str
embeddings_device: Optional[str] = None
class ModelLoadResponse(BaseModel):
"""Represents a model load response."""

View File

@@ -32,15 +32,26 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = None):
return model_card_list
-async def get_current_model_list(is_draft: bool = False):
-"""Gets the current model in list format and with path only."""
+async def get_current_model_list(model_type: str = "model"):
+"""
+Gets the current model in list format and with path only.
+Unified to fetch both regular models and embedding models.
+"""
current_models = []
model_path = None
# Make sure the model container exists
-if model.container:
-model_path = model.container.get_model_path(is_draft)
-if model_path:
-current_models.append(ModelCard(id=model_path.name))
+if model_type == "model" or model_type == "draft":
+if model.container:
+model_path = model.container.get_model_path(model_type == "draft")
+elif model_type == "embedding":
+if model.embeddings_container:
+model_path = model.embeddings_container.model_dir
+if model_path:
+current_models.append(ModelCard(id=model_path.name))
return ModelList(data=current_models)
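With the unified signature, callers choose the container via model_type. An illustrative sketch (these calls must run inside an async function):

# Illustrative: the three ways the unified helper can now be called.
main_models = await get_current_model_list()                             # loaded main model
draft_models = await get_current_model_list(model_type="draft")          # loaded draft model
embedding_models = await get_current_model_list(model_type="embedding")  # loaded embedding model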