Templates: Switch to Jinja2

Jinja2 is a lightweight template engine that Transformers uses for
rendering chat completion prompts. It's much more efficient than
FastChat and can be installed as part of the requirements.

Also allows for unblocking Pydantic's version.

Users must now provide their own template if they need one. A separate
repo may be used to store common prompt templates.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-12-17 00:41:42 -05:00
committed by Brian Dashore
parent 95fd0f075e
commit f631dd6ff7
14 changed files with 115 additions and 74 deletions

View File

@@ -10,18 +10,9 @@ from OAI.types.chat_completion import (
from OAI.types.common import UsageStats
from OAI.types.lora import LoraList, LoraCard
from OAI.types.model import ModelList, ModelCard
from packaging import version
from typing import Optional, List
from utils import unwrap
from typing import Optional
# Check fastchat
try:
import fastchat
from fastchat.model.model_adapter import get_conversation_template, get_conv_template
from fastchat.conversation import SeparatorStyle
_fastchat_available = True
except ImportError:
_fastchat_available = False
from utils import unwrap
def create_completion_response(text: str, prompt_tokens: int, completion_tokens: int, model_name: Optional[str]):
choice = CompletionRespChoice(
@@ -110,45 +101,3 @@ def get_lora_list(lora_path: pathlib.Path):
lora_list.data.append(lora_card)
return lora_list
def get_chat_completion_prompt(
    model_path: str,
    messages: List[ChatCompletionMessage],
    prompt_template: Optional[str] = None,
):
    """Render a list of chat messages into a single prompt string via fastchat.

    Args:
        model_path: Passed to fastchat's ``get_conversation_template`` when
            no explicit template name is supplied.
        messages: Chat messages; each one exposes ``role`` and ``content``.
        prompt_template: Optional template name passed to fastchat's
            ``get_conv_template``. When falsy, the template is looked up
            from ``model_path`` instead.

    Returns:
        The prompt produced by the conversation's ``get_prompt()``. The
        prompt is also printed to stdout before being returned.

    Raises:
        ModuleNotFoundError: fastchat could not be imported.
        ImportError: the installed fastchat is older than 0.2.23.
        ValueError: a message role is not "system", "user" or "assistant".
    """

    # TODO: Replace fastchat with in-house jinja templates

    # fastchat is an optional dependency; bail out early when it is missing
    if not _fastchat_available:
        raise ModuleNotFoundError(
            "Fastchat must be installed to parse these chat completion messages.\n"
            "Please run the following command: pip install fschat[model_worker]"
        )

    # Refuse fastchat releases older than the minimum supported version
    minimum_version = version.parse("0.2.23")
    if version.parse(fastchat.__version__) < minimum_version:
        raise ImportError(
            "Parsing these chat completion messages requires fastchat 0.2.23 or greater. "
            f"Current version: {fastchat.__version__}\n"
            "Please upgrade fastchat by running the following command: "
            "pip install -U fschat[model_worker]"
        )

    # An explicit template name wins over the lookup keyed on the model path
    conversation = (
        get_conv_template(prompt_template)
        if prompt_template
        else get_conversation_template(model_path)
    )

    # Templates without a separator style fall back to the LLAMA2 style
    if conversation.sep_style is None:
        conversation.sep_style = SeparatorStyle.LLAMA2

    for entry in messages:
        role = entry.role

        # System messages configure the conversation rather than extend it
        if role == "system":
            conversation.set_system_message(entry.content)
            continue

        if role not in ("user", "assistant"):
            raise ValueError(f"Unknown role: {role}")

        # roles[0] is used for user turns, roles[1] for assistant turns
        speaker = conversation.roles[0] if role == "user" else conversation.roles[1]
        conversation.append_message(speaker, entry.content)

    # Finish with an empty roles[1] (assistant-side) message
    conversation.append_message(conversation.roles[1], None)

    prompt = conversation.get_prompt()
    print(prompt)
    return prompt