diff --git a/endpoints/OAI/app.py b/endpoints/OAI/app.py
index d6690ee..6c523d2 100644
--- a/endpoints/OAI/app.py
+++ b/endpoints/OAI/app.py
@@ -508,10 +508,7 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest)
     if isinstance(data.messages, str):
         prompt = data.messages
     else:
-        # Compile the prompt and get any additional stop strings from the template
-        # Template stop strings can be overriden by sampler overrides if force is true
-        prompt, template_stop_strings = format_prompt_with_template(data)
-        data.stop += template_stop_strings
+        prompt = format_prompt_with_template(data)
 
     disable_request_streaming = unwrap(
         config.developer_config().get("disable_request_streaming"), False
diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index d74db83..6f18c9c 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -130,18 +130,32 @@ def _create_stream_chunk(
 
 
 def format_prompt_with_template(data: ChatCompletionRequest):
+    """
+    Compile the prompt and get any additional stop strings from the template.
+    Template stop strings can be overridden by sampler overrides if force is true.
+    """
+
     try:
         special_tokens_dict = model.container.get_special_tokens(
             unwrap(data.add_bos_token, True),
             unwrap(data.ban_eos_token, False),
         )
 
-        return get_prompt_from_template(
+        prompt, template_stop_strings = get_prompt_from_template(
             data.messages,
             model.container.prompt_template,
             data.add_generation_prompt,
             special_tokens_dict,
         )
+
+        # Append template stop strings
+        if isinstance(data.stop, str):
+            data.stop = [data.stop] + template_stop_strings
+        else:
+            data.stop += template_stop_strings
+
+        return prompt
+
     except KeyError as exc:
         raise HTTPException(
             400,
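
For illustration, a minimal standalone sketch (not part of the patch) of the stop-merging rule the new format_prompt_with_template applies: data.stop may arrive as a single string or a list of strings, and the template's stop strings are appended either way. FakeRequest and merge_template_stops are hypothetical stand-ins for ChatCompletionRequest and the in-function logic.

from dataclasses import dataclass, field
from typing import List, Union


@dataclass
class FakeRequest:
    # Hypothetical stand-in for ChatCompletionRequest.stop
    stop: Union[str, List[str]] = field(default_factory=list)


def merge_template_stops(data: FakeRequest, template_stop_strings: List[str]) -> None:
    # Normalize a bare string into a list before appending, mirroring the
    # isinstance check added in format_prompt_with_template
    if isinstance(data.stop, str):
        data.stop = [data.stop] + template_stop_strings
    else:
        data.stop += template_stop_strings


req = FakeRequest(stop="</s>")
merge_template_stops(req, ["<|im_end|>"])
assert req.stop == ["</s>", "<|im_end|>"]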