API: Add generator error handling

If the generator raises an exception, there is no handling in place to
send an error packet to the client and close the connection.

This is especially important for unloading models when a load fails at
any stage, so the user's VRAM can be reclaimed. Previously, raising an
exception caused the model_container object to stay locked and never
get freed by the GC.

Therefore, it made sense to propagate SSE errors across all generator
functions rather than relying on abort signals.

Signed-off-by: kingbri <bdashore3@proton.me>
Author: kingbri
Date:   2023-11-30 00:37:48 -05:00
Parent: 2bc3da0155
Commit: 56f9b1d1a8

2 changed files with 88 additions and 44 deletions
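Seen from a client, the effect of the change is that a failed load or
generation now ends the SSE stream with a final data: event carrying an
error object instead of the connection silently dropping. A minimal
consumer sketch (the endpoint path and payload fields are assumptions
read off this diff, not a documented API):

import json
import requests

def stream_model_load(base_url: str, payload: dict):
    # Consume the model-load SSE stream, raising on the new error packet.
    with requests.post(f"{base_url}/v1/model/load", json=payload, stream=True) as resp:
        for raw in resp.iter_lines():
            # SSE data lines look like: data: {...json...}
            if not raw or not raw.startswith(b"data:"):
                continue
            chunk = json.loads(raw[len(b"data:"):])
            if "error" in chunk:
                # New behavior from this commit: the failure is reported
                # in-band instead of the connection just dropping.
                raise RuntimeError(chunk["error"]["message"])
            # Progress fields per ModelLoadResponse in the diff below
            print(f"{chunk['model_type']}: {chunk['module']}/{chunk['modules']} ({chunk['status']})")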

main.py (104 lines changed)

@@ -1,6 +1,6 @@
 import uvicorn
 import yaml
-import pathlib, os
+import pathlib
 from auth import check_admin_key, check_api_key, load_auth_keys
 from fastapi import FastAPI, Request, HTTPException, Depends
 from fastapi.middleware.cors import CORSMiddleware
@@ -24,7 +24,7 @@ from OAI.utils import (
     create_chat_completion_stream_chunk
 )
 from typing import Optional
-from utils import load_progress
+from utils import get_generator_error, load_progress
 from uuid import uuid4

 app = FastAPI()
@@ -102,38 +102,50 @@ async def load_model(data: ModelLoadRequest):
     model_container = ModelContainer(model_path.resolve(), False, **load_data)

     def generator():
+        global model_container
+
+        load_failed = False
         model_type = "draft" if model_container.draft_enabled else "model"
         load_status = model_container.load_gen(load_progress)
-        for (module, modules) in load_status:
-            if module == 0:
-                loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules)
-            elif module == modules:
-                loading_bar.next()
-                loading_bar.finish()

-                response = ModelLoadResponse(
-                    model_type=model_type,
-                    module=module,
-                    modules=modules,
-                    status="finished"
-                )
+        # TODO: Maybe create an erroring generator as a common utility function
+        try:
+            for (module, modules) in load_status:
+                if module == 0:
+                    loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules)
+                elif module == modules:
+                    loading_bar.next()
+                    loading_bar.finish()

-                yield response.json(ensure_ascii=False)
+                    response = ModelLoadResponse(
+                        model_type=model_type,
+                        module=module,
+                        modules=modules,
+                        status="finished"
+                    )

-                if model_container.draft_enabled:
-                    model_type = "model"
-            else:
-                loading_bar.next()
+                    yield response.json(ensure_ascii=False)

-                response = ModelLoadResponse(
-                    model_type=model_type,
-                    module=module,
-                    modules=modules,
-                    status="processing"
-                )
+                    if model_container.draft_enabled:
+                        model_type = "model"
+                else:
+                    loading_bar.next()

-                yield response.json(ensure_ascii=False)
+                    response = ModelLoadResponse(
+                        model_type=model_type,
+                        module=module,
+                        modules=modules,
+                        status="processing"
+                    )
+
+                    yield response.json(ensure_ascii=False)
+        except Exception as e:
+            yield get_generator_error(e)
+            load_failed = True
+        finally:
+            if load_failed:
+                model_container.unload()
+                model_container = None

     return EventSourceResponse(generator())
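Two details in this hunk do the actual VRAM reclamation: generator()
must declare global model_container because it reassigns the
module-level name, and the finally block both unloads the container and
drops the last reference so the GC can collect it. Distilled (the
helper name is a placeholder for the ModelLoadResponse loop above):

def generator():
    global model_container                  # generator() reassigns the module-level name
    load_failed = False
    try:
        yield from stream_load_progress()   # stand-in; may raise at any stage of the load
    except Exception as exc:
        yield get_generator_error(exc)      # send an error packet instead of dying silently
        load_failed = True
    finally:
        if load_failed:
            model_container.unload()        # free the allocated VRAM
            model_container = None          # drop the reference so the GC can collect it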
@@ -174,14 +186,17 @@ async def generate_completion(request: Request, data: CompletionRequest):
     if data.stream:
         async def generator():
-            new_generation = model_container.generate_gen(data.prompt, **data.to_gen_params())
-            for part in new_generation:
-                if await request.is_disconnected():
-                    break
+            try:
+                new_generation = model_container.generate_gen(data.prompt, **data.to_gen_params())
+                for part in new_generation:
+                    if await request.is_disconnected():
+                        break

-                response = create_completion_response(part, model_path.name)
+                    response = create_completion_response(part, model_path.name)

-                yield response.json(ensure_ascii=False)
+                    yield response.json(ensure_ascii=False)
+            except Exception as e:
+                yield get_generator_error(e)

         return EventSourceResponse(generator())
     else:
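Note that the try wraps the generate_gen() call as well as the loop, so
a failure while starting generation and a failure mid-stream both
surface as the same error packet. On the client side, an error chunk is
distinguishable from a normal completion chunk by its top-level key
(shape taken from the utils.py changes below):

import json

def parse_completion_chunk(data_line: bytes):
    # Split a streamed SSE payload into (completion_chunk, error) parts.
    chunk = json.loads(data_line)
    # Normal chunks follow the OpenAI completion shape; the error packet
    # is {"error": {"message": ..., "trace": ...}} per TabbyGeneratorError.
    if "error" in chunk:
        return None, chunk["error"]
    return chunk, None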
@@ -203,18 +218,21 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
     if data.stream:
         const_id = f"chatcmpl-{uuid4().hex}"

         async def generator():
-            new_generation = model_container.generate_gen(prompt, **data.to_gen_params())
-            for part in new_generation:
-                if await request.is_disconnected():
-                    break
+            try:
+                new_generation = model_container.generate_gen(prompt, **data.to_gen_params())
+                for part in new_generation:
+                    if await request.is_disconnected():
+                        break

-                response = create_chat_completion_stream_chunk(
-                    const_id,
-                    part,
-                    model_path.name
-                )
+                    response = create_chat_completion_stream_chunk(
+                        const_id,
+                        part,
+                        model_path.name
+                    )

-                yield response.json(ensure_ascii=False)
+                    yield response.json(ensure_ascii=False)
+            except Exception as e:
+                yield get_generator_error(e)

         return EventSourceResponse(generator())
     else:
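For the chat endpoint the stream looks the same on the wire: a failed
generation now ends with one extra event instead of an abrupt
disconnect. Roughly, with illustrative values and an assumed
OpenAI-style chunk shape for create_chat_completion_stream_chunk:

# data: {"id": "chatcmpl-<hex>", "choices": [{"delta": {"content": "Hel"}}], ...}
# data: {"id": "chatcmpl-<hex>", "choices": [{"delta": {"content": "lo"}}], ...}
# data: {"error": {"message": "CUDA out of memory", "trace": "Traceback (most recent call last): ..."}}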

utils.py

@@ -1,3 +1,29 @@
+import traceback
+from pydantic import BaseModel
+from typing import Optional
+
 # Wrapper callback for load progress
 def load_progress(module, modules):
     yield module, modules
+
+# Common error types
+class TabbyGeneratorErrorMessage(BaseModel):
+    message: str
+    trace: Optional[str] = None
+
+class TabbyGeneratorError(BaseModel):
+    error: TabbyGeneratorErrorMessage
+
+def get_generator_error(exception: Exception):
+    error_message = TabbyGeneratorErrorMessage(
+        message = str(exception),
+        trace = traceback.format_exc()
+    )
+
+    generator_error = TabbyGeneratorError(
+        error = error_message
+    )
+
+    # Log and send the exception
+    print(f"\n{generator_error.error.trace}")
+    return generator_error.json()
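One subtlety worth noting: traceback.format_exc() only returns the
traceback of the exception currently being handled, so
get_generator_error() has to be called from inside an except block,
which all three call sites above do. A quick illustration of the
payload it yields into the stream (values illustrative):

import json

# Must run inside the except block so traceback.format_exc()
# can see the active exception.
try:
    raise ValueError("model file not found")
except Exception as e:
    packet = get_generator_error(e)

print(json.loads(packet)["error"]["message"])   # -> model file not found
# The full packet is a JSON string shaped like:
# {"error": {"message": "model file not found", "trace": "Traceback (most recent call last): ..."}}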