Mirror of https://github.com/turboderp-org/exui.git (synced 2026-04-19 22:08:58 +00:00).
Add Llama3 template
This commit is contained in:
@@ -16,6 +16,9 @@ class PromptFormat:
|
||||
def encode_special_tokens(self):
    """Whether special tokens in the prompt should be encoded as tokens
    (rather than treated as plain text). Default: yes."""
    return True
|
||||
|
||||
def context_bos(self):
    """Whether a one-off BOS token should be prepended to the context.
    Default: no; subclasses (e.g. Llama-3) override to return True."""
    return False
|
||||
|
||||
@staticmethod
|
||||
def supports_system_prompt():
|
||||
return True
|
||||
@@ -71,6 +74,42 @@ class PromptFormat_llama(PromptFormat):
|
||||
return text
|
||||
|
||||
|
||||
class PromptFormat_llama3(PromptFormat):
    """Llama-3 instruct chat template.

    Each turn is rendered as:

        <|start_header_id|>{role}<|end_header_id|>\\n\\n{content}<|eot_id|>

    with roles "system", "user" and "assistant". A single BOS token is
    expected at the very start of the context (see context_bos).
    """

    # NOTE(review): the scraped source read "Llama-3 instruct template.chat";
    # the trailing ".chat" appears to be an extraction artifact and is dropped.
    description = "Llama-3 instruct template"

    def __init__(self):
        super().__init__()
        pass

    def is_instruct(self):
        # Instruct-tuned template (as opposed to raw/roleplay chat).
        return True

    def stop_conditions(self, tokenizer, settings):
        """Return the token IDs that should terminate generation.

        Stops on the end-of-turn token, the start of a new header (in case
        the model tries to begin another turn), and the tokenizer's EOS.
        """
        return \
            [tokenizer.single_id("<|eot_id|>"),
             tokenizer.single_id("<|start_header_id|>"),
             tokenizer.eos_token_id]

    def format(self, prompt, response, system_prompt, settings):
        """Render one exchange as Llama-3 instruct text.

        :param prompt: user message for this turn
        :param response: assistant reply, or falsy to leave the assistant
            header open for generation
        :param system_prompt: optional system message; skipped when empty
            or whitespace-only
        :param settings: unused here (kept for interface parity)
        :return: the formatted prompt string
        """
        text = ""
        if system_prompt and system_prompt.strip() != "":
            text += "<|start_header_id|>system<|end_header_id|>\n\n"
            text += system_prompt
            text += "<|eot_id|>"
        text += "<|start_header_id|>user<|end_header_id|>\n\n"
        text += prompt
        text += "<|eot_id|>"
        text += "<|start_header_id|>assistant<|end_header_id|>\n\n"
        if response:
            text += response
            text += "<|eot_id|>"
        return text

    def context_bos(self):
        # Llama-3 expects exactly one BOS at the start of the whole context;
        # the session prepends it (and reserves room for it in max_len).
        return True
|
||||
|
||||
|
||||
class PromptFormat_mistrallite(PromptFormat):
|
||||
|
||||
description = "MistralLite format"
|
||||
@@ -353,6 +392,7 @@ prompt_formats = \
|
||||
{
|
||||
"Chat-RP": PromptFormat_raw,
|
||||
"Llama-chat": PromptFormat_llama,
|
||||
"Llama3-instruct": PromptFormat_llama3,
|
||||
"ChatML": PromptFormat_chatml,
|
||||
"TinyLlama-chat": PromptFormat_tinyllama,
|
||||
"MistralLite": PromptFormat_mistrallite,
|
||||
|
||||
@@ -215,6 +215,11 @@ class Session:
|
||||
prompts = []
|
||||
responses = []
|
||||
|
||||
# Make room for one-off BOS token
|
||||
|
||||
if prompt_format.context_bos():
|
||||
max_len -= 1
|
||||
|
||||
# Prepare prefix
|
||||
|
||||
prefix_ids = None
|
||||
@@ -295,8 +300,15 @@ class Session:
|
||||
context_str += " " + prefix
|
||||
context_ids = torch.cat([context_ids, prefix_ids], dim = -1)
|
||||
|
||||
# Add context BOS
|
||||
|
||||
if prompt_format.context_bos():
|
||||
context_str = tokenizer.bos_token + context_str
|
||||
context_ids = torch.cat([tokenizer.single_token(tokenizer.bos_token_id), context_ids], dim = -1)
|
||||
|
||||
# print("self.history_first", self.history_first)
|
||||
# print("context_ids.shape[-1]", context_ids.shape[-1])
|
||||
|
||||
return context_str, context_ids
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user