diff --git a/common/templating.py b/common/templating.py
index ed8af06..ec03dc1 100644
--- a/common/templating.py
+++ b/common/templating.py
@@ -4,8 +4,9 @@ import json
 import pathlib
 from functools import lru_cache
 from importlib.metadata import version as package_version
-from jinja2 import TemplateError
+from jinja2 import Template, TemplateError
 from jinja2.sandbox import ImmutableSandboxedEnvironment
+from loguru import logger
 from packaging import version
 from pydantic import BaseModel
 from typing import Optional, Dict
@@ -34,15 +35,19 @@ def get_prompt_from_template(
         )
 
     compiled_template = _compile_template(prompt_template.template)
-    return compiled_template.render(
+    rendered_template = compiled_template.render(
         messages=messages,
         add_generation_prompt=add_generation_prompt,
         **special_tokens,
     )
+    template_stop_strings = _get_template_stop_strings(compiled_template)
+
+    return rendered_template, template_stop_strings
 
 
 # Inspired from
 # https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761
+# TODO: Migrate to compile when template is loaded (removes the need for an lru_cache)
 @lru_cache
 def _compile_template(template: str):
     """Compiles a Jinja2 template"""
@@ -58,6 +63,24 @@ def _compile_template(template: str):
     return jinja_template
 
 
+# TODO: Migrate to run during template load
+def _get_template_stop_strings(prompt_template: Template):
+    """Fetches extra stop strings if present in a chat template."""
+
+    extra_stop_strings = []
+
+    if hasattr(prompt_template.module, "stop_strings"):
+        if isinstance(prompt_template.module.stop_strings, list):
+            extra_stop_strings += prompt_template.module.stop_strings
+        else:
+            logger.warning(
+                "Skipping append of stop strings from chat template "
+                "because stop_strings isn't a list."
+            )
+
+    return extra_stop_strings
+
+
 def get_all_templates():
     """Fetches all templates from the templates directory"""
 
diff --git a/endpoints/OAI/app.py b/endpoints/OAI/app.py
index 6c523d2..d6690ee 100644
--- a/endpoints/OAI/app.py
+++ b/endpoints/OAI/app.py
@@ -508,7 +508,10 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest)
     if isinstance(data.messages, str):
         prompt = data.messages
     else:
-        prompt = format_prompt_with_template(data)
+        # Compile the prompt and get any additional stop strings from the template
+        # Template stop strings can be overridden by sampler overrides if force is true
+        prompt, template_stop_strings = format_prompt_with_template(data)
+        data.stop += template_stop_strings
 
     disable_request_streaming = unwrap(
         config.developer_config().get("disable_request_streaming"), False
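
For reviewers, here is a minimal standalone sketch (not part of the patch) of the mechanism `_get_template_stop_strings` relies on: a top-level `{% set %}` in a Jinja2 template is exported as an attribute of `Template.module`, which is what the `hasattr(prompt_template.module, "stop_strings")` check inspects. The environment options below are assumptions for illustration, not copied from `_compile_template`:

```python
# Sketch of the stop_strings discovery mechanism. The environment options
# are assumed; only Jinja2's documented Template.module behavior is relied on.
from jinja2.sandbox import ImmutableSandboxedEnvironment

env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)

# A chat template can export stop strings with a top-level {% set %};
# Jinja2 exposes top-level assignments as attributes of template.module.
template = env.from_string(
    "{% set stop_strings = ['</s>', '<|im_end|>'] %}"
    "{% for message in messages %}{{ message['content'] }}\n{% endfor %}"
)

# Mirrors the check in _get_template_stop_strings
if hasattr(template.module, "stop_strings"):
    print(template.module.stop_strings)  # ['</s>', '<|im_end|>']
```

Note that the `data.stop += template_stop_strings` line in `chat_completion_request` assumes the request model has already normalized `data.stop` to a list; if it can still be `None` or a bare string at that point, the `+=` would fail.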