Files
tabbyAPI/common/templating.py
turboderp 179479199b Rework tool calls and OAI chat completions
- move tool config from template_vars to separate yml config
- new per-gen stream collector used for both streaming and non-streaming requests to ensure logic is consistent for both
- move responsibility for switching between phases to stream collector
- collect tool calls during streaming and parse at the end of each gen
- prevent streaming empty content spans (be nice to clients)
- correctly aggregate usage stats for n>1 requests, always emit with last chunk in last gen to finish
- collect logprobs in model wrapper and correctly handle logprobs for multi-token chars etc.
- respect top_logprobs argument in request
- handle a number of edge cases like <think> tag being part of held string, etc.
- retain tool parsing and inference-abort fixes from #413, apply similar fix to non-stream request as well

Still TODO:
- testing and validation with more models and tool schemas (tested on Qwen so far)
- enable JSON constraint for JSON tool models
- possibly some pydantification
- documentation
2026-03-30 00:22:55 +02:00

272 lines
9.6 KiB
Python

"""Small replication of AutoTokenizer's chat template system for efficiency"""
import traceback
import aiofiles
import json
import pathlib
from ruamel.yaml import YAML
from datetime import datetime
from importlib.metadata import version as package_version
from typing import Optional
from jinja2 import Template, TemplateError
from jinja2.ext import loopcontrols
from jinja2.sandbox import ImmutableSandboxedEnvironment
from common.logger import xlogger
from common.config_models import ToolConfig
from markupsafe import Markup
from packaging import version
from common.utils import unwrap
class TemplateLoadError(Exception):
    """Raised when a prompt template cannot be loaded."""
