Model: Move find_template function to templating

Makes sense to extract to a utility function instead.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
kingbri
2025-04-20 18:27:53 -04:00
parent 8e238fa8f6
commit 13beef8021
2 changed files with 58 additions and 62 deletions

View File

@@ -1,5 +1,6 @@
"""Small replication of AutoTokenizer's chat template system for efficiency"""
import traceback
import aiofiles
import json
import pathlib
@@ -211,3 +212,58 @@ def find_template_from_model(model_path: pathlib.Path):
return template_name
else:
raise TemplateLoadError("Could not find template from model name.")
async def find_prompt_template(template_name, model_dir: pathlib.Path):
"""Tries to find a prompt template using various methods."""
logger.info("Attempting to load a prompt template if present.")
find_template_functions = [
lambda: PromptTemplate.from_model_json(
model_dir / "chat_template.json",
key="chat_template",
),
lambda: PromptTemplate.from_model_json(
model_dir / "tokenizer_config.json",
key="chat_template",
),
lambda: PromptTemplate.from_file(find_template_from_model(model_dir)),
]
# Find the template in the model directory if it exists
model_dir_template_path = model_dir / "tabby_template.jinja"
if model_dir_template_path.exists():
find_template_functions[:0] = [
lambda: PromptTemplate.from_file(model_dir_template_path)
]
# Add lookup from prompt template name if provided
if template_name:
find_template_functions[:0] = [
lambda: PromptTemplate.from_file(
pathlib.Path("templates") / template_name
),
lambda: PromptTemplate.from_model_json(
model_dir / "tokenizer_config.json",
key="chat_template",
name=template_name,
),
]
# Continue on exception since functions are tried as they fail
for template_func in find_template_functions:
try:
prompt_template = await template_func()
if prompt_template is not None:
return prompt_template
except TemplateLoadError as e:
logger.warning(f"TemplateLoadError: {str(e)}")
continue
except Exception:
logger.error(traceback.format_exc())
logger.warning(
"An unexpected error happened when trying to load the template. "
"Trying other methods."
)
continue