Mirror of https://github.com/theroyallab/tabbyAPI.git
OAI: Tokenize chat completion messages
Since chat completion messages are a structure, format them into a prompt before passing it to the tokenizer.

Signed-off-by: kingbri <bdashore3@proton.me>
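To see why a message list needs formatting first, here is a minimal sketch of the flattening step, assuming a simple Jinja2 chat template. This is a stand-in for the project's common.templating helpers; the template and token strings are illustrative, not the real ones.

    # Illustrative only: a stand-in template showing why structured messages
    # must be rendered into a single string before they can be tokenized.
    from jinja2 import Template

    template = Template(
        "{{ bos_token }}"
        "{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\n{% endfor %}"
    )

    messages = [
        {"role": "user", "content": "What is a token?"},
        {"role": "assistant", "content": "A unit of text the model reads."},
    ]

    prompt = template.render(messages=messages, bos_token="<s>")
    print(prompt)
    # <s>user: What is a token?
    # assistant: A unit of text the model reads.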
@@ -566,7 +566,9 @@ class ExllamaV2Container:
             decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
         )[0]
 
-    def get_special_tokens(self, add_bos_token: bool, ban_eos_token: bool):
+    def get_special_tokens(
+        self, add_bos_token: bool = True, ban_eos_token: bool = False
+    ):
         return {
             "bos_token": self.tokenizer.bos_token if add_bos_token else "",
             "eos_token": self.tokenizer.eos_token if not ban_eos_token else "",
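With the new defaults, callers can omit both flags. A minimal stand-in for the updated method, with placeholder token strings (the real method reads them from self.tokenizer):

    # Stand-in for the updated method; "<s>"/"</s>" are placeholders for
    # whatever the loaded model's tokenizer actually defines.
    def get_special_tokens(add_bos_token: bool = True, ban_eos_token: bool = False):
        return {
            "bos_token": "<s>" if add_bos_token else "",
            "eos_token": "</s>" if not ban_eos_token else "",
        }

    print(get_special_tokens())                    # {'bos_token': '<s>', 'eos_token': '</s>'}
    print(get_special_tokens(ban_eos_token=True))  # {'bos_token': '<s>', 'eos_token': ''}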
@@ -16,6 +16,7 @@ from common.concurrency import (
 from common.networking import handle_request_error, run_with_request_disconnect
 from common.templating import (
     get_all_templates,
+    get_prompt_from_template,
     get_template_from_file,
 )
 from common.utils import coalesce, unwrap
@@ -386,8 +387,26 @@ async def unload_loras():
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
 async def encode_tokens(data: TokenEncodeRequest):
-    """Encodes a string into tokens."""
-    raw_tokens = model.container.encode_tokens(data.text, **data.get_params())
+    """Encodes a string or chat completion messages into tokens."""
+
+    if isinstance(data.text, str):
+        text = data.text
+    else:
+        special_tokens_dict = model.container.get_special_tokens(
+            unwrap(data.add_bos_token, True)
+        )
+
+        template_vars = {
+            "messages": data.text,
+            "add_generation_prompt": False,
+            **special_tokens_dict,
+        }
+
+        text, _ = get_prompt_from_template(
+            model.container.prompt_template, template_vars
+        )
+
+    raw_tokens = model.container.encode_tokens(text, **data.get_params())
     tokens = unwrap(raw_tokens, [])
     response = TokenEncodeResponse(tokens=tokens, length=len(tokens))
 
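Client-side, both payload shapes should now work against the encode endpoint. A hedged sketch, assuming the route is /v1/token/encode on a local instance and that auth uses an x-api-key header (both assumptions, not shown in this diff):

    # Hypothetical client calls; host, route, and header name are assumptions.
    import requests

    BASE = "http://127.0.0.1:5000"          # assumed local tabbyAPI instance
    HEADERS = {"x-api-key": "example-key"}  # assumed auth header

    # A plain string is tokenized directly.
    resp = requests.post(
        f"{BASE}/v1/token/encode",
        json={"text": "Hello there"},
        headers=HEADERS,
    )
    print(resp.json())  # e.g. {"tokens": [...], "length": ...}

    # Chat messages are rendered through the prompt template, then tokenized.
    resp = requests.post(
        f"{BASE}/v1/token/encode",
        json={"text": [{"role": "user", "content": "Hello there"}]},
        headers=HEADERS,
    )
    print(resp.json())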
@@ -1,7 +1,7 @@
 """Tokenization types"""
 
 from pydantic import BaseModel
-from typing import List
+from typing import Dict, List, Union
 
 
 class CommonTokenRequest(BaseModel):
@@ -23,7 +23,7 @@ class CommonTokenRequest(BaseModel):
 class TokenEncodeRequest(CommonTokenRequest):
     """Represents a tokenization request."""
 
-    text: str
+    text: Union[str, List[Dict[str, str]]]
 
 
 class TokenEncodeResponse(BaseModel):
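The widened annotation can be checked in isolation; a minimal sketch using a bare BaseModel in place of the real CommonTokenRequest parent:

    # Minimal sketch: validates the widened `text` field on its own
    # (the real class inherits shared params from CommonTokenRequest).
    from typing import Dict, List, Union
    from pydantic import BaseModel

    class EncodeSketch(BaseModel):
        text: Union[str, List[Dict[str, str]]]

    print(EncodeSketch(text="plain string").text)
    print(EncodeSketch(text=[{"role": "user", "content": "hi"}]).text)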