# Accepted tool-call serialization format names. Not referenced within this
# module; presumably validated against a ToolConfig field by callers —
# confirm against the tool_formats config consumers.
VALID_TOOL_CALL_FORMATS = {"json", "xml", "auto"}
class PromptTemplate:
    """A template for chat completion prompts.

    Wraps a jinja2 template compiled in an immutable sandboxed environment
    so model-provided templates cannot mutate interpreter state.
    """

    # Display name of the template (file stem or tokenizer-config entry name)
    name: str
    # Unrendered jinja2 source text
    raw_template: str
    # Compiled template object
    template: Template
    # Shared class-level environment: async rendering enabled, loop controls
    # extension for templates that use {% break %}/{% continue %}
    environment: ImmutableSandboxedEnvironment = ImmutableSandboxedEnvironment(
        trim_blocks=True,
        lstrip_blocks=True,
        enable_async=True,
        extensions=[loopcontrols],
    )

    @staticmethod
    def _tojson_compat(value, indent=None, ensure_ascii=True):
        """Compatibility JSON filter for chat templates.

        Some model templates call ``tojson(ensure_ascii=False)`` while the
        bundled Jinja filter may not accept that keyword in sandboxed mode.
        """

        return Markup(
            json.dumps(
                value,
                indent=indent,
                ensure_ascii=ensure_ascii,
                separators=(",", ": "),
            )
        )

    async def render(self, template_vars: dict) -> str:
        """Render the template with the given variables into a prompt string.

        Raises:
            ImportError: If the installed jinja2 is older than 3.0.0, which
                is required to parse these chat completion messages.
        """

        if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
            raise ImportError(
                "Parsing these chat completion messages requires jinja2 3.0.0 "
                f"or greater. Current version: {package_version('jinja2')}\n"
                "Please update jinja by running the following command: "
                "pip install --upgrade jinja2"
            )

        rendered_template = await self.template.render_async(**template_vars)
        return rendered_template

    def compile(self, template_str: str) -> Template:
        """Compiles and stores a jinja2 template with tabby's helper globals."""

        # Some models require strftime_now, e.g. Granite3
        def strftime_now(format: str) -> str:
            current_time = datetime.now()
            return current_time.strftime(format)

        # Exception handler so templates can signal errors explicitly
        def raise_exception(message):
            raise TemplateError(message)

        self.environment.globals["strftime_now"] = strftime_now
        self.environment.globals["raise_exception"] = raise_exception
        # Override tojson so templates may pass ensure_ascii/indent keywords
        self.environment.filters["tojson"] = self._tojson_compat

        return self.environment.from_string(template_str)

    def __init__(self, name: str, raw_template: str):
        """Initializer for the PromptTemplate class.

        Stores the raw template text and compiles it immediately.
        """

        self.name = name
        self.raw_template = raw_template
        self.template = self.compile(raw_template)

    @classmethod
    async def from_file(cls, template_path: pathlib.Path):
        """Get a template from a jinja file.

        Raises:
            TemplateLoadError: If the template file does not exist.
        """

        # Add the jinja extension if it isn't provided
        # (a path suffix contains exactly one dot, so equality suffices)
        if template_path.suffix == ".jinja":
            template_name = template_path.name.split(".jinja")[0]
        else:
            template_name = template_path.name
            template_path = template_path.with_suffix(".jinja")

        if template_path.exists():
            async with aiofiles.open(
                template_path, "r", encoding="utf8"
            ) as raw_template_stream:
                contents = await raw_template_stream.read()
                return cls(
                    name=template_name,
                    raw_template=contents,
                )
        else:
            # Let the user know if the template file isn't found
            raise TemplateLoadError(
                f'Chat template "{template_name}" not found in files.'
            )

    @classmethod
    async def from_model_json(
        cls, json_path: pathlib.Path, key: str, name: Optional[str] = None
    ):
        """Get a template from a JSON file. Requires a key and template name.

        Supports both the legacy single-string ``chat_template`` value and
        the newer list-of-{name, template} style.

        Raises:
            TemplateLoadError: If the JSON file is missing, the key has no
                value, or no list entry provides a template.
        """

        if not json_path.exists():
            raise TemplateLoadError(f'Model JSON path "{json_path}" not found.')

        async with aiofiles.open(json_path, "r", encoding="utf8") as config_file:
            contents = await config_file.read()
            model_config = json.loads(contents)
            chat_template = model_config.get(key)

            if not chat_template:
                raise TemplateLoadError(
                    "Could not find a value from chat_template key in the passed JSON. "
                    "Check the tokenizer config?"
                )

            if isinstance(chat_template, list):
                # Handles the new list style of chat templates
                if name:
                    wrapped_template = next(
                        (x for x in chat_template if x.get("name") == name),
                        {},
                    )
                else:
                    wrapped_template = chat_template[0]
                    name = unwrap(
                        wrapped_template.get("name"), "from_tokenizer_config"
                    )

                selected_template = wrapped_template.get("template")

                if selected_template:
                    # Use cls (not PromptTemplate) so subclasses construct
                    # instances of themselves, consistent with from_file
                    return cls(name=name, raw_template=selected_template)
                else:
                    raise TemplateLoadError(
                        f'Chat template with name "{name}" not found in model templates list.'
                    )
            else:
                # Can safely assume the chat template is the old style
                return cls(
                    name="from_tokenizer_config",
                    raw_template=chat_template,
                )
def get_all_templates():
    """Return an iterator over all *.jinja files in the templates directory."""

    return pathlib.Path("templates").glob("*.jinja")
def find_template_from_model(model_path: pathlib.Path):
    """Find a matching template name from a model path.

    Returns the lowercase stem of the first bundled template whose name
    appears inside the model directory's name.

    Raises:
        TemplateLoadError: If no bundled template name matches.
    """

    model_name_lower = model_path.name.lower()

    for template_file in get_all_templates():
        candidate = template_file.stem.lower()
        # Check if the template name is present in the model name
        if candidate in model_name_lower:
            return candidate

    raise TemplateLoadError("Could not find template from model name.")
