From b4752c1e62c8daa27246f3e0583eeb3e7c4ad27f Mon Sep 17 00:00:00 2001 From: kingbri Date: Sat, 17 Aug 2024 11:30:50 -0400 Subject: [PATCH] Templates: Revert to loading metadata at runtime Metadata is generated via a template's module. This requires a single iteration through the template. If a template tries to access a passed variable that doesn't exist, it will raise an error. Therefore, generate the metadata at runtime to prevent these errors from happening. To optimize further, cache the metadata after the first generation to prevent the expensive call of making a template module. Signed-off-by: kingbri --- common/templating.py | 21 ++++++++++++++++----- endpoints/OAI/utils/chat_completion.py | 4 +++- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/common/templating.py b/common/templating.py index 47ce7e8..021d1d4 100644 --- a/common/templating.py +++ b/common/templating.py @@ -1,5 +1,6 @@ """Small replication of AutoTokenizer's chat template system for efficiency""" +from functools import lru_cache import json import pathlib from importlib.metadata import version as package_version @@ -34,14 +35,24 @@ class PromptTemplate: environment: ImmutableSandboxedEnvironment = ImmutableSandboxedEnvironment( trim_blocks=True, lstrip_blocks=True ) - metadata: TemplateMetadata + metadata: Optional[TemplateMetadata] = None - def extract_metadata(self): - """Returns deserialized template metadata from a chat template.""" + def extract_metadata(self, template_vars: dict): + """ + Returns deserialized template metadata from a chat template. + + NOTE: Requires all template vars to be passed in since the template + is run once to make a module and errors can result. 
+ """ + + # No need to extract new metadata if it already exists + # This might be removed if stored metadata becomes arbitrary + if self.metadata: + return self.metadata template_metadata = TemplateMetadata() - template_module = self.template.make_module() + template_module = self.template.make_module(template_vars) if hasattr(template_module, "stop_strings"): if isinstance(template_module.stop_strings, list): @@ -60,6 +71,7 @@ class PromptTemplate: if isinstance(template_module.tool_start_token, int): template_metadata.tool_starts.append(template_module.tool_start_token) + self.metadata = template_metadata return template_metadata def render(self, template_vars: dict): @@ -93,7 +105,6 @@ class PromptTemplate: self.name = name self.raw_template = raw_template self.template = self.compile(raw_template) - self.metadata = self.extract_metadata() @classmethod def from_file(self, prompt_template_name: str): diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index e15e820..d924b5e 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -181,7 +181,9 @@ def _create_stream_chunk( def _append_template_metadata(data: ChatCompletionRequest): """Adding metadata is a one-time process.""" - template_metadata = model.container.prompt_template.metadata + template_metadata = model.container.prompt_template.extract_metadata( + data.template_vars + ) # Stop strings if isinstance(data.stop, str):