OAI: Add chat completions endpoint

Chat completions is the endpoint that OAI will use going forward, so
it makes sense to support it even though the completions endpoint
will likely see more use.

Also unify common parameters between the chat completion and completion
requests since they're very similar.

Signed-off-by: kingbri <bdashore3@proton.me>
Author: kingbri
Date: 2023-11-16 01:06:07 -05:00
Parent: 593471a04d
Commit: 5e8419ec0c

5 changed files with 247 additions and 96 deletions
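
The unified parameters mentioned in the commit message live under OAI/types/ and are not part of the main.py diff below, so the class and field names here are assumptions. A minimal sketch of what a shared request base with a to_gen_params() helper could look like:

# Rough sketch only: the real classes live in OAI/types/ and may name
# their fields differently.
from typing import List, Optional, Union
from pydantic import BaseModel

class CommonCompletionRequest(BaseModel):
    # Sampling options shared by /v1/completions and /v1/chat/completions
    max_tokens: Optional[int] = 150
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    stream: Optional[bool] = False

    def to_gen_params(self):
        # Collapse the request into kwargs for ModelContainer.generate()/generate_gen()
        return {
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
        }

class CompletionRequest(CommonCompletionRequest):
    prompt: Union[str, List[str]] = ""

class ChatCompletionRequest(CommonCompletionRequest):
    messages: Union[str, List[dict]] = []

The exact field set isn't visible from this diff; the point is only that both request types share one base and one to_gen_params().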

main.py (61 changed lines)

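The new endpoint in the diff below builds its prompt with get_chat_completion_prompt() from OAI/utils.py, which is not part of this file's changes; it turns the OpenAI-style message list into a single prompt string. Purely as an illustration (the real helper may use model-specific templates), it could look roughly like:

# Illustrative guess at the helper imported from OAI.utils; the real
# implementation is not shown in this diff.
def get_chat_completion_prompt(model_name: str, messages: list) -> str:
    # model_name could select a prompt template; unused in this simple sketch
    # Flatten role-tagged messages into one prompt and cue the assistant reply
    lines = [f"{message['role']}: {message['content']}" for message in messages]
    lines.append("assistant:")
    return "\n".join(lines)
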
@@ -7,11 +7,19 @@ from model import ModelContainer
 from progress.bar import IncrementalBar
 from sse_starlette import EventSourceResponse
 from OAI.types.completion import CompletionRequest
+from OAI.types.chat_completion import ChatCompletionRequest
 from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse
 from OAI.types.token import TokenEncodeRequest, TokenEncodeResponse, TokenDecodeRequest, TokenDecodeResponse
-from OAI.utils import create_completion_response, get_model_list
+from OAI.utils import (
+    create_completion_response,
+    get_model_list,
+    get_chat_completion_prompt,
+    create_chat_completion_response,
+    create_chat_completion_stream_chunk
+)
 from typing import Optional
 from utils import load_progress
+from uuid import uuid4

 app = FastAPI()
@@ -45,8 +53,8 @@ async def load_model(data: ModelLoadRequest):
     model_config = config["model"]
     model_path = pathlib.Path(model_config["model_dir"] or "models")
     model_path = model_path / data.name

-    model_container = ModelContainer(model_path, False, **data.model_dump())
+    model_container = ModelContainer(model_path, False, **data.dict())
     load_status = model_container.load_gen(load_progress)
     for (module, modules) in load_status:
         if module == 0:
@@ -97,21 +105,58 @@ async def decode_tokens(data: TokenDecodeRequest):
 @app.post("/v1/completions", dependencies=[Depends(check_api_key)])
 async def generate_completion(request: Request, data: CompletionRequest):
+    model_path = model_container.get_model_path()
+
+    if isinstance(data.prompt, list):
+        data.prompt = "\n".join(data.prompt)
+
     if data.stream:
         async def generator():
-            new_generation = model_container.generate_gen(**data.to_gen_params())
-            for index, part in enumerate(new_generation):
+            new_generation = model_container.generate_gen(data.prompt, **data.to_gen_params())
+            for part in new_generation:
                 if await request.is_disconnected():
                     break

-                response = create_completion_response(part, index, model_container.get_model_path().name)
+                response = create_completion_response(part, model_path.name)

                 yield response.json()

         return EventSourceResponse(generator())
     else:
-        response_text = model_container.generate(**data.to_gen_params())
-        response = create_completion_response(response_text, 0, model_container.get_model_path().name)
+        response_text = model_container.generate(data.prompt, **data.to_gen_params())
+        response = create_completion_response(response_text, model_path.name)

         return response.json()

+@app.post("/v1/chat/completions", dependencies=[Depends(check_api_key)])
+async def generate_chat_completion(request: Request, data: ChatCompletionRequest):
+    model_path = model_container.get_model_path()
+
+    if isinstance(data.messages, str):
+        prompt = data.messages
+    else:
+        prompt = get_chat_completion_prompt(model_path.name, data.messages)
+
+    if data.stream:
+        const_id = f"chatcmpl-{uuid4().hex}"
+
+        async def generator():
+            new_generation = model_container.generate_gen(prompt, **data.to_gen_params())
+            for part in new_generation:
+                if await request.is_disconnected():
+                    break
+
+                response = create_chat_completion_stream_chunk(
+                    const_id,
+                    part,
+                    model_path.name
+                )
+
+                yield response.json()
+
+        return EventSourceResponse(generator())
+    else:
+        response_text = model_container.generate(prompt, **data.to_gen_params())
+        response = create_chat_completion_response(response_text, model_path.name)
+
+        return response.json()
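
For a quick check of the new route, a client call might look like the following. This assumes the server listens on localhost:8000 and that check_api_key reads an x-api-key header; both are assumptions, since neither the host configuration nor the auth module appears in this diff.

# Hypothetical client for the new endpoint; host, port, and header name are
# assumptions and should be adjusted to the actual deployment.
import requests

url = "http://localhost:8000/v1/chat/completions"
headers = {"x-api-key": "<api key>"}
payload = {
    "messages": [{"role": "user", "content": "Hello!"}],
    "max_tokens": 100,
}

# Non-streaming: one response body for the whole generation
resp = requests.post(url, json=payload, headers=headers)
print(resp.text)

# Streaming: stream=True switches the endpoint to SSE; each "data:" line
# carries one chunk produced by create_chat_completion_stream_chunk
with requests.post(url, json={**payload, "stream": True}, headers=headers, stream=True) as resp:
    for line in resp.iter_lines():
        if line.startswith(b"data:"):
            print(line.decode())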