OAI: Add logit bias support

Use exllamav2's token bias, which is the functional equivalent of
OAI's logit bias parameter.

Strings are cast to integers on request, and an error occurs if an
invalid value is passed.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-12-18 18:04:12 -05:00
committed by Brian Dashore
parent 46f6dc824e
commit c3f7898967
2 changed files with 23 additions and 4 deletions

View File

@@ -21,7 +21,6 @@ class CommonCompletionRequest(BaseModel):
# Extra OAI request stuff
best_of: Optional[int] = Field(description = "Not parsed. Only used for OAI compliance.", default = None)
echo: Optional[bool] = Field(description = "Not parsed. Only used for OAI compliance.", default = False)
logit_bias: Optional[Dict[str, float]] = Field(description = "Not parsed. Only used for OAI compliance.", default = None)
logprobs: Optional[int] = Field(description = "Not parsed. Only used for OAI compliance.", default = None)
n: Optional[int] = Field(description = "Not parsed. Only used for OAI compliance.", default = 1)
suffix: Optional[str] = Field(description = "Not parsed. Only used for OAI compliance.", default = None)
@@ -55,6 +54,7 @@ class CommonCompletionRequest(BaseModel):
mirostat_eta: Optional[float] = 0.1
add_bos_token: Optional[bool] = True
ban_eos_token: Optional[bool] = False
logit_bias: Optional[Dict[int, float]] = None
# Aliased variables
repetition_range: Optional[int] = Field(
@@ -78,6 +78,7 @@ class CommonCompletionRequest(BaseModel):
"add_bos_token": self.add_bos_token,
"ban_eos_token": self.ban_eos_token,
"token_healing": self.token_healing,
"logit_bias": self.logit_bias,
"temperature": self.temperature,
"temperature_last": self.temperature_last,
"top_k": self.top_k,