Sampler: Add allow_tokens

This commit is contained in:
turboderp
2024-06-01 03:30:31 +02:00
parent 1027424755
commit 74b49ba28b

View File

@@ -91,6 +91,26 @@ class ExLlamaV2Sampler:
self.token_bias[tokens] = float("-inf")
def allow_tokens(
    self,
    tokenizer: ExLlamaV2Tokenizer,
    tokens: list[int | str]
):
    """
    Restrict sampling to a whitelist of tokens.

    Sets/updates the logit bias so that every token outside the supplied
    list is suppressed. Entries may be raw token IDs (int) or single-token
    strings resolved through the tokenizer.
    """

    # Lazily allocate the bias tensor on first use, filled with -inf so all
    # tokens start out disallowed. Length is padded up to a multiple of 32
    # to match the padded logit dimension.
    if self.token_bias is None:
        vocab = tokenizer.config.vocab_size
        pad = -vocab % 32
        self.token_bias = torch.full((vocab + pad,), float("-inf"), dtype = torch.float)

    # Zero out the bias for each whitelisted token, re-enabling it.
    for entry in tokens:
        if isinstance(entry, str):
            self.token_bias[tokenizer.single_id(entry)] = 0.0
        elif isinstance(entry, int):
            self.token_bias[entry] = 0.0
        else:
            raise ValueError("Incorrect type in allow_tokens list")
@staticmethod
# @profile
def sample(