Templates: Support bos_token and eos_token fields

These are commonly seen in HuggingFace-provided chat templates and
aren't difficult to add.

For feature parity, honor the add_bos_token and ban_eos_token
parameters when constructing the prompt.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-12-22 10:31:50 -05:00
parent 2bf8087de3
commit a14abfe21c
3 changed files with 21 additions and 3 deletions

View File

@@ -5,6 +5,7 @@ from importlib.metadata import version as package_version
from jinja2.sandbox import ImmutableSandboxedEnvironment
from packaging import version
from pydantic import BaseModel
from typing import Optional, Dict
# Small replication of AutoTokenizer's chat template system for efficiency
@@ -12,7 +13,10 @@ class PromptTemplate(BaseModel):
name: str
template: str
def get_prompt_from_template(messages,
                             prompt_template: PromptTemplate,
                             add_generation_prompt: bool,
                             special_tokens: Optional[Dict[str, str]] = None):
    """Render a chat prompt string from a Jinja2 chat template.

    Args:
        messages: Sequence of chat message dicts passed through to the
            template's ``messages`` variable.
        prompt_template: Template wrapper whose ``template`` attribute holds
            the Jinja2 source.
        add_generation_prompt: Forwarded to the template so it can append the
            assistant-turn header when True.
        special_tokens: Optional mapping of extra template variables
            (e.g. ``bos_token``/``eos_token``) injected into the render
            context. May be omitted.

    Returns:
        The rendered prompt string.

    Raises:
        ImportError: If the installed jinja2 is older than 3.0.0.
    """
    # Chat-template rendering relies on jinja2 >= 3.0.0 features.
    if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
        raise ImportError(
            "Parsing these chat completion messages requires jinja2 3.0.0 "
            "or greater. "
            # NOTE(review): the rest of this message is elided in the diff
            # view — confirm against the full file before relying on it.
        )

    compiled_template = _compile_template(prompt_template.template)
    return compiled_template.render(
        messages=messages,
        add_generation_prompt=add_generation_prompt,
        # Fix: the default is None, and `**None` raises TypeError; fall back
        # to an empty mapping so callers may omit special_tokens.
        **(special_tokens or {}),
    )
# Inspired from https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761