Model: Move find_template function to templating

Makes sense to extract to a utility function instead.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
kingbri
2025-04-20 18:27:53 -04:00
parent 8e238fa8f6
commit 13beef8021
2 changed files with 58 additions and 62 deletions

View File

@@ -49,11 +49,7 @@ from common.gen_logging import (
from common.health import HealthManager
from common.multimodal import MultimodalEmbeddingWrapper
from common.sampling import BaseSamplerRequest
from common.templating import (
PromptTemplate,
TemplateLoadError,
find_template_from_model,
)
from common.templating import PromptTemplate, find_prompt_template
from common.transformers_utils import GenerationConfig
from common.utils import calculate_rope_alpha, coalesce, unwrap
@@ -322,7 +318,7 @@ class ExllamaV2Container(BaseModelContainer):
self.cache_size = self.config.max_seq_len
# Try to set prompt template
self.prompt_template = await self.find_prompt_template(
self.prompt_template = await find_prompt_template(
kwargs.get("prompt_template"), model_directory
)
@@ -383,62 +379,6 @@ class ExllamaV2Container(BaseModelContainer):
# Return the created instance
return self
async def find_prompt_template(self, prompt_template_name, model_directory):
"""Tries to find a prompt template using various methods."""
logger.info("Attempting to load a prompt template if present.")
find_template_functions = [
lambda: PromptTemplate.from_model_json(
pathlib.Path(self.config.model_dir) / "chat_template.json",
key="chat_template",
),
lambda: PromptTemplate.from_model_json(
pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
key="chat_template",
),
lambda: PromptTemplate.from_file(find_template_from_model(model_directory)),
]
# Find the template in the model directory if it exists
model_dir_template_path = (
pathlib.Path(self.config.model_dir) / "tabby_template.jinja"
)
if model_dir_template_path.exists():
find_template_functions[:0] = [
lambda: PromptTemplate.from_file(model_dir_template_path)
]
# Add lookup from prompt template name if provided
if prompt_template_name:
find_template_functions[:0] = [
lambda: PromptTemplate.from_file(
pathlib.Path("templates") / prompt_template_name
),
lambda: PromptTemplate.from_model_json(
pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
key="chat_template",
name=prompt_template_name,
),
]
# Continue on exception since functions are tried as they fail
for template_func in find_template_functions:
try:
prompt_template = await template_func()
if prompt_template is not None:
return prompt_template
except TemplateLoadError as e:
logger.warning(f"TemplateLoadError: {str(e)}")
continue
except Exception:
logger.error(traceback.format_exc())
logger.warning(
"An unexpected error happened when trying to load the template. "
"Trying other methods."
)
continue
def get_model_parameters(self):
model_params = {
"name": self.model_dir.name,

View File

@@ -1,5 +1,6 @@
"""Small replication of AutoTokenizer's chat template system for efficiency"""
import traceback
import aiofiles
import json
import pathlib
@@ -211,3 +212,58 @@ def find_template_from_model(model_path: pathlib.Path):
return template_name
else:
raise TemplateLoadError("Could not find template from model name.")
async def find_prompt_template(template_name, model_dir: pathlib.Path):
"""Tries to find a prompt template using various methods."""
logger.info("Attempting to load a prompt template if present.")
find_template_functions = [
lambda: PromptTemplate.from_model_json(
model_dir / "chat_template.json",
key="chat_template",
),
lambda: PromptTemplate.from_model_json(
model_dir / "tokenizer_config.json",
key="chat_template",
),
lambda: PromptTemplate.from_file(find_template_from_model(model_dir)),
]
# Find the template in the model directory if it exists
model_dir_template_path = model_dir / "tabby_template.jinja"
if model_dir_template_path.exists():
find_template_functions[:0] = [
lambda: PromptTemplate.from_file(model_dir_template_path)
]
# Add lookup from prompt template name if provided
if template_name:
find_template_functions[:0] = [
lambda: PromptTemplate.from_file(
pathlib.Path("templates") / template_name
),
lambda: PromptTemplate.from_model_json(
model_dir / "tokenizer_config.json",
key="chat_template",
name=template_name,
),
]
# Continue on exception since functions are tried as they fail
for template_func in find_template_functions:
try:
prompt_template = await template_func()
if prompt_template is not None:
return prompt_template
except TemplateLoadError as e:
logger.warning(f"TemplateLoadError: {str(e)}")
continue
except Exception:
logger.error(traceback.format_exc())
logger.warning(
"An unexpected error happened when trying to load the template. "
"Trying other methods."
)
continue