mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-15 00:07:28 +00:00
HuggingFace updated transformers to provide templates in a list for tokenizers. Update to support this new format. Providing the name of a template for the "prompt_template" value in config.yml will also look inside the template list. In addition, log if there's a template exception, but continue model loading since it shouldn't shut down the application. Signed-off-by: kingbri <bdashore3@proton.me>
164 lines
5.6 KiB
Python
164 lines
5.6 KiB
Python
"""Small replication of AutoTokenizer's chat template system for efficiency"""
|
|
|
|
import json
|
|
import pathlib
|
|
from functools import lru_cache
|
|
from importlib.metadata import version as package_version
|
|
from typing import Optional
|
|
from jinja2 import Template, TemplateError
|
|
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
|
from loguru import logger
|
|
from packaging import version
|
|
from pydantic import BaseModel
|
|
|
|
from common.utils import unwrap
|
|
|
|
|
|
class PromptTemplate(BaseModel):
    """A template for chat completion prompts."""

    # Identifier used to select/report the template (e.g. set from the
    # .jinja filename stem or the tokenizer config entry's "name" field).
    name: str
    # Raw Jinja2 template source text, compiled lazily at render time.
    template: str
|
|
|
|
|
|
class TemplateLoadError(Exception):
    """Raised when a prompt template fails to load."""
|
|
|
|
|
|
def get_prompt_from_template(prompt_template: PromptTemplate, template_vars: dict):
    """Render a chat prompt and collect any template-declared stop strings.

    Returns a tuple of (rendered prompt string, list of extra stop strings).
    Raises ImportError when the installed jinja2 is older than 3.0.0.
    """

    installed_jinja = package_version("jinja2")

    # Chat templates rely on jinja2 3.x features; fail fast on older installs.
    if version.parse(installed_jinja) < version.parse("3.0.0"):
        raise ImportError(
            "Parsing these chat completion messages requires jinja2 3.0.0 "
            f"or greater. Current version: {installed_jinja}\n"
            "Please upgrade jinja by running the following command: "
            "pip install --upgrade jinja2"
        )

    compiled = _compile_template(prompt_template.template)

    return (
        compiled.render(**template_vars),
        _get_template_stop_strings(compiled, template_vars),
    )
|
|
|
|
|
|
# Inspired from
# https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761
# TODO: Migrate to compile when template is loaded (removes the need for an lru_cache)
@lru_cache
def _compile_template(template: str):
    """Compile a Jinja2 template string inside an immutable sandbox."""

    def raise_exception(message):
        # Exposed to templates so they can abort rendering with an error.
        raise TemplateError(message)

    sandbox_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
    sandbox_env.globals["raise_exception"] = raise_exception

    return sandbox_env.from_string(template)
|
|
|
|
|
|
# TODO: Migrate to run during template load
def _get_template_stop_strings(prompt_template: Template, template_vars: dict):
    """Collect extra stop strings a chat template declares, if any.

    Templates may set a module-level `stop_strings` list; anything else
    is ignored with a warning. Always returns a (possibly empty) list.
    """

    template_module = prompt_template.make_module(template_vars)

    if not hasattr(template_module, "stop_strings"):
        return []

    declared = template_module.stop_strings
    if not isinstance(declared, list):
        logger.warning(
            "Skipping append of stopping strings from chat template "
            "because stop_strings isn't a list."
        )
        return []

    # Copy so callers never mutate the template module's list.
    return list(declared)
|
|
|
|
|
|
def get_all_templates():
    """Yield every .jinja file found in the templates directory."""

    return pathlib.Path("templates").glob("*.jinja")
|
|
|
|
|
|
def find_template_from_model(model_path: pathlib.Path):
    """Find a matching template name from a model path.

    Matches case-insensitively against the stems of the template files;
    raises TemplateLoadError when nothing matches.
    """

    lowered_model_name = model_path.name.lower()

    for template_file in get_all_templates():
        candidate = template_file.stem.lower()

        # A template matches when its name appears inside the model name
        if candidate in lowered_model_name:
            return candidate

    raise TemplateLoadError("Could not find template from model name.")
|
|
|
|
|
|
def get_template_from_file(prompt_template_name: str):
    """Load a prompt template from its .jinja file on disk.

    Raises TemplateLoadError when the file does not exist.
    """

    template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja")

    # Let the user know if the template file isn't found
    if not template_path.exists():
        raise TemplateLoadError(
            f'Chat template "{prompt_template_name}" not found in files.'
        )

    template_text = template_path.read_text(encoding="utf8")
    return PromptTemplate(name=prompt_template_name, template=template_text)
|
|
|
|
|
|
# Get a template from a JSON file
|
|
# Requires a key and template name
|
|
def get_template_from_model_json(
    json_path: pathlib.Path, key: str, name: Optional[str] = None
):
    """Get a template from a JSON file. Requires a key and template name.

    Args:
        json_path: Path to the model's JSON config (e.g. tokenizer_config.json).
        key: The JSON key whose value holds the chat template(s).
        name: Optional template name to select when the key holds a list
            of templates (new HuggingFace transformers format).

    Returns:
        A PromptTemplate built from the selected template string.

    Raises:
        TemplateLoadError: If the file, key, or named template is missing.
    """

    if not json_path.exists():
        raise TemplateLoadError(f'Model JSON path "{json_path}" not found.')

    with open(json_path, "r", encoding="utf8") as config_file:
        model_config = json.load(config_file)

    chat_template = model_config.get(key)

    # Also rejects an empty list/string, not just a missing key
    if not chat_template:
        # Fix: report the key that was actually searched instead of the
        # previously hard-coded "chat_template" label.
        raise TemplateLoadError(
            f"Could not find a value from {key} key in the passed JSON. "
            "Check the tokenizer config?"
        )

    if isinstance(chat_template, list):
        # Handles the new list style of chat templates:
        # a list of {"name": ..., "template": ...} dicts
        if name:
            # Select by name; fall back to an empty dict so the
            # "not found" error below fires with the requested name
            wrapped_template = next(
                (x for x in chat_template if x.get("name") == name),
                {},
            )
        else:
            # No name given: take the first entry and adopt its name
            wrapped_template = chat_template[0]
            name = unwrap(wrapped_template.get("name"), "from_tokenizer_config")

        selected_template = wrapped_template.get("template")

        if selected_template:
            return PromptTemplate(name=name, template=selected_template)
        else:
            raise TemplateLoadError(
                f'Chat template with name "{name}" not found '
                "in model templates list."
            )
    else:
        # Can safely assume the chat template is the old style:
        # a single template string
        return PromptTemplate(name="from_tokenizer_config", template=chat_template)
|