From f53c98db94e11cfef8679ca039c5302c8deed7b3 Mon Sep 17 00:00:00 2001 From: Aaron Veden Date: Wed, 20 Dec 2023 22:45:55 -0800 Subject: [PATCH] Templates: Added automatic detection of chat templates from tokenizer_config.json --- model.py | 8 ++++++-- templating.py | 17 ++++++++++++++++- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/model.py b/model.py index a5d4e2b..d2d340c 100644 --- a/model.py +++ b/model.py @@ -17,7 +17,7 @@ from exllamav2.generator import( from gen_logging import log_generation_params, log_prompt, log_response from typing import List, Optional, Union -from templating import PromptTemplate, find_template_from_model, get_template_from_config, get_template_from_file +from templating import PromptTemplate, find_template_from_model, get_template_from_model_config, get_template_from_tokenizer_config, get_template_from_file from utils import coalesce, unwrap # Bytes to reserve on first device when loading with auto split @@ -124,9 +124,13 @@ class ModelContainer: self.prompt_template = get_template_from_file(prompt_template_name) else: # Try finding the chat template from the model's config.json - self.prompt_template = get_template_from_config( + self.prompt_template = get_template_from_model_config( pathlib.Path(self.config.model_config) ) + if self.prompt_template == None: + self.prompt_template = get_template_from_tokenizer_config( + model_directory + ) # If that fails, attempt fetching from model name if self.prompt_template == None: diff --git a/templating.py b/templating.py index b19513c..bb5a545 100644 --- a/templating.py +++ b/templating.py @@ -54,7 +54,7 @@ def get_template_from_file(prompt_template_name: str): ) # Get a template from model config -def get_template_from_config(model_config_path: pathlib.Path): +def get_template_from_model_config(model_config_path: pathlib.Path): with open(model_config_path, "r", encoding = "utf8") as model_config_file: model_config = json.load(model_config_file) chat_template = model_config.get("chat_template") @@ -65,3 +65,18 @@ def get_template_from_config(model_config_path: pathlib.Path): ) else: return None + +# Get a template from tokenizer config +def get_template_from_tokenizer_config(model_dir_path: pathlib.Path): + tokenizer_config_path=model_dir_path / "tokenizer_config.json" + if tokenizer_config_path.exists(): + with open(tokenizer_config_path, "r", encoding = "utf8") as tokenizer_config_file: + tokenizer_config = json.load(tokenizer_config_file) + chat_template = tokenizer_config.get("chat_template") + if chat_template: + return PromptTemplate( + name = "from_tokenizer_config", + template = chat_template + ) + + return None