From ea91d17a11820cb356bae369a214af4176966258 Mon Sep 17 00:00:00 2001
From: kingbri
Date: Tue, 14 Nov 2023 23:05:47 -0500
Subject: [PATCH] Api: Add ban_eos_token and add_bos_token support

Adds the ability for the client to specify whether to add the BOS token
and ban the EOS token.

Signed-off-by: kingbri
---
 OAI/types/completions.py |  4 ++++
 OAI/types/token.py       |  4 ++--
 model.py                 | 15 +++++++++++++--
 3 files changed, 19 insertions(+), 4 deletions(-)

diff --git a/OAI/types/completions.py b/OAI/types/completions.py
index cd95ff5..ffe84c7 100644
--- a/OAI/types/completions.py
+++ b/OAI/types/completions.py
@@ -54,6 +54,8 @@ class CompletionRequest(BaseModel):
     mirostat_mode: Optional[int] = 0
     mirostat_tau: Optional[float] = 1.5
     mirostat_eta: Optional[float] = 0.1
+    add_bos_token: Optional[bool] = True
+    ban_eos_token: Optional[bool] = False
 
     # Converts to internal generation parameters
     def to_gen_params(self):
@@ -73,6 +75,8 @@ class CompletionRequest(BaseModel):
             "prompt": self.prompt,
             "stop": self.stop,
             "max_tokens": self.max_tokens,
+            "add_bos_token": self.add_bos_token,
+            "ban_eos_token": self.ban_eos_token,
             "token_healing": self.token_healing,
             "temperature": self.temperature,
             "top_k": self.top_k,
diff --git a/OAI/types/token.py b/OAI/types/token.py
index fdc59c6..a0bf3f9 100644
--- a/OAI/types/token.py
+++ b/OAI/types/token.py
@@ -2,13 +2,13 @@ from pydantic import BaseModel
 from typing import List
 
 class CommonTokenRequest(BaseModel):
-    add_bos: bool = True
+    add_bos_token: bool = True
     encode_special_tokens: bool = True
     decode_special_tokens: bool = True
 
     def get_params(self):
         return {
-            "add_bos": self.add_bos,
+            "add_bos_token": self.add_bos_token,
             "encode_special_tokens": self.encode_special_tokens,
             "decode_special_tokens": self.decode_special_tokens
         }
diff --git a/model.py b/model.py
index e07291c..1ea9da4 100644
--- a/model.py
+++ b/model.py
@@ -200,7 +200,8 @@
         if text:
             # Assume token encoding
             return self.tokenizer.encode(
-                text, add_bos = kwargs.get("add_bos", True),
+                text,
+                add_bos = kwargs.get("add_bos_token", True),
                 encode_special_tokens = kwargs.get("encode_special_tokens", True)
             )
         if ids:
@@ -236,6 +237,8 @@
         'repetition_decay' (int): Repetition penalty range (default: same as range)
         'stop' (list): List of stop strings/tokens to end response (default: [EOS])
         'max_tokens' (int): Max no. tokens in response (default: 150)
+        'add_bos_token' (bool): Adds the BOS token to the start of the prompt (default: True)
+        'ban_eos_token' (bool): Bans the EOS token from generation (default: False)
         'stream_interval' (float): Interval in seconds between each output chunk (default: immediate)
         'generate_window' (int): Space to reserve at the end of the model's context when generating.
             Rolls context window by the same amount if context length is exceeded to allow generating past
@@ -266,6 +269,10 @@
         gen_settings.token_repetition_range = kwargs.get("repetition_range", self.config.max_seq_len)
         gen_settings.token_repetition_decay = kwargs.get("repetition_decay", gen_settings.token_repetition_range)
 
+        # Ban the EOS token if specified
+        if kwargs.get("ban_eos_token", False):
+            gen_settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
+
         # Override sampler settings for temp = 0
 
         if gen_settings.temperature == 0:
@@ -280,7 +287,11 @@
 
         # Tokenized context
-        ids = self.tokenizer.encode(prompt, encode_special_tokens = True)
+        ids = self.tokenizer.encode(
+            prompt,
+            add_bos=kwargs.get("add_bos_token", True),
+            encode_special_tokens = True
+        )
 
         # Begin