mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-15 00:07:28 +00:00
Tree: Refactor code organization
Move common functions into their own folder and refactor the backends to use their own folder as well. Also cleanup imports and alphabetize import statments themselves. Finally, move colab and docker into their own folders as well. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
104
common/templating.py
Normal file
104
common/templating.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""Small replication of AutoTokenizer's chat template system for efficiency"""
|
||||
import json
|
||||
import pathlib
|
||||
from functools import lru_cache
|
||||
from importlib.metadata import version as package_version
|
||||
from jinja2 import TemplateError
|
||||
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
||||
from packaging import version
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, Dict
|
||||
|
||||
|
||||
class PromptTemplate(BaseModel):
|
||||
"""A template for chat completion prompts."""
|
||||
|
||||
name: str
|
||||
template: str
|
||||
|
||||
|
||||
def get_prompt_from_template(
|
||||
messages,
|
||||
prompt_template: PromptTemplate,
|
||||
add_generation_prompt: bool,
|
||||
special_tokens: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
"""Get a prompt from a template and a list of messages."""
|
||||
if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
|
||||
raise ImportError(
|
||||
"Parsing these chat completion messages requires jinja2 3.0.0 "
|
||||
f"or greater. Current version: {package_version('jinja2')}\n"
|
||||
"Please upgrade jinja by running the following command: "
|
||||
"pip install --upgrade jinja2"
|
||||
)
|
||||
|
||||
compiled_template = _compile_template(prompt_template.template)
|
||||
return compiled_template.render(
|
||||
messages=messages,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
**special_tokens,
|
||||
)
|
||||
|
||||
|
||||
# Inspired from
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761
|
||||
@lru_cache
|
||||
def _compile_template(template: str):
|
||||
"""Compiles a Jinja2 template"""
|
||||
|
||||
# Exception handler
|
||||
def raise_exception(message):
|
||||
raise TemplateError(message)
|
||||
|
||||
jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
|
||||
jinja_env.globals["raise_exception"] = raise_exception
|
||||
|
||||
jinja_template = jinja_env.from_string(template)
|
||||
return jinja_template
|
||||
|
||||
|
||||
def get_all_templates():
|
||||
"""Fetches all templates from the templates directory"""
|
||||
|
||||
template_directory = pathlib.Path("templates")
|
||||
return template_directory.glob("*.jinja")
|
||||
|
||||
|
||||
def find_template_from_model(model_path: pathlib.Path):
|
||||
"""Find a matching template name from a model path."""
|
||||
model_name = model_path.name
|
||||
template_files = get_all_templates()
|
||||
for filepath in template_files:
|
||||
template_name = filepath.stem.lower()
|
||||
|
||||
# Check if the template name is present in the model name
|
||||
if template_name in model_name.lower():
|
||||
return template_name
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_template_from_file(prompt_template_name: str):
|
||||
"""Get a template from a jinja file."""
|
||||
template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja")
|
||||
if template_path.exists():
|
||||
with open(template_path, "r", encoding="utf8") as raw_template:
|
||||
return PromptTemplate(
|
||||
name=prompt_template_name, template=raw_template.read()
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# Get a template from a JSON file
|
||||
# Requires a key and template name
|
||||
def get_template_from_model_json(json_path: pathlib.Path, key: str, name: str):
|
||||
"""Get a template from a JSON file. Requires a key and template name"""
|
||||
if json_path.exists():
|
||||
with open(json_path, "r", encoding="utf8") as config_file:
|
||||
model_config = json.load(config_file)
|
||||
chat_template = model_config.get(key)
|
||||
if chat_template:
|
||||
return PromptTemplate(name=name, template=chat_template)
|
||||
|
||||
return None
|
||||
Reference in New Issue
Block a user