Templates: Switch to common function for JSON loading

Fix redundancy in code when loading templates. However, loading
a template from config.json may be a mistake since tokenizer_config.json
is the main place where chat templates are stored.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-12-21 23:08:51 -05:00
parent 72e19dbc12
commit 5d80a049ae
2 changed files with 46 additions and 48 deletions

View File

@@ -47,36 +47,29 @@ def find_template_from_model(model_path: pathlib.Path):
# Get a template from a jinja file
def get_template_from_file(prompt_template_name: str):
with open(pathlib.Path(f"templates/{prompt_template_name}.jinja"), "r", encoding = "utf8") as raw_template:
return PromptTemplate(
name = prompt_template_name,
template = raw_template.read()
)
# Get a template from model config
def get_template_from_model_config(model_config_path: pathlib.Path):
with open(model_config_path, "r", encoding = "utf8") as model_config_file:
model_config = json.load(model_config_file)
chat_template = model_config.get("chat_template")
if chat_template:
template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja")
if template_path.exists():
with open(template_path, "r", encoding = "utf8") as raw_template:
return PromptTemplate(
name = "from_model_config",
template = chat_template
name = prompt_template_name,
template = raw_template.read()
)
else:
return None
# Get a template from tokenizer config
def get_template_from_tokenizer_config(model_dir_path: pathlib.Path):
tokenizer_config_path=model_dir_path / "tokenizer_config.json"
if tokenizer_config_path.exists():
with open(tokenizer_config_path, "r", encoding = "utf8") as tokenizer_config_file:
tokenizer_config = json.load(tokenizer_config_file)
chat_template = tokenizer_config.get("chat_template")
return None
# Get a template from a JSON file
# Requires a key and template name
def get_template_from_model_json(json_path: pathlib.Path, key: str, name: str):
    """Load a chat template stored under `key` in the JSON file at `json_path`.

    Args:
        json_path: Path to the JSON file to read.
        key: Top-level key whose value holds the template text.
        name: Name given to the resulting PromptTemplate.

    Returns:
        A PromptTemplate built from the value under `key`, or None when the
        file does not exist or the value is missing/empty.
    """

    # exists must be *called*: the bare bound method is always truthy, which
    # made this guard a no-op and let open() raise on a missing file.
    if json_path.exists():
        with open(json_path, "r", encoding = "utf8") as config_file:
            model_config = json.load(config_file)
            chat_template = model_config.get(key)

            # A missing key or an empty template both fall through to None
            if chat_template:
                return PromptTemplate(
                    name = name,
                    template = chat_template
                )

    return None