Allow multiple valid prefixes in ExLlamaV2PrefixFilter

This commit is contained in:
turboderp
2024-06-03 19:16:59 +02:00
parent 4a07955e50
commit 127d4c70e5
2 changed files with 51 additions and 35 deletions

View File

@@ -61,7 +61,7 @@ for p in i_prompts:
prompts.append(p)
filters.append([
ExLlamaV2TokenEnforcerFilter(schema_parser, tokenizer),
ExLlamaV2PrefixFilter(model, tokenizer, "{")
ExLlamaV2PrefixFilter(model, tokenizer, ["{", " {"])
])
# Generate
@@ -91,5 +91,5 @@ for i in range(len(i_prompts)):
print()
print("With filter:")
print("------------")
print(json.dumps(json.loads(outputs[i * 2 + 1]), indent = 4))
print(json.dumps(json.loads(outputs[i * 2 + 1]), indent = 4).strip())
print()

View File

@@ -1,75 +1,91 @@
from exllamav2 import (
ExLlamaV2,
ExLlamaV2Tokenizer
)
from __future__ import annotations
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
from exllamav2.generator.filters.base import ExLlamaV2Filter
class ExLlamaV2PrefixFilter(ExLlamaV2Filter):
    """
    Filter that forces generation to begin with one of a set of fixed string prefixes.

    State is tracked as the text generated so far (current_str) plus the subset of
    prefixes still consistent with that text (current_prefixes). Once the generated
    text is at least as long as the shortest surviving prefix, the constraint is
    satisfied and next() returns None for the pass set (i.e. no further constraint).
    """

    # Candidate prefixes supplied by the caller; always normalized to a list
    prefix_strings: list[str]
    # Prefixes still consistent with the text generated since begin()
    current_prefixes: set[str]
    # Concatenation of the token pieces fed since begin()
    current_str: str

    def __init__(self,
                 model: ExLlamaV2,
                 tokenizer: ExLlamaV2Tokenizer,
                 prefix_strings: str | list[str]):
        """
        :param model:
            Model, passed through to the base filter

        :param tokenizer:
            Tokenizer whose piece list, character trie and prefix->ids dict drive the matching

        :param prefix_strings:
            Force generation to start with one of the specified strings. Note that if two strings have a shared
            prefix, only the shorter of the two is effective, since matching the shorter prefix is enough to fully
            satisfy the constraint. I.e. ["story", "storytime"] is effectively the same constraint as ["story"]
        """
        super().__init__(model, tokenizer)

        # Accept a bare string for backwards compatibility with the single-prefix interface
        self.prefix_strings = prefix_strings if isinstance(prefix_strings, list) else [prefix_strings]
        self.current_prefixes = set()
        self.current_str = ""

    def clone(self, c = None):
        """
        Return a copy of the filter. The matching state (current_prefixes, current_str)
        is carried over by reference; base class handles model/tokenizer references.
        """
        if c is None:
            # __new__ skips __init__ so the clone does not re-normalize arguments
            c = ExLlamaV2PrefixFilter.__new__(ExLlamaV2PrefixFilter)
        super().clone(c)
        c.prefix_strings = self.prefix_strings
        c.current_prefixes = self.current_prefixes
        c.current_str = self.current_str
        return c

    def begin(self, prefix_str: str = ""):
        # Reset matching state: every prefix is a live candidate, nothing generated yet
        self.current_prefixes = set(self.prefix_strings)
        self.current_str = ""

    def feed(self, token: int):
        """
        Advance the filter by one sampled token: append its string piece to the
        generated text and drop any prefixes the text no longer matches.
        """
        id_to_piece = self.tokenizer.get_id_to_piece_list()
        piece = id_to_piece[token]
        self.current_str += piece

        # Collect then remove mismatches (can't mutate the set while iterating it)
        end_prefixes = set()
        for prefix in self.current_prefixes:
            if not prefix[:len(self.current_str)] == self.current_str:
                end_prefixes.add(prefix)
        self.current_prefixes -= end_prefixes

    def next(self):
        """
        Return (pass_tokens, end_tokens): the set of token IDs allowed next, and an
        empty end set. Returns (None, set()) once the constraint is fully satisfied.
        """
        # Constraint is satisfied as soon as the generated text covers the shortest
        # surviving prefix (or no prefixes survive, in which case nothing constrains)
        min_valid_length = 0 if not self.current_prefixes else min(len(s) for s in self.current_prefixes)
        if len(self.current_str) >= min_valid_length:
            return None, set()

        pass_tokens_all = set()
        for prefix in self.current_prefixes:

            # Fetch fresh per prefix so each walk starts from the trie root
            char_trie = self.tokenizer.get_char_trie()
            prefix_to_ids = self.tokenizer.get_prefix_to_ids_dict()

            # Use prefix dict if string could be completed by one token
            rem_str = prefix[len(self.current_str):]
            if rem_str in prefix_to_ids:
                pass_tokens = set(prefix_to_ids[rem_str])
            else:
                pass_tokens = set()

            # Find tokens that would advance along the prefix from the current offset
            for c in rem_str:
                if c in char_trie.children:
                    char_trie = char_trie.children[c]
                else:
                    break
            # Tokens ending exactly at the node reached are (partial) advances along rem_str
            pass_tokens |= set(char_trie.leaf)

            pass_tokens_all |= pass_tokens

        return pass_tokens_all, set()