mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-20 14:28:54 +00:00
Templates: Fix stop_string parsing
Template modules grab all set vars, including ones that use runtime vars. If a template var is set to a runtime var and a module is created, an UndefinedError fires. Use make_module instead to pass runtime vars when creating a template module. Resolves #92 Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -9,7 +9,6 @@ from jinja2.sandbox import ImmutableSandboxedEnvironment
|
||||
from loguru import logger
|
||||
from packaging import version
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, Dict
|
||||
|
||||
|
||||
class PromptTemplate(BaseModel):
|
||||
@@ -19,12 +18,7 @@ class PromptTemplate(BaseModel):
|
||||
template: str
|
||||
|
||||
|
||||
def get_prompt_from_template(
|
||||
messages,
|
||||
prompt_template: PromptTemplate,
|
||||
add_generation_prompt: bool,
|
||||
special_tokens: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
def get_prompt_from_template(prompt_template: PromptTemplate, template_vars: dict):
|
||||
"""Get a prompt from a template and a list of messages."""
|
||||
if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
|
||||
raise ImportError(
|
||||
@@ -35,12 +29,8 @@ def get_prompt_from_template(
|
||||
)
|
||||
|
||||
compiled_template = _compile_template(prompt_template.template)
|
||||
rendered_template = compiled_template.render(
|
||||
messages=messages,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
**special_tokens,
|
||||
)
|
||||
template_stop_strings = _get_template_stop_strings(compiled_template)
|
||||
rendered_template = compiled_template.render(**template_vars)
|
||||
template_stop_strings = _get_template_stop_strings(compiled_template, template_vars)
|
||||
|
||||
return rendered_template, template_stop_strings
|
||||
|
||||
@@ -64,14 +54,15 @@ def _compile_template(template: str):
|
||||
|
||||
|
||||
# TODO: Migrate to run during template load
|
||||
def _get_template_stop_strings(prompt_template: Template):
|
||||
def _get_template_stop_strings(prompt_template: Template, template_vars: dict):
|
||||
"""Appends extra stop strings if present in a chat template."""
|
||||
|
||||
extra_stop_strings = []
|
||||
template_module = prompt_template.make_module(template_vars)
|
||||
|
||||
if hasattr(prompt_template.module, "stop_strings"):
|
||||
if isinstance(prompt_template.module.stop_strings, list):
|
||||
extra_stop_strings += prompt_template.module.stop_strings
|
||||
if hasattr(template_module, "stop_strings"):
|
||||
if isinstance(template_module.stop_strings, list):
|
||||
extra_stop_strings += template_module.stop_strings
|
||||
else:
|
||||
logger.warning(
|
||||
"Skipping append of stopping strings from chat template "
|
||||
|
||||
@@ -141,11 +141,14 @@ def format_prompt_with_template(data: ChatCompletionRequest):
|
||||
unwrap(data.ban_eos_token, False),
|
||||
)
|
||||
|
||||
template_vars = {
|
||||
"messages": data.messages,
|
||||
"add_generation_prompt": data.add_generation_prompt,
|
||||
**special_tokens_dict,
|
||||
}
|
||||
|
||||
prompt, template_stop_strings = get_prompt_from_template(
|
||||
data.messages,
|
||||
model.container.prompt_template,
|
||||
data.add_generation_prompt,
|
||||
special_tokens_dict,
|
||||
model.container.prompt_template, template_vars
|
||||
)
|
||||
|
||||
# Append template stop strings
|
||||
@@ -157,17 +160,17 @@ def format_prompt_with_template(data: ChatCompletionRequest):
|
||||
return prompt
|
||||
|
||||
except KeyError as exc:
|
||||
raise HTTPException(
|
||||
400,
|
||||
error_message = handle_request_error(
|
||||
"Could not find a Conversation from prompt template "
|
||||
f"'{model.container.prompt_template.name}'. "
|
||||
"Check your spelling?",
|
||||
) from exc
|
||||
).error.message
|
||||
|
||||
raise HTTPException(400, error_message) from exc
|
||||
except TemplateError as exc:
|
||||
raise HTTPException(
|
||||
400,
|
||||
f"TemplateError: {str(exc)}",
|
||||
) from exc
|
||||
error_message = handle_request_error(f"TemplateError: {str(exc)}").error.message
|
||||
|
||||
raise HTTPException(400, error_message) from exc
|
||||
|
||||
|
||||
async def stream_generate_chat_completion(
|
||||
|
||||
Reference in New Issue
Block a user