diff --git a/model.py b/model.py index daa5f65..e39a478 100644 --- a/model.py +++ b/model.py @@ -17,7 +17,7 @@ from exllamav2.generator import( from gen_logging import log_generation_params, log_prompt, log_response from typing import List, Optional, Union -from templating import PromptTemplate, get_template_from_file +from templating import PromptTemplate, find_template_from_model, get_template_from_config, get_template_from_file from utils import coalesce, unwrap # Bytes to reserve on first device when loading with auto split @@ -110,17 +110,18 @@ class ModelContainer: # Read the template self.prompt_template = get_template_from_file(prompt_template_name) else: - # Try autodetection of the template from the model path name - model_name = model_directory.name - template_directory = pathlib.Path("templates") - for filepath in template_directory.glob("*.jinja"): - template_name = filepath.stem.lower() - # Check if the template name is present in the model name - if template_name in model_name.lower(): - self.prompt_template = get_template_from_file(template_name) - break + # Try finding the chat template from the model's config.json + self.prompt_template = get_template_from_config( + pathlib.Path(self.config.model_config) + ) + + # If that fails, attempt fetching from model name + if self.prompt_template == None: + template_match = find_template_from_model(model_directory) + if template_match: + self.prompt_template = get_template_from_file(template_match) except OSError: - # The template couldn't be found in the user's filesystem + # The template or config.json couldn't be found in the user's filesystem print(f"Could not find template file with name {prompt_template_name}.jinja") self.prompt_template = None diff --git a/templating.py b/templating.py index cb19510..b19513c 100644 --- a/templating.py +++ b/templating.py @@ -1,3 +1,4 @@ +import json import pathlib from functools import lru_cache from importlib.metadata import version as package_version @@ -33,9 
# Find a matching template name from a model path
def find_template_from_model(model_path: pathlib.Path):
    """Return the stem of the template file that best matches a model name.

    Every ``*.jinja`` file in the ``templates`` directory (relative to the
    current working directory) is a candidate. A candidate matches when its
    stem occurs, case-insensitively, anywhere in the final component of
    ``model_path``.

    Args:
        model_path: Path to the model directory. Only ``model_path.name`` is
            inspected.

    Returns:
        The stem of the matching template file exactly as it is spelled on
        disk, or ``None`` when no template matches.
    """
    model_name = model_path.name.lower()
    template_directory = pathlib.Path("templates")

    # Compare case-insensitively but keep the on-disk spelling: the result is
    # used to build a file path, and a lower-cased stem would not open a file
    # such as "ChatML.jinja" on a case-sensitive filesystem.
    matches = [
        filepath.stem
        for filepath in template_directory.glob("*.jinja")
        if filepath.stem.lower() in model_name
    ]
    if not matches:
        return None

    # glob() yields entries in no guaranteed order, so choose deterministically:
    # the longest stem wins (e.g. "llama3" over "llama"), ties go alphabetically.
    return min(matches, key = lambda stem: (-len(stem), stem))


# Get a template from a jinja file
def get_template_from_file(prompt_template_name: str):
    """Load ``templates/<prompt_template_name>.jinja`` as a PromptTemplate.

    Args:
        prompt_template_name: File stem of the template, without the
            ``.jinja`` extension.

    Returns:
        A ``PromptTemplate`` named after the stem whose ``template`` is the
        file's text.

    Raises:
        OSError: If the template file cannot be opened.
    """
    template_path = pathlib.Path("templates") / f"{prompt_template_name}.jinja"
    with open(template_path, "r", encoding = "utf8") as raw_template:
        return PromptTemplate(
            name = prompt_template_name,
            template = raw_template.read()
        )


# Get a template from model config
def get_template_from_config(model_config_path: pathlib.Path):
    """Build a PromptTemplate from the ``chat_template`` key of a JSON config.

    Args:
        model_config_path: Path to the model's JSON config file.

    Returns:
        A ``PromptTemplate`` named ``"from_model_config"`` when the config is
        a JSON object holding a non-empty string under ``"chat_template"``;
        otherwise ``None``.

    Raises:
        OSError: If the config file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    with open(model_config_path, "r", encoding = "utf8") as model_config_file:
        model_config = json.load(model_config_file)

    # A top-level array or scalar has no keys to look up.
    if not isinstance(model_config, dict):
        return None

    chat_template = model_config.get("chat_template")

    # Only a non-empty string can serve as Jinja template source.
    if not chat_template or not isinstance(chat_template, str):
        return None

    return PromptTemplate(
        name = "from_model_config",
        template = chat_template
    )