diff --git a/main.py b/main.py
index e5d7d26..a76de8c 100644
--- a/main.py
+++ b/main.py
@@ -1,6 +1,6 @@
 import uvicorn
 import yaml
-import pathlib, os
+import pathlib
 from auth import check_admin_key, check_api_key, load_auth_keys
 from fastapi import FastAPI, Request, HTTPException, Depends
 from fastapi.middleware.cors import CORSMiddleware
@@ -24,7 +24,7 @@ from OAI.utils import (
     create_chat_completion_stream_chunk
 )
 from typing import Optional
-from utils import load_progress
+from utils import get_generator_error, load_progress
 from uuid import uuid4
 
 app = FastAPI()
@@ -102,38 +102,50 @@ async def load_model(data: ModelLoadRequest):
     model_container = ModelContainer(model_path.resolve(), False, **load_data)
 
     def generator():
+        global model_container
+
+        load_failed = False
         model_type = "draft" if model_container.draft_enabled else "model"
         load_status = model_container.load_gen(load_progress)
-        for (module, modules) in load_status:
-            if module == 0:
-                loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules)
-            elif module == modules:
-                loading_bar.next()
-                loading_bar.finish()
+        # TODO: Maybe create an erroring generator as a common utility function
+        try:
+            for (module, modules) in load_status:
+                if module == 0:
+                    loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules)
+                elif module == modules:
+                    loading_bar.next()
+                    loading_bar.finish()
 
-                response = ModelLoadResponse(
-                    model_type=model_type,
-                    module=module,
-                    modules=modules,
-                    status="finished"
-                )
+                    response = ModelLoadResponse(
+                        model_type=model_type,
+                        module=module,
+                        modules=modules,
+                        status="finished"
+                    )
 
-                yield response.json(ensure_ascii=False)
+                    yield response.json(ensure_ascii=False)
 
-                if model_container.draft_enabled:
-                    model_type = "model"
-            else:
-                loading_bar.next()
-
-                response = ModelLoadResponse(
-                    model_type=model_type,
-                    module=module,
-                    modules=modules,
-                    status="processing"
-                )
+                    if model_container.draft_enabled:
+                        model_type = "model"
+                else:
+                    loading_bar.next()
 
-                yield response.json(ensure_ascii=False)
+                    response = ModelLoadResponse(
+                        model_type=model_type,
+                        module=module,
+                        modules=modules,
+                        status="processing"
+                    )
+
+                    yield response.json(ensure_ascii=False)
+        except Exception as e:
+            yield get_generator_error(e)
+            load_failed = True
+        finally:
+            if load_failed:
+                model_container.unload()
+                model_container = None
 
     return EventSourceResponse(generator())
@@ -174,14 +186,17 @@ async def generate_completion(request: Request, data: CompletionRequest):
 
     if data.stream:
         async def generator():
-            new_generation = model_container.generate_gen(data.prompt, **data.to_gen_params())
-            for part in new_generation:
-                if await request.is_disconnected():
-                    break
+            try:
+                new_generation = model_container.generate_gen(data.prompt, **data.to_gen_params())
+                for part in new_generation:
+                    if await request.is_disconnected():
+                        break
 
-                response = create_completion_response(part, model_path.name)
+                    response = create_completion_response(part, model_path.name)
 
-                yield response.json(ensure_ascii=False)
+                    yield response.json(ensure_ascii=False)
+            except Exception as e:
+                yield get_generator_error(e)
 
         return EventSourceResponse(generator())
     else:
@@ -203,18 +218,21 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest):
     if data.stream:
        const_id = f"chatcmpl-{uuid4().hex}"
         async def generator():
-            new_generation = model_container.generate_gen(prompt, **data.to_gen_params())
-            for part in new_generation:
-                if await request.is_disconnected():
-                    break
+            try:
+                new_generation = model_container.generate_gen(prompt, **data.to_gen_params())
+                for part in new_generation:
+                    if await request.is_disconnected():
+                        break
 
-                response = create_chat_completion_stream_chunk(
-                    const_id,
-                    part,
-                    model_path.name
-                )
+                    response = create_chat_completion_stream_chunk(
+                        const_id,
+                        part,
+                        model_path.name
+                    )
 
-                yield response.json(ensure_ascii=False)
+                    yield response.json(ensure_ascii=False)
+            except Exception as e:
+                yield get_generator_error(e)
 
         return EventSourceResponse(generator())
     else:
diff --git a/utils.py b/utils.py
index 1fa6283..dbfdb6b 100644
--- a/utils.py
+++ b/utils.py
@@ -1,3 +1,29 @@
+import traceback
+from pydantic import BaseModel
+from typing import Optional
+
 # Wrapper callback for load progress
 def load_progress(module, modules):
-    yield module, modules
\ No newline at end of file
+    yield module, modules
+
+# Common error types
+class TabbyGeneratorErrorMessage(BaseModel):
+    message: str
+    trace: Optional[str] = None
+
+class TabbyGeneratorError(BaseModel):
+    error: TabbyGeneratorErrorMessage
+
+def get_generator_error(exception: Exception):
+    error_message = TabbyGeneratorErrorMessage(
+        message = str(exception),
+        trace = traceback.format_exc()
+    )
+
+    generator_error = TabbyGeneratorError(
+        error = error_message
+    )
+
+    # Log and send the exception
+    print(f"\n{generator_error.error.trace}")
+    return generator_error.json()
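With this change, a generation failure no longer kills the SSE stream silently: the generator yields one final event shaped like TabbyGeneratorError from utils.py ({"error": {"message": ..., "trace": ...}}) before the stream closes. Below is a minimal client-side sketch of how such a payload could be told apart from a normal completion chunk. The endpoint URL, port, API-key header, and request body are illustrative assumptions, not part of this diff.

# Usage sketch (not part of the diff): consume the completion stream and
# detect the error payload introduced by get_generator_error().
import json
import httpx

def stream_completion(prompt: str):
    # Hypothetical URL and header; adjust to your deployment.
    with httpx.stream(
        "POST",
        "http://localhost:5000/v1/completions",
        json={"prompt": prompt, "stream": True},
        headers={"x-api-key": "your-key-here"},
        timeout=None,
    ) as response:
        for line in response.iter_lines():
            # sse-starlette frames each yielded string as "data: <payload>"
            if not line.startswith("data: "):
                continue
            payload = json.loads(line[len("data: "):])

            # Error payloads mirror TabbyGeneratorError from utils.py:
            # {"error": {"message": "...", "trace": "..."}}
            if "error" in payload:
                raise RuntimeError(payload["error"]["message"])

            yield payload  # a normal completion chunk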