From dc456f4cc21efaaeeeefd2fec492ca7afb819320 Mon Sep 17 00:00:00 2001
From: kingbri
Date: Wed, 27 Mar 2024 22:07:43 -0400
Subject: [PATCH] Templates: Add stop_strings meta param

Adding the stop_strings var to chat templates will allow for the
template creator to specify stopping strings to add onto chat
completions. These get appended with existing stopping strings that are
passed in the API request. However, a sampler override with force: true
will override all stopping strings.

Signed-off-by: kingbri
---
 common/templating.py | 27 +++++++++++++++++++++++++--
 endpoints/OAI/app.py |  5 ++++-
 2 files changed, 29 insertions(+), 3 deletions(-)

diff --git a/common/templating.py b/common/templating.py
index ed8af06..ec03dc1 100644
--- a/common/templating.py
+++ b/common/templating.py
@@ -4,8 +4,9 @@ import json
 import pathlib
 from functools import lru_cache
 from importlib.metadata import version as package_version
-from jinja2 import TemplateError
+from jinja2 import Template, TemplateError
 from jinja2.sandbox import ImmutableSandboxedEnvironment
+from loguru import logger
 from packaging import version
 from pydantic import BaseModel
 from typing import Optional, Dict
@@ -34,15 +35,19 @@ def get_prompt_from_template(
     )
 
     compiled_template = _compile_template(prompt_template.template)
-    return compiled_template.render(
+    rendered_template = compiled_template.render(
         messages=messages,
         add_generation_prompt=add_generation_prompt,
         **special_tokens,
     )
+    template_stop_strings = _get_template_stop_strings(compiled_template)
+
+    return rendered_template, template_stop_strings
 
 
 # Inspired from
 # https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761
+# TODO: Migrate to compile when template is loaded (removes the need for an lru_cache)
 @lru_cache
 def _compile_template(template: str):
     """Compiles a Jinja2 template"""
@@ -58,6 +63,24 @@ def _compile_template(template: str):
     return jinja_template
 
 
+# TODO: Migrate to run during template load
+def _get_template_stop_strings(prompt_template: Template):
+    """Appends extra stop strings if present in a chat template."""
+
+    extra_stop_strings = []
+
+    if hasattr(prompt_template.module, "stop_strings"):
+        if isinstance(prompt_template.module.stop_strings, list):
+            extra_stop_strings += prompt_template.module.stop_strings
+        else:
+            logger.warning(
+                "Skipping append of stopping strings from chat template "
+                "because stop_strings isn't a list."
+            )
+
+    return extra_stop_strings
+
+
 def get_all_templates():
     """Fetches all templates from the templates directory"""
 
diff --git a/endpoints/OAI/app.py b/endpoints/OAI/app.py
index 6c523d2..d6690ee 100644
--- a/endpoints/OAI/app.py
+++ b/endpoints/OAI/app.py
@@ -508,7 +508,10 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest)
     if isinstance(data.messages, str):
         prompt = data.messages
     else:
-        prompt = format_prompt_with_template(data)
+        # Compile the prompt and get any additional stop strings from the template
+        # Template stop strings can be overridden by sampler overrides if force is true
+        prompt, template_stop_strings = format_prompt_with_template(data)
+        data.stop += template_stop_strings
 
     disable_request_streaming = unwrap(
         config.developer_config().get("disable_request_streaming"), False