diff --git a/model.py b/model.py
index a5d4e2b..d2d340c 100644
--- a/model.py
+++ b/model.py
@@ -17,7 +17,7 @@ from exllamav2.generator import(
 
 from gen_logging import log_generation_params, log_prompt, log_response
 from typing import List, Optional, Union
-from templating import PromptTemplate, find_template_from_model, get_template_from_config, get_template_from_file
+from templating import PromptTemplate, find_template_from_model, get_template_from_model_config, get_template_from_tokenizer_config, get_template_from_file
 from utils import coalesce, unwrap
 
 # Bytes to reserve on first device when loading with auto split
@@ -124,9 +124,13 @@ class ModelContainer:
             self.prompt_template = get_template_from_file(prompt_template_name)
         else:
             # Try finding the chat template from the model's config.json
-            self.prompt_template = get_template_from_config(
+            self.prompt_template = get_template_from_model_config(
                 pathlib.Path(self.config.model_config)
             )
+            if self.prompt_template == None:
+                self.prompt_template = get_template_from_tokenizer_config(
+                    model_directory
+                )
 
         # If that fails, attempt fetching from model name
         if self.prompt_template == None:
diff --git a/templating.py b/templating.py
index b19513c..bb5a545 100644
--- a/templating.py
+++ b/templating.py
@@ -54,7 +54,7 @@ def get_template_from_file(prompt_template_name: str):
         )
 
 # Get a template from model config
-def get_template_from_config(model_config_path: pathlib.Path):
+def get_template_from_model_config(model_config_path: pathlib.Path):
     with open(model_config_path, "r", encoding = "utf8") as model_config_file:
         model_config = json.load(model_config_file)
         chat_template = model_config.get("chat_template")
@@ -65,3 +65,18 @@ def get_template_from_config(model_config_path: pathlib.Path):
         )
     else:
         return None
+
+# Get a template from tokenizer config
+def get_template_from_tokenizer_config(model_dir_path: pathlib.Path):
+    tokenizer_config_path = model_dir_path / "tokenizer_config.json"
+    if tokenizer_config_path.exists():
+        with open(tokenizer_config_path, "r", encoding = "utf8") as tokenizer_config_file:
+            tokenizer_config = json.load(tokenizer_config_file)
+            chat_template = tokenizer_config.get("chat_template")
+            if chat_template:
+                return PromptTemplate(
+                    name = "from_tokenizer_config",
+                    template = chat_template
+                )
+
+    return None