async def find_prompt_template(template_name, model_dir: pathlib.Path):
    """Tries to find a prompt template using various methods.

    Lookup order (first success wins):
      1. templates/<template_name>.jinja and the tokenizer_config entry
         named <template_name> (only when a name is provided)
      2. <model_dir>/tabby_template.jinja (if present)
      3. <model_dir>/chat_template.jinja
      4. "chat_template" key of <model_dir>/chat_template.json
      5. "chat_template" key of <model_dir>/tokenizer_config.json
      6. A bundled template whose filename matches the model name

    Returns the first successfully loaded PromptTemplate, or None if every
    method fails.
    """

    xlogger.info("Attempting to load a prompt template if present.")

    find_template_functions = [
        lambda: PromptTemplate.from_file(model_dir / "chat_template.jinja"),
        lambda: PromptTemplate.from_model_json(
            model_dir / "chat_template.json",
            key="chat_template",
        ),
        lambda: PromptTemplate.from_model_json(
            model_dir / "tokenizer_config.json",
            key="chat_template",
        ),
        # find_template_from_model returns a bare template name (a str), so
        # build the full path into the templates directory before loading —
        # passing the string directly would crash in from_file (.suffix on str)
        lambda: PromptTemplate.from_file(
            pathlib.Path("templates") / find_template_from_model(model_dir)
        ),
    ]

    # Find the template in the model directory if it exists
    model_dir_template_path = model_dir / "tabby_template.jinja"
    if model_dir_template_path.exists():
        find_template_functions[:0] = [
            lambda: PromptTemplate.from_file(model_dir_template_path)
        ]

    # Add lookup from prompt template name if provided
    # TODO: Possibly link to the TokenizerConfig class
    if template_name:
        find_template_functions[:0] = [
            lambda: PromptTemplate.from_file(
                pathlib.Path("templates") / template_name
            ),
            lambda: PromptTemplate.from_model_json(
                model_dir / "tokenizer_config.json",
                key="chat_template",
                name=template_name,
            ),
        ]

    # Continue on exception since functions are tried as they fail
    for template_func in find_template_functions:
        try:
            prompt_template = await template_func()
            if prompt_template is not None:
                return prompt_template
        except TemplateLoadError as e:
            xlogger.warning(
                "TemplateLoadError", {"exception": str(e)}, details=f"{str(e)}"
            )
            continue
        except Exception:
            xlogger.error(traceback.format_exc())
            xlogger.warning(
                "An unexpected error happened when trying to load the template. "
                "Trying other methods."
            )
            continue

    return None
def get_all_tool_formats():
    """Fetches all tool format configs from the tool_formats directory.

    Yields paths with either the .yaml or .yml extension. The original glob
    matched only *.yaml, but tool_config_from_file loads
    "tool_formats/<name>.yml" — so a config the loader could open was
    invisible to this lister. Accepting both extensions closes the gap.
    """

    tool_formats_directory = pathlib.Path("tool_formats")
    yield from tool_formats_directory.glob("*.yaml")
    yield from tool_formats_directory.glob("*.yml")
async def tool_config_from_file(tool_format_name: Optional[str]):
    """Fetches a tool config from a file.

    Falls back to a default-constructed ToolConfig when no name is given
    or the preset file cannot be found.
    """

    # No format requested: use ToolConfig defaults
    if not tool_format_name:
        return ToolConfig.model_validate({})

    preset_path = pathlib.Path(f"tool_formats/{tool_format_name}.yml")

    if not preset_path.exists():
        xlogger.error(
            f'Tool format name "{tool_format_name}" was not found. '
            + "Make sure it's located in the tool_formats folder."
        )
        return ToolConfig.model_validate({})

    async with aiofiles.open(preset_path, "r", encoding="utf8") as raw_preset:
        contents = await raw_preset.read()

    # Create a temporary YAML parser
    yaml = YAML(typ="safe")
    cfg = yaml.load(contents)

    xlogger.info(f"Loaded tool config {tool_format_name}", {"tool_config": cfg})
    return ToolConfig.model_validate(cfg)