Templates: Switch to common function for JSON loading

Fix redundancy in code when loading templates. However, loading
a template from config.json may be a mistake since tokenizer_config.json
is the main place where chat templates are stored.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-12-21 23:08:51 -05:00
parent 72e19dbc12
commit 5d80a049ae
2 changed files with 46 additions and 48 deletions

View File

@@ -17,7 +17,7 @@ from exllamav2.generator import(
from gen_logging import log_generation_params, log_prompt, log_response from gen_logging import log_generation_params, log_prompt, log_response
from typing import List, Optional, Union from typing import List, Optional, Union
from templating import PromptTemplate, find_template_from_model, get_template_from_model_config, get_template_from_tokenizer_config, get_template_from_file from templating import PromptTemplate, find_template_from_model, get_template_from_model_json, get_template_from_file
from utils import coalesce, unwrap from utils import coalesce, unwrap
# Bytes to reserve on first device when loading with auto split # Bytes to reserve on first device when loading with auto split
@@ -118,35 +118,40 @@ class ModelContainer:
# Set prompt template override if provided # Set prompt template override if provided
prompt_template_name = kwargs.get("prompt_template") prompt_template_name = kwargs.get("prompt_template")
try: if prompt_template_name:
if prompt_template_name: print(f"Attempting to load prompt template with name {prompt_template_name}")
# Read the template # Read the template
self.prompt_template = get_template_from_file(prompt_template_name) self.prompt_template = get_template_from_file(prompt_template_name)
else: else:
# Try finding the chat template from the model's config.json # Then try finding the template from the tokenizer_config.json
self.prompt_template = get_template_from_model_config( self.prompt_template = get_template_from_model_json(
pathlib.Path(self.config.model_config) pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
) "chat_template",
if self.prompt_template == None: "from_tokenizer_config"
self.prompt_template = get_template_from_tokenizer_config( )
model_directory
)
# If that fails, attempt fetching from model name # Try finding the chat template from the model's config.json
if self.prompt_template is None: # TODO: This may not even be used with huggingface models, mark for removal.
template_match = find_template_from_model(model_directory) if self.prompt_template is None:
if template_match: self.prompt_template = get_template_from_model_json(
self.prompt_template = get_template_from_file(template_match) pathlib.Path(self.config.model_config),
except OSError: "chat_template",
# The template or config.json couldn't be found in the user's filesystem "from_model_config"
print(f"Could not find template file with name {prompt_template_name}.jinja") )
self.prompt_template = None
# If that fails, attempt fetching from model name
if self.prompt_template is None:
template_match = find_template_from_model(model_directory)
if template_match:
self.prompt_template = get_template_from_file(template_match)
# Catch all for template lookup errors # Catch all for template lookup errors
if self.prompt_template: if self.prompt_template:
print(f"Using template {self.prompt_template.name} for chat completions.") print(f"Using template {self.prompt_template.name} for chat completions.")
else: else:
print("Chat completions are disabled because a prompt template wasn't provided or auto-detected.") print(
"Chat completions are disabled because a prompt template wasn't provided or auto-detected."
)
# Set num of experts per token if provided # Set num of experts per token if provided
num_experts_override = kwargs.get("num_experts_per_token") num_experts_override = kwargs.get("num_experts_per_token")

View File

@@ -47,36 +47,29 @@ def find_template_from_model(model_path: pathlib.Path):
# Get a template from a jinja file
def get_template_from_file(prompt_template_name: str):
    """Load a prompt template from templates/<name>.jinja.

    Returns a PromptTemplate named after the file, or None when no such
    template file exists.
    """
    template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja")

    # Guard clause: bail out early when the template file is absent
    if not template_path.exists():
        return None

    with open(template_path, "r", encoding = "utf8") as raw_template:
        template_text = raw_template.read()

    return PromptTemplate(
        name = prompt_template_name,
        template = template_text
    )
# Get a template from a JSON file
# Requires a key and template name
def get_template_from_model_json(json_path: pathlib.Path, key: str, name: str):
    """Fetch a chat template string stored under `key` in a JSON config file.

    Returns a PromptTemplate called `name` wrapping the template text, or
    None when the file does not exist or the key is missing/empty.
    """
    # Bug fix: the original tested `json_path.exists` without calling it.
    # A bound method is always truthy, so a missing file fell through to
    # open() and raised FileNotFoundError instead of returning None.
    if json_path.exists():
        with open(json_path, "r", encoding = "utf8") as config_file:
            model_config = json.load(config_file)

        chat_template = model_config.get(key)
        if chat_template:
            return PromptTemplate(
                name = name,
                template = chat_template
            )
    return None