OAI: Add response_prefix and fix BOS token issues in chat completions

response_prefix adds a prefix to the prompt before the next message is
generated. This is useful in many cases, such as continuing a prompt
(see #96).
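
For example, a client could continue a partially written assistant
message like this (a hedged sketch: the host, port, model name, and
messages are placeholders, and requests is only used for illustration):

    import requests

    payload = {
        "model": "my-model",  # placeholder model name
        "messages": [
            {"role": "user", "content": "How do I bake bread?"}
        ],
        # The next assistant message will start with this text
        "response_prefix": "Sure, here are the steps:",
    }

    # Standard OAI-style route; adjust host/port for your server
    resp = requests.post(
        "http://localhost:5000/v1/chat/completions", json=payload
    )
    print(resp.json()["choices"][0]["message"]["content"])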

Also, if a template includes the BOS token in its output, enabling
add_bos_token results in two BOS tokens at the start of the prompt.
Add a check which strips a leading BOS token from the prompt if one
exists.
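
A minimal sketch of the failure mode (plain strings; "<s>" stands in
for whatever BOS token the template specifies):

    bos_token = "<s>"

    # A template that bakes BOS into its rendered output
    prompt = "<s>[INST] Hello [/INST]"

    # With add_bos_token=True the tokenizer would prepend another BOS,
    # producing "<s><s>[INST] ...". Strip the leading copy first:
    if bos_token and prompt.startswith(bos_token):
        prompt = prompt.removeprefix(bos_token)

    assert prompt == "[INST] Hello [/INST]"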

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri
2024-04-25 00:54:43 -04:00
parent ed7cd3cb59
commit fb1d2f34c1
4 changed files with 20 additions and 1 deletion


@@ -878,6 +878,7 @@ class ExllamaV2Container:
             encode_special_tokens=True,
             return_offsets=True,
         )
+        print(ids)
         mask = (
             self.tokenizer.padding_mask(ids)
             if self.use_cfg and gen_settings.cfg_scale not in [None, 1.0]


@@ -45,6 +45,7 @@ class ChatCompletionRequest(CommonCompletionRequest):
     prompt_template: Optional[str] = None
     add_generation_prompt: Optional[bool] = True
     template_vars: Optional[dict] = {}
+    response_prefix: Optional[str] = None


 class ChatCompletionResponse(BaseModel):


@@ -8,6 +8,7 @@ from uuid import uuid4
 from fastapi import HTTPException
 from jinja2 import TemplateError
+from loguru import logger

 from common import model
 from common.networking import (
@@ -153,6 +154,22 @@ def format_prompt_with_template(data: ChatCompletionRequest):
         data.template_vars
     )

+    # Append response prefix if present
+    if data.response_prefix:
+        if data.add_generation_prompt:
+            prompt += data.response_prefix
+        else:
+            logger.warning(
+                "Could not add response prefix because "
+                "add_generation_prompt is False"
+            )
+
+    # Removes the starting BOS token if present
+    # This is to prevent add_bos_token from adding multiple bos tokens
+    bos_token = special_tokens_dict.get("bos_token")
+    if bos_token and prompt.startswith(bos_token):
+        prompt = prompt.removeprefix(bos_token)
+
     # Append template stop strings
     if isinstance(data.stop, str):
         data.stop = [data.stop] + template_stop_strings
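
Taken together, the two new steps behave like this standalone sketch
(Req stands in for the real ChatCompletionRequest model; only the two
fields used here are modeled):

    from dataclasses import dataclass
    from loguru import logger

    @dataclass
    class Req:
        response_prefix: str | None = None
        add_generation_prompt: bool = True

    def finalize(prompt: str, data: Req, bos_token: str | None) -> str:
        # Append the response prefix only when a generation prompt
        # was added; otherwise warn and leave the prompt untouched
        if data.response_prefix:
            if data.add_generation_prompt:
                prompt += data.response_prefix
            else:
                logger.warning(
                    "Could not add response prefix because "
                    "add_generation_prompt is False"
                )

        # Strip a template-provided BOS so add_bos_token does not
        # produce a second one
        if bos_token and prompt.startswith(bos_token):
            prompt = prompt.removeprefix(bos_token)

        return prompt

    print(finalize("<s>[INST] Hi [/INST]", Req("Sure,"), "<s>"))
    # -> "[INST] Hi [/INST]Sure,"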


@@ -94,8 +94,8 @@ async def generate_completion(data: CompletionRequest, model_path: pathlib.Path)
     try:
         generation = await model.container.generate(data.prompt, **data.to_gen_params())
         response = _create_response(generation, model_path.name)

         return response
     except Exception as exc:
         error_message = handle_request_error(