From c3f78989673b638a5c0b7c0821e765b402cad178 Mon Sep 17 00:00:00 2001 From: kingbri Date: Mon, 18 Dec 2023 18:04:12 -0500 Subject: [PATCH] OAI: Add logit bias support Use exllamav2's token bias, which is the functional equivalent of OAI's logit bias parameter. Strings are cast to integers on request, and an error is raised if an invalid value is passed. Signed-off-by: kingbri --- OAI/types/common.py | 3 ++- model.py | 24 +++++++++++++++++++++--- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/OAI/types/common.py b/OAI/types/common.py index 065ec5d..ca636b9 100644 --- a/OAI/types/common.py +++ b/OAI/types/common.py @@ -21,7 +21,6 @@ class CommonCompletionRequest(BaseModel): # Extra OAI request stuff best_of: Optional[int] = Field(description = "Not parsed. Only used for OAI compliance.", default = None) echo: Optional[bool] = Field(description = "Not parsed. Only used for OAI compliance.", default = False) - logit_bias: Optional[Dict[str, float]] = Field(description = "Not parsed. Only used for OAI compliance.", default = None) logprobs: Optional[int] = Field(description = "Not parsed. Only used for OAI compliance.", default = None) n: Optional[int] = Field(description = "Not parsed. Only used for OAI compliance.", default = 1) suffix: Optional[str] = Field(description = "Not parsed. 
Only used for OAI compliance.", default = None) @@ -55,6 +54,7 @@ class CommonCompletionRequest(BaseModel): mirostat_eta: Optional[float] = 0.1 add_bos_token: Optional[bool] = True ban_eos_token: Optional[bool] = False + logit_bias: Optional[Dict[int, float]] = None # Aliased variables repetition_range: Optional[int] = Field( @@ -78,6 +78,7 @@ class CommonCompletionRequest(BaseModel): "add_bos_token": self.add_bos_token, "ban_eos_token": self.ban_eos_token, "token_healing": self.token_healing, + "logit_bias": self.logit_bias, "temperature": self.temperature, "temperature_last": self.temperature_last, "top_k": self.top_k, diff --git a/model.py b/model.py index ff418ee..5c8be3d 100644 --- a/model.py +++ b/model.py @@ -345,6 +345,7 @@ class ModelContainer: 'max_tokens' (int): Max no. tokens in response (default: 150) 'add_bos_token' (bool): Adds the BOS token to the start of the prompt (default: True) 'ban_eos_token' (bool): Bans the EOS token from generation (default: False) + 'logit_bias' (Dict[int, float]): Biases specific tokens to either show up more or less (default: None) 'stream_interval' (float): Interval in seconds between each output chunk (default: immediate) 'generate_window' (int): Space to reserve at the end of the model's context when generating. 
Rolls context window by the same amount if context length is exceeded to allow generating past @@ -396,7 +397,9 @@ class ModelContainer: gen_settings.token_repetition_decay = coalesce(kwargs.get("repetition_decay"), fallback_decay, 0) stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), []) + add_bos_token = unwrap(kwargs.get("add_bos_token"), True) ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False) + logit_bias = kwargs.get("logit_bias") # Override sampler settings for temp = 0 if gen_settings.temperature == 0: @@ -406,16 +409,31 @@ class ModelContainer: gen_settings.typical = 0 # Log generation options to console + # Some options are too large, so log the args instead log_generation_params( + max_tokens = max_tokens, **vars(gen_settings), token_healing = token_healing, - max_tokens = max_tokens, - stop_conditions = stop_conditions + add_bos_token = add_bos_token, + ban_eos_token = ban_eos_token, + stop_conditions = stop_conditions, + logit_bias = logit_bias ) # Log prompt to console log_prompt(prompt) + # Set logit bias + if logit_bias: + # Create a vocab tensor if it doesn't exist for token biasing + if gen_settings.token_bias is None: + padding = -self.tokenizer.config.vocab_size % 32 + gen_settings.token_bias = torch.zeros((self.tokenizer.config.vocab_size + padding,), dtype = torch.float) + + # Map logits to the tensor with their biases + for token, bias in logit_bias.items(): + gen_settings.token_bias[token] = bias + # Ban the EOS token if specified. If not, append to stop conditions as well. # Set this below logging to avoid polluting the stop strings array if ban_eos_token: @@ -429,7 +447,7 @@ class ModelContainer: # Tokenized context ids = self.tokenizer.encode( prompt, - add_bos = unwrap(kwargs.get("add_bos_token"), True), + add_bos = add_bos_token, encode_special_tokens = True ) context_len = len(ids[0])