OAI: Add ability to specify fastchat prompt template

Sometimes fastchat may not be able to detect the prompt template from
the model path. Therefore, add the ability to set it in config.yml or
via the request object itself.

Also include the provided prompt template in the model info response.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-12-10 15:23:49 -05:00
parent 9f195af5ad
commit db87efde4a
7 changed files with 34 additions and 8 deletions

View File

@@ -25,6 +25,7 @@ class ChatCompletionRequest(CommonCompletionRequest):
# Messages # Messages
# Take in a string as well even though it's not part of the OAI spec # Take in a string as well even though it's not part of the OAI spec
messages: Union[str, List[ChatCompletionMessage]] messages: Union[str, List[ChatCompletionMessage]]
prompt_template: Optional[str] = None
class ChatCompletionResponse(BaseModel): class ChatCompletionResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}") id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")

View File

@@ -1,4 +1,4 @@
from pydantic import BaseModel, Field; from pydantic import BaseModel, Field
from time import time from time import time
from typing import Optional, List from typing import Optional, List

View File

@@ -6,6 +6,7 @@ class ModelCardParameters(BaseModel):
max_seq_len: Optional[int] = 4096 max_seq_len: Optional[int] = 4096
rope_scale: Optional[float] = 1.0 rope_scale: Optional[float] = 1.0
rope_alpha: Optional[float] = 1.0 rope_alpha: Optional[float] = 1.0
prompt_template: Optional[str] = None
draft: Optional['ModelCard'] = None draft: Optional['ModelCard'] = None
class ModelCard(BaseModel): class ModelCard(BaseModel):
@@ -34,6 +35,7 @@ class ModelLoadRequest(BaseModel):
rope_alpha: Optional[float] = 1.0 rope_alpha: Optional[float] = 1.0
no_flash_attention: Optional[bool] = False no_flash_attention: Optional[bool] = False
low_mem: Optional[bool] = False low_mem: Optional[bool] = False
prompt_template: Optional[str] = None
draft: Optional[DraftModelLoadRequest] = None draft: Optional[DraftModelLoadRequest] = None
class ModelLoadResponse(BaseModel): class ModelLoadResponse(BaseModel):

View File

@@ -1,5 +1,5 @@
import os, pathlib import pathlib
from OAI.types.completion import CompletionResponse, CompletionRespChoice, UsageStats from OAI.types.completion import CompletionResponse, CompletionRespChoice
from OAI.types.chat_completion import ( from OAI.types.chat_completion import (
ChatCompletionMessage, ChatCompletionMessage,
ChatCompletionRespChoice, ChatCompletionRespChoice,
@@ -11,13 +11,13 @@ from OAI.types.common import UsageStats
from OAI.types.lora import LoraList, LoraCard from OAI.types.lora import LoraList, LoraCard
from OAI.types.model import ModelList, ModelCard from OAI.types.model import ModelList, ModelCard
from packaging import version from packaging import version
from typing import Optional, List, Dict from typing import Optional, List
from utils import unwrap from utils import unwrap
# Check fastchat # Check fastchat
try: try:
import fastchat import fastchat
from fastchat.model.model_adapter import get_conversation_template from fastchat.model.model_adapter import get_conversation_template, get_conv_template
from fastchat.conversation import SeparatorStyle from fastchat.conversation import SeparatorStyle
_fastchat_available = True _fastchat_available = True
except ImportError: except ImportError:
@@ -111,8 +111,9 @@ def get_lora_list(lora_path: pathlib.Path):
return lora_list return lora_list
def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMessage]): def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMessage], prompt_template: Optional[str] = None):
# TODO: Replace fastchat with in-house jinja templates
# Check if fastchat is available # Check if fastchat is available
if not _fastchat_available: if not _fastchat_available:
raise ModuleNotFoundError( raise ModuleNotFoundError(
@@ -127,7 +128,11 @@ def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMes
"pip install -U fschat[model_worker]" "pip install -U fschat[model_worker]"
) )
conv = get_conversation_template(model_path) if prompt_template:
conv = get_conv_template(prompt_template)
else:
conv = get_conversation_template(model_path)
if conv.sep_style is None: if conv.sep_style is None:
conv.sep_style = SeparatorStyle.LLAMA2 conv.sep_style = SeparatorStyle.LLAMA2
@@ -145,4 +150,5 @@ def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMes
conv.append_message(conv.roles[1], None) conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt() prompt = conv.get_prompt()
print(prompt)
return prompt return prompt

