Mirror of https://github.com/theroyallab/tabbyAPI.git

API: Fix disconnect handling on streaming responses

Starlette's StreamingResponse has an issue where it yields after a request has disconnected. A bugfix to Starlette will resolve this, but FastAPI depends on starlette <= 0.36, which isn't ideal. Therefore, switch back to sse-starlette, which handles these disconnects correctly.

Also, don't try to yield after the request is disconnected; just return out of the generator instead.

Signed-off-by: kingbri <bdashore3@proton.me>
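
For context, here is a minimal sketch of the pattern this commit moves to. It assumes a standalone FastAPI app; the endpoint path, payload shape, and chunk count are illustrative, not tabbyAPI's actual code:

    from fastapi import FastAPI, Request
    from sse_starlette import EventSourceResponse

    app = FastAPI()

    @app.get("/stream")
    async def stream(request: Request):
        async def generator():
            for i in range(3):
                # Get out if the request gets disconnected instead of
                # yielding into a connection that is already closed.
                if await request.is_disconnected():
                    return
                # Yield bare JSON strings; EventSourceResponse applies
                # the SSE "data:" framing on its own.
                yield f'{{"chunk": {i}}}'
            yield "[DONE]"

        return EventSourceResponse(generator())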

common/utils.py:

@@ -29,7 +29,7 @@ def get_generator_error(message: str, exc_info: bool = True):
 
     generator_error = handle_request_error(message)
 
-    return get_sse_packet(generator_error.model_dump_json())
+    return generator_error.model_dump_json()
 
 
 def handle_request_error(message: str, exc_info: bool = True):
@@ -50,11 +50,6 @@ def handle_request_error(message: str, exc_info: bool = True):
     return request_error
 
 
-def get_sse_packet(json_data: str):
-    """Get an SSE packet."""
-    return f"data: {json_data}\n\n"
-
-
 def unwrap(wrapped, default=None):
     """Unwrap function for Optionals."""
     if wrapped is None:
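
With get_sse_packet removed, get_generator_error (and the generators in main.py) yield bare JSON strings, and sse-starlette applies the "data:" framing that get_sse_packet used to add by hand. A quick sketch of that framing, assuming sse-starlette also exports its ServerSentEvent helper:

    from sse_starlette import ServerSentEvent

    # encode() shows the wire format sse-starlette emits for a yielded
    # string, roughly: b'data: {"status": "ok"}\r\n\r\n'
    print(ServerSentEvent(data='{"status": "ok"}').encode())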

main.py (40 lines changed):

@@ -5,6 +5,7 @@ import signal
 import sys
 import time
 import threading
+from sse_starlette import EventSourceResponse
 import uvicorn
 from asyncio import CancelledError
 from typing import Optional
@@ -13,7 +14,6 @@ from jinja2 import TemplateError
 from fastapi import FastAPI, Depends, HTTPException, Request
 from fastapi.concurrency import run_in_threadpool
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse
 from functools import partial
 from loguru import logger
 
@@ -47,7 +47,6 @@ from common.templating import (
 )
 from common.utils import (
     get_generator_error,
-    get_sse_packet,
     handle_request_error,
     load_progress,
     unwrap,
@@ -235,12 +234,14 @@ async def load_model(request: Request, data: ModelLoadRequest):
         progress.start()
 
         for module, modules in load_status:
+
+            # Get out if the request gets disconnected
             if await request.is_disconnected():
                 logger.error(
                     "Model load cancelled by user. "
                     "Please make sure to run unload to free up resources."
                 )
-                break
+                return
 
             if module == 0:
                 loading_task = progress.add_task(
@@ -256,7 +257,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
                     status="processing",
                 )
 
-                yield get_sse_packet(response.model_dump_json())
+                yield response.model_dump_json()
 
             if module == modules:
                 response = ModelLoadResponse(
@@ -266,7 +267,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
                     status="finished",
                 )
 
-                yield get_sse_packet(response.model_dump_json())
+                yield response.model_dump_json()
 
             # Switch to model progress if the draft model is loaded
             if model_type == "draft":
@@ -294,7 +295,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
     else:
         generator_callback = partial(generate_with_semaphore, generator)
 
-        return StreamingResponse(generator_callback(), media_type="text/event-stream")
+        return EventSourceResponse(generator_callback())
 
 
 # Unload model endpoint
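
Note that EventSourceResponse defaults its media type to text/event-stream, which is why the explicit media_type argument passed to StreamingResponse is dropped here and in the streaming endpoints below.
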
@@ -515,31 +516,30 @@ async def generate_completion(request: Request, data: CompletionRequest):
     if data.stream and not disable_request_streaming:
 
         async def generator():
-            """Generator for the generation process."""
             try:
                 new_generation = MODEL_CONTAINER.generate_gen(
                     data.prompt, **data.to_gen_params()
                 )
                 for generation in new_generation:
+
+                    # Get out if the request gets disconnected
                     if await request.is_disconnected():
                         logger.error("Completion generation cancelled by user.")
-                        break
+                        return
 
                     response = create_completion_response(generation, model_path.name)
 
-                    yield get_sse_packet(response.model_dump_json())
+                    yield response.model_dump_json()
 
                 # Yield a finish response on successful generation
-                yield get_sse_packet("[DONE]")
+                yield "[DONE]"
             except Exception:
                 yield get_generator_error(
                     "Completion aborted. Please check the server console."
                 )
+            print("Finished generation")
 
-        return StreamingResponse(
-            generate_with_semaphore(generator),
-            media_type="text/event-stream",
-        )
+        return EventSourceResponse(generate_with_semaphore(generator))
 
     try:
         generation = await call_with_semaphore(
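
The switch from break to return in these generators is deliberate: a break would fall out of the loop and still reach the trailing "[DONE]" yield, pushing one more event at a client that has already disconnected, while return ends the generator immediately. A tiny hypothetical illustration:

    async def stream(disconnected: bool):
        for chunk in ("a", "b"):
            if disconnected:
                # `break` here would continue past the loop and yield
                # "[DONE]" to a client that is already gone; `return`
                # ends the async generator immediately instead.
                return
            yield chunk
        yield "[DONE]"
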
@@ -620,30 +620,30 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
                     prompt, **data.to_gen_params()
                 )
                 for generation in new_generation:
+
+                    # Get out if the request gets disconnected
                     if await request.is_disconnected():
                         logger.error("Chat completion generation cancelled by user.")
-                        break
+                        return
 
                     response = create_chat_completion_stream_chunk(
                         const_id, generation, model_path.name
                     )
 
-                    yield get_sse_packet(response.model_dump_json())
+                    yield response.model_dump_json()
 
                 # Yield a finish response on successful generation
                 finish_response = create_chat_completion_stream_chunk(
                     const_id, finish_reason="stop"
                 )
 
-                yield get_sse_packet(finish_response.model_dump_json())
+                yield finish_response.model_dump_json()
             except Exception:
                 yield get_generator_error(
                     "Chat completion aborted. Please check the server console."
                 )
 
-        return StreamingResponse(
-            generate_with_semaphore(generator), media_type="text/event-stream"
-        )
+        return EventSourceResponse(generate_with_semaphore(generator))
 
     try:
         generation = await call_with_semaphore(