diff --git a/main.py b/main.py index 2e05beb..e5b9eb1 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,7 @@ import uvicorn import yaml import pathlib +from asyncio import CancelledError from auth import check_admin_key, check_api_key, load_auth_keys from fastapi import FastAPI, Request, HTTPException, Depends from fastapi.middleware.cors import CORSMiddleware @@ -77,7 +78,7 @@ async def get_current_model(): # Load model endpoint @app.post("/v1/model/load", dependencies=[Depends(check_admin_key)]) -async def load_model(data: ModelLoadRequest): +async def load_model(request: Request, data: ModelLoadRequest): global model_container if model_container and model_container.model: @@ -104,18 +105,19 @@ async def load_model(data: ModelLoadRequest): model_container = ModelContainer(model_path.resolve(), False, **load_data) - def generator(): + async def generator(): global model_container - load_failed = False model_type = "draft" if model_container.draft_enabled else "model" load_status = model_container.load_gen(load_progress) - # TODO: Maybe create an erroring generator as a common utility function try: for (module, modules) in load_status: + if await request.is_disconnected(): + break + if module == 0: - loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules) + loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules) elif module == modules: loading_bar.next() loading_bar.finish() @@ -142,13 +144,10 @@ async def load_model(data: ModelLoadRequest): ) yield get_sse_packet(response.json(ensure_ascii=False)) + except CancelledError as e: + print("\nError: Model load cancelled by user. Please make sure to run unload to free up resources.") except Exception as e: - yield get_generator_error(e) - load_failed = True - finally: - if load_failed: - model_container.unload() - model_container = None + yield get_generator_error(str(e)) return StreamingResponse(generator(), media_type = "text/event-stream") @@ -201,8 +200,10 @@ async def generate_completion(request: Request, data: CompletionRequest): model_path.name) yield get_sse_packet(response.json(ensure_ascii=False)) + except CancelledError: + print("Error: Completion request cancelled by user.") except Exception as e: - yield get_generator_error(e) + yield get_generator_error(str(e)) return StreamingResponse( generate_with_semaphore(generator), @@ -251,8 +252,10 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest ) yield get_sse_packet(finish_response.json(ensure_ascii=False)) + except CancelledError: + print("Error: Chat completion cancelled by user.") except Exception as e: - yield get_generator_error(e) + yield get_generator_error(str(e)) return StreamingResponse( generate_with_semaphore(generator), diff --git a/utils.py b/utils.py index 2b730eb..623b27b 100644 --- a/utils.py +++ b/utils.py @@ -14,9 +14,9 @@ class TabbyGeneratorErrorMessage(BaseModel): class TabbyGeneratorError(BaseModel): error: TabbyGeneratorErrorMessage -def get_generator_error(exception: Exception): +def get_generator_error(message: str): error_message = TabbyGeneratorErrorMessage( - message = str(exception), + message = message, trace = traceback.format_exc() )