View File

@@ -48,6 +48,10 @@ model:
# Enable 8 bit cache mode for VRAM savings (slight performance hit). Possible values FP16, FP8. (default: FP16) # Enable 8 bit cache mode for VRAM savings (slight performance hit). Possible values FP16, FP8. (default: FP16)
cache_mode: FP16 cache_mode: FP16
# Set the prompt template for this model. If empty, fastchat will automatically find the best template to use (default: None)
# NOTE: Only works with chat completion message lists!
prompt_template:
# Options for draft models (speculative decoding). This will use more VRAM! # Options for draft models (speculative decoding). This will use more VRAM!
draft: draft:
# Overrides the directory to look for draft (default: models) # Overrides the directory to look for draft (default: models)

10
main.py
View File

@@ -80,6 +80,7 @@ async def get_current_model():
rope_scale = model_container.config.scale_pos_emb, rope_scale = model_container.config.scale_pos_emb,
rope_alpha = model_container.config.scale_alpha_value, rope_alpha = model_container.config.scale_alpha_value,
max_seq_len = model_container.config.max_seq_len, max_seq_len = model_container.config.max_seq_len,
prompt_template = unwrap(model_container.prompt_template, "auto")
) )
) )
@@ -302,7 +303,14 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
if isinstance(data.messages, str): if isinstance(data.messages, str):
prompt = data.messages prompt = data.messages
else: else:
prompt = get_chat_completion_prompt(model_path.name, data.messages) # If the request specified prompt template isn't found, use the one from model container
# Otherwise, let fastchat figure it out
prompt_template = unwrap(data.prompt_template, model_container.prompt_template)
try:
prompt = get_chat_completion_prompt(model_path.name, data.messages, prompt_template)
except KeyError:
return HTTPException(400, f"Could not find a Conversation from prompt template '{prompt_template}'. Check your spelling?")
if data.stream: if data.stream:
const_id = f"chatcmpl-{uuid4().hex}" const_id = f"chatcmpl-{uuid4().hex}"

View File

@@ -27,6 +27,7 @@ class ModelContainer:
draft_cache: Optional[ExLlamaV2Cache] = None draft_cache: Optional[ExLlamaV2Cache] = None
tokenizer: Optional[ExLlamaV2Tokenizer] = None tokenizer: Optional[ExLlamaV2Tokenizer] = None
generator: Optional[ExLlamaV2StreamingGenerator] = None generator: Optional[ExLlamaV2StreamingGenerator] = None
prompt_template: Optional[str] = None
cache_fp8: bool = False cache_fp8: bool = False
gpu_split_auto: bool = True gpu_split_auto: bool = True
@@ -48,6 +49,7 @@ class ModelContainer:
'max_seq_len' (int): Override model's default max sequence length (default: 4096) 'max_seq_len' (int): Override model's default max sequence length (default: 4096)
'rope_scale' (float): Set RoPE scaling factor for model (default: 1.0) 'rope_scale' (float): Set RoPE scaling factor for model (default: 1.0)
'rope_alpha' (float): Set RoPE alpha (NTK) factor for model (default: 1.0) 'rope_alpha' (float): Set RoPE alpha (NTK) factor for model (default: 1.0)
'prompt_template' (str): Manually sets the prompt template for this model (default: None)
'chunk_size' (int): Sets the maximum chunk size for the model (default: 2048) 'chunk_size' (int): Sets the maximum chunk size for the model (default: 2048)
Inferencing in chunks reduces overall VRAM overhead by processing very long sequences in smaller Inferencing in chunks reduces overall VRAM overhead by processing very long sequences in smaller
batches. This limits the size of temporary buffers needed for the hidden state and attention batches. This limits the size of temporary buffers needed for the hidden state and attention
@@ -93,6 +95,9 @@ class ModelContainer:
self.config.set_low_mem() self.config.set_low_mem()
""" """
# Set prompt template override if provided
self.prompt_template = kwargs.get("prompt_template")
chunk_size = min(unwrap(kwargs.get("chunk_size"), 2048), self.config.max_seq_len) chunk_size = min(unwrap(kwargs.get("chunk_size"), 2048), self.config.max_seq_len)
self.config.max_input_len = chunk_size self.config.max_input_len = chunk_size
self.config.max_attn_size = chunk_size ** 2 self.config.max_attn_size = chunk_size ** 2