Add Llama3 template

This commit is contained in:
turboderp
2024-04-19 17:41:01 +02:00
parent 0e919cb1a1
commit b33d45d45b
2 changed files with 52 additions and 0 deletions

View File

@@ -16,6 +16,9 @@ class PromptFormat:
def encode_special_tokens(self):
    """Whether special tokens in the prompt text should be tokenized as
    special tokens (presumably consumed by the tokenizer call site —
    confirm against caller)."""
    return True
def context_bos(self):
    """Whether a single one-off BOS token must be prepended to the whole
    context for this prompt format.

    Base-class default is no BOS; formats that need it (e.g. Llama-3)
    override this to return True.
    """
    return False
@staticmethod
def supports_system_prompt():
    """Whether this prompt format can carry a system prompt.

    Default: supported. Static because the answer does not depend on
    per-instance state.
    """
    return True
@@ -71,6 +74,42 @@ class PromptFormat_llama(PromptFormat):
return text
class PromptFormat_llama3(PromptFormat):
    """Llama-3 instruct prompt format.

    Wraps each turn in Llama-3's header special tokens
    (``<|start_header_id|>role<|end_header_id|>`` ... ``<|eot_id|>``) and
    requests a one-off BOS token at the very start of the context.
    """

    # NOTE(review): original text was "Llama-3 instruct template.chat" — the
    # trailing ".chat" looks like extraction garbling, removed here.
    description = "Llama-3 instruct template"

    # Redundant __init__ (super().__init__() + pass) removed; the inherited
    # constructor is behaviorally identical.

    def is_instruct(self):
        """This format is an instruct/chat template, not raw completion."""
        return True

    def stop_conditions(self, tokenizer, settings):
        """Return token IDs that should terminate generation.

        :param tokenizer: project tokenizer exposing single_id() and
            eos_token_id — assumed, confirm against caller.
        :param settings: unused here; kept for interface parity.
        """
        return [
            tokenizer.single_id("<|eot_id|>"),            # end of assistant turn
            tokenizer.single_id("<|start_header_id|>"),   # model opening a new header
            tokenizer.eos_token_id,
        ]

    def format(self, prompt, response, system_prompt, settings):
        """Render one chat turn as a Llama-3 template string.

        :param prompt: user message text.
        :param response: assistant reply, or falsy to leave the assistant
            header open for generation.
        :param system_prompt: optional system message; skipped when empty
            or whitespace-only.
        :param settings: unused here; kept for interface parity.
        :return: the formatted template string.
        """
        text = ""
        if system_prompt and system_prompt.strip() != "":
            text += "<|start_header_id|>system<|end_header_id|>\n\n"
            text += system_prompt
            text += "<|eot_id|>"
        text += "<|start_header_id|>user<|end_header_id|>\n\n"
        text += prompt
        text += "<|eot_id|>"
        # Open the assistant header; only close it if a response is supplied.
        text += "<|start_header_id|>assistant<|end_header_id|>\n\n"
        if response:
            text += response
            text += "<|eot_id|>"
        return text

    def context_bos(self):
        """Llama-3 expects exactly one BOS token at the start of the context."""
        return True
class PromptFormat_mistrallite(PromptFormat):
description = "MistralLite format"
@@ -353,6 +392,7 @@ prompt_formats = \
{
"Chat-RP": PromptFormat_raw,
"Llama-chat": PromptFormat_llama,
"Llama3-instruct": PromptFormat_llama3,
"ChatML": PromptFormat_chatml,
"TinyLlama-chat": PromptFormat_tinyllama,
"MistralLite": PromptFormat_mistrallite,

View File

@@ -215,6 +215,11 @@ class Session:
prompts = []
responses = []
# Make room for one-off BOS token
if prompt_format.context_bos():
max_len -= 1
# Prepare prefix
prefix_ids = None
@@ -295,8 +300,15 @@ class Session:
context_str += " " + prefix
context_ids = torch.cat([context_ids, prefix_ids], dim = -1)
# Add context BOS
if prompt_format.context_bos():
context_str = tokenizer.bos_token + context_str
context_ids = torch.cat([tokenizer.single_token(tokenizer.bos_token_id), context_ids], dim = -1)
# print("self.history_first", self.history_first)
# print("context_ids.shape[-1]", context_ids.shape[-1])
return context_str, context_ids