mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-14 15:57:27 +00:00
OAI: Add logit bias support
Use exllamav2's token bias which is the functional equivalent of OAI's logit bias parameter. Strings are casted to integers on request and errors if an invalid value is passed. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -21,7 +21,6 @@ class CommonCompletionRequest(BaseModel):
|
||||
# Extra OAI request stuff
|
||||
best_of: Optional[int] = Field(description = "Not parsed. Only used for OAI compliance.", default = None)
|
||||
echo: Optional[bool] = Field(description = "Not parsed. Only used for OAI compliance.", default = False)
|
||||
logit_bias: Optional[Dict[str, float]] = Field(description = "Not parsed. Only used for OAI compliance.", default = None)
|
||||
logprobs: Optional[int] = Field(description = "Not parsed. Only used for OAI compliance.", default = None)
|
||||
n: Optional[int] = Field(description = "Not parsed. Only used for OAI compliance.", default = 1)
|
||||
suffix: Optional[str] = Field(description = "Not parsed. Only used for OAI compliance.", default = None)
|
||||
@@ -55,6 +54,7 @@ class CommonCompletionRequest(BaseModel):
|
||||
mirostat_eta: Optional[float] = 0.1
|
||||
add_bos_token: Optional[bool] = True
|
||||
ban_eos_token: Optional[bool] = False
|
||||
logit_bias: Optional[Dict[int, float]] = None
|
||||
|
||||
# Aliased variables
|
||||
repetition_range: Optional[int] = Field(
|
||||
@@ -78,6 +78,7 @@ class CommonCompletionRequest(BaseModel):
|
||||
"add_bos_token": self.add_bos_token,
|
||||
"ban_eos_token": self.ban_eos_token,
|
||||
"token_healing": self.token_healing,
|
||||
"logit_bias": self.logit_bias,
|
||||
"temperature": self.temperature,
|
||||
"temperature_last": self.temperature_last,
|
||||
"top_k": self.top_k,
|
||||
|
||||
24
model.py
24
model.py
@@ -345,6 +345,7 @@ class ModelContainer:
|
||||
'max_tokens' (int): Max no. tokens in response (default: 150)
|
||||
'add_bos_token' (bool): Adds the BOS token to the start of the prompt (default: True)
|
||||
'ban_eos_token' (bool): Bans the EOS token from generation (default: False)
|
||||
'logit_bias' (Dict[int, float]): Biases specific tokens to either show up more or less (default: None)
|
||||
'stream_interval' (float): Interval in seconds between each output chunk (default: immediate)
|
||||
'generate_window' (int): Space to reserve at the end of the model's context when generating.
|
||||
Rolls context window by the same amount if context length is exceeded to allow generating past
|
||||
@@ -396,7 +397,9 @@ class ModelContainer:
|
||||
gen_settings.token_repetition_decay = coalesce(kwargs.get("repetition_decay"), fallback_decay, 0)
|
||||
|
||||
stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), [])
|
||||
add_bos_token = unwrap(kwargs.get("add_bos_token"), True)
|
||||
ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False)
|
||||
logit_bias = kwargs.get("logit_bias")
|
||||
|
||||
# Override sampler settings for temp = 0
|
||||
if gen_settings.temperature == 0:
|
||||
@@ -406,16 +409,31 @@ class ModelContainer:
|
||||
gen_settings.typical = 0
|
||||
|
||||
# Log generation options to console
|
||||
# Some options are too large, so log the args instead
|
||||
log_generation_params(
|
||||
max_tokens = max_tokens,
|
||||
**vars(gen_settings),
|
||||
token_healing = token_healing,
|
||||
max_tokens = max_tokens,
|
||||
stop_conditions = stop_conditions
|
||||
add_bos_token = add_bos_token,
|
||||
ban_eos_token = ban_eos_token,
|
||||
stop_conditions = stop_conditions,
|
||||
logit_bias = logit_bias
|
||||
)
|
||||
|
||||
# Log prompt to console
|
||||
log_prompt(prompt)
|
||||
|
||||
# Set logit bias
|
||||
if logit_bias:
|
||||
# Create a vocab tensor if it doesn't exist for token biasing
|
||||
if gen_settings.token_bias is None:
|
||||
padding = -self.tokenizer.config.vocab_size % 32
|
||||
gen_settings.token_bias = torch.zeros((self.tokenizer.config.vocab_size + padding,), dtype = torch.float)
|
||||
|
||||
# Map logits to the tensor with their biases
|
||||
for token, bias in logit_bias.items():
|
||||
gen_settings.token_bias[token] = bias
|
||||
|
||||
# Ban the EOS token if specified. If not, append to stop conditions as well.
|
||||
# Set this below logging to avoid polluting the stop strings array
|
||||
if ban_eos_token:
|
||||
@@ -429,7 +447,7 @@ class ModelContainer:
|
||||
# Tokenized context
|
||||
ids = self.tokenizer.encode(
|
||||
prompt,
|
||||
add_bos = unwrap(kwargs.get("add_bos_token"), True),
|
||||
add_bos = add_bos_token,
|
||||
encode_special_tokens = True
|
||||
)
|
||||
context_len = len(ids[0])
|
||||
|
||||
Reference in New Issue
Block a user