mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-14 15:57:27 +00:00
Templates: Attempt loading from model config
Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import pathlib
|
||||
from functools import lru_cache
|
||||
from importlib.metadata import version as package_version
|
||||
@@ -33,9 +34,34 @@ def _compile_template(template: str):
|
||||
jinja_template = jinja_env.from_string(template)
|
||||
return jinja_template
|
||||
|
||||
# Find a matching template name from a model path
|
||||
def find_template_from_model(model_path: pathlib.Path):
|
||||
model_name = model_path.name
|
||||
template_directory = pathlib.Path("templates")
|
||||
for filepath in template_directory.glob("*.jinja"):
|
||||
template_name = filepath.stem.lower()
|
||||
|
||||
# Check if the template name is present in the model name
|
||||
if template_name in model_name.lower():
|
||||
return template_name
|
||||
|
||||
# Get a template from a jinja file
|
||||
def get_template_from_file(prompt_template_name: str):
|
||||
with open(pathlib.Path(f"templates/{prompt_template_name}.jinja"), "r") as raw_template:
|
||||
with open(pathlib.Path(f"templates/{prompt_template_name}.jinja"), "r", encoding = "utf8") as raw_template:
|
||||
return PromptTemplate(
|
||||
name = prompt_template_name,
|
||||
template = raw_template.read()
|
||||
)
|
||||
|
||||
# Get a template from model config
|
||||
def get_template_from_config(model_config_path: pathlib.Path):
|
||||
with open(model_config_path, "r", encoding = "utf8") as model_config_file:
|
||||
model_config = json.load(model_config_file)
|
||||
chat_template = model_config.get("chat_template")
|
||||
if chat_template:
|
||||
return PromptTemplate(
|
||||
name = "from_model_config",
|
||||
template = chat_template
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user