OAI: Implement completion API endpoint

Add support for /v1/completions with the option to use streaming
if needed. Also rewrite API endpoints to use async when possible
since that improves request performance.

Model container parameter names also needed rewrites as well and
set fallback cases to their disabled values.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-11-13 18:24:12 -05:00
parent 4fa4386275
commit eee8b642bd
6 changed files with 190 additions and 57 deletions

51
main.py
View File

@@ -1,41 +1,37 @@
import uvicorn
import yaml
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from fastapi import FastAPI, Request
from model import ModelContainer
from progress.bar import IncrementalBar
from sse_starlette import EventSourceResponse
from OAI.models.completions import CompletionRequest, CompletionResponse, CompletionRespChoice
from OAI.utils import create_completion_response
app = FastAPI()
# Initialize a model container. This can be undefined at any period of time
model_container: ModelContainer = None
class TextRequest(BaseModel):
model: str = None # Make the "model" field optional with a default value of None
prompt: str
max_tokens: int = 200
temperature: float = 1
top_p: float = 0.9
seed: int = 10
stream: bool = False
token_repetition_penalty: float = 1.0
stop: list = None
@app.post("/v1/completions")
async def generate_completion(request: Request, data: CompletionRequest):
if data.stream:
async def generator():
new_generation = model_container.generate_gen(**data.to_gen_params())
for index, part in enumerate(new_generation):
if await request.is_disconnected():
break
class TextResponse(BaseModel):
response: str
generation_time: str
response = create_completion_response(part, index, model_container.get_model_name())
yield response.model_dump_json()
return EventSourceResponse(generator())
else:
response_text = model_container.generate(**data.to_gen_params())
response = create_completion_response(response_text, 0, model_container.get_model_name())
return response.model_dump_json()
# TODO: Currently broken
@app.post("/generate-text", response_model=TextResponse)
def generate_text(request: TextRequest):
global modelManager
try:
prompt = request.prompt # Get the prompt from the request
user_message = prompt # Assuming that prompt is equivalent to the user's message
output, generation_time = modelManager.generate_text(prompt=user_message)
return {"response": output, "generation_time": generation_time}
except RuntimeError as e:
raise HTTPException(status_code=500, detail=str(e))
# Wrapper callback for load progress
def load_progress(module, modules):
@@ -63,5 +59,4 @@ if __name__ == "__main__":
print("Model successfully loaded.")
# Reload is for dev purposes ONLY!
uvicorn.run("main:app", host="0.0.0.0", port=8012, log_level="debug", reload=True)
uvicorn.run(app, host="0.0.0.0", port=8012, log_level="debug")