mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-30 03:01:44 +00:00
Templates: Switch to common function for JSON loading
Fix redundancy in code when loading templates. However, loading a template from config.json may be a mistake since tokenizer_config.json is the main place where chat templates are stored. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
53
model.py
53
model.py
@@ -17,7 +17,7 @@ from exllamav2.generator import(
|
|||||||
|
|
||||||
from gen_logging import log_generation_params, log_prompt, log_response
|
from gen_logging import log_generation_params, log_prompt, log_response
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
from templating import PromptTemplate, find_template_from_model, get_template_from_model_config, get_template_from_tokenizer_config, get_template_from_file
|
from templating import PromptTemplate, find_template_from_model, get_template_from_model_json, get_template_from_file
|
||||||
from utils import coalesce, unwrap
|
from utils import coalesce, unwrap
|
||||||
|
|
||||||
# Bytes to reserve on first device when loading with auto split
|
# Bytes to reserve on first device when loading with auto split
|
||||||
@@ -118,35 +118,40 @@ class ModelContainer:
|
|||||||
|
|
||||||
# Set prompt template override if provided
|
# Set prompt template override if provided
|
||||||
prompt_template_name = kwargs.get("prompt_template")
|
prompt_template_name = kwargs.get("prompt_template")
|
||||||
try:
|
if prompt_template_name:
|
||||||
if prompt_template_name:
|
print(f"Attempting to load prompt template with name {prompt_template_name}")
|
||||||
# Read the template
|
# Read the template
|
||||||
self.prompt_template = get_template_from_file(prompt_template_name)
|
self.prompt_template = get_template_from_file(prompt_template_name)
|
||||||
else:
|
else:
|
||||||
# Try finding the chat template from the model's config.json
|
# Then try finding the template from the tokenizer_config.json
|
||||||
self.prompt_template = get_template_from_model_config(
|
self.prompt_template = get_template_from_model_json(
|
||||||
pathlib.Path(self.config.model_config)
|
pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
|
||||||
)
|
"chat_template",
|
||||||
if self.prompt_template == None:
|
"from_tokenizer_config"
|
||||||
self.prompt_template = get_template_from_tokenizer_config(
|
)
|
||||||
model_directory
|
|
||||||
)
|
|
||||||
|
|
||||||
# If that fails, attempt fetching from model name
|
# Try finding the chat template from the model's config.json
|
||||||
if self.prompt_template is None:
|
# TODO: This may not even be used with huggingface models, mark for removal.
|
||||||
template_match = find_template_from_model(model_directory)
|
if self.prompt_template is None:
|
||||||
if template_match:
|
self.prompt_template = get_template_from_model_json(
|
||||||
self.prompt_template = get_template_from_file(template_match)
|
pathlib.Path(self.config.model_config),
|
||||||
except OSError:
|
"chat_template",
|
||||||
# The template or config.json couldn't be found in the user's filesystem
|
"from_model_config"
|
||||||
print(f"Could not find template file with name {prompt_template_name}.jinja")
|
)
|
||||||
self.prompt_template = None
|
|
||||||
|
# If that fails, attempt fetching from model name
|
||||||
|
if self.prompt_template is None:
|
||||||
|
template_match = find_template_from_model(model_directory)
|
||||||
|
if template_match:
|
||||||
|
self.prompt_template = get_template_from_file(template_match)
|
||||||
|
|
||||||
# Catch all for template lookup errors
|
# Catch all for template lookup errors
|
||||||
if self.prompt_template:
|
if self.prompt_template:
|
||||||
print(f"Using template {self.prompt_template.name} for chat completions.")
|
print(f"Using template {self.prompt_template.name} for chat completions.")
|
||||||
else:
|
else:
|
||||||
print("Chat completions are disabled because a prompt template wasn't provided or auto-detected.")
|
print(
|
||||||
|
"Chat completions are disabled because a prompt template wasn't provided or auto-detected."
|
||||||
|
)
|
||||||
|
|
||||||
# Set num of experts per token if provided
|
# Set num of experts per token if provided
|
||||||
num_experts_override = kwargs.get("num_experts_per_token")
|
num_experts_override = kwargs.get("num_experts_per_token")
|
||||||
|
|||||||
@@ -47,36 +47,29 @@ def find_template_from_model(model_path: pathlib.Path):
|
|||||||
|
|
||||||
# Get a template from a jinja file
def get_template_from_file(prompt_template_name: str):
    """Load a prompt template from ``templates/<prompt_template_name>.jinja``.

    The path is relative, so it is resolved against the current working
    directory.

    Args:
        prompt_template_name: Template file name without the ``.jinja``
            extension. It is also used as the name of the returned template.

    Returns:
        A ``PromptTemplate`` holding the file's contents, or ``None`` when the
        template file does not exist.
    """
    template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja")

    # Open the file directly instead of probing with exists() first: the
    # file could disappear between the check and the open. A missing file
    # (or a missing parent path component) still maps to None.
    try:
        with open(template_path, "r", encoding="utf8") as raw_template:
            template_text = raw_template.read()
    except (FileNotFoundError, NotADirectoryError):
        return None

    return PromptTemplate(
        name=prompt_template_name,
        template=template_text,
    )
|
||||||
# Get a template from a JSON file
# Requires a key and template name
def get_template_from_model_json(json_path: pathlib.Path, key: str, name: str):
    """Build a prompt template from a string stored in a model JSON file.

    Args:
        json_path: Path to the JSON file to read (the caller passes the
            model's ``tokenizer_config.json`` or its config file).
        key: Top-level key whose value holds the template text.
        name: Name to give the returned template.

    Returns:
        A ``PromptTemplate`` named ``name`` when the file exists and ``key``
        maps to a non-empty value; otherwise ``None``.
    """
    # The previous guard tested the bound method ``json_path.exists`` without
    # calling it, which is always truthy, so a missing file raised instead of
    # returning None. Opening directly and catching the error fixes that and
    # avoids a check-then-open race.
    try:
        with open(json_path, "r", encoding="utf8") as config_file:
            model_config = json.load(config_file)
    except (FileNotFoundError, NotADirectoryError):
        return None

    chat_template = model_config.get(key)
    if chat_template:
        return PromptTemplate(
            name=name,
            template=chat_template,
        )

    return None
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user