mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-28 18:21:42 +00:00
ruff: formatting
This commit is contained in:
@@ -16,24 +16,22 @@ embeddings_params_initialized = False
|
|||||||
|
|
||||||
|
|
||||||
def initialize_embedding_params():
|
def initialize_embedding_params():
|
||||||
'''
|
"""
|
||||||
using 'lazy loading' to avoid circular import
|
using 'lazy loading' to avoid circular import
|
||||||
so this function will be executed only once
|
so this function will be executed only once
|
||||||
'''
|
"""
|
||||||
global embeddings_params_initialized
|
global embeddings_params_initialized
|
||||||
if not embeddings_params_initialized:
|
if not embeddings_params_initialized:
|
||||||
|
|
||||||
global st_model, embeddings_model, embeddings_device
|
global st_model, embeddings_model, embeddings_device
|
||||||
|
|
||||||
st_model = os.environ.get("OPENAI_EMBEDDING_MODEL",
|
st_model = os.environ.get("OPENAI_EMBEDDING_MODEL", "all-mpnet-base-v2")
|
||||||
'all-mpnet-base-v2')
|
|
||||||
embeddings_model = None
|
embeddings_model = None
|
||||||
# OPENAI_EMBEDDING_DEVICE: auto (best or cpu),
|
# OPENAI_EMBEDDING_DEVICE: auto (best or cpu),
|
||||||
# cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep,
|
# cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep,
|
||||||
# hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta,
|
# hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta,
|
||||||
# hpu, mtia, privateuseone
|
# hpu, mtia, privateuseone
|
||||||
embeddings_device = os.environ.get("OPENAI_EMBEDDING_DEVICE", 'cpu')
|
embeddings_device = os.environ.get("OPENAI_EMBEDDING_DEVICE", "cpu")
|
||||||
if embeddings_device.lower() == 'auto':
|
if embeddings_device.lower() == "auto":
|
||||||
embeddings_device = None
|
embeddings_device = None
|
||||||
|
|
||||||
embeddings_params_initialized = True
|
embeddings_params_initialized = True
|
||||||
@@ -43,19 +41,20 @@ def load_embedding_model(model: str):
|
|||||||
try:
|
try:
|
||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
print("The sentence_transformers module has not been found. " +
|
print(
|
||||||
"Please install it manually with " +
|
"The sentence_transformers module has not been found. "
|
||||||
"pip install -U sentence-transformers.")
|
+ "Please install it manually with "
|
||||||
|
+ "pip install -U sentence-transformers."
|
||||||
|
)
|
||||||
raise ModuleNotFoundError from None
|
raise ModuleNotFoundError from None
|
||||||
|
|
||||||
initialize_embedding_params()
|
initialize_embedding_params()
|
||||||
global embeddings_device, embeddings_model
|
global embeddings_device, embeddings_model
|
||||||
try:
|
try:
|
||||||
print(f"Try embedding model: {model} on {embeddings_device}")
|
print(f"Try embedding model: {model} on {embeddings_device}")
|
||||||
if 'jina-embeddings' in model:
|
if "jina-embeddings" in model:
|
||||||
# trust_remote_code is needed to use the encode method
|
# trust_remote_code is needed to use the encode method
|
||||||
embeddings_model = AutoModel.from_pretrained(
|
embeddings_model = AutoModel.from_pretrained(model, trust_remote_code=True)
|
||||||
model, trust_remote_code=True)
|
|
||||||
embeddings_model = embeddings_model.to(embeddings_device)
|
embeddings_model = embeddings_model.to(embeddings_device)
|
||||||
else:
|
else:
|
||||||
embeddings_model = SentenceTransformer(
|
embeddings_model = SentenceTransformer(
|
||||||
@@ -66,8 +65,9 @@ def load_embedding_model(model: str):
|
|||||||
print(f"Loaded embedding model: {model}")
|
print(f"Loaded embedding model: {model}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
embeddings_model = None
|
embeddings_model = None
|
||||||
raise Exception(f"Error: Failed to load embedding model: {model}",
|
raise Exception(
|
||||||
internal_message=repr(e)) from None
|
f"Error: Failed to load embedding model: {model}", internal_message=repr(e)
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
def get_embeddings_model():
|
def get_embeddings_model():
|
||||||
@@ -87,17 +87,17 @@ def get_embeddings_model_name() -> str:
|
|||||||
|
|
||||||
def get_embeddings(input: list) -> np.ndarray:
|
def get_embeddings(input: list) -> np.ndarray:
|
||||||
model = get_embeddings_model()
|
model = get_embeddings_model()
|
||||||
embedding = model.encode(input,
|
embedding = model.encode(
|
||||||
convert_to_numpy=True,
|
input,
|
||||||
normalize_embeddings=True,
|
convert_to_numpy=True,
|
||||||
convert_to_tensor=False,
|
normalize_embeddings=True,
|
||||||
show_progress_bar=False)
|
convert_to_tensor=False,
|
||||||
|
show_progress_bar=False,
|
||||||
|
)
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
|
|
||||||
async def embeddings(input: list,
|
async def embeddings(input: list, encoding_format: str, model: str = None) -> dict:
|
||||||
encoding_format: str,
|
|
||||||
model: str = None) -> dict:
|
|
||||||
if model is None:
|
if model is None:
|
||||||
model = st_model
|
model = st_model
|
||||||
else:
|
else:
|
||||||
@@ -105,17 +105,15 @@ async def embeddings(input: list,
|
|||||||
|
|
||||||
embeddings = get_embeddings(input)
|
embeddings = get_embeddings(input)
|
||||||
if encoding_format == "base64":
|
if encoding_format == "base64":
|
||||||
data = [{
|
data = [
|
||||||
"object": "embedding",
|
{"object": "embedding", "embedding": float_list_to_base64(emb), "index": n}
|
||||||
"embedding": float_list_to_base64(emb),
|
for n, emb in enumerate(embeddings)
|
||||||
"index": n
|
]
|
||||||
} for n, emb in enumerate(embeddings)]
|
|
||||||
else:
|
else:
|
||||||
data = [{
|
data = [
|
||||||
"object": "embedding",
|
{"object": "embedding", "embedding": emb.tolist(), "index": n}
|
||||||
"embedding": emb.tolist(),
|
for n, emb in enumerate(embeddings)
|
||||||
"index": n
|
]
|
||||||
} for n, emb in enumerate(embeddings)]
|
|
||||||
|
|
||||||
response = {
|
response = {
|
||||||
"object": "list",
|
"object": "list",
|
||||||
@@ -124,7 +122,7 @@ async def embeddings(input: list,
|
|||||||
"usage": {
|
"usage": {
|
||||||
"prompt_tokens": 0,
|
"prompt_tokens": 0,
|
||||||
"total_tokens": 0,
|
"total_tokens": 0,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@@ -140,6 +138,5 @@ def float_list_to_base64(float_array: np.ndarray) -> str:
|
|||||||
encoded_bytes = base64.b64encode(bytes_array)
|
encoded_bytes = base64.b64encode(bytes_array)
|
||||||
|
|
||||||
# Turn raw base64 encoded bytes into ASCII
|
# Turn raw base64 encoded bytes into ASCII
|
||||||
ascii_string = encoded_bytes.decode('ascii')
|
ascii_string = encoded_bytes.decode("ascii")
|
||||||
return ascii_string
|
return ascii_string
|
||||||
|
|
||||||
|
|||||||
@@ -15,10 +15,7 @@ from endpoints.OAI.types.chat_completion import (
|
|||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
ChatCompletionResponse,
|
ChatCompletionResponse,
|
||||||
)
|
)
|
||||||
from endpoints.OAI.types.embedding import (
|
from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse
|
||||||
EmbeddingsRequest,
|
|
||||||
EmbeddingsResponse
|
|
||||||
)
|
|
||||||
from endpoints.OAI.utils.chat_completion import (
|
from endpoints.OAI.utils.chat_completion import (
|
||||||
format_prompt_with_template,
|
format_prompt_with_template,
|
||||||
generate_chat_completion,
|
generate_chat_completion,
|
||||||
@@ -132,19 +129,19 @@ async def chat_completion_request(
|
|||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
# Embeddings endpoint
|
# Embeddings endpoint
|
||||||
@router.post(
|
@router.post(
|
||||||
"/v1/embeddings",
|
"/v1/embeddings",
|
||||||
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
||||||
response_model=EmbeddingsResponse
|
response_model=EmbeddingsResponse,
|
||||||
)
|
)
|
||||||
async def handle_embeddings(request: EmbeddingsRequest):
|
async def handle_embeddings(request: EmbeddingsRequest):
|
||||||
input = request.input
|
input = request.input
|
||||||
if not input:
|
if not input:
|
||||||
raise JSONResponse(status_code=400,
|
raise JSONResponse(
|
||||||
content={"error": "Missing required argument input"})
|
status_code=400, content={"error": "Missing required argument input"}
|
||||||
|
)
|
||||||
model = request.model if request.model else None
|
model = request.model if request.model else None
|
||||||
response = await OAIembeddings.embeddings(input, request.encoding_format,
|
response = await OAIembeddings.embeddings(input, request.encoding_format, model)
|
||||||
model)
|
|
||||||
return JSONResponse(response)
|
return JSONResponse(response)
|
||||||
|
|
||||||
|
|||||||
@@ -8,32 +8,35 @@ class UsageInfo(BaseModel):
|
|||||||
total_tokens: int = 0
|
total_tokens: int = 0
|
||||||
completion_tokens: Optional[int] = 0
|
completion_tokens: Optional[int] = 0
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingsRequest(BaseModel):
|
class EmbeddingsRequest(BaseModel):
|
||||||
input: List[str] = Field(
|
input: List[str] = Field(
|
||||||
..., description="List of input texts to generate embeddings for.")
|
..., description="List of input texts to generate embeddings for."
|
||||||
|
)
|
||||||
encoding_format: str = Field(
|
encoding_format: str = Field(
|
||||||
"float",
|
"float",
|
||||||
description="Encoding format for the embeddings. "
|
description="Encoding format for the embeddings. "
|
||||||
"Can be 'float' or 'base64'.")
|
"Can be 'float' or 'base64'.",
|
||||||
|
)
|
||||||
model: Optional[str] = Field(
|
model: Optional[str] = Field(
|
||||||
None,
|
None,
|
||||||
description="Name of the embedding model to use. "
|
description="Name of the embedding model to use. "
|
||||||
"If not provided, the default model will be used.")
|
"If not provided, the default model will be used.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingObject(BaseModel):
|
class EmbeddingObject(BaseModel):
|
||||||
object: str = Field("embedding", description="Type of the object.")
|
object: str = Field("embedding", description="Type of the object.")
|
||||||
embedding: List[float] = Field(
|
embedding: List[float] = Field(
|
||||||
..., description="Embedding values as a list of floats.")
|
..., description="Embedding values as a list of floats."
|
||||||
|
)
|
||||||
index: int = Field(
|
index: int = Field(
|
||||||
...,
|
..., description="Index of the input text corresponding to " "the embedding."
|
||||||
description="Index of the input text corresponding to "
|
)
|
||||||
"the embedding.")
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingsResponse(BaseModel):
|
class EmbeddingsResponse(BaseModel):
|
||||||
object: str = Field("list", description="Type of the response object.")
|
object: str = Field("list", description="Type of the response object.")
|
||||||
data: List[EmbeddingObject] = Field(
|
data: List[EmbeddingObject] = Field(..., description="List of embedding objects.")
|
||||||
..., description="List of embedding objects.")
|
|
||||||
model: str = Field(..., description="Name of the embedding model used.")
|
model: str = Field(..., description="Name of the embedding model used.")
|
||||||
usage: UsageInfo = Field(..., description="Information about token usage.")
|
usage: UsageInfo = Field(..., description="Information about token usage.")
|
||||||
|
|||||||
Reference in New Issue
Block a user