From 8ba3bfa6b3d047a058023f39d8de69f85ff1190c Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 5 Dec 2023 00:23:15 -0500 Subject: [PATCH] API: Fix load exception handling Models do not fully unload if an exception is caught in load. Therefore, leave it to the client to unload on cancel. Also add handlers in the event an SSE stream is cancelled. These packets can't be sent back to the client since the client has severed the connection, so print them in terminal. Signed-off-by: kingbri --- main.py | 29 ++++++++++++++++------------- utils.py | 4 ++-- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/main.py b/main.py index 2e05beb..e5b9eb1 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,7 @@ import uvicorn import yaml import pathlib +from asyncio import CancelledError from auth import check_admin_key, check_api_key, load_auth_keys from fastapi import FastAPI, Request, HTTPException, Depends from fastapi.middleware.cors import CORSMiddleware @@ -77,7 +78,7 @@ async def get_current_model(): # Load model endpoint @app.post("/v1/model/load", dependencies=[Depends(check_admin_key)]) -async def load_model(data: ModelLoadRequest): +async def load_model(request: Request, data: ModelLoadRequest): global model_container if model_container and model_container.model: @@ -104,18 +105,19 @@ async def load_model(data: ModelLoadRequest): model_container = ModelContainer(model_path.resolve(), False, **load_data) - def generator(): + async def generator(): global model_container - load_failed = False model_type = "draft" if model_container.draft_enabled else "model" load_status = model_container.load_gen(load_progress) - # TODO: Maybe create an erroring generator as a common utility function try: for (module, modules) in load_status: + if await request.is_disconnected(): + break + if module == 0: - loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules) + loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules) elif module == modules: 
loading_bar.next() loading_bar.finish() @@ -142,13 +144,10 @@ async def load_model(data: ModelLoadRequest): ) yield get_sse_packet(response.json(ensure_ascii=False)) + except CancelledError as e: + print("\nError: Model load cancelled by user. Please make sure to run unload to free up resources.") except Exception as e: - yield get_generator_error(e) - load_failed = True - finally: - if load_failed: - model_container.unload() - model_container = None + yield get_generator_error(str(e)) return StreamingResponse(generator(), media_type = "text/event-stream") @@ -201,8 +200,10 @@ async def generate_completion(request: Request, data: CompletionRequest): model_path.name) yield get_sse_packet(response.json(ensure_ascii=False)) + except CancelledError: + print("Error: Completion request cancelled by user.") except Exception as e: - yield get_generator_error(e) + yield get_generator_error(str(e)) return StreamingResponse( generate_with_semaphore(generator), @@ -251,8 +252,10 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest ) yield get_sse_packet(finish_response.json(ensure_ascii=False)) + except CancelledError: + print("Error: Chat completion cancelled by user.") except Exception as e: - yield get_generator_error(e) + yield get_generator_error(str(e)) return StreamingResponse( generate_with_semaphore(generator), diff --git a/utils.py b/utils.py index 2b730eb..623b27b 100644 --- a/utils.py +++ b/utils.py @@ -14,9 +14,9 @@ class TabbyGeneratorErrorMessage(BaseModel): class TabbyGeneratorError(BaseModel): error: TabbyGeneratorErrorMessage -def get_generator_error(exception: Exception): +def get_generator_error(message: str): error_message = TabbyGeneratorErrorMessage( - message = str(exception), + message = message, trace = traceback.format_exc() )