Templates: Switch to Jinja2

Jinja2 is a lightweight templating engine that Transformers uses to
render chat prompts for chat completions. It's much more efficient than
Fastchat and can be installed as an ordinary requirement.

This also allows Pydantic's version to be unpinned.

Users now have to provide their own template if one is needed. A
separate repo could serve as shared storage for common prompt templates.
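
For context, rendering a chat prompt through Jinja2 looks roughly like
the sketch below. This is illustrative only, not the code added in this
commit; the template text and message format are assumptions.

from jinja2 import Template

# An illustrative chat template; real templates live in templates/*.jinja
chat_template = Template(
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}\n"
    "{% endfor %}"
)

# Render a conversation into a single prompt string
prompt = chat_template.render(messages=[
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi there."},
])
print(prompt)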

Signed-off-by: kingbri <bdashore3@proton.me>
commit f631dd6ff7 (parent 95fd0f075e)
Author: kingbri
Committed by: Brian Dashore
Date: 2023-12-17 00:41:42 -05:00
14 changed files with 115 additions and 74 deletions


@@ -17,6 +17,7 @@ from exllamav2.generator import(
 from gen_logging import log_generation_params, log_prompt, log_response
 from typing import List, Optional, Union
+from templating import PromptTemplate
 from utils import coalesce, unwrap

 # Bytes to reserve on first device when loading with auto split
@@ -31,7 +32,7 @@ class ModelContainer:
     draft_cache: Optional[ExLlamaV2Cache] = None
     tokenizer: Optional[ExLlamaV2Tokenizer] = None
     generator: Optional[ExLlamaV2StreamingGenerator] = None
-    prompt_template: Optional[str] = None
+    prompt_template: Optional[PromptTemplate] = None

     cache_fp8: bool = False
     gpu_split_auto: bool = True
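
The hunk above retypes prompt_template to the new PromptTemplate class,
and the loading hunk below constructs it with name and template fields.
A plausible minimal shape is sketched here as an assumption; the
commit's actual templating module isn't shown on this page.

from dataclasses import dataclass

@dataclass
class PromptTemplate:
    """Holds a prompt template's name and its raw Jinja source."""
    name: str
    template: str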
@@ -103,7 +104,20 @@ class ModelContainer:
         """

         # Set prompt template override if provided
-        self.prompt_template = kwargs.get("prompt_template")
+        prompt_template_name = kwargs.get("prompt_template")
+        if prompt_template_name:
+            try:
+                with open(pathlib.Path(f"templates/{prompt_template_name}.jinja"), "r") as raw_template:
+                    self.prompt_template = PromptTemplate(
+                        name = prompt_template_name,
+                        template = raw_template.read()
+                    )
+            except OSError:
+                print("Chat completions are disabled because the provided prompt template couldn't be found.")
+                self.prompt_template = None
+        else:
+            print("Chat completions are disabled because a prompt template wasn't provided.")
+            self.prompt_template = None

         # Set num of experts per token if provided
         num_experts_override = kwargs.get("num_experts_per_token")
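
Downstream, the loaded template would be rendered per chat-completion
request. A hedged sketch follows, assuming Transformers-style template
variables (messages, add_generation_prompt); the commit's actual
rendering code isn't part of this page.

from jinja2 import Template

def render_prompt(prompt_template, messages):
    # Compile the stored Jinja source and render it against the
    # conversation; the variable names here are assumptions.
    return Template(prompt_template.template).render(
        messages=messages,
        add_generation_prompt=True,
    )

# e.g. render_prompt(container.prompt_template,
#                    [{"role": "user", "content": "Hi"}])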