From 515b3c2930f46be5d4fa7b9282466fad80cc642a Mon Sep 17 00:00:00 2001
From: kingbri
Date: Mon, 15 Apr 2024 14:17:16 -0400
Subject: [PATCH] OAI: Tokenize chat completion messages

Since chat completion messages are structured data, format the prompt
before passing it to the tokenizer.

Signed-off-by: kingbri
---
 backends/exllamav2/model.py  |  4 +++-
 endpoints/OAI/router.py      | 23 +++++++++++++++++++++--
 endpoints/OAI/types/token.py |  4 ++--
 3 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index c8cc6e5..b610d8e 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -566,7 +566,9 @@ class ExllamaV2Container:
             decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
         )[0]
 
-    def get_special_tokens(self, add_bos_token: bool, ban_eos_token: bool):
+    def get_special_tokens(
+        self, add_bos_token: bool = True, ban_eos_token: bool = False
+    ):
         return {
             "bos_token": self.tokenizer.bos_token if add_bos_token else "",
             "eos_token": self.tokenizer.eos_token if not ban_eos_token else "",
diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py
index 833945c..e2fec3f 100644
--- a/endpoints/OAI/router.py
+++ b/endpoints/OAI/router.py
@@ -16,6 +16,7 @@ from common.concurrency import (
 from common.networking import handle_request_error, run_with_request_disconnect
 from common.templating import (
     get_all_templates,
+    get_prompt_from_template,
     get_template_from_file,
 )
 from common.utils import coalesce, unwrap
@@ -386,8 +387,26 @@ async def unload_loras():
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
 async def encode_tokens(data: TokenEncodeRequest):
-    """Encodes a string into tokens."""
-    raw_tokens = model.container.encode_tokens(data.text, **data.get_params())
+    """Encodes a string or chat completion messages into tokens."""
+
+    if isinstance(data.text, str):
+        text = data.text
+    else:
+        special_tokens_dict = model.container.get_special_tokens(
+            unwrap(data.add_bos_token, True)
+        )
+
+        template_vars = {
+            "messages": data.text,
+            "add_generation_prompt": False,
+            **special_tokens_dict,
+        }
+
+        text, _ = get_prompt_from_template(
+            model.container.prompt_template, template_vars
+        )
+
+    raw_tokens = model.container.encode_tokens(text, **data.get_params())
     tokens = unwrap(raw_tokens, [])
 
     response = TokenEncodeResponse(tokens=tokens, length=len(tokens))
diff --git a/endpoints/OAI/types/token.py b/endpoints/OAI/types/token.py
index cda1cf1..945adbf 100644
--- a/endpoints/OAI/types/token.py
+++ b/endpoints/OAI/types/token.py
@@ -1,7 +1,7 @@
 """Tokenization types"""
 
 from pydantic import BaseModel
-from typing import List
+from typing import Dict, List, Union
 
 
 class CommonTokenRequest(BaseModel):
@@ -23,7 +23,7 @@ class CommonTokenRequest(BaseModel):
 class TokenEncodeRequest(CommonTokenRequest):
     """Represents a tokenization request."""
 
-    text: str
+    text: Union[str, List[Dict[str, str]]]
 
 
 class TokenEncodeResponse(BaseModel):
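
A minimal client-side sketch of the two request shapes this patch allows.
The base URL, the /v1/token/encode path, and the x-api-key header are
assumptions based on typical tabbyAPI defaults, not taken from the diff
itself:

# Client-side sketch: the base URL, endpoint path, and auth header below
# are assumed defaults, not confirmed by this patch.
import requests

BASE_URL = "http://127.0.0.1:5000"  # assumed default host/port
HEADERS = {"x-api-key": "your-api-key"}  # assumed auth header

# A plain string still works as before:
resp = requests.post(
    f"{BASE_URL}/v1/token/encode",
    headers=HEADERS,
    json={"text": "Hello there!"},
)
print(resp.json())  # e.g. {"tokens": [...], "length": ...}

# With this patch, a list of chat completion messages is also accepted;
# the server renders them through the model's prompt template first.
resp = requests.post(
    f"{BASE_URL}/v1/token/encode",
    headers=HEADERS,
    json={
        "text": [
            {"role": "user", "content": "Hello there!"},
            {"role": "assistant", "content": "Hi! How can I help?"},
        ],
        "add_bos_token": True,
    },
)
print(resp.json())

Note that the server renders the messages with add_generation_prompt set
to False, so the rendered prompt ends after the last message rather than
opening a new assistant turn; the returned token count reflects only the
messages supplied.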