mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-15 00:07:28 +00:00
Templates: Switch to Jinja2
Jinja2 is a lightweight template engine that Transformers uses for parsing chat completions. It is much more efficient than Fastchat and can be installed as part of the requirements. Switching to it also unblocks Pydantic's version. Users now have to provide their own template if needed. A separate repo may be usable for common prompt template storage. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
55
OAI/utils.py
55
OAI/utils.py
@@ -10,18 +10,9 @@ from OAI.types.chat_completion import (
|
||||
from OAI.types.common import UsageStats
|
||||
from OAI.types.lora import LoraList, LoraCard
|
||||
from OAI.types.model import ModelList, ModelCard
|
||||
from packaging import version
|
||||
from typing import Optional, List
|
||||
from utils import unwrap
|
||||
from typing import Optional
|
||||
|
||||
# Check fastchat
|
||||
try:
|
||||
import fastchat
|
||||
from fastchat.model.model_adapter import get_conversation_template, get_conv_template
|
||||
from fastchat.conversation import SeparatorStyle
|
||||
_fastchat_available = True
|
||||
except ImportError:
|
||||
_fastchat_available = False
|
||||
from utils import unwrap
|
||||
|
||||
def create_completion_response(text: str, prompt_tokens: int, completion_tokens: int, model_name: Optional[str]):
|
||||
choice = CompletionRespChoice(
|
||||
@@ -110,45 +101,3 @@ def get_lora_list(lora_path: pathlib.Path):
|
||||
lora_list.data.append(lora_card)
|
||||
|
||||
return lora_list
|
||||
|
||||
def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMessage], prompt_template: Optional[str] = None):
    """
    Render a list of chat messages into one prompt string via fastchat.

    The conversation template is looked up by name when ``prompt_template``
    is given; otherwise it is resolved from ``model_path``. After all
    messages are applied, an empty turn for the second template role (the
    one used for "assistant" messages) is appended so the prompt ends where
    the model is expected to continue.

    Raises:
        ModuleNotFoundError: fastchat could not be imported.
        ImportError: the installed fastchat is older than 0.2.23.
        ValueError: a message has a role other than "system", "user" or
            "assistant".
    """

    # TODO: Replace fastchat with in-house jinja templates
    # Bail out early when the optional fastchat import failed at module load
    if not _fastchat_available:
        raise ModuleNotFoundError(
            "Fastchat must be installed to parse these chat completion messages.\n"
            "Please run the following command: pip install fschat[model_worker]"
        )

    # Enforce the minimum supported fastchat release
    installed_version = version.parse(fastchat.__version__)
    if installed_version < version.parse("0.2.23"):
        raise ImportError(
            "Parsing these chat completion messages requires fastchat 0.2.23 or greater. "
            f"Current version: {fastchat.__version__}\n"
            "Please upgrade fastchat by running the following command: "
            "pip install -U fschat[model_worker]"
        )

    # An explicit template name takes priority over model-path detection
    conversation = (
        get_conv_template(prompt_template)
        if prompt_template
        else get_conversation_template(model_path)
    )

    # Fall back to the Llama 2 separator style when the template has none
    if conversation.sep_style is None:
        conversation.sep_style = SeparatorStyle.LLAMA2

    for entry in messages:
        role = entry.role

        # System messages replace the template's system text rather than
        # becoming a conversation turn
        if role == "system":
            conversation.set_system_message(entry.content)
            continue

        # Map the OAI role onto the template's role slot
        if role == "user":
            role_slot = 0
        elif role == "assistant":
            role_slot = 1
        else:
            raise ValueError(f"Unknown role: {role}")

        conversation.append_message(conversation.roles[role_slot], entry.content)

    # Open an empty turn for the second role so generation continues from it
    conversation.append_message(conversation.roles[1], None)
    rendered_prompt = conversation.get_prompt()

    print(rendered_prompt)
    return rendered_prompt
|
||||
|
||||
Reference in New Issue
Block a user