Api: Add ban_eos_token and add_bos_token support

Adds the ability for the client to specify whether to add the BOS
token and ban the EOS token.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-11-14 23:05:47 -05:00
parent 8fea5391a8
commit ea91d17a11
3 changed files with 19 additions and 4 deletions

View File

@@ -200,7 +200,8 @@ class ModelContainer:
if text:
# Assume token encoding
return self.tokenizer.encode(
text, add_bos = kwargs.get("add_bos", True),
text,
add_bos = kwargs.get("add_bos_token", True),
encode_special_tokens = kwargs.get("encode_special_tokens", True)
)
if ids:
@@ -236,6 +237,8 @@ class ModelContainer:
'repetition_decay' (int): Repetition penalty range (default: same as range)
'stop' (list): List of stop strings/tokens to end response (default: [EOS])
'max_tokens' (int): Max no. tokens in response (default: 150)
'add_bos_token' (bool): Adds the BOS token to the start of the prompt (default: True)
'ban_eos_token' (bool): Bans the EOS token from generation (default: False)
'stream_interval' (float): Interval in seconds between each output chunk (default: immediate)
'generate_window' (int): Space to reserve at the end of the model's context when generating.
Rolls context window by the same amount if context length is exceeded to allow generating past
@@ -266,6 +269,10 @@ class ModelContainer:
gen_settings.token_repetition_range = kwargs.get("repetition_range", self.config.max_seq_len)
gen_settings.token_repetition_decay = kwargs.get("repetition_decay", gen_settings.token_repetition_range)
# Ban the EOS token if specified
if kwargs.get("ban_eos_token", False):
gen_settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
# Override sampler settings for temp = 0
if gen_settings.temperature == 0:
@@ -280,7 +287,11 @@ class ModelContainer:
# Tokenized context
ids = self.tokenizer.encode(prompt, encode_special_tokens = True)
ids = self.tokenizer.encode(
prompt,
add_bos=kwargs.get("add_bos_token", True),
encode_special_tokens = True
)
# Begin