diff --git a/main.py b/main.py index c497999..fb1fe26 100644 --- a/main.py +++ b/main.py @@ -3,6 +3,7 @@ import pathlib from asyncio import CancelledError from typing import Optional from uuid import uuid4 +from jinja2 import TemplateError import uvicorn import yaml @@ -395,7 +396,7 @@ async def generate_completion(request: Request, data: CompletionRequest): async def generate_chat_completion(request: Request, data: ChatCompletionRequest): """Generates a chat completion from a prompt.""" if MODEL_CONTAINER.prompt_template is None: - return HTTPException( + raise HTTPException( 422, "This endpoint is disabled because a prompt template is not set.", ) @@ -416,13 +417,18 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest data.add_generation_prompt, special_tokens_dict, ) - except KeyError: - return HTTPException( + except KeyError as exc: + raise HTTPException( 400, "Could not find a Conversation from prompt template " f"'{MODEL_CONTAINER.prompt_template.name}'. " "Check your spelling?", - ) + ) from exc + except TemplateError as exc: + raise HTTPException( + 400, + f"TemplateError: {str(exc)}", + ) from exc if data.stream: const_id = f"chatcmpl-{uuid4().hex}" diff --git a/templating.py b/templating.py index f1a86e3..fb4f030 100644 --- a/templating.py +++ b/templating.py @@ -3,7 +3,7 @@ import json import pathlib from functools import lru_cache from importlib.metadata import version as package_version - +from jinja2 import TemplateError from jinja2.sandbox import ImmutableSandboxedEnvironment from packaging import version from pydantic import BaseModel @@ -44,7 +44,15 @@ def get_prompt_from_template( # https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761 @lru_cache def _compile_template(template: str): + """Compiles a Jinja2 template""" + + # Exception handler + def raise_exception(message): + raise TemplateError(message) + jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True) + jinja_env.globals["raise_exception"] = raise_exception + jinja_template = jinja_env.from_string(template) return jinja_template