Templates: Added automatic detection of chat templates from tokenizer_config.json

This commit is contained in:
Aaron Veden
2023-12-20 22:45:55 -08:00
parent bee758dae9
commit f53c98db94
2 changed files with 22 additions and 3 deletions

View File

@@ -17,7 +17,7 @@ from exllamav2.generator import(
from gen_logging import log_generation_params, log_prompt, log_response
from typing import List, Optional, Union
from templating import PromptTemplate, find_template_from_model, get_template_from_config, get_template_from_file
from templating import PromptTemplate, find_template_from_model, get_template_from_model_config, get_template_from_tokenizer_config, get_template_from_file
from utils import coalesce, unwrap
# Bytes to reserve on first device when loading with auto split
@@ -124,9 +124,13 @@ class ModelContainer:
self.prompt_template = get_template_from_file(prompt_template_name)
else:
# Try finding the chat template from the model's config.json
self.prompt_template = get_template_from_config(
self.prompt_template = get_template_from_model_config(
pathlib.Path(self.config.model_config)
)
if self.prompt_template == None:
self.prompt_template = get_template_from_tokenizer_config(
model_directory
)
# If that fails, attempt fetching from model name
if self.prompt_template == None:

View File

@@ -54,7 +54,7 @@ def get_template_from_file(prompt_template_name: str):
)
# Get a template from model config
def get_template_from_config(model_config_path: pathlib.Path):
def get_template_from_model_config(model_config_path: pathlib.Path):
with open(model_config_path, "r", encoding = "utf8") as model_config_file:
model_config = json.load(model_config_file)
chat_template = model_config.get("chat_template")
@@ -65,3 +65,18 @@ def get_template_from_config(model_config_path: pathlib.Path):
)
else:
return None
# Get a template from tokenizer config
def get_template_from_tokenizer_config(model_dir_path: pathlib.Path):
tokenizer_config_path=model_dir_path / "tokenizer_config.json"
if tokenizer_config_path.exists():
with open(tokenizer_config_path, "r", encoding = "utf8") as tokenizer_config_file:
tokenizer_config = json.load(tokenizer_config_file)
chat_template = tokenizer_config.get("chat_template")
if chat_template:
return PromptTemplate(
name = "from_tokenizer_config",
template = chat_template
)
return None