OAI: Add logit bias support

Use exllamav2's token bias, which is the functional equivalent of
OAI's logit bias parameter.

String keys are cast to integers on request, and an error is raised
if an invalid value is passed.
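
The diff below types the parsed field as Dict[int, float], so the cast most
likely comes from Pydantic's key coercion. A minimal sketch of that behavior
(the Example model here is illustrative, not the commit's code):

    from typing import Dict, Optional
    from pydantic import BaseModel, ValidationError

    class Example(BaseModel):
        logit_bias: Optional[Dict[int, float]] = None

    # JSON object keys always arrive as strings; Pydantic casts them to int
    print(Example(logit_bias = {"128": 100.0}).logit_bias)  # {128: 100.0}

    # A non-numeric key raises a validation error instead of passing through
    try:
        Example(logit_bias = {"not_a_token_id": 100.0})
    except ValidationError as err:
        print(err)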

Signed-off-by: kingbri <bdashore3@proton.me>
Author: kingbri
Date: 2023-12-18 18:04:12 -05:00
Committed by: Brian Dashore
Parent: 46f6dc824e
Commit: c3f7898967
2 changed files with 23 additions and 4 deletions

File 1 of 2

@@ -21,7 +21,6 @@ class CommonCompletionRequest(BaseModel):
     # Extra OAI request stuff
     best_of: Optional[int] = Field(description = "Not parsed. Only used for OAI compliance.", default = None)
     echo: Optional[bool] = Field(description = "Not parsed. Only used for OAI compliance.", default = False)
-    logit_bias: Optional[Dict[str, float]] = Field(description = "Not parsed. Only used for OAI compliance.", default = None)
     logprobs: Optional[int] = Field(description = "Not parsed. Only used for OAI compliance.", default = None)
     n: Optional[int] = Field(description = "Not parsed. Only used for OAI compliance.", default = 1)
     suffix: Optional[str] = Field(description = "Not parsed. Only used for OAI compliance.", default = None)
@@ -55,6 +54,7 @@ class CommonCompletionRequest(BaseModel):
     mirostat_eta: Optional[float] = 0.1
     add_bos_token: Optional[bool] = True
     ban_eos_token: Optional[bool] = False
+    logit_bias: Optional[Dict[int, float]] = None
 
     # Aliased variables
     repetition_range: Optional[int] = Field(
@@ -78,6 +78,7 @@ class CommonCompletionRequest(BaseModel):
             "add_bos_token": self.add_bos_token,
             "ban_eos_token": self.ban_eos_token,
             "token_healing": self.token_healing,
+            "logit_bias": self.logit_bias,
             "temperature": self.temperature,
             "temperature_last": self.temperature_last,
             "top_k": self.top_k,

File 2 of 2

@@ -345,6 +345,7 @@ class ModelContainer:
            'max_tokens' (int): Max no. tokens in response (default: 150)
            'add_bos_token' (bool): Adds the BOS token to the start of the prompt (default: True)
            'ban_eos_token' (bool): Bans the EOS token from generation (default: False)
+           'logit_bias' (Dict[int, float]): Biases specific tokens to either show up more or less (default: None)
            'stream_interval' (float): Interval in seconds between each output chunk (default: immediate)
            'generate_window' (int): Space to reserve at the end of the model's context when generating.
                Rolls context window by the same amount if context length is exceeded to allow generating past
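
As a usage sketch of these kwargs, a hypothetical call (the container and
method names are assumptions drawn from the docstring context, not from the
commit itself):

    # Token ids as keys, bias values as floats; 2 is only a placeholder id
    output = container.generate(
        prompt,
        max_tokens = 150,
        ban_eos_token = False,
        logit_bias = {2: -100.0},  # strongly discourage token id 2
    )
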
@@ -396,7 +397,9 @@ class ModelContainer:
         gen_settings.token_repetition_decay = coalesce(kwargs.get("repetition_decay"), fallback_decay, 0)
         stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), [])
+        add_bos_token = unwrap(kwargs.get("add_bos_token"), True)
         ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False)
+        logit_bias = kwargs.get("logit_bias")
 
         # Override sampler settings for temp = 0
         if gen_settings.temperature == 0:
@@ -406,16 +409,31 @@
             gen_settings.typical = 0
 
         # Log generation options to console
         # Some options are too large, so log the args instead
         log_generation_params(
+            max_tokens = max_tokens,
             **vars(gen_settings),
             token_healing = token_healing,
-            max_tokens = max_tokens,
-            stop_conditions = stop_conditions
+            add_bos_token = add_bos_token,
+            ban_eos_token = ban_eos_token,
+            stop_conditions = stop_conditions,
+            logit_bias = logit_bias
         )
 
         # Log prompt to console
         log_prompt(prompt)
 
+        # Set logit bias
+        if logit_bias:
+            # Create a vocab tensor if it doesn't exist for token biasing
+            if gen_settings.token_bias is None:
+                padding = -self.tokenizer.config.vocab_size % 32
+                gen_settings.token_bias = torch.zeros((self.tokenizer.config.vocab_size + padding,), dtype = torch.float)
+
+            # Map logits to the tensor with their biases
+            for token, bias in logit_bias.items():
+                gen_settings.token_bias[token] = bias
+
         # Ban the EOS token if specified. If not, append to stop conditions as well.
         # Set this below logging to avoid polluting the stop strings array
         if ban_eos_token:
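
One non-obvious detail above: "padding = -vocab_size % 32" rounds the bias
tensor's length up to the next multiple of 32, presumably to match the padded
logit dimension the exllamav2 backend allocates. A quick check of the
arithmetic:

    # -n % 32 is the distance from n to the next multiple of 32 (0 if aligned)
    for vocab_size in (32000, 32003):
        padding = -vocab_size % 32
        print(vocab_size, padding, vocab_size + padding)
    # 32000 -> padding 0  -> tensor length 32000
    # 32003 -> padding 29 -> tensor length 32032
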
@@ -429,7 +447,7 @@
         # Tokenized context
         ids = self.tokenizer.encode(
             prompt,
-            add_bos = unwrap(kwargs.get("add_bos_token"), True),
+            add_bos = add_bos_token,
             encode_special_tokens = True
         )
         context_len = len(ids[0])
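
End to end, a request exercising the new parameter might look like this
(endpoint path, port, and the token id are assumptions, not taken from this
commit):

    import requests

    response = requests.post(
        "http://localhost:5000/v1/completions",
        json = {
            "prompt": "The capital of France is",
            "max_tokens": 10,
            # JSON keys are strings; the server casts them to token ids
            "logit_bias": {"128": -100.0},
        },
    )
    print(response.json())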