From e8b6a02aa8f5027002a29151c7c45f9092da733d Mon Sep 17 00:00:00 2001 From: kingbri Date: Fri, 29 Mar 2024 02:24:13 -0400 Subject: [PATCH] API: Move prompt template construction to utils Best to move the inner workings within its inner function. Also fix an edge case where stop strings can be a string rather than an array. Signed-off-by: kingbri --- endpoints/OAI/app.py | 5 +---- endpoints/OAI/utils/chat_completion.py | 16 +++++++++++++++- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/endpoints/OAI/app.py b/endpoints/OAI/app.py index d6690ee..6c523d2 100644 --- a/endpoints/OAI/app.py +++ b/endpoints/OAI/app.py @@ -508,10 +508,7 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest) if isinstance(data.messages, str): prompt = data.messages else: - # Compile the prompt and get any additional stop strings from the template - # Template stop strings can be overriden by sampler overrides if force is true - prompt, template_stop_strings = format_prompt_with_template(data) - data.stop += template_stop_strings + prompt = format_prompt_with_template(data) disable_request_streaming = unwrap( config.developer_config().get("disable_request_streaming"), False diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index d74db83..6f18c9c 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -130,18 +130,32 @@ def _create_stream_chunk( def format_prompt_with_template(data: ChatCompletionRequest): + """ + Compile the prompt and get any additional stop strings from the template. + Template stop strings can be overridden by sampler overrides if force is true. 
+ """ + try: special_tokens_dict = model.container.get_special_tokens( unwrap(data.add_bos_token, True), unwrap(data.ban_eos_token, False), ) - return get_prompt_from_template( + prompt, template_stop_strings = get_prompt_from_template( data.messages, model.container.prompt_template, data.add_generation_prompt, special_tokens_dict, ) + + # Append template stop strings + if isinstance(data.stop, str): + data.stop = [data.stop] + template_stop_strings + else: + data.stop += template_stop_strings + + return prompt + except KeyError as exc: raise HTTPException( 400,