From ae69b185836604fdf0441c29f336fbaae391a6a5 Mon Sep 17 00:00:00 2001
From: kingbri
Date: Fri, 1 Dec 2023 01:54:35 -0500
Subject: [PATCH] API: Use FastAPI streaming instead of sse_starlette

sse_starlette kept firing ping responses whenever an event took too
long to arrive. Rather than patching around that behavior, switch to
FastAPI's built-in streaming response and construct SSE packets with a
utility function.

This makes the API more robust and removes an extra dependency.

Signed-off-by: kingbri
---
 main.py          | 20 ++++++++++----------
 model.py         |  1 +
 requirements.txt |  1 -
 utils.py         |  5 ++++-
 4 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/main.py b/main.py
index c817027..16ed0b0 100644
--- a/main.py
+++ b/main.py
@@ -4,9 +4,9 @@ import pathlib
 from auth import check_admin_key, check_api_key, load_auth_keys
 from fastapi import FastAPI, Request, HTTPException, Depends
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse
 from model import ModelContainer
 from progress.bar import IncrementalBar
-from sse_starlette import EventSourceResponse
 from OAI.types.completion import CompletionRequest
 from OAI.types.chat_completion import ChatCompletionRequest
 from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse
@@ -24,7 +24,7 @@ from OAI.utils import (
     create_chat_completion_stream_chunk
 )
 from typing import Optional
-from utils import get_generator_error, load_progress
+from utils import get_generator_error, get_sse_packet, load_progress
 from uuid import uuid4
 
 app = FastAPI()
@@ -126,7 +126,7 @@ async def load_model(data: ModelLoadRequest):
                 status="finished"
             )
 
-            yield response.json(ensure_ascii=False)
+            yield get_sse_packet(response.json(ensure_ascii=False))
 
             if model_container.draft_enabled:
                 model_type = "model"
@@ -140,7 +140,7 @@ async def load_model(data: ModelLoadRequest):
                     status="processing"
                 )
 
-                yield response.json(ensure_ascii=False)
+                yield get_sse_packet(response.json(ensure_ascii=False))
         except Exception as e:
             yield get_generator_error(e)
             load_failed = True
@@ -149,7 +149,7 @@ async def load_model(data: ModelLoadRequest):
             model_container.unload()
             model_container = None
 
-    return EventSourceResponse(generator())
+    return StreamingResponse(generator(), media_type = "text/event-stream")
 
 # Unload model endpoint
 @app.get("/v1/model/unload", dependencies=[Depends(check_admin_key), Depends(_check_model_container)])
@@ -199,11 +199,11 @@ async def generate_completion(request: Request, data: CompletionRequest):
                     completion_tokens,
                     model_path.name)
 
-                yield response.json(ensure_ascii=False)
+                yield get_sse_packet(response.json(ensure_ascii=False))
             except Exception as e:
                 yield get_generator_error(e)
 
-        return EventSourceResponse(generator())
+        return StreamingResponse(generator(), media_type = "text/event-stream")
     else:
         response_text, prompt_tokens, completion_tokens = model_container.generate(data.prompt, **data.to_gen_params())
         response = create_completion_response(response_text,
@@ -238,7 +238,7 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
                     model_path.name
                 )
 
-                yield response.json(ensure_ascii=False)
+                yield get_sse_packet(response.json(ensure_ascii=False))
             except Exception as e:
                 yield get_generator_error(e)
             finally:
@@ -249,9 +249,9 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
                     finish_reason = "stop"
                 )
 
-                yield finish_response.json(ensure_ascii=False)
+                yield get_sse_packet(finish_response.json(ensure_ascii=False))
 
-        return EventSourceResponse(generator())
+        return StreamingResponse(generator(), media_type = "text/event-stream")
     else:
         response_text, prompt_tokens, completion_tokens = model_container.generate(prompt, **data.to_gen_params())
         response = create_chat_completion_response(response_text,
diff --git a/model.py b/model.py
index 3b2dc8d..1eb90ae 100644
--- a/model.py
+++ b/model.py
@@ -311,6 +311,7 @@ class ModelContainer:
         stop_conditions: List[Union[str, int]] = kwargs.get("stop", [])
         ban_eos_token = kwargs.get("ban_eos_token", False)
 
+        # Ban the EOS token if specified. If not, append to stop conditions as well.
         if ban_eos_token:
diff --git a/requirements.txt b/requirements.txt
index fbe3a2a..66b172f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,7 +2,6 @@
 pydantic < 2,>= 1
 PyYAML
 progress
-sse_starlette
 uvicorn
 
 # Wheels
diff --git a/utils.py b/utils.py
index dbfdb6b..2b730eb 100644
--- a/utils.py
+++ b/utils.py
@@ -26,4 +26,7 @@ def get_generator_error(exception: Exception):
     # Log and send the exception
     print(f"\n{generator_error.error.trace}")
 
-    return generator_error.json()
+    return get_sse_packet(generator_error.json(ensure_ascii = False))
+
+def get_sse_packet(json_data: str):
+    return f"data: {json_data}\n\n"
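
Note on the pattern, shown outside the patch itself: StreamingResponse
sends each string yielded by a generator as a raw body chunk, so the
SSE wire framing (a "data:" field terminated by a blank line) has to be
applied by hand, which is all that get_sse_packet does. The sketch
below is a standalone illustration under that assumption; the
/demo/stream endpoint, its payload, and the sleep are made up for the
example and are not code from this repository.

import asyncio
import json

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

def get_sse_packet(json_data: str) -> str:
    # Frame a JSON string as a single SSE event:
    # a "data:" field followed by a blank line.
    return f"data: {json_data}\n\n"

@app.get("/demo/stream")  # hypothetical endpoint for illustration
async def demo_stream():
    async def generator():
        # Stand-in for token-by-token generation work.
        for i in range(3):
            yield get_sse_packet(json.dumps({"chunk": i}))
            await asyncio.sleep(0.1)

    # text/event-stream tells clients to parse the body as SSE.
    return StreamingResponse(generator(), media_type="text/event-stream")

Unlike sse_starlette's EventSourceResponse, StreamingResponse never
injects keep-alive ping events of its own, which is exactly the
behavior change the commit message describes.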