Mirror of https://github.com/theroyallab/tabbyAPI.git
OAI: Implement completion API endpoint
Add support for /v1/completions, with optional streaming. Also rewrite API endpoints to be async where possible, since that improves request performance. Model container parameter names needed rewrites as well, with fallback cases now set to their disabled values.

Signed-off-by: kingbri <bdashore3@proton.me>
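For reference, a client call against the new endpoint might look like the sketch below. This is a minimal, hypothetical example: the host and port come from the uvicorn.run call at the bottom of main.py, the stream flag is the field checked by generate_completion, and the remaining request fields (prompt, max_tokens) are assumed to follow the OpenAI completion schema that CompletionRequest presumably models.

import requests  # hypothetical client; not part of this commit

# Assumed request fields: "prompt" and "max_tokens" follow the OpenAI
# completion schema; "stream" is the flag checked by generate_completion.
payload = {
    "prompt": "Once upon a time",
    "max_tokens": 200,
    "stream": False,  # set to True to receive server-sent events instead
}

# Port 8012 matches the uvicorn.run call in main.py
resp = requests.post("http://localhost:8012/v1/completions", json=payload)
print(resp.json())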
main.py (51 changed lines)
@@ -1,41 +1,37 @@
 import uvicorn
 import yaml
-from fastapi import FastAPI, HTTPException
-from pydantic import BaseModel
+from fastapi import FastAPI, Request
 from model import ModelContainer
 from progress.bar import IncrementalBar
+from sse_starlette import EventSourceResponse
+from OAI.models.completions import CompletionRequest, CompletionResponse, CompletionRespChoice
+from OAI.utils import create_completion_response
 
 app = FastAPI()
 
 # Initialize a model container. This can be undefined at any period of time
 model_container: ModelContainer = None
 
-class TextRequest(BaseModel):
-    model: str = None # Make the "model" field optional with a default value of None
-    prompt: str
-    max_tokens: int = 200
-    temperature: float = 1
-    top_p: float = 0.9
-    seed: int = 10
-    stream: bool = False
-    token_repetition_penalty: float = 1.0
-    stop: list = None
-
-class TextResponse(BaseModel):
-    response: str
-    generation_time: str
-
-# TODO: Currently broken
-@app.post("/generate-text", response_model=TextResponse)
-def generate_text(request: TextRequest):
-    global modelManager
-    try:
-        prompt = request.prompt # Get the prompt from the request
-        user_message = prompt # Assuming that prompt is equivalent to the user's message
-        output, generation_time = modelManager.generate_text(prompt=user_message)
-        return {"response": output, "generation_time": generation_time}
-    except RuntimeError as e:
-        raise HTTPException(status_code=500, detail=str(e))
+@app.post("/v1/completions")
+async def generate_completion(request: Request, data: CompletionRequest):
+    if data.stream:
+        async def generator():
+            new_generation = model_container.generate_gen(**data.to_gen_params())
+            for index, part in enumerate(new_generation):
+                if await request.is_disconnected():
+                    break
+
+                response = create_completion_response(part, index, model_container.get_model_name())
+
+                yield response.model_dump_json()
+
+        return EventSourceResponse(generator())
+    else:
+        response_text = model_container.generate(**data.to_gen_params())
+        response = create_completion_response(response_text, 0, model_container.get_model_name())
+
+        return response.model_dump_json()
 
 # Wrapper callback for load progress
 def load_progress(module, modules):
@@ -63,5 +59,4 @@ if __name__ == "__main__":
 
     print("Model successfully loaded.")
 
-    # Reload is for dev purposes ONLY!
-    uvicorn.run("main:app", host="0.0.0.0", port=8012, log_level="debug", reload=True)
+    uvicorn.run(app, host="0.0.0.0", port=8012, log_level="debug")
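When stream is true, the endpoint wraps the generator in EventSourceResponse, so each yielded CompletionResponse JSON payload arrives framed as an SSE "data:" line. A rough sketch of consuming that stream from a client, assuming the same hypothetical request fields as above and sse-starlette's default event framing:

import json

import requests  # hypothetical client; not part of this commit

# Stream the response and decode each SSE "data: <json>" frame as it arrives.
with requests.post(
    "http://localhost:8012/v1/completions",
    json={"prompt": "Once upon a time", "stream": True},  # fields assumed as above
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if line.startswith(b"data: "):
            chunk = json.loads(line[len(b"data: "):])
            print(chunk)  # one CompletionResponse-shaped payload per event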