mirror of https://github.com/theroyallab/tabbyAPI.git (synced 2026-03-14 15:57:27 +00:00)
OAI: Add chat completions endpoint
Chat completions is the endpoint that OAI will use going forward, so it makes sense to support it even though the completions endpoint will see more use. Also unify the common parameters between the chat completion and completion requests, since the two are very similar.

Signed-off-by: kingbri <bdashore3@proton.me>
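The request models themselves aren't part of this diff, so as context for the "unify common parameters" note, here is a minimal sketch of the idea, assuming a hypothetical shared base class; the real models live in OAI/types/ and their field names and defaults may differ:

    from typing import List, Optional, Union
    from pydantic import BaseModel

    # Hypothetical base model holding the sampling parameters shared by both
    # request types; the real models in OAI/types/ may name things differently.
    class CommonCompletionRequest(BaseModel):
        model: Optional[str] = None
        stream: Optional[bool] = False
        max_tokens: Optional[int] = 150
        temperature: Optional[float] = 1.0
        top_p: Optional[float] = 1.0

        def to_gen_params(self) -> dict:
            # Map OAI-style field names onto the generator's keyword arguments
            return {
                "max_new_tokens": self.max_tokens,
                "temperature": self.temperature,
                "top_p": self.top_p,
            }

    class CompletionRequest(CommonCompletionRequest):
        prompt: Union[str, List[str]]

    class ChatCompletionRequest(CommonCompletionRequest):
        messages: Union[str, List[dict]]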
main.py | 61
@@ -7,11 +7,19 @@ from model import ModelContainer
 from progress.bar import IncrementalBar
 from sse_starlette import EventSourceResponse
 from OAI.types.completion import CompletionRequest
+from OAI.types.chat_completion import ChatCompletionRequest
 from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse
 from OAI.types.token import TokenEncodeRequest, TokenEncodeResponse, TokenDecodeRequest, TokenDecodeResponse
-from OAI.utils import create_completion_response, get_model_list
+from OAI.utils import (
+    create_completion_response,
+    get_model_list,
+    get_chat_completion_prompt,
+    create_chat_completion_response,
+    create_chat_completion_stream_chunk
+)
 from typing import Optional
 from utils import load_progress
+from uuid import uuid4
 
 app = FastAPI()
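The helpers newly imported from OAI.utils are not shown in this diff. For orientation, an OAI-compatible chat completion stream chunk generally has the shape below; this is a sketch of the wire format only, not the repo's actual create_chat_completion_stream_chunk, which (judging by the response.json() calls later in the diff) returns a pydantic model rather than a plain dict:

    import time

    # Sketch of the OAI chat.completion.chunk wire format; returns a plain
    # dict for brevity, while the repo's helper returns a pydantic model.
    def chat_completion_stream_chunk(const_id: str, text: str, model_name: str) -> dict:
        return {
            "id": const_id,  # stays constant across every chunk of one response
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model_name,
            "choices": [{
                "index": 0,
                "delta": {"role": "assistant", "content": text},
                "finish_reason": None,
            }],
        }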
@@ -45,8 +53,8 @@ async def load_model(data: ModelLoadRequest):
     model_config = config["model"]
     model_path = pathlib.Path(model_config["model_dir"] or "models")
     model_path = model_path / data.name
 
-    model_container = ModelContainer(model_path, False, **data.model_dump())
+    model_container = ModelContainer(model_path, False, **data.dict())
     load_status = model_container.load_gen(load_progress)
     for (module, modules) in load_status:
         if module == 0:
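A side note on the hunk above: .model_dump() is the pydantic v2 spelling and .dict() the v1 spelling of the same serialization call, so this change pins the code to the pydantic major version the project actually uses. Either way, ** unpacks the request's fields into keyword arguments. A self-contained illustration (the stand-in fields are assumptions, not the real ModelLoadRequest):

    from typing import Optional
    from pydantic import BaseModel

    # Illustrative stand-in; the real ModelLoadRequest in OAI/types/ has more fields.
    class ModelLoadRequest(BaseModel):
        name: str
        gpu_split_auto: Optional[bool] = True

    data = ModelLoadRequest(name="my-model")
    print(data.dict())          # pydantic v1: {'name': 'my-model', 'gpu_split_auto': True}
    # print(data.model_dump())  # pydantic v2 spelling of the same call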
@@ -97,21 +105,58 @@ async def decode_tokens(data: TokenDecodeRequest):
 
 @app.post("/v1/completions", dependencies=[Depends(check_api_key)])
 async def generate_completion(request: Request, data: CompletionRequest):
+    model_path = model_container.get_model_path()
+
+    if isinstance(data.prompt, list):
+        data.prompt = "\n".join(data.prompt)
+
     if data.stream:
         async def generator():
-            new_generation = model_container.generate_gen(**data.to_gen_params())
-            for index, part in enumerate(new_generation):
+            new_generation = model_container.generate_gen(data.prompt, **data.to_gen_params())
+            for part in new_generation:
                 if await request.is_disconnected():
                     break
 
-                response = create_completion_response(part, index, model_container.get_model_path().name)
+                response = create_completion_response(part, model_path.name)
 
                 yield response.json()
 
         return EventSourceResponse(generator())
     else:
-        response_text = model_container.generate(**data.to_gen_params())
-        response = create_completion_response(response_text, 0, model_container.get_model_path().name)
+        response_text = model_container.generate(data.prompt, **data.to_gen_params())
+        response = create_completion_response(response_text, model_path.name)
 
         return response.json()
+
+@app.post("/v1/chat/completions", dependencies=[Depends(check_api_key)])
+async def generate_chat_completion(request: Request, data: ChatCompletionRequest):
+    model_path = model_container.get_model_path()
+
+    if isinstance(data.messages, str):
+        prompt = data.messages
+    else:
+        prompt = get_chat_completion_prompt(model_path.name, data.messages)
+
+    if data.stream:
+        const_id = f"chatcmpl-{uuid4().hex}"
+        async def generator():
+            new_generation = model_container.generate_gen(prompt, **data.to_gen_params())
+            for part in new_generation:
+                if await request.is_disconnected():
+                    break
+
+                response = create_chat_completion_stream_chunk(
+                    const_id,
+                    part,
+                    model_path.name
+                )
+
+                yield response.json()
+
+        return EventSourceResponse(generator())
+    else:
+        response_text = model_container.generate(prompt, **data.to_gen_params())
+        response = create_chat_completion_response(response_text, model_path.name)
+
+        return response.json()
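Once a model is loaded, the new endpoint can be exercised with a small client like the one below. The host, port, and API key are placeholders, and the x-api-key header name is an assumption about how check_api_key reads credentials; sse_starlette frames each yielded chunk as a "data:" line:

    import json
    import requests

    # Placeholder host/port/key; the x-api-key header name is an assumption
    # about check_api_key, so adjust to your deployment's auth scheme.
    url = "http://localhost:8000/v1/chat/completions"
    headers = {"x-api-key": "YOUR_API_KEY"}
    payload = {
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 100,
        "stream": True,
    }

    with requests.post(url, json=payload, headers=headers, stream=True) as resp:
        for line in resp.iter_lines():
            # sse_starlette emits "data: <json>" frames for each yielded chunk
            if line.startswith(b"data: "):
                chunk = json.loads(line[len(b"data: "):])
                print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)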