Templates: Support bos_token and eos_token fields

These are commonly seen in huggingface provided chat templates and
aren't that difficult to add in.

For feature parity, honor the add_bos_token and ban_eos_token
parameters when constructing the prompt.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-12-22 10:31:50 -05:00
parent 2bf8087de3
commit a14abfe21c
3 changed files with 21 additions and 3 deletions

View File

@@ -341,6 +341,14 @@ class ModelContainer:
ids = torch.tensor([ids])
return self.tokenizer.decode(ids, decode_special_tokens = unwrap(kwargs.get("decode_special_tokens"), True))[0]
def get_special_tokens(self, add_bos_token: bool, ban_eos_token: bool):
return {
"bos_token": self.tokenizer.bos_token if add_bos_token else "",
"eos_token": self.tokenizer.eos_token if not ban_eos_token else "",
"pad_token": self.tokenizer.pad_token,
"unk_token": self.tokenizer.unk_token,
}
def generate(self, prompt: str, **kwargs):
generation = list(self.generate_gen(prompt, **kwargs))
if generation: