From cab789e685e6d3d7dc0d9271054cfc6d584e27c1 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sun, 21 Apr 2024 23:28:14 -0400 Subject: [PATCH] Templates: Migrate to class Having many utility functions for initialization doesn't make much sense. Instead, handle anything regarding template creation inside the class, which reduces the number of function imports. Signed-off-by: kingbri --- backends/exllamav2/model.py | 10 +- common/templating.py | 226 ++++++++++++------------- endpoints/OAI/router.py | 13 +- endpoints/OAI/utils/chat_completion.py | 5 +- 4 files changed, 122 insertions(+), 132 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index d6b86ea..36754f8 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -34,8 +34,6 @@ from common.templating import ( PromptTemplate, TemplateLoadError, find_template_from_model, - get_template_from_model_json, - get_template_from_file, ) from common.transformers_utils import GenerationConfig from common.utils import coalesce, unwrap @@ -276,18 +274,18 @@ class ExllamaV2Container: logger.info("Attempting to load a prompt template if present.") find_template_functions = [ - lambda: get_template_from_model_json( + lambda: PromptTemplate.from_model_json( pathlib.Path(self.config.model_dir) / "tokenizer_config.json", "chat_template", ), - lambda: get_template_from_file(find_template_from_model(model_directory)), + lambda: PromptTemplate.from_file(find_template_from_model(model_directory)), ] # Add lookup from prompt template name if provided if prompt_template_name: find_template_functions[:0] = [ - lambda: get_template_from_file(prompt_template_name), - lambda: get_template_from_model_json( + lambda: PromptTemplate.from_file(prompt_template_name), + lambda: PromptTemplate.from_model_json( pathlib.Path(self.config.model_dir) / "tokenizer_config.json", "chat_template", prompt_template_name, diff --git a/common/templating.py b/common/templating.py index 6bd2a88..f742386 100644 
--- a/common/templating.py +++ b/common/templating.py @@ -2,83 +2,143 @@ import json import pathlib -from functools import lru_cache from importlib.metadata import version as package_version from typing import Optional from jinja2 import Template, TemplateError from jinja2.sandbox import ImmutableSandboxedEnvironment from loguru import logger from packaging import version -from pydantic import BaseModel from common.utils import unwrap -class PromptTemplate(BaseModel): - """A template for chat completion prompts.""" - - name: str - template: str - - class TemplateLoadError(Exception): """Raised on prompt template load""" pass -def get_prompt_from_template(prompt_template: PromptTemplate, template_vars: dict): - """Get a prompt from a template and a list of messages.""" - if version.parse(package_version("jinja2")) < version.parse("3.0.0"): - raise ImportError( - "Parsing these chat completion messages requires jinja2 3.0.0 " - f"or greater. Current version: {package_version('jinja2')}\n" - "Please upgrade jinja by running the following command: " - "pip install --upgrade jinja2" - ) +class PromptTemplate: + """A template for chat completion prompts.""" - compiled_template = _compile_template(prompt_template.template) - rendered_template = compiled_template.render(**template_vars) - template_stop_strings = _get_template_stop_strings(compiled_template, template_vars) + name: str + raw_template: str + template: Template + environment: ImmutableSandboxedEnvironment = ImmutableSandboxedEnvironment( + trim_blocks=True, lstrip_blocks=True + ) - return rendered_template, template_stop_strings + def stop_strings(self, template_vars: dict): + """Appends extra stop strings if present in a chat template.""" + extra_stop_strings = [] + template_module = self.template.make_module(template_vars) -# Inspired from -# https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761 -# TODO: Migrate to compile when template is loaded (removes the 
need for an lru_cache) -@lru_cache -def _compile_template(template: str): - """Compiles a Jinja2 template""" + if hasattr(template_module, "stop_strings"): + if isinstance(template_module.stop_strings, list): + extra_stop_strings += template_module.stop_strings + else: + logger.warning( + "Skipping append of stopping strings from chat template " + "because stop_strings isn't a list." + ) - # Exception handler - def raise_exception(message): - raise TemplateError(message) + return extra_stop_strings - jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True) - jinja_env.globals["raise_exception"] = raise_exception - - jinja_template = jinja_env.from_string(template) - return jinja_template - - -# TODO: Migrate to run during template load -def _get_template_stop_strings(prompt_template: Template, template_vars: dict): - """Appends extra stop strings if present in a chat template.""" - - extra_stop_strings = [] - template_module = prompt_template.make_module(template_vars) - - if hasattr(template_module, "stop_strings"): - if isinstance(template_module.stop_strings, list): - extra_stop_strings += template_module.stop_strings - else: - logger.warning( - "Skipping append of stopping strings from chat template " - "because stop_strings isn't a list." + def render(self, template_vars: dict): + """Get a prompt from a template and a list of messages.""" + if version.parse(package_version("jinja2")) < version.parse("3.0.0"): + raise ImportError( + "Parsing these chat completion messages requires jinja2 3.0.0 " + f"or greater. 
Current version: {package_version('jinja2')}\n" + "Please upgrade jinja by running the following command: " + "pip install --upgrade jinja2" ) - return extra_stop_strings + rendered_template = self.template.render(**template_vars) + template_stop_strings = self.stop_strings(template_vars) + + return rendered_template, template_stop_strings + + def compile(self, template_str: str): + """Compiles and stores a jinja2 template""" + + # Exception handler + def raise_exception(message): + raise TemplateError(message) + + self.environment.globals["raise_exception"] = raise_exception + + return self.environment.from_string(template_str) + + def __init__(self, name: str, raw_template: str): + """Initializer for the PromptTemplate class.""" + + self.name = name + self.raw_template = raw_template + self.template = self.compile(raw_template) + + @classmethod + def from_file(self, prompt_template_name: str): + """Get a template from a jinja file.""" + + template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja") + if template_path.exists(): + with open(template_path, "r", encoding="utf8") as raw_template_stream: + return PromptTemplate( + name=prompt_template_name, + raw_template=raw_template_stream.read(), + ) + else: + # Let the user know if the template file isn't found + raise TemplateLoadError( + f'Chat template "{prompt_template_name}" not found in files.' + ) + + @classmethod + def from_model_json( + self, json_path: pathlib.Path, key: str, name: Optional[str] = None + ): + """Get a template from a JSON file. Requires a key and template name""" + if not json_path.exists(): + raise TemplateLoadError(f'Model JSON path "{json_path}" not found.') + + with open(json_path, "r", encoding="utf8") as config_file: + model_config = json.load(config_file) + chat_template = model_config.get(key) + + if not chat_template: + raise TemplateLoadError( + "Could not find a value from chat_template key in the passed JSON. " + "Check the tokenizer config?" 
+ ) + + if isinstance(chat_template, list): + # Handles the new list style of chat templates + if name: + wrapped_template = next( + (x for x in chat_template if x.get("name") == name), + {}, + ) + else: + wrapped_template = chat_template[0] + name = unwrap(wrapped_template.get("name"), "from_tokenizer_config") + + selected_template = wrapped_template.get("template") + + if selected_template: + return PromptTemplate(name=name, raw_template=selected_template) + else: + raise TemplateLoadError( + f'Chat template with name "{name}" not found ' + "in model templates list." + ) + else: + # Can safely assume the chat template is the old style + return PromptTemplate( + name="from_tokenizer_config", + raw_template=chat_template, + ) def get_all_templates(): @@ -101,63 +161,3 @@ def find_template_from_model(model_path: pathlib.Path): return template_name else: raise TemplateLoadError("Could not find template from model name.") - - -def get_template_from_file(prompt_template_name: str): - """Get a template from a jinja file.""" - - template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja") - if template_path.exists(): - with open(template_path, "r", encoding="utf8") as raw_template: - return PromptTemplate( - name=prompt_template_name, template=raw_template.read() - ) - else: - # Let the user know if the template file isn't found - raise TemplateLoadError( - f'Chat template "{prompt_template_name}" not found in files.' - ) - - -# Get a template from a JSON file -# Requires a key and template name -def get_template_from_model_json( - json_path: pathlib.Path, key: str, name: Optional[str] = None -): - """Get a template from a JSON file. 
Requires a key and template name""" - if not json_path.exists(): - raise TemplateLoadError(f'Model JSON path "{json_path}" not found.') - - with open(json_path, "r", encoding="utf8") as config_file: - model_config = json.load(config_file) - chat_template = model_config.get(key) - - if not chat_template: - raise TemplateLoadError( - "Could not find a value from chat_template key in the passed JSON. " - "Check the tokenizer config?" - ) - - if isinstance(chat_template, list): - # Handles the new list style of chat templates - if name: - wrapped_template = next( - (x for x in chat_template if x.get("name") == name), - {}, - ) - else: - wrapped_template = chat_template[0] - name = unwrap(wrapped_template.get("name"), "from_tokenizer_config") - - selected_template = wrapped_template.get("template") - - if selected_template: - return PromptTemplate(name=name, template=selected_template) - else: - raise TemplateLoadError( - f'Chat template with name "{name}" not found ' - "in model templates list." 
- ) - else: - # Can safely assume the chat template is the old style - return PromptTemplate(name="from_tokenizer_config", template=chat_template) diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index e2fec3f..5b67714 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -14,11 +14,7 @@ from common.concurrency import ( generate_with_semaphore, ) from common.networking import handle_request_error, run_with_request_disconnect -from common.templating import ( - get_all_templates, - get_prompt_from_template, - get_template_from_file, -) +from common.templating import PromptTemplate, get_all_templates from common.utils import coalesce, unwrap from endpoints.OAI.types.auth import AuthPermissionResponse from endpoints.OAI.types.completion import CompletionRequest @@ -224,8 +220,7 @@ async def switch_template(data: TemplateSwitchRequest): raise HTTPException(400, error_message) try: - template = get_template_from_file(data.name) - model.container.prompt_template = template + model.container.prompt_template = PromptTemplate.from_file(data.name) except FileNotFoundError as e: error_message = handle_request_error( f"The template name {data.name} doesn't exist. 
Check the spelling?", @@ -402,9 +397,7 @@ async def encode_tokens(data: TokenEncodeRequest): **special_tokens_dict, } - text, _ = get_prompt_from_template( - model.container.prompt_template, template_vars - ) + text, _ = model.container.prompt_template.render(template_vars) raw_tokens = model.container.encode_tokens(text, **data.get_params()) tokens = unwrap(raw_tokens, []) diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index 0ddaa94..155f806 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -15,7 +15,6 @@ from common.networking import ( handle_request_disconnect, handle_request_error, ) -from common.templating import get_prompt_from_template from common.utils import unwrap from endpoints.OAI.types.chat_completion import ( ChatCompletionLogprobs, @@ -150,8 +149,8 @@ def format_prompt_with_template(data: ChatCompletionRequest): } ) - prompt, template_stop_strings = get_prompt_from_template( - model.container.prompt_template, data.template_vars + prompt, template_stop_strings = model.container.prompt_template.render( + data.template_vars ) # Append template stop strings