Templates: Attempt loading from model config

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-12-19 22:58:47 -05:00
parent da69ad8cd3
commit d3246747c0
2 changed files with 39 additions and 12 deletions

View File

@@ -17,7 +17,7 @@ from exllamav2.generator import(
from gen_logging import log_generation_params, log_prompt, log_response
from typing import List, Optional, Union
from templating import PromptTemplate, get_template_from_file
from templating import PromptTemplate, find_template_from_model, get_template_from_config, get_template_from_file
from utils import coalesce, unwrap
# Bytes to reserve on first device when loading with auto split
@@ -110,17 +110,18 @@ class ModelContainer:
# Read the template
self.prompt_template = get_template_from_file(prompt_template_name)
else:
# Try autodetection of the template from the model path name
model_name = model_directory.name
template_directory = pathlib.Path("templates")
for filepath in template_directory.glob("*.jinja"):
template_name = filepath.stem.lower()
# Check if the template name is present in the model name
if template_name in model_name.lower():
self.prompt_template = get_template_from_file(template_name)
break
# Try finding the chat template from the model's config.json
self.prompt_template = get_template_from_config(
pathlib.Path(self.config.model_config)
)
# If that fails, attempt fetching from model name
if self.prompt_template == None:
template_match = find_template_from_model(model_directory)
if template_match:
self.prompt_template = get_template_from_file(template_match)
except OSError:
# The template couldn't be found in the user's filesystem
# The template or config.json couldn't be found in the user's filesystem
print(f"Could not find template file with name {prompt_template_name}.jinja")
self.prompt_template = None

View File

@@ -1,3 +1,4 @@
import json
import pathlib
from functools import lru_cache
from importlib.metadata import version as package_version
@@ -33,9 +34,34 @@ def _compile_template(template: str):
jinja_template = jinja_env.from_string(template)
return jinja_template
# Find a matching template name from a model path
def find_template_from_model(model_path: pathlib.Path):
model_name = model_path.name
template_directory = pathlib.Path("templates")
for filepath in template_directory.glob("*.jinja"):
template_name = filepath.stem.lower()
# Check if the template name is present in the model name
if template_name in model_name.lower():
return template_name
# Get a template from a jinja file
def get_template_from_file(prompt_template_name: str):
with open(pathlib.Path(f"templates/{prompt_template_name}.jinja"), "r") as raw_template:
with open(pathlib.Path(f"templates/{prompt_template_name}.jinja"), "r", encoding = "utf8") as raw_template:
return PromptTemplate(
name = prompt_template_name,
template = raw_template.read()
)
# Get a template from model config
def get_template_from_config(model_config_path: pathlib.Path):
with open(model_config_path, "r", encoding = "utf8") as model_config_file:
model_config = json.load(model_config_file)
chat_template = model_config.get("chat_template")
if chat_template:
return PromptTemplate(
name = "from_model_config",
template = chat_template
)
else:
return None