diff --git a/OAI/utils.py b/OAI/utils.py index 7fc3126..a60e3d9 100644 --- a/OAI/utils.py +++ b/OAI/utils.py @@ -16,6 +16,7 @@ from typing import Optional, List try: import fastchat from fastchat.model.model_adapter import get_conversation_template + from fastchat.conversation import SeparatorStyle _fastchat_available = True except ImportError: _fastchat_available = False @@ -116,6 +117,9 @@ def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMes ) conv = get_conversation_template(model_path) + if conv.sep_style is None: + conv.sep_style = SeparatorStyle.LLAMA2 + for message in messages: msg_role = message.role if msg_role == "system":