Add prefix filter

This commit is contained in:
turboderp
2024-02-18 23:58:25 +01:00
parent 26f4bf8997
commit daf7844d18
2 changed files with 59 additions and 0 deletions

View File

@@ -2,3 +2,4 @@ from exllamav2.version import __version__
from exllamav2.generator.filters.base import ExLlamaV2Filter
from exllamav2.generator.filters.select import ExLlamaV2SelectFilter
from exllamav2.generator.filters.prefix import ExLlamaV2PrefixFilter

View File

@@ -0,0 +1,58 @@
from exllamav2 import (
ExLlamaV2,
ExLlamaV2Tokenizer
)
from exllamav2.generator.filters.base import ExLlamaV2Filter
class ExLlamaV2PrefixFilter(ExLlamaV2Filter):
offset: int
prefix_string: str
def __init__(self, model, tokenizer, prefix_string):
super().__init__(model, tokenizer)
self.prefix_string = prefix_string
self.offset = 0
def begin(self, prefix_str = ""):
self.offset = 0
def feed(self, token):
id_to_piece = self.tokenizer.get_id_to_piece_list()
piece = id_to_piece[token]
self.offset += len(piece)
def next(self):
if self.offset >= len(self.prefix_string):
return None, set()
char_trie = self.tokenizer.get_char_trie()
prefix_to_ids = self.tokenizer.get_prefix_to_ids_dict()
rem_str = self.prefix_string[self.offset:]
# Use prefix dict if string could be completed by one token
if rem_str in prefix_to_ids:
pass_tokens = set(prefix_to_ids[rem_str])
else:
pass_tokens = set()
# Find tokens that would advance along the prefix from the current offset
for c in rem_str:
if c in char_trie.children:
char_trie = char_trie.children[c]
else:
break
pass_tokens |= set(char_trie.leaf)
return pass_tokens, set()