mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-04-20 06:19:00 +00:00
Sampler: Add allow_tokens
This commit is contained in:
@@ -91,6 +91,26 @@ class ExLlamaV2Sampler:
|
||||
self.token_bias[tokens] = float("-inf")
|
||||
|
||||
|
||||
def allow_tokens(
|
||||
self,
|
||||
tokenizer: ExLlamaV2Tokenizer,
|
||||
tokens: list[int | str]
|
||||
):
|
||||
"""Utility function to set/update the logit bias, disallowing all but specific tokens in the supplied list"""
|
||||
|
||||
if self.token_bias is None:
|
||||
padding = -tokenizer.config.vocab_size % 32
|
||||
self.token_bias = torch.full((tokenizer.config.vocab_size + padding,), float("-inf"), dtype = torch.float)
|
||||
|
||||
for t in tokens:
|
||||
if isinstance(t, int):
|
||||
self.token_bias[t] = 0.0
|
||||
elif isinstance(t, str):
|
||||
self.token_bias[tokenizer.single_id(t)] = 0.0
|
||||
else:
|
||||
raise ValueError("Incorrect type in allow_tokens list")
|
||||
|
||||
|
||||
@staticmethod
|
||||
# @profile
|
||||
def sample(
|
||||
|
||||
Reference in New Issue
Block a user