From 6f9da97114422b35ddb4e7c3df99c198c1855251 Mon Sep 17 00:00:00 2001
From: kingbri
Date: Sun, 28 Apr 2024 00:40:34 -0400
Subject: [PATCH] API: Add banned_tokens

Appends the banned tokens to the generation. This is equivalent of
setting logit bias to -100 on a specific set of tokens.

Signed-off-by: kingbri
---
 backends/exllamav2/model.py |  6 ++++++
 common/sampling.py          | 11 +++++++++++
 2 files changed, 17 insertions(+)

diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 3cf4400..2df3bd5 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -813,6 +813,11 @@ class ExllamaV2Container:
         # Store the gen settings for logging purposes
         gen_settings_log_dict = vars(gen_settings)
 
+        # Set banned tokens
+        banned_tokens = unwrap(kwargs.get("banned_tokens"), [])
+        if banned_tokens:
+            gen_settings.disallow_tokens(self.tokenizer, banned_tokens)
+
         # Set logit bias
         if logit_bias:
             # Create a vocab tensor if it doesn't exist for token biasing
@@ -953,6 +958,7 @@ class ExllamaV2Container:
             speculative_ngram=self.generator.speculative_ngram,
             logprobs=request_logprobs,
             stop_conditions=stop_conditions,
+            banned_tokens=banned_tokens,
             logit_bias=logit_bias,
         )
 
diff --git a/common/sampling.py b/common/sampling.py
index e8be21b..c782042 100644
--- a/common/sampling.py
+++ b/common/sampling.py
@@ -177,6 +177,13 @@ class BaseSamplerRequest(BaseModel):
         examples=[1.0],
     )
 
+    banned_tokens: Optional[Union[List[int], str]] = Field(
+        default_factory=lambda: get_default_sampler_value("banned_tokens", []),
+        validation_alias=AliasChoices("banned_tokens", "custom_token_bans"),
+        description="Aliases: custom_token_bans",
+        examples=[[128, 330]],
+    )
+
     # TODO: Return back to adaptable class-based validation But that's just too much
     # abstraction compared to simple if statements at the moment
     def validate_params(self):
@@ -245,6 +252,9 @@ class BaseSamplerRequest(BaseModel):
         if isinstance(self.stop, str):
             self.stop = [self.stop]
 
+        if isinstance(self.banned_tokens, str):
+            self.banned_tokens = list(map(int, self.banned_tokens.split(",")))
+
         gen_params = {
             "max_tokens": self.max_tokens,
             "generate_window": self.generate_window,
@@ -254,6 +264,7 @@ class BaseSamplerRequest(BaseModel):
             "skip_special_tokens": self.skip_special_tokens,
             "token_healing": self.token_healing,
             "logit_bias": self.logit_bias,
+            "banned_tokens": self.banned_tokens,
             "temperature": self.temperature,
             "temperature_last": self.temperature_last,
             "min_temp": self.min_temp,