From f631dd6ff7ad2407ad53b077db83e12d30eed8b6 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sun, 17 Dec 2023 00:41:42 -0500 Subject: [PATCH] Templates: Switch to Jinja2 Jinja2 is a lightweight template parser that's used in Transformers for parsing chat completions. It's much more efficient than Fastchat and can be imported as part of requirements. Also allows for unblocking Pydantic's version. Users now have to provide their own template if needed. A separate repo may be usable for common prompt template storage. Signed-off-by: kingbri --- OAI/types/chat_completion.py | 4 +-- OAI/utils.py | 55 ++---------------------------------- README.md | 8 ++++-- config_sample.yml | 4 +-- main.py | 19 +++++++------ model.py | 18 ++++++++++-- requirements-amd.txt | 1 + requirements-cu118.txt | 1 + requirements.txt | 1 + templates/README.md | 7 +++++ templates/alpaca.jinja | 29 +++++++++++++++++++ templates/chatml.jinja | 2 ++ templating.py | 30 ++++++++++++++++++++ tests/wheel_test.py | 10 +++---- 14 files changed, 115 insertions(+), 74 deletions(-) create mode 100644 templates/README.md create mode 100644 templates/alpaca.jinja create mode 100644 templates/chatml.jinja create mode 100644 templating.py diff --git a/OAI/types/chat_completion.py b/OAI/types/chat_completion.py index de1c935..e6b2b51 100644 --- a/OAI/types/chat_completion.py +++ b/OAI/types/chat_completion.py @@ -1,7 +1,7 @@ from uuid import uuid4 from time import time from pydantic import BaseModel, Field -from typing import Union, List, Optional +from typing import Union, List, Optional, Dict from OAI.types.common import UsageStats, CommonCompletionRequest class ChatCompletionMessage(BaseModel): @@ -24,7 +24,7 @@ class ChatCompletionStreamChoice(BaseModel): class ChatCompletionRequest(CommonCompletionRequest): # Messages # Take in a string as well even though it's not part of the OAI spec - messages: Union[str, List[ChatCompletionMessage]] + messages: Union[str, List[Dict[str, str]]] prompt_template: Optional[str] = None class ChatCompletionResponse(BaseModel): diff --git a/OAI/utils.py b/OAI/utils.py index 650b44a..18e73ec 100644 --- a/OAI/utils.py +++ b/OAI/utils.py @@ -10,18 +10,9 @@ from OAI.types.chat_completion import ( from OAI.types.common import UsageStats from OAI.types.lora import LoraList, LoraCard from OAI.types.model import ModelList, ModelCard -from packaging import version -from typing import Optional, List -from utils import unwrap +from typing import Optional -# Check fastchat -try: - import fastchat - from fastchat.model.model_adapter import get_conversation_template, get_conv_template - from fastchat.conversation import SeparatorStyle - _fastchat_available = True -except ImportError: - _fastchat_available = False +from utils import unwrap def create_completion_response(text: str, prompt_tokens: int, completion_tokens: int, model_name: Optional[str]): choice = CompletionRespChoice( @@ -110,45 +101,3 @@ def get_lora_list(lora_path: pathlib.Path): lora_list.data.append(lora_card) return lora_list - -def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMessage], prompt_template: Optional[str] = None): - - # TODO: Replace fastchat with in-house jinja templates - # Check if fastchat is available - if not _fastchat_available: - raise ModuleNotFoundError( - "Fastchat must be installed to parse these chat completion messages.\n" - "Please run the following command: pip install fschat[model_worker]" - ) - if version.parse(fastchat.__version__) < version.parse("0.2.23"): - raise ImportError( - "Parsing these chat completion messages requires fastchat 0.2.23 or greater. " - f"Current version: {fastchat.__version__}\n" - "Please upgrade fastchat by running the following command: " - "pip install -U fschat[model_worker]" - ) - - if prompt_template: - conv = get_conv_template(prompt_template) - else: - conv = get_conversation_template(model_path) - - if conv.sep_style is None: - conv.sep_style = SeparatorStyle.LLAMA2 - - for message in messages: - msg_role = message.role - if msg_role == "system": - conv.set_system_message(message.content) - elif msg_role == "user": - conv.append_message(conv.roles[0], message.content) - elif msg_role == "assistant": - conv.append_message(conv.roles[1], message.content) - else: - raise ValueError(f"Unknown role: {msg_role}") - - conv.append_message(conv.roles[1], None) - prompt = conv.get_prompt() - - print(prompt) - return prompt diff --git a/README.md b/README.md index ddd486b..88e9497 100644 --- a/README.md +++ b/README.md @@ -54,8 +54,6 @@ NOTE: For Flash Attention 2 to work on Windows, CUDA 12.x **must** be installed! 3. ROCm 5.6: `pip install -r requirements-amd.txt` -5. If you want the `/v1/chat/completions` endpoint to work with a list of messages, install fastchat by running `pip install fschat[model_worker]` - ## Configuration A config.yml file is required for overriding project defaults. If you are okay with the defaults, you don't need a config file! @@ -126,6 +124,12 @@ All routes require an API key except for the following which require an **admin* - `/v1/model/unload` +## Chat Completions + +`/v1/chat/completions` now uses Jinja2 for templating. Please read [Huggingface's documentation](https://huggingface.co/docs/transformers/main/chat_templating) for more information of how chat templates work. + +Also make sure to set the template name in `config.yml` to the template's filename. + ## Common Issues - AMD cards will error out with flash attention installed, even if the config option is set to False. Run `pip uninstall flash_attn` to remove the wheel from your system. diff --git a/config_sample.yml b/config_sample.yml index 3598585..25dfd90 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -56,9 +56,9 @@ model: # Enable 8 bit cache mode for VRAM savings (slight performance hit). Possible values FP16, FP8. (default: FP16) cache_mode: FP16 - # Set the prompt template for this model. If empty, fastchat will automatically find the best template to use (default: None) + # Set the prompt template for this model. If empty, chat completions will be disabled. (default: alpaca) # NOTE: Only works with chat completion message lists! - prompt_template: + prompt_template: alpaca # Number of experts to use per token. Loads from the model's config.json if not specified (default: None) # WARNING: Don't set this unless you know what you're doing! diff --git a/main.py b/main.py index b204be7..b55fe64 100644 --- a/main.py +++ b/main.py @@ -27,10 +27,10 @@ from OAI.utils import ( create_completion_response, get_model_list, get_lora_list, - get_chat_completion_prompt, create_chat_completion_response, create_chat_completion_stream_chunk ) +from templating import get_prompt_from_template from utils import get_generator_error, get_sse_packet, load_progress, unwrap app = FastAPI() @@ -76,6 +76,7 @@ async def list_models(): @app.get("/v1/internal/model/info", dependencies=[Depends(check_api_key), Depends(_check_model_container)]) async def get_current_model(): model_name = model_container.get_model_path().name + prompt_template = model_container.prompt_template model_card = ModelCard( id = model_name, parameters = ModelCardParameters( @@ -83,7 +84,7 @@ async def get_current_model(): rope_alpha = model_container.config.scale_alpha_value, max_seq_len = model_container.config.max_seq_len, cache_mode = "FP8" if model_container.cache_fp8 else "FP16", - prompt_template = unwrap(model_container.prompt_template, "auto") + prompt_template = prompt_template.name if prompt_template else None ), logging = gen_logging.config ) @@ -302,19 +303,21 @@ async def generate_completion(request: Request, data: CompletionRequest): # Chat completions endpoint @app.post("/v1/chat/completions", dependencies=[Depends(check_api_key), Depends(_check_model_container)]) async def generate_chat_completion(request: Request, data: ChatCompletionRequest): + if model_container.prompt_template is None: + return HTTPException(422, "This endpoint is disabled because a prompt template is not set.") + model_path = model_container.get_model_path() if isinstance(data.messages, str): prompt = data.messages else: - # If the request specified prompt template isn't found, use the one from model container - # Otherwise, let fastchat figure it out - prompt_template = unwrap(data.prompt_template, model_container.prompt_template) - try: - prompt = get_chat_completion_prompt(model_path.name, data.messages, prompt_template) + prompt = get_prompt_from_template(data.messages, model_container.prompt_template) except KeyError: - return HTTPException(400, f"Could not find a Conversation from prompt template '{prompt_template}'. Check your spelling?") + return HTTPException( + 400, + f"Could not find a Conversation from prompt template '{model_container.prompt_template.name}'. Check your spelling?" + ) if data.stream: const_id = f"chatcmpl-{uuid4().hex}" diff --git a/model.py b/model.py index 2aed565..9d53812 100644 --- a/model.py +++ b/model.py @@ -17,6 +17,7 @@ from exllamav2.generator import( from gen_logging import log_generation_params, log_prompt, log_response from typing import List, Optional, Union +from templating import PromptTemplate from utils import coalesce, unwrap # Bytes to reserve on first device when loading with auto split @@ -31,7 +32,7 @@ class ModelContainer: draft_cache: Optional[ExLlamaV2Cache] = None tokenizer: Optional[ExLlamaV2Tokenizer] = None generator: Optional[ExLlamaV2StreamingGenerator] = None - prompt_template: Optional[str] = None + prompt_template: Optional[PromptTemplate] = None cache_fp8: bool = False gpu_split_auto: bool = True @@ -103,7 +104,20 @@ class ModelContainer: """ # Set prompt template override if provided - self.prompt_template = kwargs.get("prompt_template") + prompt_template_name = kwargs.get("prompt_template") + if prompt_template_name: + try: + with open(pathlib.Path(f"templates/{prompt_template_name}.jinja"), "r") as raw_template: + self.prompt_template = PromptTemplate( + name = prompt_template_name, + template = raw_template.read() + ) + except OSError: + print("Chat completions are disabled because the provided prompt template couldn't be found.") + self.prompt_template = None + else: + print("Chat completions are disabled because a provided prompt template couldn't be found.") + self.prompt_template = None # Set num of experts per token if provided num_experts_override = kwargs.get("num_experts_per_token") diff --git a/requirements-amd.txt b/requirements-amd.txt index 3695758..00bc209 100644 --- a/requirements-amd.txt +++ b/requirements-amd.txt @@ -12,3 +12,4 @@ pydantic < 2,>= 1 PyYAML progress uvicorn +jinja2 diff --git a/requirements-cu118.txt b/requirements-cu118.txt index c09ee96..96d6a33 100644 --- a/requirements-cu118.txt +++ b/requirements-cu118.txt @@ -18,6 +18,7 @@ pydantic < 2,>= 1 PyYAML progress uvicorn +jinja2 # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases https://github.com/Dao-AILab/flash-attention/releases/download/v2.3.6/flash_attn-2.3.6+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" diff --git a/requirements.txt b/requirements.txt index 3e48c3f..5d8d51b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,6 +18,7 @@ pydantic < 2,>= 1 PyYAML progress uvicorn +jinja2 # Flash attention v2 diff --git a/templates/README.md b/templates/README.md new file mode 100644 index 0000000..a5720a1 --- /dev/null +++ b/templates/README.md @@ -0,0 +1,7 @@ +# Templates + +NOTE: This folder will be replaced by a submodule or something similar in the future + +These templates are examples from [Aphrodite Engine](https://github.com/PygmalionAI/aphrodite-engine/tree/main/examples) + +Please look at [Huggingface's documentation](https://huggingface.co/docs/transformers/main/chat_templating) for making Jinja2 templates. diff --git a/templates/alpaca.jinja b/templates/alpaca.jinja new file mode 100644 index 0000000..45837b0 --- /dev/null +++ b/templates/alpaca.jinja @@ -0,0 +1,29 @@ +{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }} + +{% for message in messages %} +{% if message['role'] == 'user' %} +### Instruction: +{{ message['content']|trim -}} +{% if not loop.last %} + + +{% endif %} +{% elif message['role'] == 'assistant' %} +### Response: +{{ message['content']|trim -}} +{% if not loop.last %} + + +{% endif %} +{% elif message['role'] == 'user_context' %} +### Input: +{{ message['content']|trim -}} +{% if not loop.last %} + + +{% endif %} +{% endif %} +{% endfor %} +{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %} +### Response: +{% endif %} diff --git a/templates/chatml.jinja b/templates/chatml.jinja new file mode 100644 index 0000000..18efd05 --- /dev/null +++ b/templates/chatml.jinja @@ -0,0 +1,2 @@ +{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %} +{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %} diff --git a/templating.py b/templating.py new file mode 100644 index 0000000..9d35ddc --- /dev/null +++ b/templating.py @@ -0,0 +1,30 @@ +from functools import lru_cache +from importlib.metadata import version as package_version +from packaging import version +from jinja2.sandbox import ImmutableSandboxedEnvironment +from pydantic import BaseModel + +# Small replication of AutoTokenizer's chat template system for efficiency + +class PromptTemplate(BaseModel): + name: str + template: str + +def get_prompt_from_template(messages, prompt_template: PromptTemplate): + if version.parse(package_version("jinja2")) < version.parse("3.0.0"): + raise ImportError( + "Parsing these chat completion messages requires fastchat 0.2.23 or greater. " + f"Current version: {version('jinja2')}\n" + "Please upgrade fastchat by running the following command: " + "pip install -U fschat[model_worker]" + ) + + compiled_template = _compile_template(prompt_template.template) + return compiled_template.render(messages = messages) + +# Inspired from https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761 +@lru_cache +def _compile_template(template: str): + jinja_env = ImmutableSandboxedEnvironment(trim_blocks = True, lstrip_blocks = True) + jinja_template = jinja_env.from_string(template) + return jinja_template diff --git a/tests/wheel_test.py b/tests/wheel_test.py index 150d66e..58343bf 100644 --- a/tests/wheel_test.py +++ b/tests/wheel_test.py @@ -25,12 +25,12 @@ else: print("Torch is not found in your environment.") errored_packages.append("torch") -if find_spec("fastchat") is not None: - print(f"Fastchat on version {version('fschat')} successfully imported") - successful_packages.append("fastchat") +if find_spec("jinja2") is not None: + print(f"Jinja2 on version {version('jinja2')} successfully imported") + successful_packages.append("jinja2") else: - print("Fastchat is not found in your environment. It isn't needed unless you're using chat completions with message arrays.") - errored_packages.append("fastchat") + print("Jinja2 is not found in your environment.") + errored_packages.append("jinja2") print( f"\nSuccessful imports: {', '.join(successful_packages)}",