Templates: Migrate to class

Having many utility functions for initialization doesn't make much sense.
Instead, handle everything regarding template creation inside the
class, which reduces the number of function imports.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-04-21 23:28:14 -04:00
parent 9f93505bc1
commit cab789e685
4 changed files with 122 additions and 132 deletions

View File

@@ -34,8 +34,6 @@ from common.templating import (
PromptTemplate, PromptTemplate,
TemplateLoadError, TemplateLoadError,
find_template_from_model, find_template_from_model,
get_template_from_model_json,
get_template_from_file,
) )
from common.transformers_utils import GenerationConfig from common.transformers_utils import GenerationConfig
from common.utils import coalesce, unwrap from common.utils import coalesce, unwrap
@@ -276,18 +274,18 @@ class ExllamaV2Container:
logger.info("Attempting to load a prompt template if present.") logger.info("Attempting to load a prompt template if present.")
find_template_functions = [ find_template_functions = [
lambda: get_template_from_model_json( lambda: PromptTemplate.from_model_json(
pathlib.Path(self.config.model_dir) / "tokenizer_config.json", pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
"chat_template", "chat_template",
), ),
lambda: get_template_from_file(find_template_from_model(model_directory)), lambda: PromptTemplate.from_file(find_template_from_model(model_directory)),
] ]
# Add lookup from prompt template name if provided # Add lookup from prompt template name if provided
if prompt_template_name: if prompt_template_name:
find_template_functions[:0] = [ find_template_functions[:0] = [
lambda: get_template_from_file(prompt_template_name), lambda: PromptTemplate.from_file(prompt_template_name),
lambda: get_template_from_model_json( lambda: PromptTemplate.from_model_json(
pathlib.Path(self.config.model_dir) / "tokenizer_config.json", pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
"chat_template", "chat_template",
prompt_template_name, prompt_template_name,

View File

@@ -2,83 +2,143 @@
import json import json
import pathlib import pathlib
from functools import lru_cache
from importlib.metadata import version as package_version from importlib.metadata import version as package_version
from typing import Optional from typing import Optional
from jinja2 import Template, TemplateError from jinja2 import Template, TemplateError
from jinja2.sandbox import ImmutableSandboxedEnvironment from jinja2.sandbox import ImmutableSandboxedEnvironment
from loguru import logger from loguru import logger
from packaging import version from packaging import version
from pydantic import BaseModel
from common.utils import unwrap from common.utils import unwrap
class PromptTemplate(BaseModel):
"""A template for chat completion prompts."""
name: str
template: str
class TemplateLoadError(Exception): class TemplateLoadError(Exception):
"""Raised on prompt template load""" """Raised on prompt template load"""
pass pass
def get_prompt_from_template(prompt_template: PromptTemplate, template_vars: dict): class PromptTemplate:
"""Get a prompt from a template and a list of messages.""" """A template for chat completion prompts."""
if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
raise ImportError(
"Parsing these chat completion messages requires jinja2 3.0.0 "
f"or greater. Current version: {package_version('jinja2')}\n"
"Please upgrade jinja by running the following command: "
"pip install --upgrade jinja2"
)
compiled_template = _compile_template(prompt_template.template) name: str
rendered_template = compiled_template.render(**template_vars) raw_template: str
template_stop_strings = _get_template_stop_strings(compiled_template, template_vars) template: Template
environment: ImmutableSandboxedEnvironment = ImmutableSandboxedEnvironment(
trim_blocks=True, lstrip_blocks=True
)
return rendered_template, template_stop_strings def stop_strings(self, template_vars: dict):
"""Appends extra stop strings if present in a chat template."""
extra_stop_strings = []
template_module = self.template.make_module(template_vars)
# Inspired from if hasattr(template_module, "stop_strings"):
# https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761 if isinstance(template_module.stop_strings, list):
# TODO: Migrate to compile when template is loaded (removes the need for an lru_cache) extra_stop_strings += template_module.stop_strings
@lru_cache else:
def _compile_template(template: str): logger.warning(
"""Compiles a Jinja2 template""" "Skipping append of stopping strings from chat template "
"because stop_strings isn't a list."
)
# Exception handler return extra_stop_strings
def raise_exception(message):
raise TemplateError(message)
jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True) def render(self, template_vars: dict):
jinja_env.globals["raise_exception"] = raise_exception """Get a prompt from a template and a list of messages."""
if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
jinja_template = jinja_env.from_string(template) raise ImportError(
return jinja_template "Parsing these chat completion messages requires jinja2 3.0.0 "
f"or greater. Current version: {package_version('jinja2')}\n"
"Please upgrade jinja by running the following command: "
# TODO: Migrate to run during template load "pip install --upgrade jinja2"
def _get_template_stop_strings(prompt_template: Template, template_vars: dict):
"""Appends extra stop strings if present in a chat template."""
extra_stop_strings = []
template_module = prompt_template.make_module(template_vars)
if hasattr(template_module, "stop_strings"):
if isinstance(template_module.stop_strings, list):
extra_stop_strings += template_module.stop_strings
else:
logger.warning(
"Skipping append of stopping strings from chat template "
"because stop_strings isn't a list."
) )
return extra_stop_strings rendered_template = self.template.render(**template_vars)
template_stop_strings = self.stop_strings(template_vars)
return rendered_template, template_stop_strings
def compile(self, template_str: str):
"""Compiles and stores a jinja2 template"""
# Exception handler
def raise_exception(message):
raise TemplateError(message)
self.environment.globals["raise_exception"] = raise_exception
return self.environment.from_string(template_str)
def __init__(self, name: str, raw_template: str):
    """Build a PromptTemplate, compiling the raw jinja text immediately."""

    # Compile first so a bad template fails before any state is stored.
    # compile() only touches the class-level sandboxed environment.
    compiled_template = self.compile(raw_template)

    self.name = name
    self.raw_template = raw_template
    self.template = compiled_template
@classmethod
def from_file(cls, prompt_template_name: str):
    """Get a template from a jinja file.

    Args:
        prompt_template_name: Template filename (without the .jinja suffix)
            looked up inside the local "templates" directory.

    Returns:
        A prompt template instance built from the file contents.

    Raises:
        TemplateLoadError: If the template file does not exist.
    """

    template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja")

    if template_path.exists():
        with open(template_path, "r", encoding="utf8") as raw_template_stream:
            # Use cls (not the hard-coded class name) so subclasses
            # get instances of themselves from this constructor.
            return cls(
                name=prompt_template_name,
                raw_template=raw_template_stream.read(),
            )
    else:
        # Let the user know if the template file isn't found
        raise TemplateLoadError(
            f'Chat template "{prompt_template_name}" not found in files.'
        )
@classmethod
def from_model_json(
    cls, json_path: pathlib.Path, key: str, name: Optional[str] = None
):
    """Get a template from a JSON file. Requires a key and template name.

    Args:
        json_path: Path to the JSON file (e.g. tokenizer_config.json).
        key: JSON key holding the chat template value.
        name: Optional template name to select when the value is the
            new list-of-dicts style.

    Returns:
        A prompt template instance built from the JSON value.

    Raises:
        TemplateLoadError: If the path, key, or named template is missing.
    """

    if not json_path.exists():
        raise TemplateLoadError(f'Model JSON path "{json_path}" not found.')

    with open(json_path, "r", encoding="utf8") as config_file:
        model_config = json.load(config_file)
        chat_template = model_config.get(key)

        if not chat_template:
            # Report the actual key that was searched rather than
            # hard-coding "chat_template" (key is a parameter).
            raise TemplateLoadError(
                f'Could not find a value from "{key}" key in the passed JSON. '
                "Check the tokenizer config?"
            )

        if isinstance(chat_template, list):
            # Handles the new list style of chat templates
            if name:
                wrapped_template = next(
                    (x for x in chat_template if x.get("name") == name),
                    {},
                )
            else:
                wrapped_template = chat_template[0]
                name = unwrap(wrapped_template.get("name"), "from_tokenizer_config")

            selected_template = wrapped_template.get("template")

            if selected_template:
                # Use cls so subclasses construct instances of themselves
                return cls(name=name, raw_template=selected_template)
            else:
                raise TemplateLoadError(
                    f'Chat template with name "{name}" not found '
                    "in model templates list."
                )
        else:
            # Can safely assume the chat template is the old style
            return cls(
                name="from_tokenizer_config",
                raw_template=chat_template,
            )
def get_all_templates(): def get_all_templates():
@@ -101,63 +161,3 @@ def find_template_from_model(model_path: pathlib.Path):
return template_name return template_name
else: else:
raise TemplateLoadError("Could not find template from model name.") raise TemplateLoadError("Could not find template from model name.")
def get_template_from_file(prompt_template_name: str):
    """Get a template from a jinja file."""

    template_path = pathlib.Path("templates") / f"{prompt_template_name}.jinja"

    # Bail out early with a descriptive error if the file is missing
    if not template_path.exists():
        raise TemplateLoadError(
            f'Chat template "{prompt_template_name}" not found in files.'
        )

    raw_template = template_path.read_text(encoding="utf8")
    return PromptTemplate(name=prompt_template_name, template=raw_template)
# Get a template from a JSON file
# Requires a key and template name
def get_template_from_model_json(
    json_path: pathlib.Path, key: str, name: Optional[str] = None
):
    """Get a template from a JSON file. Requires a key and template name"""

    if not json_path.exists():
        raise TemplateLoadError(f'Model JSON path "{json_path}" not found.')

    with open(json_path, "r", encoding="utf8") as config_file:
        model_config = json.load(config_file)

    chat_template = model_config.get(key)

    if not chat_template:
        raise TemplateLoadError(
            "Could not find a value from chat_template key in the passed JSON. "
            "Check the tokenizer config?"
        )

    # Old-style configs store the template as a bare string
    if not isinstance(chat_template, list):
        return PromptTemplate(name="from_tokenizer_config", template=chat_template)

    # New-style configs store a list of named template entries
    if name:
        wrapped_template = next(
            (entry for entry in chat_template if entry.get("name") == name),
            {},
        )
    else:
        wrapped_template = chat_template[0]
        name = unwrap(wrapped_template.get("name"), "from_tokenizer_config")

    selected_template = wrapped_template.get("template")

    if not selected_template:
        raise TemplateLoadError(
            f'Chat template with name "{name}" not found '
            "in model templates list."
        )

    return PromptTemplate(name=name, template=selected_template)

View File

@@ -14,11 +14,7 @@ from common.concurrency import (
generate_with_semaphore, generate_with_semaphore,
) )
from common.networking import handle_request_error, run_with_request_disconnect from common.networking import handle_request_error, run_with_request_disconnect
from common.templating import ( from common.templating import PromptTemplate, get_all_templates
get_all_templates,
get_prompt_from_template,
get_template_from_file,
)
from common.utils import coalesce, unwrap from common.utils import coalesce, unwrap
from endpoints.OAI.types.auth import AuthPermissionResponse from endpoints.OAI.types.auth import AuthPermissionResponse
from endpoints.OAI.types.completion import CompletionRequest from endpoints.OAI.types.completion import CompletionRequest
@@ -224,8 +220,7 @@ async def switch_template(data: TemplateSwitchRequest):
raise HTTPException(400, error_message) raise HTTPException(400, error_message)
try: try:
template = get_template_from_file(data.name) model.container.prompt_template = PromptTemplate.from_file(data.name)
model.container.prompt_template = template
except FileNotFoundError as e: except FileNotFoundError as e:
error_message = handle_request_error( error_message = handle_request_error(
f"The template name {data.name} doesn't exist. Check the spelling?", f"The template name {data.name} doesn't exist. Check the spelling?",
@@ -402,9 +397,7 @@ async def encode_tokens(data: TokenEncodeRequest):
**special_tokens_dict, **special_tokens_dict,
} }
text, _ = get_prompt_from_template( text, _ = model.container.prompt_template.render(template_vars)
model.container.prompt_template, template_vars
)
raw_tokens = model.container.encode_tokens(text, **data.get_params()) raw_tokens = model.container.encode_tokens(text, **data.get_params())
tokens = unwrap(raw_tokens, []) tokens = unwrap(raw_tokens, [])

View File

@@ -15,7 +15,6 @@ from common.networking import (
handle_request_disconnect, handle_request_disconnect,
handle_request_error, handle_request_error,
) )
from common.templating import get_prompt_from_template
from common.utils import unwrap from common.utils import unwrap
from endpoints.OAI.types.chat_completion import ( from endpoints.OAI.types.chat_completion import (
ChatCompletionLogprobs, ChatCompletionLogprobs,
@@ -150,8 +149,8 @@ def format_prompt_with_template(data: ChatCompletionRequest):
} }
) )
prompt, template_stop_strings = get_prompt_from_template( prompt, template_stop_strings = model.container.prompt_template.render(
model.container.prompt_template, data.template_vars data.template_vars
) )
# Append template stop strings # Append template stop strings