mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-25 16:59:09 +00:00
Model: Fix parsing of stop conditions
Append the EOS token to the stop conditions after reading them from kwargs. If ban_eos_token is enabled, do not append the EOS token, as an extra safeguard. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
18
model.py
18
model.py
@@ -11,7 +11,7 @@ from exllamav2.generator import(
|
|||||||
ExLlamaV2StreamingGenerator,
|
ExLlamaV2StreamingGenerator,
|
||||||
ExLlamaV2Sampler
|
ExLlamaV2Sampler
|
||||||
)
|
)
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
# Bytes to reserve on first device when loading with auto split
|
# Bytes to reserve on first device when loading with auto split
|
||||||
auto_split_reserve_bytes = 96 * 1024**2
|
auto_split_reserve_bytes = 96 * 1024**2
|
||||||
@@ -237,7 +237,7 @@ class ModelContainer:
|
|||||||
'repetition_penalty' (float): Token repetition/presence penalty (default: 1.15)
|
'repetition_penalty' (float): Token repetition/presence penalty (default: 1.15)
|
||||||
'repetition_range' (int): Repetition penalty range (default: whole context)
|
'repetition_range' (int): Repetition penalty range (default: whole context)
|
||||||
'repetition_decay' (int): Repetition penalty decay (default: same as range)
|
'repetition_decay' (int): Repetition penalty decay (default: same as range)
|
||||||
'stop' (list): List of stop strings/tokens to end response (default: [EOS])
|
'stop' (List[Union[str, int]]): List of stop strings/tokens to end response (default: [EOS])
|
||||||
'max_tokens' (int): Max no. tokens in response (default: 150)
|
'max_tokens' (int): Max no. tokens in response (default: 150)
|
||||||
'add_bos_token' (bool): Adds the BOS token to the start of the prompt (default: True)
|
'add_bos_token' (bool): Adds the BOS token to the start of the prompt (default: True)
|
||||||
'ban_eos_token' (bool): Bans the EOS token from generation (default: False)
|
'ban_eos_token' (bool): Bans the EOS token from generation (default: False)
|
||||||
@@ -271,9 +271,15 @@ class ModelContainer:
|
|||||||
gen_settings.token_repetition_range = kwargs.get("repetition_range", self.config.max_seq_len)
|
gen_settings.token_repetition_range = kwargs.get("repetition_range", self.config.max_seq_len)
|
||||||
gen_settings.token_repetition_decay = kwargs.get("repetition_decay", gen_settings.token_repetition_range)
|
gen_settings.token_repetition_decay = kwargs.get("repetition_decay", gen_settings.token_repetition_range)
|
||||||
|
|
||||||
# Ban the EOS token if specified
|
stop_conditions: List[Union[str, int]] = kwargs.get("stop", [])
|
||||||
if kwargs.get("ban_eos_token", False):
|
ban_eos_token = kwargs.get("ban_eos_token", False)
|
||||||
|
|
||||||
|
# Ban the EOS token if specified. If not, append to stop conditions as well.
|
||||||
|
|
||||||
|
if ban_eos_token:
|
||||||
gen_settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
|
gen_settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
|
||||||
|
else:
|
||||||
|
stop_conditions.append(self.tokenizer.eos_token_id)
|
||||||
|
|
||||||
# Override sampler settings for temp = 0
|
# Override sampler settings for temp = 0
|
||||||
|
|
||||||
@@ -283,9 +289,9 @@ class ModelContainer:
|
|||||||
gen_settings.top_p = 0
|
gen_settings.top_p = 0
|
||||||
gen_settings.typical = 0
|
gen_settings.typical = 0
|
||||||
|
|
||||||
# Stop conditions
|
# Stop conditions
|
||||||
|
|
||||||
self.generator.set_stop_conditions(kwargs.get("stop", [self.tokenizer.eos_token_id]))
|
self.generator.set_stop_conditions(stop_conditions)
|
||||||
|
|
||||||
# Tokenized context
|
# Tokenized context
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user