Templates: Switch to Jinja2

Jinja2 is a lightweight template parser that's used in Transformers
for parsing chat completions. It's much more efficient than Fastchat
and can be imported as part of requirements.

Also allows for unblocking Pydantic's version.

Users now have to provide their own template if needed. A separate
repo may be usable for common prompt template storage.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-12-17 00:41:42 -05:00
committed by Brian Dashore
parent 95fd0f075e
commit f631dd6ff7
14 changed files with 115 additions and 74 deletions

View File

@@ -1,7 +1,7 @@
from uuid import uuid4 from uuid import uuid4
from time import time from time import time
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import Union, List, Optional from typing import Union, List, Optional, Dict
from OAI.types.common import UsageStats, CommonCompletionRequest from OAI.types.common import UsageStats, CommonCompletionRequest
class ChatCompletionMessage(BaseModel): class ChatCompletionMessage(BaseModel):
@@ -24,7 +24,7 @@ class ChatCompletionStreamChoice(BaseModel):
class ChatCompletionRequest(CommonCompletionRequest): class ChatCompletionRequest(CommonCompletionRequest):
# Messages # Messages
# Take in a string as well even though it's not part of the OAI spec # Take in a string as well even though it's not part of the OAI spec
messages: Union[str, List[ChatCompletionMessage]] messages: Union[str, List[Dict[str, str]]]
prompt_template: Optional[str] = None prompt_template: Optional[str] = None
class ChatCompletionResponse(BaseModel): class ChatCompletionResponse(BaseModel):

View File

