OAI: Add ability to specify fastchat prompt template

Sometimes fastchat may not be able to detect the prompt template from
the model path. Therefore, add the ability to set it in config.yml or
via the request object itself.

Also send the provided prompt template on model info request.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-12-10 15:23:49 -05:00
parent 9f195af5ad
commit db87efde4a
7 changed files with 34 additions and 8 deletions

View File

@@ -1,5 +1,5 @@
import os, pathlib
from OAI.types.completion import CompletionResponse, CompletionRespChoice, UsageStats
import pathlib
from OAI.types.completion import CompletionResponse, CompletionRespChoice
from OAI.types.chat_completion import (
ChatCompletionMessage,
ChatCompletionRespChoice,
@@ -11,13 +11,13 @@ from OAI.types.common import UsageStats
from OAI.types.lora import LoraList, LoraCard
from OAI.types.model import ModelList, ModelCard
from packaging import version
from typing import Optional, List, Dict
from typing import Optional, List
from utils import unwrap
# Check fastchat
try:
import fastchat
from fastchat.model.model_adapter import get_conversation_template
from fastchat.model.model_adapter import get_conversation_template, get_conv_template
from fastchat.conversation import SeparatorStyle
_fastchat_available = True
except ImportError:
@@ -111,8 +111,9 @@ def get_lora_list(lora_path: pathlib.Path):
return lora_list
def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMessage]):
def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMessage], prompt_template: Optional[str] = None):
# TODO: Replace fastchat with in-house jinja templates
# Check if fastchat is available
if not _fastchat_available:
raise ModuleNotFoundError(
@@ -127,7 +128,11 @@ def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMes
"pip install -U fschat[model_worker]"
)
conv = get_conversation_template(model_path)
if prompt_template:
conv = get_conv_template(prompt_template)
else:
conv = get_conversation_template(model_path)
if conv.sep_style is None:
conv.sep_style = SeparatorStyle.LLAMA2
@@ -145,4 +150,5 @@ def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMes
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
print(prompt)
return prompt