Templates: Add error handling for template errors

Similar to the transformers library, add an error handler when an
exception is fired. This relays the error to the user.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-12-22 11:59:47 -05:00
parent fa47f51f85
commit 71f6a586f1
2 changed files with 19 additions and 5 deletions

14
main.py
View File

@@ -3,6 +3,7 @@ import pathlib
from asyncio import CancelledError
from typing import Optional
from uuid import uuid4
from jinja2 import TemplateError
import uvicorn
import yaml
@@ -395,7 +396,7 @@ async def generate_completion(request: Request, data: CompletionRequest):
async def generate_chat_completion(request: Request, data: ChatCompletionRequest):
"""Generates a chat completion from a prompt."""
if MODEL_CONTAINER.prompt_template is None:
return HTTPException(
raise HTTPException(
422,
"This endpoint is disabled because a prompt template is not set.",
)
@@ -416,13 +417,18 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
data.add_generation_prompt,
special_tokens_dict,
)
except KeyError:
return HTTPException(
except KeyError as exc:
raise HTTPException(
400,
"Could not find a Conversation from prompt template "
f"'{MODEL_CONTAINER.prompt_template.name}'. "
"Check your spelling?",
)
) from exc
except TemplateError as exc:
raise HTTPException(
400,
f"TemplateError: {str(exc)}",
) from exc
if data.stream:
const_id = f"chatcmpl-{uuid4().hex}"

View File

@@ -3,7 +3,7 @@ import json
import pathlib
from functools import lru_cache
from importlib.metadata import version as package_version
from jinja2 import TemplateError
from jinja2.sandbox import ImmutableSandboxedEnvironment
from packaging import version
from pydantic import BaseModel
@@ -44,7 +44,15 @@ def get_prompt_from_template(
# https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761
@lru_cache
def _compile_template(template: str):
"""Compiles a Jinja2 template"""
# Exception handler
def raise_exception(message):
raise TemplateError(message)
jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
jinja_env.globals["raise_exception"] = raise_exception
jinja_template = jinja_env.from_string(template)
return jinja_template