@@ -10,18 +10,9 @@ from OAI.types.chat_completion import (
from OAI.types.common import UsageStats from OAI.types.common import UsageStats
from OAI.types.lora import LoraList, LoraCard from OAI.types.lora import LoraList, LoraCard
from OAI.types.model import ModelList, ModelCard from OAI.types.model import ModelList, ModelCard
from packaging import version from typing import Optional
from typing import Optional, List
from utils import unwrap
# Check fastchat from utils import unwrap
try:
import fastchat
from fastchat.model.model_adapter import get_conversation_template, get_conv_template
from fastchat.conversation import SeparatorStyle
_fastchat_available = True
except ImportError:
_fastchat_available = False
def create_completion_response(text: str, prompt_tokens: int, completion_tokens: int, model_name: Optional[str]): def create_completion_response(text: str, prompt_tokens: int, completion_tokens: int, model_name: Optional[str]):
choice = CompletionRespChoice( choice = CompletionRespChoice(
@@ -110,45 +101,3 @@ def get_lora_list(lora_path: pathlib.Path):
lora_list.data.append(lora_card) lora_list.data.append(lora_card)
return lora_list return lora_list
def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMessage], prompt_template: Optional[str] = None):
# TODO: Replace fastchat with in-house jinja templates
# Check if fastchat is available
if not _fastchat_available:
raise ModuleNotFoundError(
"Fastchat must be installed to parse these chat completion messages.\n"
"Please run the following command: pip install fschat[model_worker]"
)
if version.parse(fastchat.__version__) < version.parse("0.2.23"):
raise ImportError(
"Parsing these chat completion messages requires fastchat 0.2.23 or greater. "
f"Current version: {fastchat.__version__}\n"
"Please upgrade fastchat by running the following command: "
"pip install -U fschat[model_worker]"
)
if prompt_template:
conv = get_conv_template(prompt_template)
else:
conv = get_conversation_template(model_path)
if conv.sep_style is None:
conv.sep_style = SeparatorStyle.LLAMA2
for message in messages:
msg_role = message.role
if msg_role == "system":
conv.set_system_message(message.content)
elif msg_role == "user":
conv.append_message(conv.roles[0], message.content)
elif msg_role == "assistant":
conv.append_message(conv.roles[1], message.content)
else:
raise ValueError(f"Unknown role: {msg_role}")
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
print(prompt)
return prompt

View File

@@ -54,8 +54,6 @@ NOTE: For Flash Attention 2 to work on Windows, CUDA 12.x **must** be installed!
3. ROCm 5.6: `pip install -r requirements-amd.txt` 3. ROCm 5.6: `pip install -r requirements-amd.txt`
5. If you want the `/v1/chat/completions` endpoint to work with a list of messages, install fastchat by running `pip install fschat[model_worker]`
## Configuration ## Configuration
A config.yml file is required for overriding project defaults. If you are okay with the defaults, you don't need a config file! A config.yml file is required for overriding project defaults. If you are okay with the defaults, you don't need a config file!
@@ -126,6 +124,12 @@ All routes require an API key except for the following which require an **admin*
- `/v1/model/unload` - `/v1/model/unload`
## Chat Completions
`/v1/chat/completions` now uses Jinja2 for templating. Please read [Huggingface's documentation](https://huggingface.co/docs/transformers/main/chat_templating) for more information of how chat templates work.
Also make sure to set the template name in `config.yml` to the template's filename.
## Common Issues ## Common Issues
- AMD cards will error out with flash attention installed, even if the config option is set to False. Run `pip uninstall flash_attn` to remove the wheel from your system. - AMD cards will error out with flash attention installed, even if the config option is set to False. Run `pip uninstall flash_attn` to remove the wheel from your system.

View File

@@ -56,9 +56,9 @@ model:
# Enable 8 bit cache mode for VRAM savings (slight performance hit). Possible values FP16, FP8. (default: FP16) # Enable 8 bit cache mode for VRAM savings (slight performance hit). Possible values FP16, FP8. (default: FP16)
cache_mode: FP16 cache_mode: FP16
# Set the prompt template for this model. If empty, fastchat will automatically find the best template to use (default: None) # Set the prompt template for this model. If empty, chat completions will be disabled. (default: alpaca)
# NOTE: Only works with chat completion message lists! # NOTE: Only works with chat completion message lists!
prompt_template: prompt_template: alpaca
# Number of experts to use per token. Loads from the model's config.json if not specified (default: None) # Number of experts to use per token. Loads from the model's config.json if not specified (default: None)
# WARNING: Don't set this unless you know what you're doing! # WARNING: Don't set this unless you know what you're doing!

19
main.py
View File

@@ -27,10 +27,10 @@ from OAI.utils import (
create_completion_response, create_completion_response,
get_model_list, get_model_list,
get_lora_list, get_lora_list,
get_chat_completion_prompt,
create_chat_completion_response, create_chat_completion_response,
create_chat_completion_stream_chunk create_chat_completion_stream_chunk
) )
from templating import get_prompt_from_template
from utils import get_generator_error, get_sse_packet, load_progress, unwrap from utils import get_generator_error, get_sse_packet, load_progress, unwrap
app = FastAPI() app = FastAPI()
@@ -76,6 +76,7 @@ async def list_models():
@app.get("/v1/internal/model/info", dependencies=[Depends(check_api_key), Depends(_check_model_container)]) @app.get("/v1/internal/model/info", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
async def get_current_model(): async def get_current_model():
model_name = model_container.get_model_path().name model_name = model_container.get_model_path().name
prompt_template = model_container.prompt_template
model_card = ModelCard( model_card = ModelCard(
id = model_name, id = model_name,
parameters = ModelCardParameters( parameters = ModelCardParameters(
@@ -83,7 +84,7 @@ async def get_current_model():
rope_alpha = model_container.config.scale_alpha_value, rope_alpha = model_container.config.scale_alpha_value,
max_seq_len = model_container.config.max_seq_len, max_seq_len = model_container.config.max_seq_len,
cache_mode = "FP8" if model_container.cache_fp8 else "FP16", cache_mode = "FP8" if model_container.cache_fp8 else "FP16",
prompt_template = unwrap(model_container.prompt_template, "auto") prompt_template = prompt_template.name if prompt_template else None
), ),
logging = gen_logging.config logging = gen_logging.config
) )
@@ -302,19 +303,21 @@ async def generate_completion(request: Request, data: CompletionRequest):
# Chat completions endpoint # Chat completions endpoint
@app.post("/v1/chat/completions", dependencies=[Depends(check_api_key), Depends(_check_model_container)]) @app.post("/v1/chat/completions", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
async def generate_chat_completion(request: Request, data: ChatCompletionRequest): async def generate_chat_completion(request: Request, data: ChatCompletionRequest):
if model_container.prompt_template is None:
return HTTPException(422, "This endpoint is disabled because a prompt template is not set.")
model_path = model_container.get_model_path() model_path = model_container.get_model_path()
if isinstance(data.messages, str): if isinstance(data.messages, str):
prompt = data.messages prompt = data.messages
else: else:
# If the request specified prompt template isn't found, use the one from model container
# Otherwise, let fastchat figure it out
prompt_template = unwrap(data.prompt_template, model_container.prompt_template)
try: try:
prompt = get_chat_completion_prompt(model_path.name, data.messages, prompt_template) prompt = get_prompt_from_template(data.messages, model_container.prompt_template)
except KeyError: except KeyError:
return HTTPException(400, f"Could not find a Conversation from prompt template '{prompt_template}'. Check your spelling?") return HTTPException(
400,
f"Could not find a Conversation from prompt template '{model_container.prompt_template.name}'. Check your spelling?"
)
if data.stream: if data.stream:
const_id = f"chatcmpl-{uuid4().hex}" const_id = f"chatcmpl-{uuid4().hex}"

View File

@@ -17,6 +17,7 @@ from exllamav2.generator import(
from gen_logging import log_generation_params, log_prompt, log_response from gen_logging import log_generation_params, log_prompt, log_response
from typing import List, Optional, Union from typing import List, Optional, Union
from templating import PromptTemplate
from utils import coalesce, unwrap from utils import coalesce, unwrap
# Bytes to reserve on first device when loading with auto split # Bytes to reserve on first device when loading with auto split
@@ -31,7 +32,7 @@ class ModelContainer:
draft_cache: Optional[ExLlamaV2Cache] = None draft_cache: Optional[ExLlamaV2Cache] = None
tokenizer: Optional[ExLlamaV2Tokenizer] = None tokenizer: Optional[ExLlamaV2Tokenizer] = None
generator: Optional[ExLlamaV2StreamingGenerator] = None generator: Optional[ExLlamaV2StreamingGenerator] = None
prompt_template: Optional[str] = None prompt_template: Optional[PromptTemplate] = None
cache_fp8: bool = False cache_fp8: bool = False
gpu_split_auto: bool = True gpu_split_auto: bool = True
@@ -103,7 +104,20 @@ class ModelContainer:
""" """
# Set prompt template override if provided # Set prompt template override if provided
self.prompt_template = kwargs.get("prompt_template") prompt_template_name = kwargs.get("prompt_template")
if prompt_template_name:
try:
with open(pathlib.Path(f"templates/{prompt_template_name}.jinja"), "r") as raw_template:
self.prompt_template = PromptTemplate(
name = prompt_template_name,
template = raw_template.read()
)
except OSError:
print("Chat completions are disabled because the provided prompt template couldn't be found.")
self.prompt_template = None
else:
print("Chat completions are disabled because a provided prompt template couldn't be found.")
self.prompt_template = None
# Set num of experts per token if provided # Set num of experts per token if provided
num_experts_override = kwargs.get("num_experts_per_token") num_experts_override = kwargs.get("num_experts_per_token")

View File

@@ -12,3 +12,4 @@ pydantic < 2,>= 1
PyYAML PyYAML
progress progress
uvicorn uvicorn
jinja2

View File

@@ -18,6 +18,7 @@ pydantic < 2,>= 1
PyYAML PyYAML
progress progress
uvicorn uvicorn
jinja2
# Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases
https://github.com/Dao-AILab/flash-attention/releases/download/v2.3.6/flash_attn-2.3.6+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" https://github.com/Dao-AILab/flash-attention/releases/download/v2.3.6/flash_attn-2.3.6+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"

View File

@@ -18,6 +18,7 @@ pydantic < 2,>= 1
PyYAML PyYAML
progress progress
uvicorn uvicorn
jinja2
# Flash attention v2 # Flash attention v2

7
templates/README.md Normal file
View File

@@ -0,0 +1,7 @@
# Templates
NOTE: This folder will be replaced by a submodule or something similar in the future
These templates are examples from [Aphrodite Engine](https://github.com/PygmalionAI/aphrodite-engine/tree/main/examples)
Please look at [Huggingface's documentation](https://huggingface.co/docs/transformers/main/chat_templating) for making Jinja2 templates.

29
templates/alpaca.jinja Normal file
View File

@@ -0,0 +1,29 @@
{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}
{% for message in messages %}
{% if message['role'] == 'user' %}
### Instruction:
{{ message['content']|trim -}}
{% if not loop.last %}
{% endif %}
{% elif message['role'] == 'assistant' %}
### Response:
{{ message['content']|trim -}}
{% if not loop.last %}
{% endif %}
{% elif message['role'] == 'user_context' %}
### Input:
{{ message['content']|trim -}}
{% if not loop.last %}
{% endif %}
{% endif %}
{% endfor %}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}
### Response:
{% endif %}

2
templates/chatml.jinja Normal file
View File

@@ -0,0 +1,2 @@
{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}

30
templating.py Normal file
View File

@@ -0,0 +1,30 @@
from functools import lru_cache
from importlib.metadata import version as package_version
from packaging import version
from jinja2.sandbox import ImmutableSandboxedEnvironment
from pydantic import BaseModel
# Small replication of AutoTokenizer's chat template system for efficiency
class PromptTemplate(BaseModel):
name: str
template: str
def get_prompt_from_template(messages, prompt_template: PromptTemplate):
if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
raise ImportError(
"Parsing these chat completion messages requires fastchat 0.2.23 or greater. "
f"Current version: {version('jinja2')}\n"
"Please upgrade fastchat by running the following command: "
"pip install -U fschat[model_worker]"
)
compiled_template = _compile_template(prompt_template.template)
return compiled_template.render(messages = messages)
# Inspired from https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761
@lru_cache
def _compile_template(template: str):
jinja_env = ImmutableSandboxedEnvironment(trim_blocks = True, lstrip_blocks = True)
jinja_template = jinja_env.from_string(template)
return jinja_template

View File

@@ -25,12 +25,12 @@ else:
print("Torch is not found in your environment.") print("Torch is not found in your environment.")
errored_packages.append("torch") errored_packages.append("torch")
if find_spec("fastchat") is not None: if find_spec("jinja2") is not None:
print(f"Fastchat on version {version('fschat')} successfully imported") print(f"Jinja2 on version {version('jinja2')} successfully imported")
successful_packages.append("fastchat") successful_packages.append("jinja2")
else: else:
print("Fastchat is not found in your environment. It isn't needed unless you're using chat completions with message arrays.") print("Jinja2 is not found in your environment.")
errored_packages.append("fastchat") errored_packages.append("jinja2")
print( print(
f"\nSuccessful imports: {', '.join(successful_packages)}", f"\nSuccessful imports: {', '.join(successful_packages)}",