ruff: formatting

This commit is contained in:
AlpinDale
2024-07-26 02:53:14 +00:00
parent 765d3593b3
commit 5adfab1cbd
3 changed files with 52 additions and 55 deletions

View File

@@ -16,24 +16,22 @@ embeddings_params_initialized = False
def initialize_embedding_params(): def initialize_embedding_params():
''' """
using 'lazy loading' to avoid circular import using 'lazy loading' to avoid circular import
so this function will be executed only once so this function will be executed only once
''' """
global embeddings_params_initialized global embeddings_params_initialized
if not embeddings_params_initialized: if not embeddings_params_initialized:
global st_model, embeddings_model, embeddings_device global st_model, embeddings_model, embeddings_device
st_model = os.environ.get("OPENAI_EMBEDDING_MODEL", st_model = os.environ.get("OPENAI_EMBEDDING_MODEL", "all-mpnet-base-v2")
'all-mpnet-base-v2')
embeddings_model = None embeddings_model = None
# OPENAI_EMBEDDING_DEVICE: auto (best or cpu), # OPENAI_EMBEDDING_DEVICE: auto (best or cpu),
# cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, # cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep,
# hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, # hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta,
# hpu, mtia, privateuseone # hpu, mtia, privateuseone
embeddings_device = os.environ.get("OPENAI_EMBEDDING_DEVICE", 'cpu') embeddings_device = os.environ.get("OPENAI_EMBEDDING_DEVICE", "cpu")
if embeddings_device.lower() == 'auto': if embeddings_device.lower() == "auto":
embeddings_device = None embeddings_device = None
embeddings_params_initialized = True embeddings_params_initialized = True
@@ -43,19 +41,20 @@ def load_embedding_model(model: str):
try: try:
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
except ModuleNotFoundError: except ModuleNotFoundError:
print("The sentence_transformers module has not been found. " + print(
"Please install it manually with " + "The sentence_transformers module has not been found. "
"pip install -U sentence-transformers.") + "Please install it manually with "
+ "pip install -U sentence-transformers."
)
raise ModuleNotFoundError from None raise ModuleNotFoundError from None
initialize_embedding_params() initialize_embedding_params()
global embeddings_device, embeddings_model global embeddings_device, embeddings_model
try: try:
print(f"Try embedding model: {model} on {embeddings_device}") print(f"Try embedding model: {model} on {embeddings_device}")
if 'jina-embeddings' in model: if "jina-embeddings" in model:
# trust_remote_code is needed to use the encode method # trust_remote_code is needed to use the encode method
embeddings_model = AutoModel.from_pretrained( embeddings_model = AutoModel.from_pretrained(model, trust_remote_code=True)
model, trust_remote_code=True)
embeddings_model = embeddings_model.to(embeddings_device) embeddings_model = embeddings_model.to(embeddings_device)
else: else:
embeddings_model = SentenceTransformer( embeddings_model = SentenceTransformer(
@@ -66,8 +65,9 @@ def load_embedding_model(model: str):
print(f"Loaded embedding model: {model}") print(f"Loaded embedding model: {model}")
except Exception as e: except Exception as e:
embeddings_model = None embeddings_model = None
raise Exception(f"Error: Failed to load embedding model: {model}", raise Exception(
internal_message=repr(e)) from None f"Error: Failed to load embedding model: {model}", internal_message=repr(e)
) from None
def get_embeddings_model(): def get_embeddings_model():
@@ -87,17 +87,17 @@ def get_embeddings_model_name() -> str:
def get_embeddings(input: list) -> np.ndarray: def get_embeddings(input: list) -> np.ndarray:
model = get_embeddings_model() model = get_embeddings_model()
embedding = model.encode(input, embedding = model.encode(
convert_to_numpy=True, input,
normalize_embeddings=True, convert_to_numpy=True,
convert_to_tensor=False, normalize_embeddings=True,
show_progress_bar=False) convert_to_tensor=False,
show_progress_bar=False,
)
return embedding return embedding
async def embeddings(input: list, async def embeddings(input: list, encoding_format: str, model: str = None) -> dict:
encoding_format: str,
model: str = None) -> dict:
if model is None: if model is None:
model = st_model model = st_model
else: else:
@@ -105,17 +105,15 @@ async def embeddings(input: list,
embeddings = get_embeddings(input) embeddings = get_embeddings(input)
if encoding_format == "base64": if encoding_format == "base64":
data = [{ data = [
"object": "embedding", {"object": "embedding", "embedding": float_list_to_base64(emb), "index": n}
"embedding": float_list_to_base64(emb), for n, emb in enumerate(embeddings)
"index": n ]
} for n, emb in enumerate(embeddings)]
else: else:
data = [{ data = [
"object": "embedding", {"object": "embedding", "embedding": emb.tolist(), "index": n}
"embedding": emb.tolist(), for n, emb in enumerate(embeddings)
"index": n ]
} for n, emb in enumerate(embeddings)]
response = { response = {
"object": "list", "object": "list",
@@ -124,7 +122,7 @@ async def embeddings(input: list,
"usage": { "usage": {
"prompt_tokens": 0, "prompt_tokens": 0,
"total_tokens": 0, "total_tokens": 0,
} },
} }
return response return response
@@ -140,6 +138,5 @@ def float_list_to_base64(float_array: np.ndarray) -> str:
encoded_bytes = base64.b64encode(bytes_array) encoded_bytes = base64.b64encode(bytes_array)
# Turn raw base64 encoded bytes into ASCII # Turn raw base64 encoded bytes into ASCII
ascii_string = encoded_bytes.decode('ascii') ascii_string = encoded_bytes.decode("ascii")
return ascii_string return ascii_string

View File

@@ -15,10 +15,7 @@ from endpoints.OAI.types.chat_completion import (
ChatCompletionRequest, ChatCompletionRequest,
ChatCompletionResponse, ChatCompletionResponse,
) )
from endpoints.OAI.types.embedding import ( from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse
EmbeddingsRequest,
EmbeddingsResponse
)
from endpoints.OAI.utils.chat_completion import ( from endpoints.OAI.utils.chat_completion import (
format_prompt_with_template, format_prompt_with_template,
generate_chat_completion, generate_chat_completion,
@@ -132,19 +129,19 @@ async def chat_completion_request(
) )
return response return response
# Embeddings endpoint # Embeddings endpoint
@router.post( @router.post(
"/v1/embeddings", "/v1/embeddings",
dependencies=[Depends(check_api_key), Depends(check_model_container)], dependencies=[Depends(check_api_key), Depends(check_model_container)],
response_model=EmbeddingsResponse response_model=EmbeddingsResponse,
) )
async def handle_embeddings(request: EmbeddingsRequest): async def handle_embeddings(request: EmbeddingsRequest):
input = request.input input = request.input
if not input: if not input:
raise JSONResponse(status_code=400, raise JSONResponse(
content={"error": "Missing required argument input"}) status_code=400, content={"error": "Missing required argument input"}
)
model = request.model if request.model else None model = request.model if request.model else None
response = await OAIembeddings.embeddings(input, request.encoding_format, response = await OAIembeddings.embeddings(input, request.encoding_format, model)
model)
return JSONResponse(response) return JSONResponse(response)

View File

@@ -8,32 +8,35 @@ class UsageInfo(BaseModel):
total_tokens: int = 0 total_tokens: int = 0
completion_tokens: Optional[int] = 0 completion_tokens: Optional[int] = 0
class EmbeddingsRequest(BaseModel): class EmbeddingsRequest(BaseModel):
input: List[str] = Field( input: List[str] = Field(
..., description="List of input texts to generate embeddings for.") ..., description="List of input texts to generate embeddings for."
)
encoding_format: str = Field( encoding_format: str = Field(
"float", "float",
description="Encoding format for the embeddings. " description="Encoding format for the embeddings. "
"Can be 'float' or 'base64'.") "Can be 'float' or 'base64'.",
)
model: Optional[str] = Field( model: Optional[str] = Field(
None, None,
description="Name of the embedding model to use. " description="Name of the embedding model to use. "
"If not provided, the default model will be used.") "If not provided, the default model will be used.",
)
class EmbeddingObject(BaseModel): class EmbeddingObject(BaseModel):
object: str = Field("embedding", description="Type of the object.") object: str = Field("embedding", description="Type of the object.")
embedding: List[float] = Field( embedding: List[float] = Field(
..., description="Embedding values as a list of floats.") ..., description="Embedding values as a list of floats."
)
index: int = Field( index: int = Field(
..., ..., description="Index of the input text corresponding to " "the embedding."
description="Index of the input text corresponding to " )
"the embedding.")
class EmbeddingsResponse(BaseModel): class EmbeddingsResponse(BaseModel):
object: str = Field("list", description="Type of the response object.") object: str = Field("list", description="Type of the response object.")
data: List[EmbeddingObject] = Field( data: List[EmbeddingObject] = Field(..., description="List of embedding objects.")
..., description="List of embedding objects.")
model: str = Field(..., description="Name of the embedding model used.") model: str = Field(..., description="Name of the embedding model used.")
usage: UsageInfo = Field(..., description="Information about token usage.") usage: UsageInfo = Field(..., description="Information about token usage.")