Templates: Add auto-detection from path

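This replicates FastChat's model path detection: when no prompt template is provided, the first template in the templates directory whose name appears in the model directory name (case-insensitive) is used.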

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-12-17 15:41:09 -05:00
committed by Brian Dashore
parent e895eaa4bd
commit 7cbc08fc72
2 changed files with 30 additions and 14 deletions

View File

@@ -17,7 +17,7 @@ from exllamav2.generator import(
from gen_logging import log_generation_params, log_prompt, log_response
from typing import List, Optional, Union
from templating import PromptTemplate
from templating import PromptTemplate, get_template_from_file
from utils import coalesce, unwrap
# Bytes to reserve on first device when loading with auto split
@@ -105,20 +105,28 @@ class ModelContainer:
# Set prompt template override if provided
prompt_template_name = kwargs.get("prompt_template")
if prompt_template_name:
try:
with open(pathlib.Path(f"templates/{prompt_template_name}.jinja"), "r") as raw_template:
self.prompt_template = PromptTemplate(
name = prompt_template_name,
template = raw_template.read()
)
except OSError:
print("Chat completions are disabled because the provided prompt template couldn't be found.")
self.prompt_template = None
else:
print("Chat completions are disabled because a provided prompt template couldn't be found.")
try:
if prompt_template_name:
# Read the template
self.prompt_template = get_template_from_file(prompt_template_name)
else:
# Try autodetection of the template from the model path name
model_name = model_directory.name
template_directory = pathlib.Path("templates")
for filepath in template_directory.glob("*.jinja"):
template_name = filepath.stem.lower()
# Check if the template name is present in the model name
if template_name in model_name.lower():
self.prompt_template = get_template_from_file(template_name)
break
except OSError:
# Silently set the prompt template to None on a file lookup error
self.prompt_template = None
# Catch-all for template lookup errors
if self.prompt_template is None:
print("Chat completions are disabled because a prompt template wasn't provided or auto-detected.")
# Set num of experts per token if provided
num_experts_override = kwargs.get("num_experts_per_token")
if num_experts_override: