Mirror of https://github.com/turboderp-org/exllamav2.git, synced 2026-04-20 06:19:00 +00:00.
Allow multiple valid prefixes in ExLlamaV2PrefixFilter
This commit is contained in:
@@ -61,7 +61,7 @@ for p in i_prompts:
|
||||
prompts.append(p)
|
||||
filters.append([
|
||||
ExLlamaV2TokenEnforcerFilter(schema_parser, tokenizer),
|
||||
ExLlamaV2PrefixFilter(model, tokenizer, "{")
|
||||
ExLlamaV2PrefixFilter(model, tokenizer, ["{", " {"])
|
||||
])
|
||||
|
||||
# Generate
|
||||
@@ -91,5 +91,5 @@ for i in range(len(i_prompts)):
|
||||
print()
|
||||
print("With filter:")
|
||||
print("------------")
|
||||
print(json.dumps(json.loads(outputs[i * 2 + 1]), indent = 4))
|
||||
print(json.dumps(json.loads(outputs[i * 2 + 1]), indent = 4).strip())
|
||||
print()
|
||||
@@ -1,75 +1,91 @@
|
||||
from exllamav2 import (
|
||||
ExLlamaV2,
|
||||
ExLlamaV2Tokenizer
|
||||
)
|
||||
|
||||
from __future__ import annotations
|
||||
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
|
||||
from exllamav2.generator.filters.base import ExLlamaV2Filter
|
||||
|
||||
class ExLlamaV2PrefixFilter(ExLlamaV2Filter):
    """
    Filter that constrains generation to begin with one of a set of allowed prefix strings.

    NOTE(review): reconstructed from a garbled diff rendering — separator junk and
    removed-side (pre-commit) lines dropped, added-side and context lines kept.
    Verify against the upstream post-commit file.
    """

    # Running count of characters fed so far (retained context field; next() no
    # longer reads it — the constraint check uses current_str instead)
    offset: int
    # All prefixes the constraint may match
    prefix_strings: list[str]
    # Prefixes still consistent with the text generated so far
    current_prefixes: set[str]
    # Text generated so far while this filter is active
    current_str: str

    def __init__(self,
                 model: ExLlamaV2,
                 tokenizer: ExLlamaV2Tokenizer,
                 prefix_strings: str | list[str]):
        """
        :param prefix_strings:
            Force generation to start with one of the specified strings. Note that if two strings have a shared
            prefix, only the shorter of the two is effective, since matching the shorter prefix is enough to fully
            satisfy the constraint. I.e. ["story", "storytime"] is effectively the same constraint as ["story"]
        """

        super().__init__(model, tokenizer)

        self.offset = 0
        # Accept a single string for backward compatibility with the old prefix_string API
        self.prefix_strings = prefix_strings if isinstance(prefix_strings, list) else [prefix_strings]
        self.current_prefixes = set()
        self.current_str = ""


    def clone(self, c = None):
        # Shallow clone: prefix_strings / current_prefixes are shared by reference,
        # matching the base-class clone convention
        if c is None:
            c = ExLlamaV2PrefixFilter.__new__(ExLlamaV2PrefixFilter)
        super().clone(c)
        c.offset = self.offset
        c.prefix_strings = self.prefix_strings
        c.current_prefixes = self.current_prefixes
        c.current_str = self.current_str
        return c


    def begin(self, prefix_str: str = ""):
        # Reset state: every configured prefix is a live candidate until tokens are fed
        self.offset = 0
        self.current_prefixes = set(self.prefix_strings)
        self.current_str = ""


    def feed(self, token: int):
        # Append the token's text piece, then drop any prefix that no longer
        # matches the accumulated string
        id_to_piece = self.tokenizer.get_id_to_piece_list()
        piece = id_to_piece[token]
        self.offset += len(piece)
        self.current_str += piece

        end_prefixes = set()
        for prefix in self.current_prefixes:
            if not prefix[:len(self.current_str)] == self.current_str:
                end_prefixes.add(prefix)
        self.current_prefixes -= end_prefixes


    def next(self):
        """
        Return (pass_tokens, end_tokens) for the next sampling step, where
        pass_tokens is the set of token IDs that keep at least one prefix
        satisfiable. Returns (None, set()) once the constraint is satisfied,
        i.e. the generated text covers the shortest surviving prefix.
        """

        min_valid_length = 0 if not self.current_prefixes else min(len(s) for s in self.current_prefixes)
        if len(self.current_str) >= min_valid_length:
            return None, set()

        char_trie = self.tokenizer.get_char_trie()
        prefix_to_ids = self.tokenizer.get_prefix_to_ids_dict()

        pass_tokens_all = set()
        for prefix in self.current_prefixes:

            rem_str = prefix[len(self.current_str):]

            # Use prefix dict if string could be completed by one token
            if rem_str in prefix_to_ids:
                pass_tokens = set(prefix_to_ids[rem_str])
            else:
                pass_tokens = set()

            # Find tokens that would advance along the prefix from the current position.
            # Bug fix: walk a fresh cursor from the trie root for each prefix — reusing a
            # single shared cursor would carry over the previous prefix's walk state.
            node = char_trie
            for c in rem_str:
                if c in node.children:
                    node = node.children[c]
                else:
                    break
                pass_tokens |= set(node.leaf)

            pass_tokens_all |= pass_tokens

        return pass_tokens_all, set()
|
||||
Reference in New Issue
Block a user