diff --git a/OAI/types/chat_completion.py b/OAI/types/chat_completion.py index de1c935..e6b2b51 100644 --- a/OAI/types/chat_completion.py +++ b/OAI/types/chat_completion.py @@ -1,7 +1,7 @@ from uuid import uuid4 from time import time from pydantic import BaseModel, Field -from typing import Union, List, Optional +from typing import Union, List, Optional, Dict from OAI.types.common import UsageStats, CommonCompletionRequest class ChatCompletionMessage(BaseModel): @@ -24,7 +24,7 @@ class ChatCompletionStreamChoice(BaseModel): class ChatCompletionRequest(CommonCompletionRequest): # Messages # Take in a string as well even though it's not part of the OAI spec - messages: Union[str, List[ChatCompletionMessage]] + messages: Union[str, List[Dict[str, str]]] prompt_template: Optional[str] = None class ChatCompletionResponse(BaseModel): diff --git a/OAI/utils.py b/OAI/utils.py index 650b44a..18e73ec 100644 --- a/OAI/utils.py +++ b/OAI/utils.py @@ -10,18 +10,9 @@ from OAI.types.chat_completion import ( from OAI.types.common import UsageStats from OAI.types.lora import LoraList, LoraCard from OAI.types.model import ModelList, ModelCard -from packaging import version -from typing import Optional, List -from utils import unwrap +from typing import Optional -# Check fastchat -try: - import fastchat - from fastchat.model.model_adapter import get_conversation_template, get_conv_template - from fastchat.conversation import SeparatorStyle - _fastchat_available = True -except ImportError: - _fastchat_available = False +from utils import unwrap def create_completion_response(text: str, prompt_tokens: int, completion_tokens: int, model_name: Optional[str]): choice = CompletionRespChoice( @@ -110,45 +101,3 @@ def get_lora_list(lora_path: pathlib.Path): lora_list.data.append(lora_card) return lora_list - -def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMessage], prompt_template: Optional[str] = None): - - # TODO: Replace fastchat with in-house jinja templates 
- # Check if fastchat is available - if not _fastchat_available: - raise ModuleNotFoundError( - "Fastchat must be installed to parse these chat completion messages.\n" - "Please run the following command: pip install fschat[model_worker]" - ) - if version.parse(fastchat.__version__) < version.parse("0.2.23"): - raise ImportError( - "Parsing these chat completion messages requires fastchat 0.2.23 or greater. " - f"Current version: {fastchat.__version__}\n" - "Please upgrade fastchat by running the following command: " - "pip install -U fschat[model_worker]" - ) - - if prompt_template: - conv = get_conv_template(prompt_template) - else: - conv = get_conversation_template(model_path) - - if conv.sep_style is None: - conv.sep_style = SeparatorStyle.LLAMA2 - - for message in messages: - msg_role = message.role - if msg_role == "system": - conv.set_system_message(message.content) - elif msg_role == "user": - conv.append_message(conv.roles[0], message.content) - elif msg_role == "assistant": - conv.append_message(conv.roles[1], message.content) - else: - raise ValueError(f"Unknown role: {msg_role}") - - conv.append_message(conv.roles[1], None) - prompt = conv.get_prompt() - - print(prompt) - return prompt diff --git a/README.md b/README.md index ddd486b..88e9497 100644 --- a/README.md +++ b/README.md @@ -54,8 +54,6 @@ NOTE: For Flash Attention 2 to work on Windows, CUDA 12.x **must** be installed! 3. ROCm 5.6: `pip install -r requirements-amd.txt` -5. If you want the `/v1/chat/completions` endpoint to work with a list of messages, install fastchat by running `pip install fschat[model_worker]` - ## Configuration A config.yml file is required for overriding project defaults. If you are okay with the defaults, you don't need a config file! @@ -126,6 +124,12 @@ All routes require an API key except for the following which require an **admin* - `/v1/model/unload` +## Chat Completions + +`/v1/chat/completions` now uses Jinja2 for templating. 
Please read [Huggingface's documentation](https://huggingface.co/docs/transformers/main/chat_templating) for more information on how chat templates work.
+
+Also make sure to set the template name in `config.yml` to the template's filename.
+
 ## Common Issues
 
 - AMD cards will error out with flash attention installed, even if the config option is set to False. Run `pip uninstall flash_attn` to remove the wheel from your system.
diff --git a/config_sample.yml b/config_sample.yml
index 3598585..25dfd90 100644
--- a/config_sample.yml
+++ b/config_sample.yml
@@ -56,9 +56,9 @@ model:
   # Enable 8 bit cache mode for VRAM savings (slight performance hit). Possible values FP16, FP8. (default: FP16)
   cache_mode: FP16
 
-  # Set the prompt template for this model. If empty, fastchat will automatically find the best template to use (default: None)
+  # Set the prompt template for this model. If empty, chat completions will be disabled. (default: alpaca)
   # NOTE: Only works with chat completion message lists!
-  prompt_template:
+  prompt_template: alpaca
 
   # Number of experts to use per token. Loads from the model's config.json if not specified (default: None)
   # WARNING: Don't set this unless you know what you're doing!
diff --git a/main.py b/main.py index b204be7..b55fe64 100644 --- a/main.py +++ b/main.py @@ -27,10 +27,10 @@ from OAI.utils import ( create_completion_response, get_model_list, get_lora_list, - get_chat_completion_prompt, create_chat_completion_response, create_chat_completion_stream_chunk ) +from templating import get_prompt_from_template from utils import get_generator_error, get_sse_packet, load_progress, unwrap app = FastAPI() @@ -76,6 +76,7 @@ async def list_models(): @app.get("/v1/internal/model/info", dependencies=[Depends(check_api_key), Depends(_check_model_container)]) async def get_current_model(): model_name = model_container.get_model_path().name + prompt_template = model_container.prompt_template model_card = ModelCard( id = model_name, parameters = ModelCardParameters( @@ -83,7 +84,7 @@ async def get_current_model(): rope_alpha = model_container.config.scale_alpha_value, max_seq_len = model_container.config.max_seq_len, cache_mode = "FP8" if model_container.cache_fp8 else "FP16", - prompt_template = unwrap(model_container.prompt_template, "auto") + prompt_template = prompt_template.name if prompt_template else None ), logging = gen_logging.config ) @@ -302,19 +303,21 @@ async def generate_completion(request: Request, data: CompletionRequest): # Chat completions endpoint @app.post("/v1/chat/completions", dependencies=[Depends(check_api_key), Depends(_check_model_container)]) async def generate_chat_completion(request: Request, data: ChatCompletionRequest): + if model_container.prompt_template is None: + return HTTPException(422, "This endpoint is disabled because a prompt template is not set.") + model_path = model_container.get_model_path() if isinstance(data.messages, str): prompt = data.messages else: - # If the request specified prompt template isn't found, use the one from model container - # Otherwise, let fastchat figure it out - prompt_template = unwrap(data.prompt_template, model_container.prompt_template) - try: - prompt = 
get_chat_completion_prompt(model_path.name, data.messages, prompt_template) + prompt = get_prompt_from_template(data.messages, model_container.prompt_template) except KeyError: - return HTTPException(400, f"Could not find a Conversation from prompt template '{prompt_template}'. Check your spelling?") + return HTTPException( + 400, + f"Could not find a Conversation from prompt template '{model_container.prompt_template.name}'. Check your spelling?" + ) if data.stream: const_id = f"chatcmpl-{uuid4().hex}" diff --git a/model.py b/model.py index 2aed565..9d53812 100644 --- a/model.py +++ b/model.py @@ -17,6 +17,7 @@ from exllamav2.generator import( from gen_logging import log_generation_params, log_prompt, log_response from typing import List, Optional, Union +from templating import PromptTemplate from utils import coalesce, unwrap # Bytes to reserve on first device when loading with auto split @@ -31,7 +32,7 @@ class ModelContainer: draft_cache: Optional[ExLlamaV2Cache] = None tokenizer: Optional[ExLlamaV2Tokenizer] = None generator: Optional[ExLlamaV2StreamingGenerator] = None - prompt_template: Optional[str] = None + prompt_template: Optional[PromptTemplate] = None cache_fp8: bool = False gpu_split_auto: bool = True @@ -103,7 +104,20 @@ class ModelContainer: """ # Set prompt template override if provided - self.prompt_template = kwargs.get("prompt_template") + prompt_template_name = kwargs.get("prompt_template") + if prompt_template_name: + try: + with open(pathlib.Path(f"templates/{prompt_template_name}.jinja"), "r") as raw_template: + self.prompt_template = PromptTemplate( + name = prompt_template_name, + template = raw_template.read() + ) + except OSError: + print("Chat completions are disabled because the provided prompt template couldn't be found.") + self.prompt_template = None + else: + print("Chat completions are disabled because a provided prompt template couldn't be found.") + self.prompt_template = None # Set num of experts per token if provided 
num_experts_override = kwargs.get("num_experts_per_token") diff --git a/requirements-amd.txt b/requirements-amd.txt index 3695758..00bc209 100644 --- a/requirements-amd.txt +++ b/requirements-amd.txt @@ -12,3 +12,4 @@ pydantic < 2,>= 1 PyYAML progress uvicorn +jinja2 diff --git a/requirements-cu118.txt b/requirements-cu118.txt index c09ee96..96d6a33 100644 --- a/requirements-cu118.txt +++ b/requirements-cu118.txt @@ -18,6 +18,7 @@ pydantic < 2,>= 1 PyYAML progress uvicorn +jinja2 # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases https://github.com/Dao-AILab/flash-attention/releases/download/v2.3.6/flash_attn-2.3.6+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" diff --git a/requirements.txt b/requirements.txt index 3e48c3f..5d8d51b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,6 +18,7 @@ pydantic < 2,>= 1 PyYAML progress uvicorn +jinja2 # Flash attention v2 diff --git a/templates/README.md b/templates/README.md new file mode 100644 index 0000000..a5720a1 --- /dev/null +++ b/templates/README.md @@ -0,0 +1,7 @@ +# Templates + +NOTE: This folder will be replaced by a submodule or something similar in the future + +These templates are examples from [Aphrodite Engine](https://github.com/PygmalionAI/aphrodite-engine/tree/main/examples) + +Please look at [Huggingface's documentation](https://huggingface.co/docs/transformers/main/chat_templating) for making Jinja2 templates. 
diff --git a/templates/alpaca.jinja b/templates/alpaca.jinja new file mode 100644 index 0000000..45837b0 --- /dev/null +++ b/templates/alpaca.jinja @@ -0,0 +1,29 @@ +{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }} + +{% for message in messages %} +{% if message['role'] == 'user' %} +### Instruction: +{{ message['content']|trim -}} +{% if not loop.last %} + + +{% endif %} +{% elif message['role'] == 'assistant' %} +### Response: +{{ message['content']|trim -}} +{% if not loop.last %} + + +{% endif %} +{% elif message['role'] == 'user_context' %} +### Input: +{{ message['content']|trim -}} +{% if not loop.last %} + + +{% endif %} +{% endif %} +{% endfor %} +{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %} +### Response: +{% endif %} diff --git a/templates/chatml.jinja b/templates/chatml.jinja new file mode 100644 index 0000000..18efd05 --- /dev/null +++ b/templates/chatml.jinja @@ -0,0 +1,2 @@ +{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %} +{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %} diff --git a/templating.py b/templating.py new file mode 100644 index 0000000..9d35ddc --- /dev/null +++ b/templating.py @@ -0,0 +1,30 @@ +from functools import lru_cache +from importlib.metadata import version as package_version +from packaging import version +from jinja2.sandbox import ImmutableSandboxedEnvironment +from pydantic import BaseModel + +# Small replication of AutoTokenizer's chat template system for efficiency + +class PromptTemplate(BaseModel): + name: str + template: str + +def get_prompt_from_template(messages, prompt_template: PromptTemplate): + if version.parse(package_version("jinja2")) < version.parse("3.0.0"): + raise 
ImportError(
+            "Parsing these chat completion messages requires jinja2 3.0.0 or greater. "
+            f"Current version: {package_version('jinja2')}\n"
+            "Please upgrade jinja2 by running the following command: "
+            "pip install -U jinja2"
+        )
+
+    compiled_template = _compile_template(prompt_template.template)
+    return compiled_template.render(messages = messages)
+
+# Inspired from https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761
+@lru_cache
+def _compile_template(template: str):
+    jinja_env = ImmutableSandboxedEnvironment(trim_blocks = True, lstrip_blocks = True)
+    jinja_template = jinja_env.from_string(template)
+    return jinja_template
diff --git a/tests/wheel_test.py b/tests/wheel_test.py
index 150d66e..58343bf 100644
--- a/tests/wheel_test.py
+++ b/tests/wheel_test.py
@@ -25,12 +25,12 @@ else:
     print("Torch is not found in your environment.")
     errored_packages.append("torch")
 
-if find_spec("fastchat") is not None:
-    print(f"Fastchat on version {version('fschat')} successfully imported")
-    successful_packages.append("fastchat")
+if find_spec("jinja2") is not None:
+    print(f"Jinja2 on version {version('jinja2')} successfully imported")
+    successful_packages.append("jinja2")
 else:
-    print("Fastchat is not found in your environment. It isn't needed unless you're using chat completions with message arrays.")
-    errored_packages.append("fastchat")
+    print("Jinja2 is not found in your environment.")
+    errored_packages.append("jinja2")
 
 print(
     f"\nSuccessful imports: {', '.join(successful_packages)}",