From de9a19b5d3777090b6a279daf2b0f42ebf356942 Mon Sep 17 00:00:00 2001
From: kingbri
Date: Sun, 17 Dec 2023 21:34:31 -0500
Subject: [PATCH] Templating: Add generation prompt appending

Append a generation prompt when the flag is set on an OAI chat
completion request. This appends the "assistant" message header to the
instruct prompt. Defaults to true since this is the intended behavior.

Signed-off-by: kingbri
---
 OAI/types/chat_completion.py | 1 +
 config_sample.yml            | 4 ++--
 main.py                      | 6 +++++-
 templating.py                | 9 ++++++---
 4 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/OAI/types/chat_completion.py b/OAI/types/chat_completion.py
index 2bb3cc9..891f548 100644
--- a/OAI/types/chat_completion.py
+++ b/OAI/types/chat_completion.py
@@ -26,6 +26,7 @@ class ChatCompletionRequest(CommonCompletionRequest):
     # Take in a string as well even though it's not part of the OAI spec
     messages: Union[str, List[Dict[str, str]]]
     prompt_template: Optional[str] = None
+    add_generation_prompt: Optional[bool] = True
 
 class ChatCompletionResponse(BaseModel):
     id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")
diff --git a/config_sample.yml b/config_sample.yml
index 25dfd90..15ce81b 100644
--- a/config_sample.yml
+++ b/config_sample.yml
@@ -56,9 +56,9 @@ model:
   # Enable 8 bit cache mode for VRAM savings (slight performance hit). Possible values FP16, FP8. (default: FP16)
   cache_mode: FP16
 
-  # Set the prompt template for this model. If empty, chat completions will be disabled. (default: alpaca)
+  # Set the prompt template for this model. If empty, chat completions will be disabled. (default: None)
   # NOTE: Only works with chat completion message lists!
-  prompt_template: alpaca
+  prompt_template:
 
   # Number of experts to use per token. Loads from the model's config.json if not specified (default: None)
   # WARNING: Don't set this unless you know what you're doing!
diff --git a/main.py b/main.py
index 4687459..c9188d5 100644
--- a/main.py
+++ b/main.py
@@ -312,7 +312,11 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
         prompt = data.messages
     else:
         try:
-            prompt = get_prompt_from_template(data.messages, model_container.prompt_template)
+            prompt = get_prompt_from_template(
+                data.messages,
+                model_container.prompt_template,
+                data.add_generation_prompt
+            )
         except KeyError:
             return HTTPException(
                 400,
diff --git a/templating.py b/templating.py
index f8d074f..81121eb 100644
--- a/templating.py
+++ b/templating.py
@@ -11,7 +11,7 @@ class PromptTemplate(BaseModel):
     name: str
     template: str
 
-def get_prompt_from_template(messages, prompt_template: PromptTemplate):
+def get_prompt_from_template(messages, prompt_template: PromptTemplate, add_generation_prompt: bool):
     if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
         raise ImportError(
             "Parsing these chat completion messages requires fastchat 0.2.23 or greater. "
@@ -21,7 +21,10 @@ def get_prompt_from_template(messages, prompt_template: PromptTemplate):
         )
 
     compiled_template = _compile_template(prompt_template.template)
-    return compiled_template.render(messages = messages)
+    return compiled_template.render(
+        messages = messages,
+        add_generation_prompt = add_generation_prompt
+    )
 
 # Inspired from https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761
 @lru_cache
@@ -30,7 +33,7 @@ def _compile_template(template: str):
     jinja_template = jinja_env.from_string(template)
     return jinja_template
 
-def get_template_from_file(prompt_template_name):
+def get_template_from_file(prompt_template_name: str):
     with open(pathlib.Path(f"templates/{prompt_template_name}.jinja"), "r") as raw_template:
         return PromptTemplate(
             name = prompt_template_name,