Templating: Add generation prompt appending

Append the generation prompt when the flag is set on an OAI chat completion
request.

This appends the "assistant" message header to the instruct prompt. Defaults
to true since this is the intended behavior.

Signed-off-by: kingbri <bdashore3@proton.me>
Author: kingbri, 2023-12-17 21:34:31 -05:00 (committed by Brian Dashore)
Parent: 041070fd6e
Commit: de9a19b5d3

4 changed files with 14 additions and 6 deletions
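To make the behavior concrete, here is a minimal sketch (the template string and messages below are invented for illustration, not taken from the repo's templates) of what the flag changes in a rendered prompt:

from jinja2 import Template

# Hypothetical chat template; the repo's real templates live in templates/*.jinja.
DEMO_TEMPLATE = (
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}assistant: {% endif %}"
)

messages = [{"role": "user", "content": "Hello!"}]
template = Template(DEMO_TEMPLATE)

# With the flag, the prompt ends with the assistant header, cueing the model
# to start its reply.
print(repr(template.render(messages=messages, add_generation_prompt=True)))
# -> 'user: Hello!\nassistant: '

print(repr(template.render(messages=messages, add_generation_prompt=False)))
# -> 'user: Hello!\n'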


@@ -26,6 +26,7 @@ class ChatCompletionRequest(CommonCompletionRequest):
     # Take in a string as well even though it's not part of the OAI spec
     messages: Union[str, List[Dict[str, str]]]
     prompt_template: Optional[str] = None
+    add_generation_prompt: Optional[bool] = True
 
 class ChatCompletionResponse(BaseModel):
     id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")
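Client-side, the new field travels in the normal request body. A hedged usage example (the host, port, and route below are assumptions about a typical OAI-style deployment, not something this diff establishes):

import requests

payload = {
    "messages": [{"role": "user", "content": "Hi there"}],
    # Optional field added by this commit; omitting it keeps the default of True.
    "add_generation_prompt": False,
}

# Assumed local endpoint; adjust to your deployment.
response = requests.post("http://localhost:5000/v1/chat/completions", json=payload)
print(response.json())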


@@ -56,9 +56,9 @@ model:
   # Enable 8 bit cache mode for VRAM savings (slight performance hit). Possible values FP16, FP8. (default: FP16)
   cache_mode: FP16
 
-  # Set the prompt template for this model. If empty, chat completions will be disabled. (default: alpaca)
+  # Set the prompt template for this model. If empty, chat completions will be disabled. (default: None)
   # NOTE: Only works with chat completion message lists!
-  prompt_template: alpaca
+  prompt_template:
 
   # Number of experts to use per token. Loads from the model's config.json if not specified (default: None)
   # WARNING: Don't set this unless you know what you're doing!


@@ -312,7 +312,11 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
         prompt = data.messages
     else:
         try:
-            prompt = get_prompt_from_template(data.messages, model_container.prompt_template)
+            prompt = get_prompt_from_template(
+                data.messages,
+                model_container.prompt_template,
+                data.add_generation_prompt
+            )
         except KeyError:
             return HTTPException(
                 400,


@@ -11,7 +11,7 @@ class PromptTemplate(BaseModel):
     name: str
     template: str
 
-def get_prompt_from_template(messages, prompt_template: PromptTemplate):
+def get_prompt_from_template(messages, prompt_template: PromptTemplate, add_generation_prompt: bool):
     if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
         raise ImportError(
             "Parsing these chat completion messages requires fastchat 0.2.23 or greater. "
@@ -21,7 +21,10 @@ def get_prompt_from_template(messages, prompt_template: PromptTemplate):
         )
 
     compiled_template = _compile_template(prompt_template.template)
-    return compiled_template.render(messages = messages)
+    return compiled_template.render(
+        messages = messages,
+        add_generation_prompt = add_generation_prompt
+    )
 
 # Inspired from https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761
 @lru_cache
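For an instruct format like alpaca (the template name this repo ships), the generation prompt is a response header rather than a literal assistant tag. A hypothetical fragment, not copied from templates/alpaca.jinja, might gate it like this:

# Hypothetical alpaca-style template showing where the flag would apply.
ALPACA_STYLE = (
    "{% for message in messages %}"
    "### Instruction:\n{{ message['content'] }}\n\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}### Response:\n{% endif %}"
)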
@@ -30,7 +33,7 @@ def _compile_template(template: str):
     jinja_template = jinja_env.from_string(template)
     return jinja_template
 
-def get_template_from_file(prompt_template_name):
+def get_template_from_file(prompt_template_name: str):
     with open(pathlib.Path(f"templates/{prompt_template_name}.jinja"), "r") as raw_template:
         return PromptTemplate(
             name = prompt_template_name,
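A side note on the last hunk: _compile_template is wrapped in lru_cache, so each distinct template string is compiled once and reused across requests. A standalone sketch of that pattern (the bare Environment construction is a placeholder for the repo's actual jinja_env):

from functools import lru_cache
from jinja2 import Environment

jinja_env = Environment()

@lru_cache
def compile_template(template: str):
    # Compilation is the expensive step; render() still takes fresh
    # variables on every call, so caching the compiled object is safe.
    return jinja_env.from_string(template)

t1 = compile_template("Hello {{ name }}")
t2 = compile_template("Hello {{ name }}")
assert t1 is t2                 # second call hit the cache
print(t1.render(name="world"))  # Hello world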