diff --git a/OAI/types/common.py b/OAI/types/common.py
index 065ec5d..ca636b9 100644
--- a/OAI/types/common.py
+++ b/OAI/types/common.py
@@ -21,7 +21,6 @@ class CommonCompletionRequest(BaseModel):
     # Extra OAI request stuff
     best_of: Optional[int] = Field(description = "Not parsed. Only used for OAI compliance.", default = None)
     echo: Optional[bool] = Field(description = "Not parsed. Only used for OAI compliance.", default = False)
-    logit_bias: Optional[Dict[str, float]] = Field(description = "Not parsed. Only used for OAI compliance.", default = None)
     logprobs: Optional[int] = Field(description = "Not parsed. Only used for OAI compliance.", default = None)
     n: Optional[int] = Field(description = "Not parsed. Only used for OAI compliance.", default = 1)
     suffix: Optional[str] = Field(description = "Not parsed. Only used for OAI compliance.", default = None)
@@ -55,6 +54,7 @@ class CommonCompletionRequest(BaseModel):
     mirostat_eta: Optional[float] = 0.1
     add_bos_token: Optional[bool] = True
    ban_eos_token: Optional[bool] = False
+    logit_bias: Optional[Dict[int, float]] = None
 
     # Aliased variables
     repetition_range: Optional[int] = Field(
@@ -78,6 +78,7 @@ class CommonCompletionRequest(BaseModel):
             "add_bos_token": self.add_bos_token,
             "ban_eos_token": self.ban_eos_token,
             "token_healing": self.token_healing,
+            "logit_bias": self.logit_bias,
             "temperature": self.temperature,
             "temperature_last": self.temperature_last,
             "top_k": self.top_k,
diff --git a/model.py b/model.py
index ff418ee..5c8be3d 100644
--- a/model.py
+++ b/model.py
@@ -345,6 +345,7 @@ class ModelContainer:
                 'max_tokens' (int): Max no. tokens in response (default: 150)
                 'add_bos_token' (bool): Adds the BOS token to the start of the prompt (default: True)
                 'ban_eos_token' (bool): Bans the EOS token from generation (default: False)
+                'logit_bias' (Dict[int, float]): Biases specific tokens to either show up more or less (default: None)
                 'stream_interval' (float): Interval in seconds between each output chunk (default: immediate)
                 'generate_window' (int): Space to reserve at the end of the model's context when generating.
                     Rolls context window by the same amount if context length is exceeded to allow generating past
@@ -396,7 +397,9 @@ class ModelContainer:
         gen_settings.token_repetition_decay = coalesce(kwargs.get("repetition_decay"), fallback_decay, 0)
 
         stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), [])
+        add_bos_token = unwrap(kwargs.get("add_bos_token"), True)
         ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False)
+        logit_bias = kwargs.get("logit_bias")
 
         # Override sampler settings for temp = 0
         if gen_settings.temperature == 0:
@@ -406,16 +409,31 @@ class ModelContainer:
             gen_settings.typical = 0
 
         # Log generation options to console
+        # Some options are too large, so log the args instead
         log_generation_params(
+            max_tokens = max_tokens,
             **vars(gen_settings),
             token_healing = token_healing,
-            max_tokens = max_tokens,
-            stop_conditions = stop_conditions
+            add_bos_token = add_bos_token,
+            ban_eos_token = ban_eos_token,
+            stop_conditions = stop_conditions,
+            logit_bias = logit_bias
         )
 
         # Log prompt to console
         log_prompt(prompt)
 
+        # Set logit bias
+        if logit_bias:
+            # Create a vocab tensor if it doesn't exist for token biasing
+            if gen_settings.token_bias is None:
+                padding = -self.tokenizer.config.vocab_size % 32
+                gen_settings.token_bias = torch.zeros((self.tokenizer.config.vocab_size + padding,), dtype = torch.float)
+
+            # Map logits to the tensor with their biases
+            for token, bias in logit_bias.items():
+                gen_settings.token_bias[token] = bias
+
         # Ban the EOS token if specified. If not, append to stop conditions as well.
         # Set this below logging to avoid polluting the stop strings array
         if ban_eos_token:
@@ -429,7 +447,7 @@ class ModelContainer:
         # Tokenized context
         ids = self.tokenizer.encode(
             prompt,
-            add_bos = unwrap(kwargs.get("add_bos_token"), True),
+            add_bos = add_bos_token,
             encode_special_tokens = True
         )
         context_len = len(ids[0])
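For reference, here is a minimal standalone sketch of the biasing step introduced in `model.py`, assuming a `logit_bias` dict keyed by token ID. The `build_token_bias` helper, its parameters, and the example token IDs are illustrative only; the padding to a multiple of 32 mirrors what the diff does so the bias tensor lines up with the backend's padded logit tensor.

```python
import torch
from typing import Dict, Optional


def build_token_bias(logit_bias: Dict[int, float],
                     vocab_size: int,
                     existing: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Build (or update) a per-token bias tensor from an OAI-style logit_bias dict.

    Hypothetical helper for illustration: the tensor is padded up to a
    multiple of 32, matching the diff's treatment of the padded logit tensor.
    """
    if existing is None:
        padding = -vocab_size % 32  # e.g. vocab_size = 32003 -> padding = 29
        existing = torch.zeros((vocab_size + padding,), dtype=torch.float)

    # Write each requested bias at its token ID; an out-of-range ID would raise IndexError
    for token_id, bias in logit_bias.items():
        existing[token_id] = bias

    return existing


# Illustrative usage: discourage token 1337, encourage token 42
bias = build_token_bias({1337: -100.0, 42: 5.0}, vocab_size=32000)
```

On the request side, the new `Optional[Dict[int, float]]` field means clients can send `logit_bias` as a JSON object mapping token IDs to bias values; pydantic should coerce string keys from JSON into ints for that field.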