From 5d80a049ae3eb93f2275e8c3d316e27311ead047 Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 21 Dec 2023 23:08:51 -0500 Subject: [PATCH] Templates: Switch to common function for JSON loading Fix redundancy in code when loading templates. However, loading a template from config.json may be a mistake since tokenizer_config.json is the main place where chat templates are stored. Signed-off-by: kingbri --- model.py | 53 ++++++++++++++++++++++++++++----------------------- templating.py | 41 +++++++++++++++++---------------------- 2 files changed, 46 insertions(+), 48 deletions(-) diff --git a/model.py b/model.py index 088a506..41d9889 100644 --- a/model.py +++ b/model.py @@ -17,7 +17,7 @@ from exllamav2.generator import( from gen_logging import log_generation_params, log_prompt, log_response from typing import List, Optional, Union -from templating import PromptTemplate, find_template_from_model, get_template_from_model_config, get_template_from_tokenizer_config, get_template_from_file +from templating import PromptTemplate, find_template_from_model, get_template_from_model_json, get_template_from_file from utils import coalesce, unwrap # Bytes to reserve on first device when loading with auto split @@ -118,35 +118,40 @@ class ModelContainer: # Set prompt template override if provided prompt_template_name = kwargs.get("prompt_template") - try: - if prompt_template_name: - # Read the template - self.prompt_template = get_template_from_file(prompt_template_name) - else: - # Try finding the chat template from the model's config.json - self.prompt_template = get_template_from_model_config( - pathlib.Path(self.config.model_config) - ) - if self.prompt_template == None: - self.prompt_template = get_template_from_tokenizer_config( - model_directory - ) + if prompt_template_name: + print(f"Attempting to load prompt template with name {prompt_template_name}") + # Read the template + self.prompt_template = get_template_from_file(prompt_template_name) + else: + # Then try 
finding the template from the tokenizer_config.json + self.prompt_template = get_template_from_model_json( + pathlib.Path(self.config.model_dir) / "tokenizer_config.json", + "chat_template", + "from_tokenizer_config" + ) - # If that fails, attempt fetching from model name - if self.prompt_template is None: - template_match = find_template_from_model(model_directory) - if template_match: - self.prompt_template = get_template_from_file(template_match) - except OSError: - # The template or config.json couldn't be found in the user's filesystem - print(f"Could not find template file with name {prompt_template_name}.jinja") - self.prompt_template = None + # Try finding the chat template from the model's config.json + # TODO: This may not even be used with huggingface models, mark for removal. + if self.prompt_template is None: + self.prompt_template = get_template_from_model_json( + pathlib.Path(self.config.model_config), + "chat_template", + "from_model_config" + ) + + # If that fails, attempt fetching from model name + if self.prompt_template is None: + template_match = find_template_from_model(model_directory) + if template_match: + self.prompt_template = get_template_from_file(template_match) # Catch all for template lookup errors if self.prompt_template: print(f"Using template {self.prompt_template.name} for chat completions.") else: - print("Chat completions are disabled because a prompt template wasn't provided or auto-detected.") + print( + "Chat completions are disabled because a prompt template wasn't provided or auto-detected." 
+ ) # Set num of experts per token if provided num_experts_override = kwargs.get("num_experts_per_token") diff --git a/templating.py b/templating.py index bb5a545..fe9207d 100644 --- a/templating.py +++ b/templating.py @@ -47,36 +47,29 @@ def find_template_from_model(model_path: pathlib.Path): # Get a template from a jinja file def get_template_from_file(prompt_template_name: str): - with open(pathlib.Path(f"templates/{prompt_template_name}.jinja"), "r", encoding = "utf8") as raw_template: - return PromptTemplate( - name = prompt_template_name, - template = raw_template.read() - ) - -# Get a template from model config -def get_template_from_model_config(model_config_path: pathlib.Path): - with open(model_config_path, "r", encoding = "utf8") as model_config_file: - model_config = json.load(model_config_file) - chat_template = model_config.get("chat_template") - if chat_template: + template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja") + if template_path.exists(): + with open(template_path, "r", encoding = "utf8") as raw_template: return PromptTemplate( - name = "from_model_config", - template = chat_template + name = prompt_template_name, + template = raw_template.read() ) - else: - return None -# Get a template from tokenizer config -def get_template_from_tokenizer_config(model_dir_path: pathlib.Path): - tokenizer_config_path=model_dir_path / "tokenizer_config.json" - if tokenizer_config_path.exists(): - with open(tokenizer_config_path, "r", encoding = "utf8") as tokenizer_config_file: - tokenizer_config = json.load(tokenizer_config_file) - chat_template = tokenizer_config.get("chat_template") + return None + + +# Get a template from a JSON file
# Requires a key and template name +def get_template_from_model_json(json_path: pathlib.Path, key: str, name: str): + if json_path.exists(): + with open(json_path, "r", encoding = "utf8") as config_file: + model_config = json.load(config_file) + chat_template = model_config.get(key) if chat_template: return 
PromptTemplate( - name = "from_tokenizer_config", + name = name, template = chat_template ) return None +