diff --git a/OAI/types/chat_completion.py b/OAI/types/chat_completion.py
index 2bb3cc9..891f548 100644
--- a/OAI/types/chat_completion.py
+++ b/OAI/types/chat_completion.py
@@ -26,6 +26,7 @@ class ChatCompletionRequest(CommonCompletionRequest):
     # Take in a string as well even though it's not part of the OAI spec
     messages: Union[str, List[Dict[str, str]]]
     prompt_template: Optional[str] = None
+    add_generation_prompt: Optional[bool] = True
 
 class ChatCompletionResponse(BaseModel):
     id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")
diff --git a/config_sample.yml b/config_sample.yml
index 25dfd90..15ce81b 100644
--- a/config_sample.yml
+++ b/config_sample.yml
@@ -56,9 +56,9 @@ model:
   # Enable 8 bit cache mode for VRAM savings (slight performance hit). Possible values FP16, FP8. (default: FP16)
   cache_mode: FP16
 
-  # Set the prompt template for this model. If empty, chat completions will be disabled. (default: alpaca)
+  # Set the prompt template for this model. If empty, chat completions will be disabled. (default: None)
   # NOTE: Only works with chat completion message lists!
-  prompt_template: alpaca
+  prompt_template:
 
   # Number of experts to use per token. Loads from the model's config.json if not specified (default: None)
   # WARNING: Don't set this unless you know what you're doing!
diff --git a/main.py b/main.py
index 4687459..c9188d5 100644
--- a/main.py
+++ b/main.py
@@ -312,7 +312,11 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
         prompt = data.messages
     else:
         try:
-            prompt = get_prompt_from_template(data.messages, model_container.prompt_template)
+            prompt = get_prompt_from_template(
+                data.messages,
+                model_container.prompt_template,
+                data.add_generation_prompt
+            )
         except KeyError:
             return HTTPException(
                 400,
diff --git a/templating.py b/templating.py
index f8d074f..81121eb 100644
--- a/templating.py
+++ b/templating.py
@@ -11,7 +11,7 @@ class PromptTemplate(BaseModel):
     name: str
     template: str
 
-def get_prompt_from_template(messages, prompt_template: PromptTemplate):
+def get_prompt_from_template(messages, prompt_template: PromptTemplate, add_generation_prompt: bool):
     if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
         raise ImportError(
             "Parsing these chat completion messages requires jinja2 3.0.0 or greater. "
         )
@@ -21,7 +21,10 @@ def get_prompt_from_template(messages, prompt_template: PromptTemplate):
 
     compiled_template = _compile_template(prompt_template.template)
-    return compiled_template.render(messages = messages)
+    return compiled_template.render(
+        messages = messages,
+        add_generation_prompt = add_generation_prompt
+    )
 
 # Inspired from https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761
 @lru_cache
 def _compile_template(template: str):
@@ -30,7 +33,7 @@ def _compile_template(template: str):
     jinja_template = jinja_env.from_string(template)
     return jinja_template
 
-def get_template_from_file(prompt_template_name):
+def get_template_from_file(prompt_template_name: str):
     with open(pathlib.Path(f"templates/{prompt_template_name}.jinja"), "r") as raw_template:
         return PromptTemplate(
             name = prompt_template_name,
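
A quick sketch for reviewers of how a template can consume the new variable. The ChatML-style markers below are purely illustrative, not a template shipped in `templates/`; the render call mirrors the updated one in `templating.py`, assuming `jinja_env` is a standard sandboxed environment as in the linked transformers code.

```python
from jinja2.sandbox import ImmutableSandboxedEnvironment

# Illustrative ChatML-style template; any real template in templates/*.jinja
# can branch on add_generation_prompt the same way.
TEMPLATE = (
    "{% for message in messages %}"
    "<|im_start|>{{ message['role'] }}\n{{ message['content'] }}<|im_end|>\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
)

env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
compiled = env.from_string(TEMPLATE)

messages = [{"role": "user", "content": "Hello!"}]

# True (the new default) ends the prompt with an open assistant turn so the
# model starts a fresh reply; False leaves the prompt as-is, e.g. for
# continuing a partially written assistant message.
print(compiled.render(messages=messages, add_generation_prompt=True))
print(compiled.render(messages=messages, add_generation_prompt=False))
```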
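
For completeness, a client-side sketch of the new request field. The endpoint path, port, and absence of auth headers are assumptions for illustration, not confirmed by this diff; adjust them to your deployment.

```python
import requests

# Hypothetical local endpoint exposing the OAI-style chat completion route.
resp = requests.post(
    "http://127.0.0.1:5000/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Hello!"}],
        # New optional field from this PR; defaults to true when omitted.
        "add_generation_prompt": False,
    },
)
print(resp.json())
```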